rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self, Impersonator},
   5    db::{
   6        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
   9        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
  10        User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, Result,
  14};
  15use anyhow::anyhow;
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::ConnectionPool;
  34use futures::{
  35    channel::oneshot,
  36    future::{self, BoxFuture},
  37    stream::FuturesUnordered,
  38    FutureExt, SinkExt, StreamExt, TryStreamExt,
  39};
  40use lazy_static::lazy_static;
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  45        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    fmt,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{field, info_span, instrument, Instrument};
  69
// NOTE(review): not referenced in the visible part of this file; presumably the
// grace period a dropped client has to reconnect — confirm at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay that the task spawned by `Server::start` waits before clearing stale
/// rooms and channel buffers and deleting stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three limits below are not used in the visible part of this
// file. Their names suggest a page size for channel-message history, a maximum
// message length (unit unverified), and a page size for notifications — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

lazy_static! {
    // Prometheus gauges, created on first access. `unwrap` panics if
    // registration fails (for example, a metric with the same name is
    // already registered).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  86
/// Type-erased handler stored in `Server::handlers`, keyed by message `TypeId`.
///
/// It takes the incoming envelope plus the `Session` of the connection that
/// received it, and returns a boxed `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  89
  90struct Response<R> {
  91    peer: Arc<Peer>,
  92    receipt: Receipt<R>,
  93    responded: Arc<AtomicBool>,
  94}
  95
  96impl<R: RequestMessage> Response<R> {
  97    fn send(self, payload: R::Response) -> Result<()> {
  98        self.responded.store(true, SeqCst);
  99        self.peer.respond(self.receipt, payload)?;
 100        Ok(())
 101    }
 102}
 103
/// Per-connection state. `Server::handle_connection` builds one per connection
/// and passes a clone to every message handler it dispatches.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    user_id: UserId,
    connection_id: ConnectionId,
    // Created per connection and shared by all clones of this session, so
    // database access via `Session::db` is serialized across the connection's
    // concurrently running handlers.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // The server-wide pool (cloned from `Server::connection_pool`); lock it
    // through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
 115
 116impl Session {
 117    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 118        #[cfg(test)]
 119        tokio::task::yield_now().await;
 120        let guard = self.db.lock().await;
 121        #[cfg(test)]
 122        tokio::task::yield_now().await;
 123        guard
 124    }
 125
 126    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 127        #[cfg(test)]
 128        tokio::task::yield_now().await;
 129        let guard = self.connection_pool.lock();
 130        ConnectionPoolGuard {
 131            guard,
 132            _not_send: PhantomData,
 133        }
 134    }
 135}
 136
 137impl fmt::Debug for Session {
 138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 139        f.debug_struct("Session")
 140            .field("user_id", &self.user_id)
 141            .field("connection_id", &self.connection_id)
 142            .finish()
 143    }
 144}
 145
 146struct DbHandle(Arc<Database>);
 147
 148impl Deref for DbHandle {
 149    type Target = Database;
 150
 151    fn deref(&self) -> &Self::Target {
 152        self.0.as_ref()
 153    }
 154}
 155
/// Collaboration RPC server. Owns the [`Peer`], the shared connection pool, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Read via `lock()`; overwritten by the test-only `Server::reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Shared with every `Session` built in `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // At most one handler per message type, keyed by `TypeId::of::<M>()`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `Server::teardown`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}
 165
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker guarantees the guard is `!Send`. Holding
/// it across an `.await` inside one of the `Send` handler futures is therefore
/// a compile error, rather than a blocking lock held across a suspension point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 170
/// Serializable snapshot of the server's peer and connection pool.
///
/// `connection_pool` is the pool's lock guard, so the pool stays locked for as
/// long as the snapshot is alive. The field is serialized as the pool itself
/// via [`serialize_deref`].
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 177
 178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 179where
 180    S: Serializer,
 181    T: Deref<Target = U>,
 182    U: Serialize,
 183{
 184    Serialize::serialize(value.deref(), serializer)
 185}
 186
 187impl Server {
    /// Creates a server with the given `id` and registers one handler for every
    /// supported message type.
    ///
    /// Request handlers (`add_request_handler`) must reply through their
    /// [`Response`]. Message handlers (`add_message_handler`) get only the
    /// payload and do not reply.
    ///
    /// # Panics
    ///
    /// Panics if the same message type is registered twice (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if `ServerId`'s inner value is
            // wider than 32 bits — confirm server ids stay in range.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        // Each message type may appear at most once below; `add_handler`
        // panics on a duplicate registration.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
 296
    /// Spawns a detached cleanup task and returns immediately; the returned
    /// `Result` is always `Ok`.
    ///
    /// After waiting [`CLEANUP_TIMEOUT`], the task asks the database for stale
    /// room and channel ids, querying with this server's environment and id.
    /// It then:
    /// - clears stale collaborators from each channel buffer and sends the
    ///   refreshed collaborator list to the buffer's remaining connections;
    /// - clears stale participants from each room, broadcasts the updated room
    ///   (and its channel, if any), sends `CallCanceled` to users whose calls
    ///   were canceled, and sends updated contact entries to the accepted
    ///   contacts of every affected user;
    /// - deletes the room's LiveKit room when a LiveKit client is configured
    ///   and the room has no participants left;
    /// - finally deletes stale servers for this environment.
    ///
    /// Each failing step's error is passed to `trace_err` and skipped; it does
    /// not abort the task.
    pub async fn start(&self) -> Result<()> {
        // The task is `async move`, so it owns clones of everything it needs
        // instead of borrowing `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if refreshing the room succeeds; the
                        // follow-up steps below are no-ops otherwise.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                // The pool lock is a temporary released at the
                                // end of this statement.
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move out the pieces needed after this block.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the
                        // `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry to
                        // every connection of each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room if the refresh above
                        // left the room with no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 448
 449    pub fn teardown(&self) {
 450        self.peer.teardown();
 451        self.connection_pool.lock().reset();
 452        let _ = self.teardown.send(());
 453    }
 454
 455    #[cfg(test)]
 456    pub fn reset(&self, id: ServerId) {
 457        self.teardown();
 458        *self.id.lock() = id;
 459        self.peer.reset(id.0 as u32);
 460    }
 461
 462    #[cfg(test)]
 463    pub fn id(&self) -> ServerId {
 464        *self.id.lock()
 465    }
 466
 467    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 468    where
 469        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 470        Fut: 'static + Send + Future<Output = Result<()>>,
 471        M: EnvelopedMessage,
 472    {
 473        let prev_handler = self.handlers.insert(
 474            TypeId::of::<M>(),
 475            Box::new(move |envelope, session| {
 476                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 477                let span = info_span!(
 478                    "handle message",
 479                    payload_type = envelope.payload_type_name()
 480                );
 481                span.in_scope(|| {
 482                    tracing::info!(
 483                        payload_type = envelope.payload_type_name(),
 484                        "message received"
 485                    );
 486                });
 487                let start_time = Instant::now();
 488                let future = (handler)(*envelope, session);
 489                async move {
 490                    let result = future.await;
 491                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 492                    match result {
 493                        Err(error) => {
 494                            tracing::error!(%error, ?duration_ms, "error handling message")
 495                        }
 496                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
 497                    }
 498                }
 499                .instrument(span)
 500                .boxed()
 501            }),
 502        );
 503        if prev_handler.is_some() {
 504            panic!("registered a handler for the same message twice");
 505        }
 506        self
 507    }
 508
 509    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 510    where
 511        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 512        Fut: 'static + Send + Future<Output = Result<()>>,
 513        M: EnvelopedMessage,
 514    {
 515        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 516        self
 517    }
 518
 519    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 520    where
 521        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 522        Fut: Send + Future<Output = Result<()>>,
 523        M: RequestMessage,
 524    {
 525        let handler = Arc::new(handler);
 526        self.add_handler(move |envelope, session| {
 527            let receipt = envelope.receipt();
 528            let handler = handler.clone();
 529            async move {
 530                let peer = session.peer.clone();
 531                let responded = Arc::new(AtomicBool::default());
 532                let response = Response {
 533                    peer: peer.clone(),
 534                    responded: responded.clone(),
 535                    receipt,
 536                };
 537                match (handler)(envelope.payload, response, session).await {
 538                    Ok(()) => {
 539                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 540                            Ok(())
 541                        } else {
 542                            Err(anyhow!("handler did not send a response"))?
 543                        }
 544                    }
 545                    Err(error) => {
 546                        let proto_err = match &error {
 547                            Error::Internal(err) => err.to_proto(),
 548                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 549                        };
 550                        peer.respond_with_error(receipt, proto_err)?;
 551                        Err(error)
 552                    }
 553                }
 554            }
 555        })
 556    }
 557
    /// Drives one client connection until it closes, its I/O fails, or the
    /// server is torn down.
    ///
    /// Setup, in order:
    /// - registers the connection with the peer and sends `Hello`;
    /// - sends the new connection id on `send_connection_id`, if provided;
    /// - on the user's first-ever connection, sends `ShowContacts`;
    /// - adds the connection to the pool and sends the initial contacts and
    ///   channels state, plus any pending incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers.
    /// Background messages are spawned on `executor`; foreground ones run
    /// concurrently on this task. At most 256 handlers are in flight per
    /// connection. When the connection ends (but not on teardown),
    /// `connection_lost` is awaited to sign the user out.
    ///
    /// # Errors
    ///
    /// Returns an error if any step of the setup above fails.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribe before the future is first polled, so a teardown that
        // happens in between is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // NOTE(review): the connection is registered with the peer above and
            // added to the pool here. Any `?` from here until the message loop
            // returns early without calling `connection_lost`. Confirm that
            // cleanup happens elsewhere; otherwise the connection stays registered.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken *before* reading the next message and held
                // until that message's handler finishes, capping in-flight
                // handlers at 256. The semaphore is never closed, so
                // `acquire_owned` cannot fail.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in the order written: teardown, then
                // I/O completion, then running handlers, then new messages.
                futures::select_biased! {
                    // `Server::teardown` resets the connection pool itself, so
                    // `connection_lost` is skipped on this path.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    // Drives foreground handlers. Once the set is empty it reports
                    // itself terminated and `select_biased!` skips this arm until a
                    // new handler is pushed, so it does not spin.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // The enter-guard borrows `span`; release it before
                                // `span` is moved into `instrument`.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel foreground handlers that are still running before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 697
 698    pub async fn invite_code_redeemed(
 699        self: &Arc<Self>,
 700        inviter_id: UserId,
 701        invitee_id: UserId,
 702    ) -> Result<()> {
 703        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 704            if let Some(code) = &user.invite_code {
 705                let pool = self.connection_pool.lock();
 706                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 707                for connection_id in pool.user_connection_ids(inviter_id) {
 708                    self.peer.send(
 709                        connection_id,
 710                        proto::UpdateContacts {
 711                            contacts: vec![invitee_contact.clone()],
 712                            ..Default::default()
 713                        },
 714                    )?;
 715                    self.peer.send(
 716                        connection_id,
 717                        proto::UpdateInviteInfo {
 718                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 719                            count: user.invite_count as u32,
 720                        },
 721                    )?;
 722                }
 723            }
 724        }
 725        Ok(())
 726    }
 727
 728    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 729        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 730            if let Some(invite_code) = &user.invite_code {
 731                let pool = self.connection_pool.lock();
 732                for connection_id in pool.user_connection_ids(user_id) {
 733                    self.peer.send(
 734                        connection_id,
 735                        proto::UpdateInviteInfo {
 736                            url: format!(
 737                                "{}{}",
 738                                self.app_state.config.invite_link_prefix, invite_code
 739                            ),
 740                            count: user.invite_count as u32,
 741                        },
 742                    )?;
 743                }
 744            }
 745        }
 746        Ok(())
 747    }
 748
 749    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 750        ServerSnapshot {
 751            connection_pool: ConnectionPoolGuard {
 752                guard: self.connection_pool.lock(),
 753                _not_send: PhantomData,
 754            },
 755            peer: &self.peer,
 756        }
 757    }
 758}
 759
 760impl<'a> Deref for ConnectionPoolGuard<'a> {
 761    type Target = ConnectionPool;
 762
 763    fn deref(&self) -> &Self::Target {
 764        &*self.guard
 765    }
 766}
 767
 768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 769    fn deref_mut(&mut self) -> &mut Self::Target {
 770        &mut *self.guard
 771    }
 772}
 773
 774impl<'a> Drop for ConnectionPoolGuard<'a> {
 775    fn drop(&mut self) {
 776        #[cfg(test)]
 777        self.check_invariants();
 778    }
 779}
 780
 781fn broadcast<F>(
 782    sender_id: Option<ConnectionId>,
 783    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 784    mut f: F,
 785) where
 786    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 787{
 788    for receiver_id in receiver_ids {
 789        if Some(receiver_id) != sender_id {
 790            if let Err(error) = f(receiver_id) {
 791                tracing::error!("failed to send to {:?} {}", receiver_id, error);
 792            }
 793        }
 794    }
 795}
 796
 797lazy_static! {
 798    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 799}
 800
 801pub struct ProtocolVersion(u32);
 802
 803impl Header for ProtocolVersion {
 804    fn name() -> &'static HeaderName {
 805        &ZED_PROTOCOL_VERSION
 806    }
 807
 808    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 809    where
 810        Self: Sized,
 811        I: Iterator<Item = &'i axum::http::HeaderValue>,
 812    {
 813        let version = values
 814            .next()
 815            .ok_or_else(axum::headers::Error::invalid)?
 816            .to_str()
 817            .map_err(|_| axum::headers::Error::invalid())?
 818            .parse()
 819            .map_err(|_| axum::headers::Error::invalid())?;
 820        Ok(Self(version))
 821    }
 822
 823    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 824        values.extend([self.0.to_string().parse().unwrap()]);
 825    }
 826}
 827
 828pub fn routes(server: Arc<Server>) -> Router<Body> {
 829    Router::new()
 830        .route("/rpc", get(handle_websocket_request))
 831        .layer(
 832            ServiceBuilder::new()
 833                .layer(Extension(server.app_state.clone()))
 834                .layer(middleware::from_fn(auth::validate_header)),
 835        )
 836        .route("/metrics", get(handle_metrics))
 837        .layer(Extension(server))
 838}
 839
 840pub async fn handle_websocket_request(
 841    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 842    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 843    Extension(server): Extension<Arc<Server>>,
 844    Extension(user): Extension<User>,
 845    Extension(impersonator): Extension<Impersonator>,
 846    ws: WebSocketUpgrade,
 847) -> axum::response::Response {
 848    if protocol_version != rpc::PROTOCOL_VERSION {
 849        return (
 850            StatusCode::UPGRADE_REQUIRED,
 851            "client must be upgraded".to_string(),
 852        )
 853            .into_response();
 854    }
 855    let socket_address = socket_address.to_string();
 856    ws.on_upgrade(move |socket| {
 857        use util::ResultExt;
 858        let socket = socket
 859            .map_ok(to_tungstenite_message)
 860            .err_into()
 861            .with(|message| async move { Ok(to_axum_message(message)) });
 862        let connection = Connection::new(Box::pin(socket));
 863        async move {
 864            server
 865                .handle_connection(
 866                    connection,
 867                    socket_address,
 868                    user,
 869                    impersonator.0,
 870                    None,
 871                    Executor::Production,
 872                )
 873                .await
 874                .log_err();
 875        }
 876    })
 877}
 878
 879pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 880    let connections = server
 881        .connection_pool
 882        .lock()
 883        .connections()
 884        .filter(|connection| !connection.admin)
 885        .count();
 886
 887    METRIC_CONNECTIONS.set(connections as _);
 888
 889    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 890    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 891
 892    let encoder = prometheus::TextEncoder::new();
 893    let metric_families = prometheus::gather();
 894    let encoded_metrics = encoder
 895        .encode_to_string(&metric_families)
 896        .map_err(|err| anyhow!("{}", err))?;
 897    Ok(encoded_metrics)
 898}
 899
