1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::{watch, Mutex, MutexGuard};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` sleeps after a connection drops before leaving the user's room,
/// declining a pending call (if the user has no other online connection) and refreshing
/// contacts. Presumably this gives the client a window to reconnect. Mirrors
/// `rpc::RECEIVE_TIMEOUT`.
pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
61
lazy_static! {
    // Prometheus gauges exported on `/metrics`; `handle_metrics` refreshes both values
    // on every scrape. `unwrap` panics if registration with the default registry fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
71
/// Type-erased message handler stored in `Server::handlers` (keyed by the message's
/// `TypeId`). It receives the boxed envelope and the sender's `Session`, and returns a
/// future that processes the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
74
/// Handle given to a request handler for replying to exactly one request.
///
/// `responded` is shared with the wrapper built in `Server::add_request_handler`, which
/// checks it after the handler returns `Ok` to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
80
81impl<R: RequestMessage> Response<R> {
82 fn send(self, payload: R::Response) -> Result<()> {
83 self.responded.store(true, SeqCst);
84 self.peer.respond(self.receipt, payload)?;
85 Ok(())
86 }
87}
88
/// Per-connection state that is cloned into every message handler for that connection.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Each connection gets its own mutex-wrapped handle (created in `handle_connection`),
    // so handlers on the same connection take turns using the database via `Session::db`.
    db: Arc<Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared with the `Server`; locked through `Session::connection_pool`.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    // `None` when no LiveKit server is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
98
99impl Session {
100 async fn db(&self) -> MutexGuard<DbHandle> {
101 #[cfg(test)]
102 tokio::task::yield_now().await;
103 let guard = self.db.lock().await;
104 #[cfg(test)]
105 tokio::task::yield_now().await;
106 guard
107 }
108
109 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
110 #[cfg(test)]
111 tokio::task::yield_now().await;
112 let guard = self.connection_pool.lock().await;
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
/// Wrapper around the shared `Database` that derefs to it. `Session` keeps one behind a
/// per-connection mutex.
struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
/// The collaboration RPC server: owns the `Peer`, the pool of live connections, and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Fired by `Server::teardown`; `sign_out` stops waiting out `RECONNECT_TIMEOUT`
    // when it observes a change.
    teardown: watch::Sender<()>,
}
148
/// Lock guard over the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future that must be
/// `Send` cannot hold it across an `.await`. In test builds, dropping the guard runs the
/// pool's `check_invariants`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
153
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The connection-pool lock is held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
160
161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
162where
163 S: Serializer,
164 T: Deref<Target = U>,
165 U: Serialize,
166{
167 Serialize::serialize(value.deref(), serializer)
168}
169
170impl Server {
171 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
172 let mut server = Self {
173 peer: Peer::new(),
174 app_state,
175 connection_pool: Default::default(),
176 handlers: Default::default(),
177 teardown: watch::channel(()).0,
178 };
179
180 server
181 .add_request_handler(ping)
182 .add_request_handler(create_room)
183 .add_request_handler(join_room)
184 .add_message_handler(leave_room)
185 .add_request_handler(call)
186 .add_request_handler(cancel_call)
187 .add_message_handler(decline_call)
188 .add_request_handler(update_participant_location)
189 .add_request_handler(share_project)
190 .add_message_handler(unshare_project)
191 .add_request_handler(join_project)
192 .add_message_handler(leave_project)
193 .add_request_handler(update_project)
194 .add_request_handler(update_worktree)
195 .add_message_handler(start_language_server)
196 .add_message_handler(update_language_server)
197 .add_message_handler(update_diagnostic_summary)
198 .add_request_handler(forward_project_request::<proto::GetHover>)
199 .add_request_handler(forward_project_request::<proto::GetDefinition>)
200 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
201 .add_request_handler(forward_project_request::<proto::GetReferences>)
202 .add_request_handler(forward_project_request::<proto::SearchProject>)
203 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
204 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
205 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
206 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
208 .add_request_handler(forward_project_request::<proto::GetCompletions>)
209 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
210 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
211 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
212 .add_request_handler(forward_project_request::<proto::PrepareRename>)
213 .add_request_handler(forward_project_request::<proto::PerformRename>)
214 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
215 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
216 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
217 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
218 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
220 .add_message_handler(create_buffer_for_peer)
221 .add_request_handler(update_buffer)
222 .add_message_handler(update_buffer_file)
223 .add_message_handler(buffer_reloaded)
224 .add_message_handler(buffer_saved)
225 .add_request_handler(save_buffer)
226 .add_request_handler(get_users)
227 .add_request_handler(fuzzy_search_users)
228 .add_request_handler(request_contact)
229 .add_request_handler(remove_contact)
230 .add_request_handler(respond_to_contact_request)
231 .add_request_handler(follow)
232 .add_message_handler(unfollow)
233 .add_message_handler(update_followers)
234 .add_message_handler(update_diff_base)
235 .add_request_handler(get_private_user_info);
236
237 Arc::new(server)
238 }
239
240 pub fn teardown(&self) {
241 let _ = self.teardown.send(());
242 }
243
244 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
245 where
246 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
247 Fut: 'static + Send + Future<Output = Result<()>>,
248 M: EnvelopedMessage,
249 {
250 let prev_handler = self.handlers.insert(
251 TypeId::of::<M>(),
252 Box::new(move |envelope, session| {
253 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
254 let span = info_span!(
255 "handle message",
256 payload_type = envelope.payload_type_name()
257 );
258 span.in_scope(|| {
259 tracing::info!(
260 payload_type = envelope.payload_type_name(),
261 "message received"
262 );
263 });
264 let future = (handler)(*envelope, session);
265 async move {
266 if let Err(error) = future.await {
267 tracing::error!(%error, "error handling message");
268 }
269 }
270 .instrument(span)
271 .boxed()
272 }),
273 );
274 if prev_handler.is_some() {
275 panic!("registered a handler for the same message twice");
276 }
277 self
278 }
279
280 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
281 where
282 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
283 Fut: 'static + Send + Future<Output = Result<()>>,
284 M: EnvelopedMessage,
285 {
286 self.add_handler(move |envelope, session| handler(envelope.payload, session));
287 self
288 }
289
290 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
291 where
292 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
293 Fut: Send + Future<Output = Result<()>>,
294 M: RequestMessage,
295 {
296 let handler = Arc::new(handler);
297 self.add_handler(move |envelope, session| {
298 let receipt = envelope.receipt();
299 let handler = handler.clone();
300 async move {
301 let peer = session.peer.clone();
302 let responded = Arc::new(AtomicBool::default());
303 let response = Response {
304 peer: peer.clone(),
305 responded: responded.clone(),
306 receipt,
307 };
308 match (handler)(envelope.payload, response, session).await {
309 Ok(()) => {
310 if responded.load(std::sync::atomic::Ordering::SeqCst) {
311 Ok(())
312 } else {
313 Err(anyhow!("handler did not send a response"))?
314 }
315 }
316 Err(error) => {
317 peer.respond_with_error(
318 receipt,
319 proto::Error {
320 message: error.to_string(),
321 },
322 )?;
323 Err(error)
324 }
325 }
326 }
327 })
328 }
329
330 pub fn handle_connection(
331 self: &Arc<Self>,
332 connection: Connection,
333 address: String,
334 user: User,
335 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
336 executor: Executor,
337 ) -> impl Future<Output = Result<()>> {
338 let this = self.clone();
339 let user_id = user.id;
340 let login = user.github_login;
341 let span = info_span!("handle connection", %user_id, %login, %address);
342 let teardown = self.teardown.subscribe();
343 async move {
344 let (connection_id, handle_io, mut incoming_rx) = this
345 .peer
346 .add_connection(connection, {
347 let executor = executor.clone();
348 move |duration| executor.sleep(duration)
349 });
350
351 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
352 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
353 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
354
355 if let Some(send_connection_id) = send_connection_id.take() {
356 let _ = send_connection_id.send(connection_id);
357 }
358
359 if !user.connected_once {
360 this.peer.send(connection_id, proto::ShowContacts {})?;
361 this.app_state.db.set_user_connected_once(user_id, true).await?;
362 }
363
364 let (contacts, invite_code) = future::try_join(
365 this.app_state.db.get_contacts(user_id),
366 this.app_state.db.get_invite_code_for_user(user_id)
367 ).await?;
368
369 {
370 let mut pool = this.connection_pool.lock().await;
371 pool.add_connection(connection_id, user_id, user.admin);
372 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
373
374 if let Some((code, count)) = invite_code {
375 this.peer.send(connection_id, proto::UpdateInviteInfo {
376 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
377 count: count as u32,
378 })?;
379 }
380 }
381
382 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
383 this.peer.send(connection_id, incoming_call)?;
384 }
385
386 let session = Session {
387 user_id,
388 connection_id,
389 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
390 peer: this.peer.clone(),
391 connection_pool: this.connection_pool.clone(),
392 live_kit_client: this.app_state.live_kit_client.clone()
393 };
394 update_user_contacts(user_id, &session).await?;
395
396 let handle_io = handle_io.fuse();
397 futures::pin_mut!(handle_io);
398
399 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
400 // This prevents deadlocks when e.g., client A performs a request to client B and
401 // client B performs a request to client A. If both clients stop processing further
402 // messages until their respective request completes, they won't have a chance to
403 // respond to the other client's request and cause a deadlock.
404 //
405 // This arrangement ensures we will attempt to process earlier messages first, but fall
406 // back to processing messages arrived later in the spirit of making progress.
407 let mut foreground_message_handlers = FuturesUnordered::new();
408 loop {
409 let next_message = incoming_rx.next().fuse();
410 futures::pin_mut!(next_message);
411 futures::select_biased! {
412 result = handle_io => {
413 if let Err(error) = result {
414 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
415 }
416 break;
417 }
418 _ = foreground_message_handlers.next() => {}
419 message = next_message => {
420 if let Some(message) = message {
421 let type_name = message.payload_type_name();
422 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
423 let span_enter = span.enter();
424 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
425 let is_background = message.is_background();
426 let handle_message = (handler)(message, session.clone());
427 drop(span_enter);
428
429 let handle_message = handle_message.instrument(span);
430 if is_background {
431 executor.spawn_detached(handle_message);
432 } else {
433 foreground_message_handlers.push(handle_message);
434 }
435 } else {
436 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
437 }
438 } else {
439 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
440 break;
441 }
442 }
443 }
444 }
445
446 drop(foreground_message_handlers);
447 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
448 if let Err(error) = sign_out(session, teardown, executor).await {
449 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
450 }
451
452 Ok(())
453 }.instrument(span)
454 }
455
456 pub async fn invite_code_redeemed(
457 self: &Arc<Self>,
458 inviter_id: UserId,
459 invitee_id: UserId,
460 ) -> Result<()> {
461 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
462 if let Some(code) = &user.invite_code {
463 let pool = self.connection_pool.lock().await;
464 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
465 for connection_id in pool.user_connection_ids(inviter_id) {
466 self.peer.send(
467 connection_id,
468 proto::UpdateContacts {
469 contacts: vec![invitee_contact.clone()],
470 ..Default::default()
471 },
472 )?;
473 self.peer.send(
474 connection_id,
475 proto::UpdateInviteInfo {
476 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
477 count: user.invite_count as u32,
478 },
479 )?;
480 }
481 }
482 }
483 Ok(())
484 }
485
486 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
487 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
488 if let Some(invite_code) = &user.invite_code {
489 let pool = self.connection_pool.lock().await;
490 for connection_id in pool.user_connection_ids(user_id) {
491 self.peer.send(
492 connection_id,
493 proto::UpdateInviteInfo {
494 url: format!(
495 "{}{}",
496 self.app_state.config.invite_link_prefix, invite_code
497 ),
498 count: user.invite_count as u32,
499 },
500 )?;
501 }
502 }
503 }
504 Ok(())
505 }
506
507 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
508 ServerSnapshot {
509 connection_pool: ConnectionPoolGuard {
510 guard: self.connection_pool.lock().await,
511 _not_send: PhantomData,
512 },
513 peer: &self.peer,
514 }
515 }
516}
517
518impl<'a> Deref for ConnectionPoolGuard<'a> {
519 type Target = ConnectionPool;
520
521 fn deref(&self) -> &Self::Target {
522 &*self.guard
523 }
524}
525
526impl<'a> DerefMut for ConnectionPoolGuard<'a> {
527 fn deref_mut(&mut self) -> &mut Self::Target {
528 &mut *self.guard
529 }
530}
531
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants whenever a guard is dropped (i.e.
    /// before the inner mutex guard releases the lock). Compiles to nothing in other builds.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
538
539fn broadcast<F>(
540 sender_id: ConnectionId,
541 receiver_ids: impl IntoIterator<Item = ConnectionId>,
542 mut f: F,
543) where
544 F: FnMut(ConnectionId) -> anyhow::Result<()>,
545{
546 for receiver_id in receiver_ids {
547 if receiver_id != sender_id {
548 f(receiver_id).trace_err();
549 }
550 }
551}
552
lazy_static! {
    // Name of the request header that carries the client's RPC protocol version
    // (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
556
/// Typed `x-zed-protocol-version` header value. `handle_websocket_request` compares it
/// against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
558
559impl Header for ProtocolVersion {
560 fn name() -> &'static HeaderName {
561 &ZED_PROTOCOL_VERSION
562 }
563
564 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
565 where
566 Self: Sized,
567 I: Iterator<Item = &'i axum::http::HeaderValue>,
568 {
569 let version = values
570 .next()
571 .ok_or_else(axum::headers::Error::invalid)?
572 .to_str()
573 .map_err(|_| axum::headers::Error::invalid())?
574 .parse()
575 .map_err(|_| axum::headers::Error::invalid())?;
576 Ok(Self(version))
577 }
578
579 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
580 values.extend([self.0.to_string().parse().unwrap()]);
581 }
582}
583
584pub fn routes(server: Arc<Server>) -> Router<Body> {
585 Router::new()
586 .route("/rpc", get(handle_websocket_request))
587 .layer(
588 ServiceBuilder::new()
589 .layer(Extension(server.app_state.clone()))
590 .layer(middleware::from_fn(auth::validate_header)),
591 )
592 .route("/metrics", get(handle_metrics))
593 .layer(Extension(server))
594}
595
596pub async fn handle_websocket_request(
597 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
598 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
599 Extension(server): Extension<Arc<Server>>,
600 Extension(user): Extension<User>,
601 ws: WebSocketUpgrade,
602) -> axum::response::Response {
603 if protocol_version != rpc::PROTOCOL_VERSION {
604 return (
605 StatusCode::UPGRADE_REQUIRED,
606 "client must be upgraded".to_string(),
607 )
608 .into_response();
609 }
610 let socket_address = socket_address.to_string();
611 ws.on_upgrade(move |socket| {
612 use util::ResultExt;
613 let socket = socket
614 .map_ok(to_tungstenite_message)
615 .err_into()
616 .with(|message| async move { Ok(to_axum_message(message)) });
617 let connection = Connection::new(Box::pin(socket));
618 async move {
619 server
620 .handle_connection(connection, socket_address, user, None, Executor::Production)
621 .await
622 .log_err();
623 }
624 })
625}
626
627pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
628 let connections = server
629 .connection_pool
630 .lock()
631 .await
632 .connections()
633 .filter(|connection| !connection.admin)
634 .count();
635
636 METRIC_CONNECTIONS.set(connections as _);
637
638 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
639 METRIC_SHARED_PROJECTS.set(shared_projects as _);
640
641 let encoder = prometheus::TextEncoder::new();
642 let metric_families = prometheus::gather();
643 let encoded_metrics = encoder
644 .encode_to_string(&metric_families)
645 .map_err(|err| anyhow!("{}", err))?;
646 Ok(encoded_metrics)
647}
648
/// Cleans up after a connection's message loop ends.
///
/// First, immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - notifies collaborators of any projects the connection was part of.
///
/// It then sleeps for `RECONNECT_TIMEOUT`, presumably to give the client a chance to
/// reconnect. If the server is torn down first, it returns without further cleanup.
/// Otherwise, after the timeout it:
/// - leaves the room;
/// - declines the user's pending call, if the user has no online connection left;
/// - refreshes the user's contacts.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool or contacts cannot be
/// updated. Other failures are logged via `trace_err` and otherwise ignored.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Biased: the timeout branch is polled first, then teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline a pending call if no other connection of this user came online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
694
695async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
696 response.send(proto::Ack {})?;
697 Ok(())
698}
699
700async fn create_room(
701 _request: proto::CreateRoom,
702 response: Response<proto::CreateRoom>,
703 session: Session,
704) -> Result<()> {
705 let live_kit_room = nanoid::nanoid!(30);
706 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
707 if let Some(_) = live_kit
708 .create_room(live_kit_room.clone())
709 .await
710 .trace_err()
711 {
712 if let Some(token) = live_kit
713 .room_token(&live_kit_room, &session.connection_id.to_string())
714 .trace_err()
715 {
716 Some(proto::LiveKitConnectionInfo {
717 server_url: live_kit.url().into(),
718 token,
719 })
720 } else {
721 None
722 }
723 } else {
724 None
725 }
726 } else {
727 None
728 };
729
730 {
731 let room = session
732 .db()
733 .await
734 .create_room(session.user_id, session.connection_id, &live_kit_room)
735 .await?;
736
737 response.send(proto::CreateRoomResponse {
738 room: Some(room.clone()),
739 live_kit_connection_info,
740 })?;
741 }
742
743 update_user_contacts(session.user_id, &session).await?;
744 Ok(())
745}
746
747async fn join_room(
748 request: proto::JoinRoom,
749 response: Response<proto::JoinRoom>,
750 session: Session,
751) -> Result<()> {
752 let room = {
753 let room = session
754 .db()
755 .await
756 .join_room(
757 RoomId::from_proto(request.id),
758 session.user_id,
759 session.connection_id,
760 )
761 .await?;
762 room_updated(&room, &session);
763 room.clone()
764 };
765
766 for connection_id in session
767 .connection_pool()
768 .await
769 .user_connection_ids(session.user_id)
770 {
771 session
772 .peer
773 .send(connection_id, proto::CallCanceled {})
774 .trace_err();
775 }
776
777 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
778 if let Some(token) = live_kit
779 .room_token(&room.live_kit_room, &session.connection_id.to_string())
780 .trace_err()
781 {
782 Some(proto::LiveKitConnectionInfo {
783 server_url: live_kit.url().into(),
784 token,
785 })
786 } else {
787 None
788 }
789 } else {
790 None
791 };
792
793 response.send(proto::JoinRoomResponse {
794 room: Some(room),
795 live_kit_connection_info,
796 })?;
797
798 update_user_contacts(session.user_id, &session).await?;
799 Ok(())
800}
801
802async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
803 leave_room_for_session(&session).await
804}
805
806async fn call(
807 request: proto::Call,
808 response: Response<proto::Call>,
809 session: Session,
810) -> Result<()> {
811 let room_id = RoomId::from_proto(request.room_id);
812 let calling_user_id = session.user_id;
813 let calling_connection_id = session.connection_id;
814 let called_user_id = UserId::from_proto(request.called_user_id);
815 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
816 if !session
817 .db()
818 .await
819 .has_contact(calling_user_id, called_user_id)
820 .await?
821 {
822 return Err(anyhow!("cannot call a user who isn't a contact"))?;
823 }
824
825 let incoming_call = {
826 let (room, incoming_call) = &mut *session
827 .db()
828 .await
829 .call(
830 room_id,
831 calling_user_id,
832 calling_connection_id,
833 called_user_id,
834 initial_project_id,
835 )
836 .await?;
837 room_updated(&room, &session);
838 mem::take(incoming_call)
839 };
840 update_user_contacts(called_user_id, &session).await?;
841
842 let mut calls = session
843 .connection_pool()
844 .await
845 .user_connection_ids(called_user_id)
846 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
847 .collect::<FuturesUnordered<_>>();
848
849 while let Some(call_response) = calls.next().await {
850 match call_response.as_ref() {
851 Ok(_) => {
852 response.send(proto::Ack {})?;
853 return Ok(());
854 }
855 Err(_) => {
856 call_response.trace_err();
857 }
858 }
859 }
860
861 {
862 let room = session
863 .db()
864 .await
865 .call_failed(room_id, called_user_id)
866 .await?;
867 room_updated(&room, &session);
868 }
869 update_user_contacts(called_user_id, &session).await?;
870
871 Err(anyhow!("failed to ring user"))?
872}
873
874async fn cancel_call(
875 request: proto::CancelCall,
876 response: Response<proto::CancelCall>,
877 session: Session,
878) -> Result<()> {
879 let called_user_id = UserId::from_proto(request.called_user_id);
880 let room_id = RoomId::from_proto(request.room_id);
881 {
882 let room = session
883 .db()
884 .await
885 .cancel_call(Some(room_id), session.connection_id, called_user_id)
886 .await?;
887 room_updated(&room, &session);
888 }
889
890 for connection_id in session
891 .connection_pool()
892 .await
893 .user_connection_ids(called_user_id)
894 {
895 session
896 .peer
897 .send(connection_id, proto::CallCanceled {})
898 .trace_err();
899 }
900 response.send(proto::Ack {})?;
901
902 update_user_contacts(called_user_id, &session).await?;
903 Ok(())
904}
905
906async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
907 let room_id = RoomId::from_proto(message.room_id);
908 {
909 let room = session
910 .db()
911 .await
912 .decline_call(Some(room_id), session.user_id)
913 .await?;
914 room_updated(&room, &session);
915 }
916
917 for connection_id in session
918 .connection_pool()
919 .await
920 .user_connection_ids(session.user_id)
921 {
922 session
923 .peer
924 .send(connection_id, proto::CallCanceled {})
925 .trace_err();
926 }
927 update_user_contacts(session.user_id, &session).await?;
928 Ok(())
929}
930
931async fn update_participant_location(
932 request: proto::UpdateParticipantLocation,
933 response: Response<proto::UpdateParticipantLocation>,
934 session: Session,
935) -> Result<()> {
936 let room_id = RoomId::from_proto(request.room_id);
937 let location = request
938 .location
939 .ok_or_else(|| anyhow!("invalid location"))?;
940 let room = session
941 .db()
942 .await
943 .update_room_participant_location(room_id, session.connection_id, location)
944 .await?;
945 room_updated(&room, &session);
946 response.send(proto::Ack {})?;
947 Ok(())
948}
949
950async fn share_project(
951 request: proto::ShareProject,
952 response: Response<proto::ShareProject>,
953 session: Session,
954) -> Result<()> {
955 let (project_id, room) = &*session
956 .db()
957 .await
958 .share_project(
959 RoomId::from_proto(request.room_id),
960 session.connection_id,
961 &request.worktrees,
962 )
963 .await?;
964 response.send(proto::ShareProjectResponse {
965 project_id: project_id.to_proto(),
966 })?;
967 room_updated(&room, &session);
968
969 Ok(())
970}
971
972async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
973 let project_id = ProjectId::from_proto(message.project_id);
974
975 let (room, guest_connection_ids) = &*session
976 .db()
977 .await
978 .unshare_project(project_id, session.connection_id)
979 .await?;
980
981 broadcast(
982 session.connection_id,
983 guest_connection_ids.iter().copied(),
984 |conn_id| session.peer.send(conn_id, message.clone()),
985 );
986 room_updated(&room, &session);
987
988 Ok(())
989}
990
991async fn join_project(
992 request: proto::JoinProject,
993 response: Response<proto::JoinProject>,
994 session: Session,
995) -> Result<()> {
996 let project_id = ProjectId::from_proto(request.project_id);
997 let guest_user_id = session.user_id;
998
999 tracing::info!(%project_id, "join project");
1000
1001 let (project, replica_id) = &mut *session
1002 .db()
1003 .await
1004 .join_project(project_id, session.connection_id)
1005 .await?;
1006
1007 let collaborators = project
1008 .collaborators
1009 .iter()
1010 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1011 .map(|collaborator| proto::Collaborator {
1012 peer_id: collaborator.connection_id as u32,
1013 replica_id: collaborator.replica_id.0 as u32,
1014 user_id: collaborator.user_id.to_proto(),
1015 })
1016 .collect::<Vec<_>>();
1017 let worktrees = project
1018 .worktrees
1019 .iter()
1020 .map(|(id, worktree)| proto::WorktreeMetadata {
1021 id: *id,
1022 root_name: worktree.root_name.clone(),
1023 visible: worktree.visible,
1024 abs_path: worktree.abs_path.clone(),
1025 })
1026 .collect::<Vec<_>>();
1027
1028 for collaborator in &collaborators {
1029 session
1030 .peer
1031 .send(
1032 ConnectionId(collaborator.peer_id),
1033 proto::AddProjectCollaborator {
1034 project_id: project_id.to_proto(),
1035 collaborator: Some(proto::Collaborator {
1036 peer_id: session.connection_id.0,
1037 replica_id: replica_id.0 as u32,
1038 user_id: guest_user_id.to_proto(),
1039 }),
1040 },
1041 )
1042 .trace_err();
1043 }
1044
1045 // First, we send the metadata associated with each worktree.
1046 response.send(proto::JoinProjectResponse {
1047 worktrees: worktrees.clone(),
1048 replica_id: replica_id.0 as u32,
1049 collaborators: collaborators.clone(),
1050 language_servers: project.language_servers.clone(),
1051 })?;
1052
1053 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1054 #[cfg(any(test, feature = "test-support"))]
1055 const MAX_CHUNK_SIZE: usize = 2;
1056 #[cfg(not(any(test, feature = "test-support")))]
1057 const MAX_CHUNK_SIZE: usize = 256;
1058
1059 // Stream this worktree's entries.
1060 let message = proto::UpdateWorktree {
1061 project_id: project_id.to_proto(),
1062 worktree_id,
1063 abs_path: worktree.abs_path.clone(),
1064 root_name: worktree.root_name,
1065 updated_entries: worktree.entries,
1066 removed_entries: Default::default(),
1067 scan_id: worktree.scan_id,
1068 is_last_update: worktree.is_complete,
1069 };
1070 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1071 session.peer.send(session.connection_id, update.clone())?;
1072 }
1073
1074 // Stream this worktree's diagnostics.
1075 for summary in worktree.diagnostic_summaries {
1076 session.peer.send(
1077 session.connection_id,
1078 proto::UpdateDiagnosticSummary {
1079 project_id: project_id.to_proto(),
1080 worktree_id: worktree.id,
1081 summary: Some(summary),
1082 },
1083 )?;
1084 }
1085 }
1086
1087 for language_server in &project.language_servers {
1088 session.peer.send(
1089 session.connection_id,
1090 proto::UpdateLanguageServer {
1091 project_id: project_id.to_proto(),
1092 language_server_id: language_server.id,
1093 variant: Some(
1094 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1095 proto::LspDiskBasedDiagnosticsUpdated {},
1096 ),
1097 ),
1098 },
1099 )?;
1100 }
1101
1102 Ok(())
1103}
1104
1105async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1106 let sender_id = session.connection_id;
1107 let project_id = ProjectId::from_proto(request.project_id);
1108
1109 let project = session
1110 .db()
1111 .await
1112 .leave_project(project_id, sender_id)
1113 .await?;
1114 tracing::info!(
1115 %project_id,
1116 host_user_id = %project.host_user_id,
1117 host_connection_id = %project.host_connection_id,
1118 "leave project"
1119 );
1120 project_left(&project, &session);
1121
1122 Ok(())
1123}
1124
1125async fn update_project(
1126 request: proto::UpdateProject,
1127 response: Response<proto::UpdateProject>,
1128 session: Session,
1129) -> Result<()> {
1130 let project_id = ProjectId::from_proto(request.project_id);
1131 let (room, guest_connection_ids) = &*session
1132 .db()
1133 .await
1134 .update_project(project_id, session.connection_id, &request.worktrees)
1135 .await?;
1136 broadcast(
1137 session.connection_id,
1138 guest_connection_ids.iter().copied(),
1139 |connection_id| {
1140 session
1141 .peer
1142 .forward_send(session.connection_id, connection_id, request.clone())
1143 },
1144 );
1145 room_updated(&room, &session);
1146 response.send(proto::Ack {})?;
1147
1148 Ok(())
1149}
1150
1151async fn update_worktree(
1152 request: proto::UpdateWorktree,
1153 response: Response<proto::UpdateWorktree>,
1154 session: Session,
1155) -> Result<()> {
1156 let guest_connection_ids = session
1157 .db()
1158 .await
1159 .update_worktree(&request, session.connection_id)
1160 .await?;
1161
1162 broadcast(
1163 session.connection_id,
1164 guest_connection_ids.iter().copied(),
1165 |connection_id| {
1166 session
1167 .peer
1168 .forward_send(session.connection_id, connection_id, request.clone())
1169 },
1170 );
1171 response.send(proto::Ack {})?;
1172 Ok(())
1173}
1174
1175async fn update_diagnostic_summary(
1176 message: proto::UpdateDiagnosticSummary,
1177 session: Session,
1178) -> Result<()> {
1179 let guest_connection_ids = session
1180 .db()
1181 .await
1182 .update_diagnostic_summary(&message, session.connection_id)
1183 .await?;
1184
1185 broadcast(
1186 session.connection_id,
1187 guest_connection_ids.iter().copied(),
1188 |connection_id| {
1189 session
1190 .peer
1191 .forward_send(session.connection_id, connection_id, message.clone())
1192 },
1193 );
1194
1195 Ok(())
1196}
1197
1198async fn start_language_server(
1199 request: proto::StartLanguageServer,
1200 session: Session,
1201) -> Result<()> {
1202 let guest_connection_ids = session
1203 .db()
1204 .await
1205 .start_language_server(&request, session.connection_id)
1206 .await?;
1207
1208 broadcast(
1209 session.connection_id,
1210 guest_connection_ids.iter().copied(),
1211 |connection_id| {
1212 session
1213 .peer
1214 .forward_send(session.connection_id, connection_id, request.clone())
1215 },
1216 );
1217 Ok(())
1218}
1219
1220async fn update_language_server(
1221 request: proto::UpdateLanguageServer,
1222 session: Session,
1223) -> Result<()> {
1224 let project_id = ProjectId::from_proto(request.project_id);
1225 let project_connection_ids = session
1226 .db()
1227 .await
1228 .project_connection_ids(project_id, session.connection_id)
1229 .await?;
1230 broadcast(
1231 session.connection_id,
1232 project_connection_ids.iter().copied(),
1233 |connection_id| {
1234 session
1235 .peer
1236 .forward_send(session.connection_id, connection_id, request.clone())
1237 },
1238 );
1239 Ok(())
1240}
1241
1242async fn forward_project_request<T>(
1243 request: T,
1244 response: Response<T>,
1245 session: Session,
1246) -> Result<()>
1247where
1248 T: EntityMessage + RequestMessage,
1249{
1250 let project_id = ProjectId::from_proto(request.remote_entity_id());
1251 let host_connection_id = {
1252 let collaborators = session
1253 .db()
1254 .await
1255 .project_collaborators(project_id, session.connection_id)
1256 .await?;
1257 ConnectionId(
1258 collaborators
1259 .iter()
1260 .find(|collaborator| collaborator.is_host)
1261 .ok_or_else(|| anyhow!("host not found"))?
1262 .connection_id as u32,
1263 )
1264 };
1265
1266 let payload = session
1267 .peer
1268 .forward_request(session.connection_id, host_connection_id, request)
1269 .await?;
1270
1271 response.send(payload)?;
1272 Ok(())
1273}
1274
1275async fn save_buffer(
1276 request: proto::SaveBuffer,
1277 response: Response<proto::SaveBuffer>,
1278 session: Session,
1279) -> Result<()> {
1280 let project_id = ProjectId::from_proto(request.project_id);
1281 let host_connection_id = {
1282 let collaborators = session
1283 .db()
1284 .await
1285 .project_collaborators(project_id, session.connection_id)
1286 .await?;
1287 let host = collaborators
1288 .iter()
1289 .find(|collaborator| collaborator.is_host)
1290 .ok_or_else(|| anyhow!("host not found"))?;
1291 ConnectionId(host.connection_id as u32)
1292 };
1293 let response_payload = session
1294 .peer
1295 .forward_request(session.connection_id, host_connection_id, request.clone())
1296 .await?;
1297
1298 let mut collaborators = session
1299 .db()
1300 .await
1301 .project_collaborators(project_id, session.connection_id)
1302 .await?;
1303 collaborators
1304 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1305 let project_connection_ids = collaborators
1306 .iter()
1307 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1308 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1309 session
1310 .peer
1311 .forward_send(host_connection_id, conn_id, response_payload.clone())
1312 });
1313 response.send(response_payload)?;
1314 Ok(())
1315}
1316
1317async fn create_buffer_for_peer(
1318 request: proto::CreateBufferForPeer,
1319 session: Session,
1320) -> Result<()> {
1321 session.peer.forward_send(
1322 session.connection_id,
1323 ConnectionId(request.peer_id),
1324 request,
1325 )?;
1326 Ok(())
1327}
1328
1329async fn update_buffer(
1330 request: proto::UpdateBuffer,
1331 response: Response<proto::UpdateBuffer>,
1332 session: Session,
1333) -> Result<()> {
1334 let project_id = ProjectId::from_proto(request.project_id);
1335 let project_connection_ids = session
1336 .db()
1337 .await
1338 .project_connection_ids(project_id, session.connection_id)
1339 .await?;
1340
1341 broadcast(
1342 session.connection_id,
1343 project_connection_ids.iter().copied(),
1344 |connection_id| {
1345 session
1346 .peer
1347 .forward_send(session.connection_id, connection_id, request.clone())
1348 },
1349 );
1350 response.send(proto::Ack {})?;
1351 Ok(())
1352}
1353
1354async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1355 let project_id = ProjectId::from_proto(request.project_id);
1356 let project_connection_ids = session
1357 .db()
1358 .await
1359 .project_connection_ids(project_id, session.connection_id)
1360 .await?;
1361
1362 broadcast(
1363 session.connection_id,
1364 project_connection_ids.iter().copied(),
1365 |connection_id| {
1366 session
1367 .peer
1368 .forward_send(session.connection_id, connection_id, request.clone())
1369 },
1370 );
1371 Ok(())
1372}
1373
1374async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1375 let project_id = ProjectId::from_proto(request.project_id);
1376 let project_connection_ids = session
1377 .db()
1378 .await
1379 .project_connection_ids(project_id, session.connection_id)
1380 .await?;
1381 broadcast(
1382 session.connection_id,
1383 project_connection_ids.iter().copied(),
1384 |connection_id| {
1385 session
1386 .peer
1387 .forward_send(session.connection_id, connection_id, request.clone())
1388 },
1389 );
1390 Ok(())
1391}
1392
1393async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1394 let project_id = ProjectId::from_proto(request.project_id);
1395 let project_connection_ids = session
1396 .db()
1397 .await
1398 .project_connection_ids(project_id, session.connection_id)
1399 .await?;
1400 broadcast(
1401 session.connection_id,
1402 project_connection_ids.iter().copied(),
1403 |connection_id| {
1404 session
1405 .peer
1406 .forward_send(session.connection_id, connection_id, request.clone())
1407 },
1408 );
1409 Ok(())
1410}
1411
1412async fn follow(
1413 request: proto::Follow,
1414 response: Response<proto::Follow>,
1415 session: Session,
1416) -> Result<()> {
1417 let project_id = ProjectId::from_proto(request.project_id);
1418 let leader_id = ConnectionId(request.leader_id);
1419 let follower_id = session.connection_id;
1420 {
1421 let project_connection_ids = session
1422 .db()
1423 .await
1424 .project_connection_ids(project_id, session.connection_id)
1425 .await?;
1426
1427 if !project_connection_ids.contains(&leader_id) {
1428 Err(anyhow!("no such peer"))?;
1429 }
1430 }
1431
1432 let mut response_payload = session
1433 .peer
1434 .forward_request(session.connection_id, leader_id, request)
1435 .await?;
1436 response_payload
1437 .views
1438 .retain(|view| view.leader_id != Some(follower_id.0));
1439 response.send(response_payload)?;
1440 Ok(())
1441}
1442
1443async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1444 let project_id = ProjectId::from_proto(request.project_id);
1445 let leader_id = ConnectionId(request.leader_id);
1446 let project_connection_ids = session
1447 .db()
1448 .await
1449 .project_connection_ids(project_id, session.connection_id)
1450 .await?;
1451 if !project_connection_ids.contains(&leader_id) {
1452 Err(anyhow!("no such peer"))?;
1453 }
1454 session
1455 .peer
1456 .forward_send(session.connection_id, leader_id, request)?;
1457 Ok(())
1458}
1459
1460async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1461 let project_id = ProjectId::from_proto(request.project_id);
1462 let project_connection_ids = session
1463 .db
1464 .lock()
1465 .await
1466 .project_connection_ids(project_id, session.connection_id)
1467 .await?;
1468
1469 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1470 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1471 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1472 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1473 });
1474 for follower_id in &request.follower_ids {
1475 let follower_id = ConnectionId(*follower_id);
1476 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1477 session
1478 .peer
1479 .forward_send(session.connection_id, follower_id, request.clone())?;
1480 }
1481 }
1482 Ok(())
1483}
1484
1485async fn get_users(
1486 request: proto::GetUsers,
1487 response: Response<proto::GetUsers>,
1488 session: Session,
1489) -> Result<()> {
1490 let user_ids = request
1491 .user_ids
1492 .into_iter()
1493 .map(UserId::from_proto)
1494 .collect();
1495 let users = session
1496 .db()
1497 .await
1498 .get_users_by_ids(user_ids)
1499 .await?
1500 .into_iter()
1501 .map(|user| proto::User {
1502 id: user.id.to_proto(),
1503 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1504 github_login: user.github_login,
1505 })
1506 .collect();
1507 response.send(proto::UsersResponse { users })?;
1508 Ok(())
1509}
1510
1511async fn fuzzy_search_users(
1512 request: proto::FuzzySearchUsers,
1513 response: Response<proto::FuzzySearchUsers>,
1514 session: Session,
1515) -> Result<()> {
1516 let query = request.query;
1517 let users = match query.len() {
1518 0 => vec![],
1519 1 | 2 => session
1520 .db()
1521 .await
1522 .get_user_by_github_account(&query, None)
1523 .await?
1524 .into_iter()
1525 .collect(),
1526 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1527 };
1528 let users = users
1529 .into_iter()
1530 .filter(|user| user.id != session.user_id)
1531 .map(|user| proto::User {
1532 id: user.id.to_proto(),
1533 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1534 github_login: user.github_login,
1535 })
1536 .collect();
1537 response.send(proto::UsersResponse { users })?;
1538 Ok(())
1539}
1540
1541async fn request_contact(
1542 request: proto::RequestContact,
1543 response: Response<proto::RequestContact>,
1544 session: Session,
1545) -> Result<()> {
1546 let requester_id = session.user_id;
1547 let responder_id = UserId::from_proto(request.responder_id);
1548 if requester_id == responder_id {
1549 return Err(anyhow!("cannot add yourself as a contact"))?;
1550 }
1551
1552 session
1553 .db()
1554 .await
1555 .send_contact_request(requester_id, responder_id)
1556 .await?;
1557
1558 // Update outgoing contact requests of requester
1559 let mut update = proto::UpdateContacts::default();
1560 update.outgoing_requests.push(responder_id.to_proto());
1561 for connection_id in session
1562 .connection_pool()
1563 .await
1564 .user_connection_ids(requester_id)
1565 {
1566 session.peer.send(connection_id, update.clone())?;
1567 }
1568
1569 // Update incoming contact requests of responder
1570 let mut update = proto::UpdateContacts::default();
1571 update
1572 .incoming_requests
1573 .push(proto::IncomingContactRequest {
1574 requester_id: requester_id.to_proto(),
1575 should_notify: true,
1576 });
1577 for connection_id in session
1578 .connection_pool()
1579 .await
1580 .user_connection_ids(responder_id)
1581 {
1582 session.peer.send(connection_id, update.clone())?;
1583 }
1584
1585 response.send(proto::Ack {})?;
1586 Ok(())
1587}
1588
1589async fn respond_to_contact_request(
1590 request: proto::RespondToContactRequest,
1591 response: Response<proto::RespondToContactRequest>,
1592 session: Session,
1593) -> Result<()> {
1594 let responder_id = session.user_id;
1595 let requester_id = UserId::from_proto(request.requester_id);
1596 let db = session.db().await;
1597 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1598 db.dismiss_contact_notification(responder_id, requester_id)
1599 .await?;
1600 } else {
1601 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1602
1603 db.respond_to_contact_request(responder_id, requester_id, accept)
1604 .await?;
1605 let requester_busy = db.is_user_busy(requester_id).await?;
1606 let responder_busy = db.is_user_busy(responder_id).await?;
1607
1608 let pool = session.connection_pool().await;
1609 // Update responder with new contact
1610 let mut update = proto::UpdateContacts::default();
1611 if accept {
1612 update
1613 .contacts
1614 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1615 }
1616 update
1617 .remove_incoming_requests
1618 .push(requester_id.to_proto());
1619 for connection_id in pool.user_connection_ids(responder_id) {
1620 session.peer.send(connection_id, update.clone())?;
1621 }
1622
1623 // Update requester with new contact
1624 let mut update = proto::UpdateContacts::default();
1625 if accept {
1626 update
1627 .contacts
1628 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1629 }
1630 update
1631 .remove_outgoing_requests
1632 .push(responder_id.to_proto());
1633 for connection_id in pool.user_connection_ids(requester_id) {
1634 session.peer.send(connection_id, update.clone())?;
1635 }
1636 }
1637
1638 response.send(proto::Ack {})?;
1639 Ok(())
1640}
1641
1642async fn remove_contact(
1643 request: proto::RemoveContact,
1644 response: Response<proto::RemoveContact>,
1645 session: Session,
1646) -> Result<()> {
1647 let requester_id = session.user_id;
1648 let responder_id = UserId::from_proto(request.user_id);
1649 let db = session.db().await;
1650 db.remove_contact(requester_id, responder_id).await?;
1651
1652 let pool = session.connection_pool().await;
1653 // Update outgoing contact requests of requester
1654 let mut update = proto::UpdateContacts::default();
1655 update
1656 .remove_outgoing_requests
1657 .push(responder_id.to_proto());
1658 for connection_id in pool.user_connection_ids(requester_id) {
1659 session.peer.send(connection_id, update.clone())?;
1660 }
1661
1662 // Update incoming contact requests of responder
1663 let mut update = proto::UpdateContacts::default();
1664 update
1665 .remove_incoming_requests
1666 .push(requester_id.to_proto());
1667 for connection_id in pool.user_connection_ids(responder_id) {
1668 session.peer.send(connection_id, update.clone())?;
1669 }
1670
1671 response.send(proto::Ack {})?;
1672 Ok(())
1673}
1674
1675async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1676 let project_id = ProjectId::from_proto(request.project_id);
1677 let project_connection_ids = session
1678 .db()
1679 .await
1680 .project_connection_ids(project_id, session.connection_id)
1681 .await?;
1682 broadcast(
1683 session.connection_id,
1684 project_connection_ids.iter().copied(),
1685 |connection_id| {
1686 session
1687 .peer
1688 .forward_send(session.connection_id, connection_id, request.clone())
1689 },
1690 );
1691 Ok(())
1692}
1693
1694async fn get_private_user_info(
1695 _request: proto::GetPrivateUserInfo,
1696 response: Response<proto::GetPrivateUserInfo>,
1697 session: Session,
1698) -> Result<()> {
1699 let metrics_id = session
1700 .db()
1701 .await
1702 .get_user_metrics_id(session.user_id)
1703 .await?;
1704 let user = session
1705 .db()
1706 .await
1707 .get_user_by_id(session.user_id)
1708 .await?
1709 .ok_or_else(|| anyhow!("user not found"))?;
1710 response.send(proto::GetPrivateUserInfoResponse {
1711 metrics_id,
1712 staff: user.admin,
1713 })?;
1714 Ok(())
1715}
1716
1717fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1718 match message {
1719 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1720 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1721 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1722 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1723 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1724 code: frame.code.into(),
1725 reason: frame.reason,
1726 })),
1727 }
1728}
1729
1730fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1731 match message {
1732 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1733 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1734 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1735 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1736 AxumMessage::Close(frame) => {
1737 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1738 code: frame.code.into(),
1739 reason: frame.reason,
1740 }))
1741 }
1742 }
1743}
1744
1745fn build_initial_contacts_update(
1746 contacts: Vec<db::Contact>,
1747 pool: &ConnectionPool,
1748) -> proto::UpdateContacts {
1749 let mut update = proto::UpdateContacts::default();
1750
1751 for contact in contacts {
1752 match contact {
1753 db::Contact::Accepted {
1754 user_id,
1755 should_notify,
1756 busy,
1757 } => {
1758 update
1759 .contacts
1760 .push(contact_for_user(user_id, should_notify, busy, &pool));
1761 }
1762 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1763 db::Contact::Incoming {
1764 user_id,
1765 should_notify,
1766 } => update
1767 .incoming_requests
1768 .push(proto::IncomingContactRequest {
1769 requester_id: user_id.to_proto(),
1770 should_notify,
1771 }),
1772 }
1773 }
1774
1775 update
1776}
1777
1778fn contact_for_user(
1779 user_id: UserId,
1780 should_notify: bool,
1781 busy: bool,
1782 pool: &ConnectionPool,
1783) -> proto::Contact {
1784 proto::Contact {
1785 user_id: user_id.to_proto(),
1786 online: pool.is_user_online(user_id),
1787 busy,
1788 should_notify,
1789 }
1790}
1791
1792fn room_updated(room: &proto::Room, session: &Session) {
1793 for participant in &room.participants {
1794 session
1795 .peer
1796 .send(
1797 ConnectionId(participant.peer_id),
1798 proto::RoomUpdated {
1799 room: Some(room.clone()),
1800 },
1801 )
1802 .trace_err();
1803 }
1804}
1805
1806async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1807 let db = session.db().await;
1808 let contacts = db.get_contacts(user_id).await?;
1809 let busy = db.is_user_busy(user_id).await?;
1810
1811 let pool = session.connection_pool().await;
1812 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1813 for contact in contacts {
1814 if let db::Contact::Accepted {
1815 user_id: contact_user_id,
1816 ..
1817 } = contact
1818 {
1819 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1820 session
1821 .peer
1822 .send(
1823 contact_conn_id,
1824 proto::UpdateContacts {
1825 contacts: vec![updated_contact.clone()],
1826 remove_contacts: Default::default(),
1827 incoming_requests: Default::default(),
1828 remove_incoming_requests: Default::default(),
1829 outgoing_requests: Default::default(),
1830 remove_outgoing_requests: Default::default(),
1831 },
1832 )
1833 .trace_err();
1834 }
1835 }
1836 }
1837 Ok(())
1838}
1839
/// Removes the session's connection from its room and performs the follow-up
/// notifications and cleanup.
///
/// Steps, in order:
/// 1. `leave_room` in the database. Each project the connection left is
///    reported through `project_left`, and the participants of the returned
///    room are sent `RoomUpdated`.
/// 2. Every connection of each user whose pending call was canceled receives
///    `CallCanceled`.
/// 3. Contact status is refreshed for the session's user and for every user
///    whose call was canceled.
/// 4. If a LiveKit client is configured, this connection is removed from the
///    LiveKit room. The LiveKit room is also deleted when the returned room
///    has no participants left.
///
/// # Errors
/// Returns the first error from `leave_room` or `update_user_contacts`.
/// Message-send and LiveKit failures are only logged (`trace_err`).
///
/// NOTE(review): if `update_user_contacts` fails, the `?` returns before the
/// LiveKit cleanup in step 4 runs. Confirm that is acceptable.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here but assigned inside the block below, so these values
    // outlive `left_room`, which is dropped when that block ends.
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session);
        // Move these out of `left_room` so they remain usable after the block.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    // The connection pool is locked only after `left_room` has been dropped,
    // and only for the duration of this block.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(connection_id, proto::CallCanceled {})
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
1890
1891fn project_left(project: &db::LeftProject, session: &Session) {
1892 for connection_id in &project.connection_ids {
1893 if project.host_user_id == session.user_id {
1894 session
1895 .peer
1896 .send(
1897 *connection_id,
1898 proto::UnshareProject {
1899 project_id: project.id.to_proto(),
1900 },
1901 )
1902 .trace_err();
1903 } else {
1904 session
1905 .peer
1906 .send(
1907 *connection_id,
1908 proto::RemoveProjectCollaborator {
1909 project_id: project.id.to_proto(),
1910 peer_id: session.connection_id.0,
1911 },
1912 )
1913 .trace_err();
1914 }
1915 }
1916
1917 session
1918 .peer
1919 .send(
1920 session.connection_id,
1921 proto::UnshareProject {
1922 project_id: project.id.to_proto(),
1923 },
1924 )
1925 .trace_err();
1926}
1927
/// Extension for fallible values whose failure should be reported and then
/// discarded rather than propagated.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Converts `self` into an `Option`, yielding `None` on failure.
    /// The `Result` implementation reports the error via `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1933
1934impl<T, E> ResultExt for Result<T, E>
1935where
1936 E: std::fmt::Debug,
1937{
1938 type Ok = T;
1939
1940 fn trace_err(self) -> Option<T> {
1941 match self {
1942 Ok(value) => Some(value),
1943 Err(error) => {
1944 tracing::error!("{:?}", error);
1945 None
1946 }
1947 }
1948 }
1949}