/// Cleans up after a client connection drops.
///
/// It first disconnects the peer and removes the connection from the pool
/// (an error from the removal is returned). It then tells the database the
/// connection was lost (errors there are only traced). After that it waits
/// for either `RECONNECT_TIMEOUT` to elapse or `teardown` to change:
/// - If the timeout wins, the session is treated as permanently gone. It
///   leaves its room and channel buffers. If the user has no other online
///   connection, a pending call is declined. Finally the user's contacts are
///   refreshed.
/// - If `teardown` changes first, that cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its arms in order, so if both are ready the
    // reconnect timeout takes precedence over `teardown`.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Decline calls only when none of the user's other connections are online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                // `None` room id: presumably declines whichever call is pending
                // for this user; confirm against `Database::decline_call`.
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 945
 946/// Acknowledges a ping from a client, used to keep the connection alive.
 947async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 948    response.send(proto::Ack {})?;
 949    Ok(())
 950}
 951
 952/// Creates a new room for calling (outside of channels)
 953async fn create_room(
 954    _request: proto::CreateRoom,
 955    response: Response<proto::CreateRoom>,
 956    session: Session,
 957) -> Result<()> {
 958    let live_kit_room = nanoid::nanoid!(30);
 959
 960    let live_kit_connection_info = {
 961        let live_kit_room = live_kit_room.clone();
 962        let live_kit = session.live_kit_client.as_ref();
 963
 964        util::async_maybe!({
 965            let live_kit = live_kit?;
 966
 967            let token = live_kit
 968                .room_token(&live_kit_room, &session.user_id.to_string())
 969                .trace_err()?;
 970
 971            Some(proto::LiveKitConnectionInfo {
 972                server_url: live_kit.url().into(),
 973                token,
 974                can_publish: true,
 975            })
 976        })
 977    }
 978    .await;
 979
 980    let room = session
 981        .db()
 982        .await
 983        .create_room(
 984            session.user_id,
 985            session.connection_id,
 986            &live_kit_room,
 987            &session.zed_environment,
 988        )
 989        .await?;
 990
 991    response.send(proto::CreateRoomResponse {
 992        room: Some(room.clone()),
 993        live_kit_connection_info,
 994    })?;
 995
 996    update_user_contacts(session.user_id, &session).await?;
 997    Ok(())
 998}
 999
1000/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1001async fn join_room(
1002    request: proto::JoinRoom,
1003    response: Response<proto::JoinRoom>,
1004    session: Session,
1005) -> Result<()> {
1006    let room_id = RoomId::from_proto(request.id);
1007
1008    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1009
1010    if let Some(channel_id) = channel_id {
1011        return join_channel_internal(channel_id, Box::new(response), session).await;
1012    }
1013
1014    let joined_room = {
1015        let room = session
1016            .db()
1017            .await
1018            .join_room(
1019                room_id,
1020                session.user_id,
1021                session.connection_id,
1022                session.zed_environment.as_ref(),
1023            )
1024            .await?;
1025        room_updated(&room.room, &session.peer);
1026        room.into_inner()
1027    };
1028
1029    for connection_id in session
1030        .connection_pool()
1031        .await
1032        .user_connection_ids(session.user_id)
1033    {
1034        session
1035            .peer
1036            .send(
1037                connection_id,
1038                proto::CallCanceled {
1039                    room_id: room_id.to_proto(),
1040                },
1041            )
1042            .trace_err();
1043    }
1044
1045    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1046        if let Some(token) = live_kit
1047            .room_token(
1048                &joined_room.room.live_kit_room,
1049                &session.user_id.to_string(),
1050            )
1051            .trace_err()
1052        {
1053            Some(proto::LiveKitConnectionInfo {
1054                server_url: live_kit.url().into(),
1055                token,
1056                can_publish: true,
1057            })
1058        } else {
1059            None
1060        }
1061    } else {
1062        None
1063    };
1064
1065    response.send(proto::JoinRoomResponse {
1066        room: Some(joined_room.room),
1067        channel_id: None,
1068        live_kit_connection_info,
1069    })?;
1070
1071    update_user_contacts(session.user_id, &session).await?;
1072    Ok(())
1073}
1074
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Moves the caller's room membership onto its new connection, then:
/// - replies with the room, the projects being re-shared (presumably those the
///   caller hosts) and the projects being rejoined (presumably those it
///   joined as a guest), and broadcasts the room update;
/// - tells collaborators of both kinds of project that the caller's peer id
///   changed from the old connection to the new one, and forwards each
///   re-shared project's worktree metadata to its collaborators;
/// - re-streams each rejoined project's worktree entries (in chunks),
///   diagnostic summaries, settings files and language-server state to the
///   caller;
/// - sends a channel update if the room belongs to a channel;
/// - refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values taken out of the rejoined room so they outlive the scope below.
    let room;
    let channel_id;
    let channel_members;
    // Scoped so `rejoined_room` (presumably a room guard) is dropped before the
    // connection pool is locked for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first, describing every project the caller reconnects to.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: point collaborators at the caller's new peer id,
        // then forward the project's current worktrees to everyone but the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: point the other collaborators at the new peer id.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream each rejoined project's state to the caller. Worktrees are
        // moved out with `mem::take` because they are consumed here.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in test builds, presumably to exercise chunking.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` for each language server,
            // presumably so the client does not wait on an in-progress update.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1266
1267/// leave room disonnects from the room.
1268async fn leave_room(
1269    _: proto::LeaveRoom,
1270    response: Response<proto::LeaveRoom>,
1271    session: Session,
1272) -> Result<()> {
1273    leave_room_for_session(&session).await?;
1274    response.send(proto::Ack {})?;
1275    Ok(())
1276}
1277
1278/// Updates the permissions of someone else in the room.
1279async fn set_room_participant_role(
1280    request: proto::SetRoomParticipantRole,
1281    response: Response<proto::SetRoomParticipantRole>,
1282    session: Session,
1283) -> Result<()> {
1284    let (live_kit_room, can_publish) = {
1285        let room = session
1286            .db()
1287            .await
1288            .set_room_participant_role(
1289                session.user_id,
1290                RoomId::from_proto(request.room_id),
1291                UserId::from_proto(request.user_id),
1292                ChannelRole::from(request.role()),
1293            )
1294            .await?;
1295
1296        let live_kit_room = room.live_kit_room.clone();
1297        let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1298        room_updated(&room, &session.peer);
1299        (live_kit_room, can_publish)
1300    };
1301
1302    if let Some(live_kit) = session.live_kit_client.as_ref() {
1303        live_kit
1304            .update_participant(
1305                live_kit_room.clone(),
1306                request.user_id.to_string(),
1307                live_kit_server::proto::ParticipantPermission {
1308                    can_subscribe: true,
1309                    can_publish,
1310                    can_publish_data: can_publish,
1311                    hidden: false,
1312                    recorder: false,
1313                },
1314            )
1315            .await
1316            .trace_err();
1317    }
1318
1319    response.send(proto::Ack {})?;
1320    Ok(())
1321}
1322
1323/// Call someone else into the current room
1324async fn call(
1325    request: proto::Call,
1326    response: Response<proto::Call>,
1327    session: Session,
1328) -> Result<()> {
1329    let room_id = RoomId::from_proto(request.room_id);
1330    let calling_user_id = session.user_id;
1331    let calling_connection_id = session.connection_id;
1332    let called_user_id = UserId::from_proto(request.called_user_id);
1333    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1334    if !session
1335        .db()
1336        .await
1337        .has_contact(calling_user_id, called_user_id)
1338        .await?
1339    {
1340        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1341    }
1342
1343    let incoming_call = {
1344        let (room, incoming_call) = &mut *session
1345            .db()
1346            .await
1347            .call(
1348                room_id,
1349                calling_user_id,
1350                calling_connection_id,
1351                called_user_id,
1352                initial_project_id,
1353            )
1354            .await?;
1355        room_updated(&room, &session.peer);
1356        mem::take(incoming_call)
1357    };
1358    update_user_contacts(called_user_id, &session).await?;
1359
1360    let mut calls = session
1361        .connection_pool()
1362        .await
1363        .user_connection_ids(called_user_id)
1364        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1365        .collect::<FuturesUnordered<_>>();
1366
1367    while let Some(call_response) = calls.next().await {
1368        match call_response.as_ref() {
1369            Ok(_) => {
1370                response.send(proto::Ack {})?;
1371                return Ok(());
1372            }
1373            Err(_) => {
1374                call_response.trace_err();
1375            }
1376        }
1377    }
1378
1379    {
1380        let room = session
1381            .db()
1382            .await
1383            .call_failed(room_id, called_user_id)
1384            .await?;
1385        room_updated(&room, &session.peer);
1386    }
1387    update_user_contacts(called_user_id, &session).await?;
1388
1389    Err(anyhow!("failed to ring user"))?
1390}
1391
1392/// Cancel an outgoing call.
1393async fn cancel_call(
1394    request: proto::CancelCall,
1395    response: Response<proto::CancelCall>,
1396    session: Session,
1397) -> Result<()> {
1398    let called_user_id = UserId::from_proto(request.called_user_id);
1399    let room_id = RoomId::from_proto(request.room_id);
1400    {
1401        let room = session
1402            .db()
1403            .await
1404            .cancel_call(room_id, session.connection_id, called_user_id)
1405            .await?;
1406        room_updated(&room, &session.peer);
1407    }
1408
1409    for connection_id in session
1410        .connection_pool()
1411        .await
1412        .user_connection_ids(called_user_id)
1413    {
1414        session
1415            .peer
1416            .send(
1417                connection_id,
1418                proto::CallCanceled {
1419                    room_id: room_id.to_proto(),
1420                },
1421            )
1422            .trace_err();
1423    }
1424    response.send(proto::Ack {})?;
1425
1426    update_user_contacts(called_user_id, &session).await?;
1427    Ok(())
1428}
1429
1430/// Decline an incoming call.
1431async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1432    let room_id = RoomId::from_proto(message.room_id);
1433    {
1434        let room = session
1435            .db()
1436            .await
1437            .decline_call(Some(room_id), session.user_id)
1438            .await?
1439            .ok_or_else(|| anyhow!("failed to decline call"))?;
1440        room_updated(&room, &session.peer);
1441    }
1442
1443    for connection_id in session
1444        .connection_pool()
1445        .await
1446        .user_connection_ids(session.user_id)
1447    {
1448        session
1449            .peer
1450            .send(
1451                connection_id,
1452                proto::CallCanceled {
1453                    room_id: room_id.to_proto(),
1454                },
1455            )
1456            .trace_err();
1457    }
1458    update_user_contacts(session.user_id, &session).await?;
1459    Ok(())
1460}
1461
1462/// Updates other participants in the room with your current location.
1463async fn update_participant_location(
1464    request: proto::UpdateParticipantLocation,
1465    response: Response<proto::UpdateParticipantLocation>,
1466    session: Session,
1467) -> Result<()> {
1468    let room_id = RoomId::from_proto(request.room_id);
1469    let location = request
1470        .location
1471        .ok_or_else(|| anyhow!("invalid location"))?;
1472
1473    let db = session.db().await;
1474    let room = db
1475        .update_room_participant_location(room_id, session.connection_id, location)
1476        .await?;
1477
1478    room_updated(&room, &session.peer);
1479    response.send(proto::Ack {})?;
1480    Ok(())
1481}
1482
1483/// Share a project into the room.
1484async fn share_project(
1485    request: proto::ShareProject,
1486    response: Response<proto::ShareProject>,
1487    session: Session,
1488) -> Result<()> {
1489    let (project_id, room) = &*session
1490        .db()
1491        .await
1492        .share_project(
1493            RoomId::from_proto(request.room_id),
1494            session.connection_id,
1495            &request.worktrees,
1496        )
1497        .await?;
1498    response.send(proto::ShareProjectResponse {
1499        project_id: project_id.to_proto(),
1500    })?;
1501    room_updated(&room, &session.peer);
1502
1503    Ok(())
1504}
1505
1506/// Unshare a project from the room.
1507async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1508    let project_id = ProjectId::from_proto(message.project_id);
1509
1510    let (room, guest_connection_ids) = &*session
1511        .db()
1512        .await
1513        .unshare_project(project_id, session.connection_id)
1514        .await?;
1515
1516    broadcast(
1517        Some(session.connection_id),
1518        guest_connection_ids.iter().copied(),
1519        |conn_id| session.peer.send(conn_id, message.clone()),
1520    );
1521    room_updated(&room, &session.peer);
1522
1523    Ok(())
1524}
1525
1526/// Join someone elses shared project.
1527async fn join_project(
1528    request: proto::JoinProject,
1529    response: Response<proto::JoinProject>,
1530    session: Session,
1531) -> Result<()> {
1532    let project_id = ProjectId::from_proto(request.project_id);
1533    let guest_user_id = session.user_id;
1534
1535    tracing::info!(%project_id, "join project");
1536
1537    let (project, replica_id) = &mut *session
1538        .db()
1539        .await
1540        .join_project(project_id, session.connection_id)
1541        .await?;
1542
1543    let collaborators = project
1544        .collaborators
1545        .iter()
1546        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1547        .map(|collaborator| collaborator.to_proto())
1548        .collect::<Vec<_>>();
1549
1550    let worktrees = project
1551        .worktrees
1552        .iter()
1553        .map(|(id, worktree)| proto::WorktreeMetadata {
1554            id: *id,
1555            root_name: worktree.root_name.clone(),
1556            visible: worktree.visible,
1557            abs_path: worktree.abs_path.clone(),
1558        })
1559        .collect::<Vec<_>>();
1560
1561    for collaborator in &collaborators {
1562        session
1563            .peer
1564            .send(
1565                collaborator.peer_id.unwrap().into(),
1566                proto::AddProjectCollaborator {
1567                    project_id: project_id.to_proto(),
1568                    collaborator: Some(proto::Collaborator {
1569                        peer_id: Some(session.connection_id.into()),
1570                        replica_id: replica_id.0 as u32,
1571                        user_id: guest_user_id.to_proto(),
1572                    }),
1573                },
1574            )
1575            .trace_err();
1576    }
1577
1578    // First, we send the metadata associated with each worktree.
1579    response.send(proto::JoinProjectResponse {
1580        worktrees: worktrees.clone(),
1581        replica_id: replica_id.0 as u32,
1582        collaborators: collaborators.clone(),
1583        language_servers: project.language_servers.clone(),
1584    })?;
1585
1586    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1587        #[cfg(any(test, feature = "test-support"))]
1588        const MAX_CHUNK_SIZE: usize = 2;
1589        #[cfg(not(any(test, feature = "test-support")))]
1590        const MAX_CHUNK_SIZE: usize = 256;
1591
1592        // Stream this worktree's entries.
1593        let message = proto::UpdateWorktree {
1594            project_id: project_id.to_proto(),
1595            worktree_id,
1596            abs_path: worktree.abs_path.clone(),
1597            root_name: worktree.root_name,
1598            updated_entries: worktree.entries,
1599            removed_entries: Default::default(),
1600            scan_id: worktree.scan_id,
1601            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1602            updated_repositories: worktree.repository_entries.into_values().collect(),
1603            removed_repositories: Default::default(),
1604        };
1605        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1606            session.peer.send(session.connection_id, update.clone())?;
1607        }
1608
1609        // Stream this worktree's diagnostics.
1610        for summary in worktree.diagnostic_summaries {
1611            session.peer.send(
1612                session.connection_id,
1613                proto::UpdateDiagnosticSummary {
1614                    project_id: project_id.to_proto(),
1615                    worktree_id: worktree.id,
1616                    summary: Some(summary),
1617                },
1618            )?;
1619        }
1620
1621        for settings_file in worktree.settings_files {
1622            session.peer.send(
1623                session.connection_id,
1624                proto::UpdateWorktreeSettings {
1625                    project_id: project_id.to_proto(),
1626                    worktree_id: worktree.id,
1627                    path: settings_file.path,
1628                    content: Some(settings_file.content),
1629                },
1630            )?;
1631        }
1632    }
1633
1634    for language_server in &project.language_servers {
1635        session.peer.send(
1636            session.connection_id,
1637            proto::UpdateLanguageServer {
1638                project_id: project_id.to_proto(),
1639                language_server_id: language_server.id,
1640                variant: Some(
1641                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1642                        proto::LspDiskBasedDiagnosticsUpdated {},
1643                    ),
1644                ),
1645            },
1646        )?;
1647    }
1648
1649    Ok(())
1650}
1651
1652/// Leave someone elses shared project.
1653async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1654    let sender_id = session.connection_id;
1655    let project_id = ProjectId::from_proto(request.project_id);
1656
1657    let (room, project) = &*session
1658        .db()
1659        .await
1660        .leave_project(project_id, sender_id)
1661        .await?;
1662    tracing::info!(
1663        %project_id,
1664        host_user_id = %project.host_user_id,
1665        host_connection_id = %project.host_connection_id,
1666        "leave project"
1667    );
1668
1669    project_left(&project, &session);
1670    room_updated(&room, &session.peer);
1671
1672    Ok(())
1673}
1674
1675/// Updates other participants with changes to the project
1676async fn update_project(
1677    request: proto::UpdateProject,
1678    response: Response<proto::UpdateProject>,
1679    session: Session,
1680) -> Result<()> {
1681    let project_id = ProjectId::from_proto(request.project_id);
1682    let (room, guest_connection_ids) = &*session
1683        .db()
1684        .await
1685        .update_project(project_id, session.connection_id, &request.worktrees)
1686        .await?;
1687    broadcast(
1688        Some(session.connection_id),
1689        guest_connection_ids.iter().copied(),
1690        |connection_id| {
1691            session
1692                .peer
1693                .forward_send(session.connection_id, connection_id, request.clone())
1694        },
1695    );
1696    room_updated(&room, &session.peer);
1697    response.send(proto::Ack {})?;
1698
1699    Ok(())
1700}
1701
1702/// Updates other participants with changes to the worktree
1703async fn update_worktree(
1704    request: proto::UpdateWorktree,
1705    response: Response<proto::UpdateWorktree>,
1706    session: Session,
1707) -> Result<()> {
1708    let guest_connection_ids = session
1709        .db()
1710        .await
1711        .update_worktree(&request, session.connection_id)
1712        .await?;
1713
1714    broadcast(
1715        Some(session.connection_id),
1716        guest_connection_ids.iter().copied(),
1717        |connection_id| {
1718            session
1719                .peer
1720                .forward_send(session.connection_id, connection_id, request.clone())
1721        },
1722    );
1723    response.send(proto::Ack {})?;
1724    Ok(())
1725}
1726
1727/// Updates other participants with changes to the diagnostics
1728async fn update_diagnostic_summary(
1729    message: proto::UpdateDiagnosticSummary,
1730    session: Session,
1731) -> Result<()> {
1732    let guest_connection_ids = session
1733        .db()
1734        .await
1735        .update_diagnostic_summary(&message, session.connection_id)
1736        .await?;
1737
1738    broadcast(
1739        Some(session.connection_id),
1740        guest_connection_ids.iter().copied(),
1741        |connection_id| {
1742            session
1743                .peer
1744                .forward_send(session.connection_id, connection_id, message.clone())
1745        },
1746    );
1747
1748    Ok(())
1749}
1750
1751/// Updates other participants with changes to the worktree settings
1752async fn update_worktree_settings(
1753    message: proto::UpdateWorktreeSettings,
1754    session: Session,
1755) -> Result<()> {
1756    let guest_connection_ids = session
1757        .db()
1758        .await
1759        .update_worktree_settings(&message, session.connection_id)
1760        .await?;
1761
1762    broadcast(
1763        Some(session.connection_id),
1764        guest_connection_ids.iter().copied(),
1765        |connection_id| {
1766            session
1767                .peer
1768                .forward_send(session.connection_id, connection_id, message.clone())
1769        },
1770    );
1771
1772    Ok(())
1773}
1774
1775/// Notify other participants that a  language server has started.
1776async fn start_language_server(
1777    request: proto::StartLanguageServer,
1778    session: Session,
1779) -> Result<()> {
1780    let guest_connection_ids = session
1781        .db()
1782        .await
1783        .start_language_server(&request, session.connection_id)
1784        .await?;
1785
1786    broadcast(
1787        Some(session.connection_id),
1788        guest_connection_ids.iter().copied(),
1789        |connection_id| {
1790            session
1791                .peer
1792                .forward_send(session.connection_id, connection_id, request.clone())
1793        },
1794    );
1795    Ok(())
1796}
1797
1798/// Notify other participants that a language server has changed.
1799async fn update_language_server(
1800    request: proto::UpdateLanguageServer,
1801    session: Session,
1802) -> Result<()> {
1803    let project_id = ProjectId::from_proto(request.project_id);
1804    let project_connection_ids = session
1805        .db()
1806        .await
1807        .project_connection_ids(project_id, session.connection_id)
1808        .await?;
1809    broadcast(
1810        Some(session.connection_id),
1811        project_connection_ids.iter().copied(),
1812        |connection_id| {
1813            session
1814                .peer
1815                .forward_send(session.connection_id, connection_id, request.clone())
1816        },
1817    );
1818    Ok(())
1819}
1820
1821/// forward a project request to the host. These requests should be read only
1822/// as guests are allowed to send them.
1823async fn forward_read_only_project_request<T>(
1824    request: T,
1825    response: Response<T>,
1826    session: Session,
1827) -> Result<()>
1828where
1829    T: EntityMessage + RequestMessage,
1830{
1831    let project_id = ProjectId::from_proto(request.remote_entity_id());
1832    let host_connection_id = session
1833        .db()
1834        .await
1835        .host_for_read_only_project_request(project_id, session.connection_id)
1836        .await?;
1837    let payload = session
1838        .peer
1839        .forward_request(session.connection_id, host_connection_id, request)
1840        .await?;
1841    response.send(payload)?;
1842    Ok(())
1843}
1844
1845/// forward a project request to the host. These requests are disallowed
1846/// for guests.
1847async fn forward_mutating_project_request<T>(
1848    request: T,
1849    response: Response<T>,
1850    session: Session,
1851) -> Result<()>
1852where
1853    T: EntityMessage + RequestMessage,
1854{
1855    let project_id = ProjectId::from_proto(request.remote_entity_id());
1856    let host_connection_id = session
1857        .db()
1858        .await
1859        .host_for_mutating_project_request(project_id, session.connection_id)
1860        .await?;
1861    let payload = session
1862        .peer
1863        .forward_request(session.connection_id, host_connection_id, request)
1864        .await?;
1865    response.send(payload)?;
1866    Ok(())
1867}
1868
1869/// Notify other participants that a new buffer has been created
1870async fn create_buffer_for_peer(
1871    request: proto::CreateBufferForPeer,
1872    session: Session,
1873) -> Result<()> {
1874    session
1875        .db()
1876        .await
1877        .check_user_is_project_host(
1878            ProjectId::from_proto(request.project_id),
1879            session.connection_id,
1880        )
1881        .await?;
1882    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1883    session
1884        .peer
1885        .forward_send(session.connection_id, peer_id.into(), request)?;
1886    Ok(())
1887}
1888
1889/// Notify other participants that a buffer has been updated. This is
1890/// allowed for guests as long as the update is limited to selections.
1891async fn update_buffer(
1892    request: proto::UpdateBuffer,
1893    response: Response<proto::UpdateBuffer>,
1894    session: Session,
1895) -> Result<()> {
1896    let project_id = ProjectId::from_proto(request.project_id);
1897    let mut guest_connection_ids;
1898    let mut host_connection_id = None;
1899
1900    let mut requires_write_permission = false;
1901
1902    for op in request.operations.iter() {
1903        match op.variant {
1904            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1905            Some(_) => requires_write_permission = true,
1906        }
1907    }
1908
1909    {
1910        let collaborators = session
1911            .db()
1912            .await
1913            .project_collaborators_for_buffer_update(
1914                project_id,
1915                session.connection_id,
1916                requires_write_permission,
1917            )
1918            .await?;
1919        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1920        for collaborator in collaborators.iter() {
1921            if collaborator.is_host {
1922                host_connection_id = Some(collaborator.connection_id);
1923            } else {
1924                guest_connection_ids.push(collaborator.connection_id);
1925            }
1926        }
1927    }
1928    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1929
1930    broadcast(
1931        Some(session.connection_id),
1932        guest_connection_ids,
1933        |connection_id| {
1934            session
1935                .peer
1936                .forward_send(session.connection_id, connection_id, request.clone())
1937        },
1938    );
1939    if host_connection_id != session.connection_id {
1940        session
1941            .peer
1942            .forward_request(session.connection_id, host_connection_id, request.clone())
1943            .await?;
1944    }
1945
1946    response.send(proto::Ack {})?;
1947    Ok(())
1948}
1949
1950/// Notify other participants that a project has been updated.
1951async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1952    request: T,
1953    session: Session,
1954) -> Result<()> {
1955    let project_id = ProjectId::from_proto(request.remote_entity_id());
1956    let project_connection_ids = session
1957        .db()
1958        .await
1959        .project_connection_ids(project_id, session.connection_id)
1960        .await?;
1961
1962    broadcast(
1963        Some(session.connection_id),
1964        project_connection_ids.iter().copied(),
1965        |connection_id| {
1966            session
1967                .peer
1968                .forward_send(session.connection_id, connection_id, request.clone())
1969        },
1970    );
1971    Ok(())
1972}
1973
1974/// Start following another user in a call.
1975async fn follow(
1976    request: proto::Follow,
1977    response: Response<proto::Follow>,
1978    session: Session,
1979) -> Result<()> {
1980    let room_id = RoomId::from_proto(request.room_id);
1981    let project_id = request.project_id.map(ProjectId::from_proto);
1982    let leader_id = request
1983        .leader_id
1984        .ok_or_else(|| anyhow!("invalid leader id"))?
1985        .into();
1986    let follower_id = session.connection_id;
1987
1988    session
1989        .db()
1990        .await
1991        .check_room_participants(room_id, leader_id, session.connection_id)
1992        .await?;
1993
1994    let response_payload = session
1995        .peer
1996        .forward_request(session.connection_id, leader_id, request)
1997        .await?;
1998    response.send(response_payload)?;
1999
2000    if let Some(project_id) = project_id {
2001        let room = session
2002            .db()
2003            .await
2004            .follow(room_id, project_id, leader_id, follower_id)
2005            .await?;
2006        room_updated(&room, &session.peer);
2007    }
2008
2009    Ok(())
2010}
2011
2012/// Stop following another user in a call.
2013async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2014    let room_id = RoomId::from_proto(request.room_id);
2015    let project_id = request.project_id.map(ProjectId::from_proto);
2016    let leader_id = request
2017        .leader_id
2018        .ok_or_else(|| anyhow!("invalid leader id"))?
2019        .into();
2020    let follower_id = session.connection_id;
2021
2022    session
2023        .db()
2024        .await
2025        .check_room_participants(room_id, leader_id, session.connection_id)
2026        .await?;
2027
2028    session
2029        .peer
2030        .forward_send(session.connection_id, leader_id, request)?;
2031
2032    if let Some(project_id) = project_id {
2033        let room = session
2034            .db()
2035            .await
2036            .unfollow(room_id, project_id, leader_id, follower_id)
2037            .await?;
2038        room_updated(&room, &session.peer);
2039    }
2040
2041    Ok(())
2042}
2043
2044/// Notify everyone following you of your current location.
2045async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2046    let room_id = RoomId::from_proto(request.room_id);
2047    let database = session.db.lock().await;
2048
2049    let connection_ids = if let Some(project_id) = request.project_id {
2050        let project_id = ProjectId::from_proto(project_id);
2051        database
2052            .project_connection_ids(project_id, session.connection_id)
2053            .await?
2054    } else {
2055        database
2056            .room_connection_ids(room_id, session.connection_id)
2057            .await?
2058    };
2059
2060    // For now, don't send view update messages back to that view's current leader.
2061    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2062        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2063        _ => None,
2064    });
2065
2066    for follower_peer_id in request.follower_ids.iter().copied() {
2067        let follower_connection_id = follower_peer_id.into();
2068        if Some(follower_peer_id) != connection_id_to_omit
2069            && connection_ids.contains(&follower_connection_id)
2070        {
2071            session.peer.forward_send(
2072                session.connection_id,
2073                follower_connection_id,
2074                request.clone(),
2075            )?;
2076        }
2077    }
2078    Ok(())
2079}
2080
2081/// Get public data about users.
2082async fn get_users(
2083    request: proto::GetUsers,
2084    response: Response<proto::GetUsers>,
2085    session: Session,
2086) -> Result<()> {
2087    let user_ids = request
2088        .user_ids
2089        .into_iter()
2090        .map(UserId::from_proto)
2091        .collect();
2092    let users = session
2093        .db()
2094        .await
2095        .get_users_by_ids(user_ids)
2096        .await?
2097        .into_iter()
2098        .map(|user| proto::User {
2099            id: user.id.to_proto(),
2100            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2101            github_login: user.github_login,
2102        })
2103        .collect();
2104    response.send(proto::UsersResponse { users })?;
2105    Ok(())
2106}
2107
2108/// Search for users (to invite) buy Github login
2109async fn fuzzy_search_users(
2110    request: proto::FuzzySearchUsers,
2111    response: Response<proto::FuzzySearchUsers>,
2112    session: Session,
2113) -> Result<()> {
2114    let query = request.query;
2115    let users = match query.len() {
2116        0 => vec![],
2117        1 | 2 => session
2118            .db()
2119            .await
2120            .get_user_by_github_login(&query)
2121            .await?
2122            .into_iter()
2123            .collect(),
2124        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2125    };
2126    let users = users
2127        .into_iter()
2128        .filter(|user| user.id != session.user_id)
2129        .map(|user| proto::User {
2130            id: user.id.to_proto(),
2131            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2132            github_login: user.github_login,
2133        })
2134        .collect();
2135    response.send(proto::UsersResponse { users })?;
2136    Ok(())
2137}
2138
2139/// Send a contact request to another user.
2140async fn request_contact(
2141    request: proto::RequestContact,
2142    response: Response<proto::RequestContact>,
2143    session: Session,
2144) -> Result<()> {
2145    let requester_id = session.user_id;
2146    let responder_id = UserId::from_proto(request.responder_id);
2147    if requester_id == responder_id {
2148        return Err(anyhow!("cannot add yourself as a contact"))?;
2149    }
2150
2151    let notifications = session
2152        .db()
2153        .await
2154        .send_contact_request(requester_id, responder_id)
2155        .await?;
2156
2157    // Update outgoing contact requests of requester
2158    let mut update = proto::UpdateContacts::default();
2159    update.outgoing_requests.push(responder_id.to_proto());
2160    for connection_id in session
2161        .connection_pool()
2162        .await
2163        .user_connection_ids(requester_id)
2164    {
2165        session.peer.send(connection_id, update.clone())?;
2166    }
2167
2168    // Update incoming contact requests of responder
2169    let mut update = proto::UpdateContacts::default();
2170    update
2171        .incoming_requests
2172        .push(proto::IncomingContactRequest {
2173            requester_id: requester_id.to_proto(),
2174        });
2175    let connection_pool = session.connection_pool().await;
2176    for connection_id in connection_pool.user_connection_ids(responder_id) {
2177        session.peer.send(connection_id, update.clone())?;
2178    }
2179
2180    send_notifications(&*connection_pool, &session.peer, notifications);
2181
2182    response.send(proto::Ack {})?;
2183    Ok(())
2184}
2185
2186/// Accept or decline a contact request
2187async fn respond_to_contact_request(
2188    request: proto::RespondToContactRequest,
2189    response: Response<proto::RespondToContactRequest>,
2190    session: Session,
2191) -> Result<()> {
2192    let responder_id = session.user_id;
2193    let requester_id = UserId::from_proto(request.requester_id);
2194    let db = session.db().await;
2195    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2196        db.dismiss_contact_notification(responder_id, requester_id)
2197            .await?;
2198    } else {
2199        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2200
2201        let notifications = db
2202            .respond_to_contact_request(responder_id, requester_id, accept)
2203            .await?;
2204        let requester_busy = db.is_user_busy(requester_id).await?;
2205        let responder_busy = db.is_user_busy(responder_id).await?;
2206
2207        let pool = session.connection_pool().await;
2208        // Update responder with new contact
2209        let mut update = proto::UpdateContacts::default();
2210        if accept {
2211            update
2212                .contacts
2213                .push(contact_for_user(requester_id, requester_busy, &pool));
2214        }
2215        update
2216            .remove_incoming_requests
2217            .push(requester_id.to_proto());
2218        for connection_id in pool.user_connection_ids(responder_id) {
2219            session.peer.send(connection_id, update.clone())?;
2220        }
2221
2222        // Update requester with new contact
2223        let mut update = proto::UpdateContacts::default();
2224        if accept {
2225            update
2226                .contacts
2227                .push(contact_for_user(responder_id, responder_busy, &pool));
2228        }
2229        update
2230            .remove_outgoing_requests
2231            .push(responder_id.to_proto());
2232
2233        for connection_id in pool.user_connection_ids(requester_id) {
2234            session.peer.send(connection_id, update.clone())?;
2235        }
2236
2237        send_notifications(&*pool, &session.peer, notifications);
2238    }
2239
2240    response.send(proto::Ack {})?;
2241    Ok(())
2242}
2243
2244/// Remove a contact.
2245async fn remove_contact(
2246    request: proto::RemoveContact,
2247    response: Response<proto::RemoveContact>,
2248    session: Session,
2249) -> Result<()> {
2250    let requester_id = session.user_id;
2251    let responder_id = UserId::from_proto(request.user_id);
2252    let db = session.db().await;
2253    let (contact_accepted, deleted_notification_id) =
2254        db.remove_contact(requester_id, responder_id).await?;
2255
2256    let pool = session.connection_pool().await;
2257    // Update outgoing contact requests of requester
2258    let mut update = proto::UpdateContacts::default();
2259    if contact_accepted {
2260        update.remove_contacts.push(responder_id.to_proto());
2261    } else {
2262        update
2263            .remove_outgoing_requests
2264            .push(responder_id.to_proto());
2265    }
2266    for connection_id in pool.user_connection_ids(requester_id) {
2267        session.peer.send(connection_id, update.clone())?;
2268    }
2269
2270    // Update incoming contact requests of responder
2271    let mut update = proto::UpdateContacts::default();
2272    if contact_accepted {
2273        update.remove_contacts.push(requester_id.to_proto());
2274    } else {
2275        update
2276            .remove_incoming_requests
2277            .push(requester_id.to_proto());
2278    }
2279    for connection_id in pool.user_connection_ids(responder_id) {
2280        session.peer.send(connection_id, update.clone())?;
2281        if let Some(notification_id) = deleted_notification_id {
2282            session.peer.send(
2283                connection_id,
2284                proto::DeleteNotification {
2285                    notification_id: notification_id.to_proto(),
2286                },
2287            )?;
2288        }
2289    }
2290
2291    response.send(proto::Ack {})?;
2292    Ok(())
2293}
2294
2295/// Creates a new channel.
2296async fn create_channel(
2297    request: proto::CreateChannel,
2298    response: Response<proto::CreateChannel>,
2299    session: Session,
2300) -> Result<()> {
2301    let db = session.db().await;
2302
2303    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2304    let CreateChannelResult {
2305        channel,
2306        participants_to_update,
2307    } = db
2308        .create_channel(&request.name, parent_id, session.user_id)
2309        .await?;
2310
2311    response.send(proto::CreateChannelResponse {
2312        channel: Some(channel.to_proto()),
2313        parent_id: request.parent_id,
2314    })?;
2315
2316    let connection_pool = session.connection_pool().await;
2317    for (user_id, channels) in participants_to_update {
2318        let update = build_channels_update(channels, vec![]);
2319        for connection_id in connection_pool.user_connection_ids(user_id) {
2320            if user_id == session.user_id {
2321                continue;
2322            }
2323            session.peer.send(connection_id, update.clone())?;
2324        }
2325    }
2326
2327    Ok(())
2328}
2329
2330/// Delete a channel
2331async fn delete_channel(
2332    request: proto::DeleteChannel,
2333    response: Response<proto::DeleteChannel>,
2334    session: Session,
2335) -> Result<()> {
2336    let db = session.db().await;
2337
2338    let channel_id = request.channel_id;
2339    let (removed_channels, member_ids) = db
2340        .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2341        .await?;
2342    response.send(proto::Ack {})?;
2343
2344    // Notify members of removed channels
2345    let mut update = proto::UpdateChannels::default();
2346    update
2347        .delete_channels
2348        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2349
2350    let connection_pool = session.connection_pool().await;
2351    for member_id in member_ids {
2352        for connection_id in connection_pool.user_connection_ids(member_id) {
2353            session.peer.send(connection_id, update.clone())?;
2354        }
2355    }
2356
2357    Ok(())
2358}
2359
2360/// Invite someone to join a channel.
2361async fn invite_channel_member(
2362    request: proto::InviteChannelMember,
2363    response: Response<proto::InviteChannelMember>,
2364    session: Session,
2365) -> Result<()> {
2366    let db = session.db().await;
2367    let channel_id = ChannelId::from_proto(request.channel_id);
2368    let invitee_id = UserId::from_proto(request.user_id);
2369    let InviteMemberResult {
2370        channel,
2371        notifications,
2372    } = db
2373        .invite_channel_member(
2374            channel_id,
2375            invitee_id,
2376            session.user_id,
2377            request.role().into(),
2378        )
2379        .await?;
2380
2381    let update = proto::UpdateChannels {
2382        channel_invitations: vec![channel.to_proto()],
2383        ..Default::default()
2384    };
2385
2386    let connection_pool = session.connection_pool().await;
2387    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2388        session.peer.send(connection_id, update.clone())?;
2389    }
2390
2391    send_notifications(&*connection_pool, &session.peer, notifications);
2392
2393    response.send(proto::Ack {})?;
2394    Ok(())
2395}
2396
2397/// remove someone from a channel
2398async fn remove_channel_member(
2399    request: proto::RemoveChannelMember,
2400    response: Response<proto::RemoveChannelMember>,
2401    session: Session,
2402) -> Result<()> {
2403    let db = session.db().await;
2404    let channel_id = ChannelId::from_proto(request.channel_id);
2405    let member_id = UserId::from_proto(request.user_id);
2406
2407    let RemoveChannelMemberResult {
2408        membership_update,
2409        notification_id,
2410    } = db
2411        .remove_channel_member(channel_id, member_id, session.user_id)
2412        .await?;
2413
2414    let connection_pool = &session.connection_pool().await;
2415    notify_membership_updated(
2416        &connection_pool,
2417        membership_update,
2418        member_id,
2419        &session.peer,
2420    );
2421    for connection_id in connection_pool.user_connection_ids(member_id) {
2422        if let Some(notification_id) = notification_id {
2423            session
2424                .peer
2425                .send(
2426                    connection_id,
2427                    proto::DeleteNotification {
2428                        notification_id: notification_id.to_proto(),
2429                    },
2430                )
2431                .trace_err();
2432        }
2433    }
2434
2435    response.send(proto::Ack {})?;
2436    Ok(())
2437}
2438
2439/// Toggle the channel between public and private
2440async fn set_channel_visibility(
2441    request: proto::SetChannelVisibility,
2442    response: Response<proto::SetChannelVisibility>,
2443    session: Session,
2444) -> Result<()> {
2445    let db = session.db().await;
2446    let channel_id = ChannelId::from_proto(request.channel_id);
2447    let visibility = request.visibility().into();
2448
2449    let SetChannelVisibilityResult {
2450        participants_to_update,
2451        participants_to_remove,
2452        channels_to_remove,
2453    } = db
2454        .set_channel_visibility(channel_id, visibility, session.user_id)
2455        .await?;
2456
2457    let connection_pool = session.connection_pool().await;
2458    for (user_id, channels) in participants_to_update {
2459        let update = build_channels_update(channels, vec![]);
2460        for connection_id in connection_pool.user_connection_ids(user_id) {
2461            session.peer.send(connection_id, update.clone())?;
2462        }
2463    }
2464    for user_id in participants_to_remove {
2465        let update = proto::UpdateChannels {
2466            delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2467            ..Default::default()
2468        };
2469        for connection_id in connection_pool.user_connection_ids(user_id) {
2470            session.peer.send(connection_id, update.clone())?;
2471        }
2472    }
2473
2474    response.send(proto::Ack {})?;
2475    Ok(())
2476}
2477
2478/// Alter the role for a user in the channel
2479async fn set_channel_member_role(
2480    request: proto::SetChannelMemberRole,
2481    response: Response<proto::SetChannelMemberRole>,
2482    session: Session,
2483) -> Result<()> {
2484    let db = session.db().await;
2485    let channel_id = ChannelId::from_proto(request.channel_id);
2486    let member_id = UserId::from_proto(request.user_id);
2487    let result = db
2488        .set_channel_member_role(
2489            channel_id,
2490            session.user_id,
2491            member_id,
2492            request.role().into(),
2493        )
2494        .await?;
2495
2496    match result {
2497        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2498            let connection_pool = session.connection_pool().await;
2499            notify_membership_updated(
2500                &connection_pool,
2501                membership_update,
2502                member_id,
2503                &session.peer,
2504            )
2505        }
2506        db::SetMemberRoleResult::InviteUpdated(channel) => {
2507            let update = proto::UpdateChannels {
2508                channel_invitations: vec![channel.to_proto()],
2509                ..Default::default()
2510            };
2511
2512            for connection_id in session
2513                .connection_pool()
2514                .await
2515                .user_connection_ids(member_id)
2516            {
2517                session.peer.send(connection_id, update.clone())?;
2518            }
2519        }
2520    }
2521
2522    response.send(proto::Ack {})?;
2523    Ok(())
2524}
2525
2526/// Change the name of a channel
2527async fn rename_channel(
2528    request: proto::RenameChannel,
2529    response: Response<proto::RenameChannel>,
2530    session: Session,
2531) -> Result<()> {
2532    let db = session.db().await;
2533    let channel_id = ChannelId::from_proto(request.channel_id);
2534    let RenameChannelResult {
2535        channel,
2536        participants_to_update,
2537    } = db
2538        .rename_channel(channel_id, session.user_id, &request.name)
2539        .await?;
2540
2541    response.send(proto::RenameChannelResponse {
2542        channel: Some(channel.to_proto()),
2543    })?;
2544
2545    let connection_pool = session.connection_pool().await;
2546    for (user_id, channel) in participants_to_update {
2547        for connection_id in connection_pool.user_connection_ids(user_id) {
2548            let update = proto::UpdateChannels {
2549                channels: vec![channel.to_proto()],
2550                ..Default::default()
2551            };
2552
2553            session.peer.send(connection_id, update.clone())?;
2554        }
2555    }
2556
2557    Ok(())
2558}
2559
2560/// Move a channel to a new parent.
2561async fn move_channel(
2562    request: proto::MoveChannel,
2563    response: Response<proto::MoveChannel>,
2564    session: Session,
2565) -> Result<()> {
2566    let channel_id = ChannelId::from_proto(request.channel_id);
2567    let to = request.to.map(ChannelId::from_proto);
2568
2569    let result = session
2570        .db()
2571        .await
2572        .move_channel(channel_id, to, session.user_id)
2573        .await?;
2574
2575    notify_channel_moved(result, session).await?;
2576
2577    response.send(Ack {})?;
2578    Ok(())
2579}
2580
2581async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2582    let Some(MoveChannelResult {
2583        participants_to_remove,
2584        participants_to_update,
2585        moved_channels,
2586    }) = result
2587    else {
2588        return Ok(());
2589    };
2590    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2591
2592    let connection_pool = session.connection_pool().await;
2593    for (user_id, channels) in participants_to_update {
2594        let mut update = build_channels_update(channels, vec![]);
2595        update.delete_channels = moved_channels.clone();
2596        for connection_id in connection_pool.user_connection_ids(user_id) {
2597            session.peer.send(connection_id, update.clone())?;
2598        }
2599    }
2600
2601    for user_id in participants_to_remove {
2602        let update = proto::UpdateChannels {
2603            delete_channels: moved_channels.clone(),
2604            ..Default::default()
2605        };
2606        for connection_id in connection_pool.user_connection_ids(user_id) {
2607            session.peer.send(connection_id, update.clone())?;
2608        }
2609    }
2610    Ok(())
2611}
2612
2613/// Get the list of channel members
2614async fn get_channel_members(
2615    request: proto::GetChannelMembers,
2616    response: Response<proto::GetChannelMembers>,
2617    session: Session,
2618) -> Result<()> {
2619    let db = session.db().await;
2620    let channel_id = ChannelId::from_proto(request.channel_id);
2621    let members = db
2622        .get_channel_participant_details(channel_id, session.user_id)
2623        .await?;
2624    response.send(proto::GetChannelMembersResponse { members })?;
2625    Ok(())
2626}
2627
2628/// Accept or decline a channel invitation.
2629async fn respond_to_channel_invite(
2630    request: proto::RespondToChannelInvite,
2631    response: Response<proto::RespondToChannelInvite>,
2632    session: Session,
2633) -> Result<()> {
2634    let db = session.db().await;
2635    let channel_id = ChannelId::from_proto(request.channel_id);
2636    let RespondToChannelInvite {
2637        membership_update,
2638        notifications,
2639    } = db
2640        .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2641        .await?;
2642
2643    let connection_pool = session.connection_pool().await;
2644    if let Some(membership_update) = membership_update {
2645        notify_membership_updated(
2646            &connection_pool,
2647            membership_update,
2648            session.user_id,
2649            &session.peer,
2650        );
2651    } else {
2652        let update = proto::UpdateChannels {
2653            remove_channel_invitations: vec![channel_id.to_proto()],
2654            ..Default::default()
2655        };
2656
2657        for connection_id in connection_pool.user_connection_ids(session.user_id) {
2658            session.peer.send(connection_id, update.clone())?;
2659        }
2660    };
2661
2662    send_notifications(&*connection_pool, &session.peer, notifications);
2663
2664    response.send(proto::Ack {})?;
2665
2666    Ok(())
2667}
2668
2669/// Join the channels' room
2670async fn join_channel(
2671    request: proto::JoinChannel,
2672    response: Response<proto::JoinChannel>,
2673    session: Session,
2674) -> Result<()> {
2675    let channel_id = ChannelId::from_proto(request.channel_id);
2676    join_channel_internal(channel_id, Box::new(response), session).await
2677}
2678
/// Response handle for a request that ends with the session joining a
/// channel's room. The impls below cover `JoinChannel` and `JoinRoom`, both of
/// which reply with a `proto::JoinRoomResponse`, so `join_channel_internal` can
/// be written once against this trait instead of a concrete `Response<_>`.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply to the request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards to the inherent `Response::send`; the fully-qualified
// path makes the delegation target explicit.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2692
/// Shared implementation for joining the room attached to `channel_id`. It
/// works with any response handle that implements
/// `JoinChannelInternalResponse`.
///
/// In order, this:
/// 1. leaves whatever room the session is currently in;
/// 2. joins the channel's room in the database, which also returns an optional
///    membership update and the user's role in the channel;
/// 3. if a LiveKit client is configured, mints a token. `ChannelRole::Guest`
///    gets a guest token with `can_publish: false`; every other role gets a
///    room token with `can_publish: true`. A token error goes through
///    `trace_err` and leaves `live_kit_connection_info` as `None` rather than
///    failing the join;
/// 4. replies with a `JoinRoomResponse`, pushes any membership update to this
///    user's connections, and sends the updated room to its participants;
/// 5. sends the channel's new participant list to the channel members and
///    refreshes this user's entry for their accepted contacts.
///
/// # Errors
/// Returns early if leaving the previous room fails, if the database join
/// fails, if sending the response fails, or if updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this inner scope drops the `connection_pool` guard taken
    // below before the pool is locked again for `channel_updated` and before
    // the final `.await`. That is presumably why the block exists.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when no LiveKit client is configured or when minting a token
        // fails (the `?` after `trace_err` returns `None` from the closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        // Tell everyone already in the room about the new participant.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Channel members (not only room participants) learn the new participant list.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2773
2774/// Start editing the channel notes
2775async fn join_channel_buffer(
2776    request: proto::JoinChannelBuffer,
2777    response: Response<proto::JoinChannelBuffer>,
2778    session: Session,
2779) -> Result<()> {
2780    let db = session.db().await;
2781    let channel_id = ChannelId::from_proto(request.channel_id);
2782
2783    let open_response = db
2784        .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2785        .await?;
2786
2787    let collaborators = open_response.collaborators.clone();
2788    response.send(open_response)?;
2789
2790    let update = UpdateChannelBufferCollaborators {
2791        channel_id: channel_id.to_proto(),
2792        collaborators: collaborators.clone(),
2793    };
2794    channel_buffer_updated(
2795        session.connection_id,
2796        collaborators
2797            .iter()
2798            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2799        &update,
2800        &session.peer,
2801    );
2802
2803    Ok(())
2804}
2805
2806/// Edit the channel notes
2807async fn update_channel_buffer(
2808    request: proto::UpdateChannelBuffer,
2809    session: Session,
2810) -> Result<()> {
2811    let db = session.db().await;
2812    let channel_id = ChannelId::from_proto(request.channel_id);
2813
2814    let (collaborators, non_collaborators, epoch, version) = db
2815        .update_channel_buffer(channel_id, session.user_id, &request.operations)
2816        .await?;
2817
2818    channel_buffer_updated(
2819        session.connection_id,
2820        collaborators,
2821        &proto::UpdateChannelBuffer {
2822            channel_id: channel_id.to_proto(),
2823            operations: request.operations,
2824        },
2825        &session.peer,
2826    );
2827
2828    let pool = &*session.connection_pool().await;
2829
2830    broadcast(
2831        None,
2832        non_collaborators
2833            .iter()
2834            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2835        |peer_id| {
2836            session.peer.send(
2837                peer_id.into(),
2838                proto::UpdateChannels {
2839                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2840                        channel_id: channel_id.to_proto(),
2841                        epoch: epoch as u64,
2842                        version: version.clone(),
2843                    }],
2844                    ..Default::default()
2845                },
2846            )
2847        },
2848    );
2849
2850    Ok(())
2851}
2852
2853/// Rejoin the channel notes after a connection blip
2854async fn rejoin_channel_buffers(
2855    request: proto::RejoinChannelBuffers,
2856    response: Response<proto::RejoinChannelBuffers>,
2857    session: Session,
2858) -> Result<()> {
2859    let db = session.db().await;
2860    let buffers = db
2861        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2862        .await?;
2863
2864    for rejoined_buffer in &buffers {
2865        let collaborators_to_notify = rejoined_buffer
2866            .buffer
2867            .collaborators
2868            .iter()
2869            .filter_map(|c| Some(c.peer_id?.into()));
2870        channel_buffer_updated(
2871            session.connection_id,
2872            collaborators_to_notify,
2873            &proto::UpdateChannelBufferCollaborators {
2874                channel_id: rejoined_buffer.buffer.channel_id,
2875                collaborators: rejoined_buffer.buffer.collaborators.clone(),
2876            },
2877            &session.peer,
2878        );
2879    }
2880
2881    response.send(proto::RejoinChannelBuffersResponse {
2882        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2883    })?;
2884
2885    Ok(())
2886}
2887
2888/// Stop editing the channel notes
2889async fn leave_channel_buffer(
2890    request: proto::LeaveChannelBuffer,
2891    response: Response<proto::LeaveChannelBuffer>,
2892    session: Session,
2893) -> Result<()> {
2894    let db = session.db().await;
2895    let channel_id = ChannelId::from_proto(request.channel_id);
2896
2897    let left_buffer = db
2898        .leave_channel_buffer(channel_id, session.connection_id)
2899        .await?;
2900
2901    response.send(Ack {})?;
2902
2903    channel_buffer_updated(
2904        session.connection_id,
2905        left_buffer.connections,
2906        &proto::UpdateChannelBufferCollaborators {
2907            channel_id: channel_id.to_proto(),
2908            collaborators: left_buffer.collaborators,
2909        },
2910        &session.peer,
2911    );
2912
2913    Ok(())
2914}
2915
2916fn channel_buffer_updated<T: EnvelopedMessage>(
2917    sender_id: ConnectionId,
2918    collaborators: impl IntoIterator<Item = ConnectionId>,
2919    message: &T,
2920    peer: &Peer,
2921) {
2922    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2923        peer.send(peer_id.into(), message.clone())
2924    });
2925}
2926
2927fn send_notifications(
2928    connection_pool: &ConnectionPool,
2929    peer: &Peer,
2930    notifications: db::NotificationBatch,
2931) {
2932    for (user_id, notification) in notifications {
2933        for connection_id in connection_pool.user_connection_ids(user_id) {
2934            if let Err(error) = peer.send(
2935                connection_id,
2936                proto::AddNotification {
2937                    notification: Some(notification.clone()),
2938                },
2939            ) {
2940                tracing::error!(
2941                    "failed to send notification to {:?} {}",
2942                    connection_id,
2943                    error
2944                );
2945            }
2946        }
2947    }
2948}
2949
2950/// Send a message to the channel
2951async fn send_channel_message(
2952    request: proto::SendChannelMessage,
2953    response: Response<proto::SendChannelMessage>,
2954    session: Session,
2955) -> Result<()> {
2956    // Validate the message body.
2957    let body = request.body.trim().to_string();
2958    if body.len() > MAX_MESSAGE_LEN {
2959        return Err(anyhow!("message is too long"))?;
2960    }
2961    if body.is_empty() {
2962        return Err(anyhow!("message can't be blank"))?;
2963    }
2964
2965    // TODO: adjust mentions if body is trimmed
2966
2967    let timestamp = OffsetDateTime::now_utc();
2968    let nonce = request
2969        .nonce
2970        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2971
2972    let channel_id = ChannelId::from_proto(request.channel_id);
2973    let CreatedChannelMessage {
2974        message_id,
2975        participant_connection_ids,
2976        channel_members,
2977        notifications,
2978    } = session
2979        .db()
2980        .await
2981        .create_channel_message(
2982            channel_id,
2983            session.user_id,
2984            &body,
2985            &request.mentions,
2986            timestamp,
2987            nonce.clone().into(),
2988        )
2989        .await?;
2990    let message = proto::ChannelMessage {
2991        sender_id: session.user_id.to_proto(),
2992        id: message_id.to_proto(),
2993        body,
2994        mentions: request.mentions,
2995        timestamp: timestamp.unix_timestamp() as u64,
2996        nonce: Some(nonce),
2997    };
2998    broadcast(
2999        Some(session.connection_id),
3000        participant_connection_ids,
3001        |connection| {
3002            session.peer.send(
3003                connection,
3004                proto::ChannelMessageSent {
3005                    channel_id: channel_id.to_proto(),
3006                    message: Some(message.clone()),
3007                },
3008            )
3009        },
3010    );
3011    response.send(proto::SendChannelMessageResponse {
3012        message: Some(message),
3013    })?;
3014
3015    let pool = &*session.connection_pool().await;
3016    broadcast(
3017        None,
3018        channel_members
3019            .iter()
3020            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3021        |peer_id| {
3022            session.peer.send(
3023                peer_id.into(),
3024                proto::UpdateChannels {
3025                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
3026                        channel_id: channel_id.to_proto(),
3027                        message_id: message_id.to_proto(),
3028                    }],
3029                    ..Default::default()
3030                },
3031            )
3032        },
3033    );
3034    send_notifications(pool, &session.peer, notifications);
3035
3036    Ok(())
3037}
3038
3039/// Delete a channel message
3040async fn remove_channel_message(
3041    request: proto::RemoveChannelMessage,
3042    response: Response<proto::RemoveChannelMessage>,
3043    session: Session,
3044) -> Result<()> {
3045    let channel_id = ChannelId::from_proto(request.channel_id);
3046    let message_id = MessageId::from_proto(request.message_id);
3047    let connection_ids = session
3048        .db()
3049        .await
3050        .remove_channel_message(channel_id, message_id, session.user_id)
3051        .await?;
3052    broadcast(Some(session.connection_id), connection_ids, |connection| {
3053        session.peer.send(connection, request.clone())
3054    });
3055    response.send(proto::Ack {})?;
3056    Ok(())
3057}
3058
3059/// Mark a channel message as read
3060async fn acknowledge_channel_message(
3061    request: proto::AckChannelMessage,
3062    session: Session,
3063) -> Result<()> {
3064    let channel_id = ChannelId::from_proto(request.channel_id);
3065    let message_id = MessageId::from_proto(request.message_id);
3066    let notifications = session
3067        .db()
3068        .await
3069        .observe_channel_message(channel_id, session.user_id, message_id)
3070        .await?;
3071    send_notifications(
3072        &*session.connection_pool().await,
3073        &session.peer,
3074        notifications,
3075    );
3076    Ok(())
3077}
3078
3079/// Mark a buffer version as synced
3080async fn acknowledge_buffer_version(
3081    request: proto::AckBufferOperation,
3082    session: Session,
3083) -> Result<()> {
3084    let buffer_id = BufferId::from_proto(request.buffer_id);
3085    session
3086        .db()
3087        .await
3088        .observe_buffer_version(
3089            buffer_id,
3090            session.user_id,
3091            request.epoch as i32,
3092            &request.version,
3093        )
3094        .await?;
3095    Ok(())
3096}
3097
3098/// Start receiving chat updates for a channel
3099async fn join_channel_chat(
3100    request: proto::JoinChannelChat,
3101    response: Response<proto::JoinChannelChat>,
3102    session: Session,
3103) -> Result<()> {
3104    let channel_id = ChannelId::from_proto(request.channel_id);
3105
3106    let db = session.db().await;
3107    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3108        .await?;
3109    let messages = db
3110        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3111        .await?;
3112    response.send(proto::JoinChannelChatResponse {
3113        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3114        messages,
3115    })?;
3116    Ok(())
3117}
3118
3119/// Stop receiving chat updates for a channel
3120async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3121    let channel_id = ChannelId::from_proto(request.channel_id);
3122    session
3123        .db()
3124        .await
3125        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3126        .await?;
3127    Ok(())
3128}
3129
3130/// Retrieve the chat history for a channel
3131async fn get_channel_messages(
3132    request: proto::GetChannelMessages,
3133    response: Response<proto::GetChannelMessages>,
3134    session: Session,
3135) -> Result<()> {
3136    let channel_id = ChannelId::from_proto(request.channel_id);
3137    let messages = session
3138        .db()
3139        .await
3140        .get_channel_messages(
3141            channel_id,
3142            session.user_id,
3143            MESSAGE_COUNT_PER_PAGE,
3144            Some(MessageId::from_proto(request.before_message_id)),
3145        )
3146        .await?;
3147    response.send(proto::GetChannelMessagesResponse {
3148        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3149        messages,
3150    })?;
3151    Ok(())
3152}
3153
3154/// Retrieve specific chat messages
3155async fn get_channel_messages_by_id(
3156    request: proto::GetChannelMessagesById,
3157    response: Response<proto::GetChannelMessagesById>,
3158    session: Session,
3159) -> Result<()> {
3160    let message_ids = request
3161        .message_ids
3162        .iter()
3163        .map(|id| MessageId::from_proto(*id))
3164        .collect::<Vec<_>>();
3165    let messages = session
3166        .db()
3167        .await
3168        .get_channel_messages_by_id(session.user_id, &message_ids)
3169        .await?;
3170    response.send(proto::GetChannelMessagesResponse {
3171        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3172        messages,
3173    })?;
3174    Ok(())
3175}
3176
3177/// Retrieve the current users notifications
3178async fn get_notifications(
3179    request: proto::GetNotifications,
3180    response: Response<proto::GetNotifications>,
3181    session: Session,
3182) -> Result<()> {
3183    let notifications = session
3184        .db()
3185        .await
3186        .get_notifications(
3187            session.user_id,
3188            NOTIFICATION_COUNT_PER_PAGE,
3189            request
3190                .before_id
3191                .map(|id| db::NotificationId::from_proto(id)),
3192        )
3193        .await?;
3194    response.send(proto::GetNotificationsResponse {
3195        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3196        notifications,
3197    })?;
3198    Ok(())
3199}
3200
3201/// Mark notifications as read
3202async fn mark_notification_as_read(
3203    request: proto::MarkNotificationRead,
3204    response: Response<proto::MarkNotificationRead>,
3205    session: Session,
3206) -> Result<()> {
3207    let database = &session.db().await;
3208    let notifications = database
3209        .mark_notification_as_read_by_id(
3210            session.user_id,
3211            NotificationId::from_proto(request.notification_id),
3212        )
3213        .await?;
3214    send_notifications(
3215        &*session.connection_pool().await,
3216        &session.peer,
3217        notifications,
3218    );
3219    response.send(proto::Ack {})?;
3220    Ok(())
3221}
3222
3223/// Get the current users information
3224async fn get_private_user_info(
3225    _request: proto::GetPrivateUserInfo,
3226    response: Response<proto::GetPrivateUserInfo>,
3227    session: Session,
3228) -> Result<()> {
3229    let db = session.db().await;
3230
3231    let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3232    let user = db
3233        .get_user_by_id(session.user_id)
3234        .await?
3235        .ok_or_else(|| anyhow!("user not found"))?;
3236    let flags = db.get_user_flags(session.user_id).await?;
3237
3238    response.send(proto::GetPrivateUserInfoResponse {
3239        metrics_id,
3240        staff: user.admin,
3241        flags,
3242    })?;
3243    Ok(())
3244}
3245
3246fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3247    match message {
3248        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3249        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3250        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3251        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3252        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3253            code: frame.code.into(),
3254            reason: frame.reason,
3255        })),
3256    }
3257}
3258
3259fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3260    match message {
3261        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3262        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3263        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3264        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3265        AxumMessage::Close(frame) => {
3266            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3267                code: frame.code.into(),
3268                reason: frame.reason,
3269            }))
3270        }
3271    }
3272}
3273
3274fn notify_membership_updated(
3275    connection_pool: &ConnectionPool,
3276    result: MembershipUpdated,
3277    user_id: UserId,
3278    peer: &Peer,
3279) {
3280    let mut update = build_channels_update(result.new_channels, vec![]);
3281    update.delete_channels = result
3282        .removed_channels
3283        .into_iter()
3284        .map(|id| id.to_proto())
3285        .collect();
3286    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3287
3288    for connection_id in connection_pool.user_connection_ids(user_id) {
3289        peer.send(connection_id, update.clone()).trace_err();
3290    }
3291}
3292
3293fn build_channels_update(
3294    channels: ChannelsForUser,
3295    channel_invites: Vec<db::Channel>,
3296) -> proto::UpdateChannels {
3297    let mut update = proto::UpdateChannels::default();
3298
3299    for channel in channels.channels {
3300        update.channels.push(channel.to_proto());
3301    }
3302
3303    update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3304    update.unseen_channel_messages = channels.channel_messages;
3305
3306    for (channel_id, participants) in channels.channel_participants {
3307        update
3308            .channel_participants
3309            .push(proto::ChannelParticipants {
3310                channel_id: channel_id.to_proto(),
3311                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3312            });
3313    }
3314
3315    for channel in channel_invites {
3316        update.channel_invitations.push(channel.to_proto());
3317    }
3318
3319    update
3320}
3321
3322fn build_initial_contacts_update(
3323    contacts: Vec<db::Contact>,
3324    pool: &ConnectionPool,
3325) -> proto::UpdateContacts {
3326    let mut update = proto::UpdateContacts::default();
3327
3328    for contact in contacts {
3329        match contact {
3330            db::Contact::Accepted { user_id, busy } => {
3331                update.contacts.push(contact_for_user(user_id, busy, &pool));
3332            }
3333            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3334            db::Contact::Incoming { user_id } => {
3335                update
3336                    .incoming_requests
3337                    .push(proto::IncomingContactRequest {
3338                        requester_id: user_id.to_proto(),
3339                    })
3340            }
3341        }
3342    }
3343
3344    update
3345}
3346
3347fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3348    proto::Contact {
3349        user_id: user_id.to_proto(),
3350        online: pool.is_user_online(user_id),
3351        busy,
3352    }
3353}
3354
3355fn room_updated(room: &proto::Room, peer: &Peer) {
3356    broadcast(
3357        None,
3358        room.participants
3359            .iter()
3360            .filter_map(|participant| Some(participant.peer_id?.into())),
3361        |peer_id| {
3362            peer.send(
3363                peer_id.into(),
3364                proto::RoomUpdated {
3365                    room: Some(room.clone()),
3366                },
3367            )
3368        },
3369    );
3370}
3371
3372fn channel_updated(
3373    channel_id: ChannelId,
3374    room: &proto::Room,
3375    channel_members: &[UserId],
3376    peer: &Peer,
3377    pool: &ConnectionPool,
3378) {
3379    let participants = room
3380        .participants
3381        .iter()
3382        .map(|p| p.user_id)
3383        .collect::<Vec<_>>();
3384
3385    broadcast(
3386        None,
3387        channel_members
3388            .iter()
3389            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3390        |peer_id| {
3391            peer.send(
3392                peer_id.into(),
3393                proto::UpdateChannels {
3394                    channel_participants: vec![proto::ChannelParticipants {
3395                        channel_id: channel_id.to_proto(),
3396                        participant_user_ids: participants.clone(),
3397                    }],
3398                    ..Default::default()
3399                },
3400            )
3401        },
3402    );
3403}
3404
3405async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3406    let db = session.db().await;
3407
3408    let contacts = db.get_contacts(user_id).await?;
3409    let busy = db.is_user_busy(user_id).await?;
3410
3411    let pool = session.connection_pool().await;
3412    let updated_contact = contact_for_user(user_id, busy, &pool);
3413    for contact in contacts {
3414        if let db::Contact::Accepted {
3415            user_id: contact_user_id,
3416            ..
3417        } = contact
3418        {
3419            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3420                session
3421                    .peer
3422                    .send(
3423                        contact_conn_id,
3424                        proto::UpdateContacts {
3425                            contacts: vec![updated_contact.clone()],
3426                            remove_contacts: Default::default(),
3427                            incoming_requests: Default::default(),
3428                            remove_incoming_requests: Default::default(),
3429                            outgoing_requests: Default::default(),
3430                            remove_outgoing_requests: Default::default(),
3431                        },
3432                    )
3433                    .trace_err();
3434            }
3435        }
3436    }
3437    Ok(())
3438}
3439
/// Removes this session's connection from the room it is in (if any) and
/// notifies everyone affected.
///
/// If `leave_room` reports no room, this returns `Ok(())` immediately.
/// Otherwise it:
/// - notifies the connections of every project that was left (`project_left`);
/// - broadcasts the updated room (`room_updated`) and, when the room belongs
///   to a channel, the channel's participants (`channel_updated`);
/// - sends `CallCanceled` to every connection of the users listed in
///   `canceled_calls_to_user_ids`;
/// - refreshes contact state for the leaving user and those users;
/// - when a LiveKit client is configured, removes the user from the LiveKit
///   room, and deletes that room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`. An
/// error from `update_user_contacts` returns before the LiveKit cleanup runs.
/// Send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // A set, so each user gets at most one contact update.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: every binding below is assigned in the `if let`
    // branch; the `else` branch returns, so they are all initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // the `if let` scrutinee, so it lives until the end of the whole
    // `if let`/`else`. If it is a lock guard, it is held while projects are
    // notified and the room is broadcast.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` *before* the room itself is taken, so the
        // `room` broadcast below carries the default value in this field.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which calls
    // `session.connection_pool()` again. Presumably this avoids holding the pool
    // across that call; confirm whether the pool lock is reentrant.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` may still be needed for `delete_room`.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3516
3517async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3518    let left_channel_buffers = session
3519        .db()
3520        .await
3521        .leave_channel_buffers(session.connection_id)
3522        .await?;
3523
3524    for left_buffer in left_channel_buffers {
3525        channel_buffer_updated(
3526            session.connection_id,
3527            left_buffer.connections,
3528            &proto::UpdateChannelBufferCollaborators {
3529                channel_id: left_buffer.channel_id.to_proto(),
3530                collaborators: left_buffer.collaborators,
3531            },
3532            &session.peer,
3533        );
3534    }
3535
3536    Ok(())
3537}
3538
3539fn project_left(project: &db::LeftProject, session: &Session) {
3540    for connection_id in &project.connection_ids {
3541        if project.host_user_id == session.user_id {
3542            session
3543                .peer
3544                .send(
3545                    *connection_id,
3546                    proto::UnshareProject {
3547                        project_id: project.id.to_proto(),
3548                    },
3549                )
3550                .trace_err();
3551        } else {
3552            session
3553                .peer
3554                .send(
3555                    *connection_id,
3556                    proto::RemoveProjectCollaborator {
3557                        project_id: project.id.to_proto(),
3558                        peer_id: Some(session.connection_id.into()),
3559                    },
3560                )
3561                .trace_err();
3562        }
3563    }
3564}
3565
/// Extension for results whose failures should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success value carried by the implementing type.
    type Ok;

    /// Converts `self` into an `Option` of its success value, yielding `None`
    /// on failure. The `Result` implementation in this module logs the error
    /// via `tracing::error!` before discarding it.
    fn trace_err(self) -> Option<Self::Ok>;
}
3571
3572impl<T, E> ResultExt for Result<T, E>
3573where
3574    E: std::fmt::Debug,
3575{
3576    type Ok = T;
3577
3578    fn trace_err(self) -> Option<T> {
3579        match self {
3580            Ok(value) => Some(value),
3581            Err(error) => {
3582                tracing::error!("{:?}", error);
3583                None
3584            }
3585        }
3586    }
3587}