1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::{watch, Mutex, MutexGuard};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
60pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
61
62lazy_static! {
63 static ref METRIC_CONNECTIONS: IntGauge =
64 register_int_gauge!("connections", "number of connections").unwrap();
65 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
66 "shared_projects",
67 "number of open projects with one or more guests"
68 )
69 .unwrap();
70}
71
72type MessageHandler =
73 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
74
75struct Response<R> {
76 peer: Arc<Peer>,
77 receipt: Receipt<R>,
78 responded: Arc<AtomicBool>,
79}
80
81impl<R: RequestMessage> Response<R> {
82 fn send(self, payload: R::Response) -> Result<()> {
83 self.responded.store(true, SeqCst);
84 self.peer.respond(self.receipt, payload)?;
85 Ok(())
86 }
87}
88
89#[derive(Clone)]
90struct Session {
91 user_id: UserId,
92 connection_id: ConnectionId,
93 db: Arc<Mutex<DbHandle>>,
94 peer: Arc<Peer>,
95 connection_pool: Arc<Mutex<ConnectionPool>>,
96 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
97}
98
99impl Session {
100 async fn db(&self) -> MutexGuard<DbHandle> {
101 #[cfg(test)]
102 tokio::task::yield_now().await;
103 let guard = self.db.lock().await;
104 #[cfg(test)]
105 tokio::task::yield_now().await;
106 guard
107 }
108
109 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
110 #[cfg(test)]
111 tokio::task::yield_now().await;
112 let guard = self.connection_pool.lock().await;
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 peer: Arc<Peer>,
143 pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
144 app_state: Arc<AppState>,
145 executor: Executor,
146 handlers: HashMap<TypeId, MessageHandler>,
147 teardown: watch::Sender<()>,
148}
149
150pub(crate) struct ConnectionPoolGuard<'a> {
151 guard: MutexGuard<'a, ConnectionPool>,
152 _not_send: PhantomData<Rc<()>>,
153}
154
155#[derive(Serialize)]
156pub struct ServerSnapshot<'a> {
157 peer: &'a Peer,
158 #[serde(serialize_with = "serialize_deref")]
159 connection_pool: ConnectionPoolGuard<'a>,
160}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 peer: Peer::new(),
175 app_state,
176 executor,
177 connection_pool: Default::default(),
178 handlers: Default::default(),
179 teardown: watch::channel(()).0,
180 };
181
182 server
183 .add_request_handler(ping)
184 .add_request_handler(create_room)
185 .add_request_handler(join_room)
186 .add_message_handler(leave_room)
187 .add_request_handler(call)
188 .add_request_handler(cancel_call)
189 .add_message_handler(decline_call)
190 .add_request_handler(update_participant_location)
191 .add_request_handler(share_project)
192 .add_message_handler(unshare_project)
193 .add_request_handler(join_project)
194 .add_message_handler(leave_project)
195 .add_request_handler(update_project)
196 .add_request_handler(update_worktree)
197 .add_message_handler(start_language_server)
198 .add_message_handler(update_language_server)
199 .add_message_handler(update_diagnostic_summary)
200 .add_request_handler(forward_project_request::<proto::GetHover>)
201 .add_request_handler(forward_project_request::<proto::GetDefinition>)
202 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
203 .add_request_handler(forward_project_request::<proto::GetReferences>)
204 .add_request_handler(forward_project_request::<proto::SearchProject>)
205 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
206 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
210 .add_request_handler(forward_project_request::<proto::GetCompletions>)
211 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
212 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
213 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
214 .add_request_handler(forward_project_request::<proto::PrepareRename>)
215 .add_request_handler(forward_project_request::<proto::PerformRename>)
216 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
217 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
218 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
222 .add_message_handler(create_buffer_for_peer)
223 .add_request_handler(update_buffer)
224 .add_message_handler(update_buffer_file)
225 .add_message_handler(buffer_reloaded)
226 .add_message_handler(buffer_saved)
227 .add_request_handler(save_buffer)
228 .add_request_handler(get_users)
229 .add_request_handler(fuzzy_search_users)
230 .add_request_handler(request_contact)
231 .add_request_handler(remove_contact)
232 .add_request_handler(respond_to_contact_request)
233 .add_request_handler(follow)
234 .add_message_handler(unfollow)
235 .add_message_handler(update_followers)
236 .add_message_handler(update_diff_base)
237 .add_request_handler(get_private_user_info);
238
239 Arc::new(server)
240 }
241
242 pub async fn start(&self) -> Result<()> {
243 self.app_state.db.delete_stale_projects().await?;
244 let db = self.app_state.db.clone();
245 let peer = self.peer.clone();
246 let timeout = self.executor.sleep(RECONNECT_TIMEOUT);
247 let pool = self.connection_pool.clone();
248 let live_kit_client = self.app_state.live_kit_client.clone();
249 self.executor.spawn_detached(async move {
250 timeout.await;
251 if let Some(room_ids) = db.outdated_room_ids().await.trace_err() {
252 for room_id in room_ids {
253 let mut contacts_to_update = HashSet::default();
254 let mut canceled_calls_to_user_ids = Vec::new();
255 let mut live_kit_room = String::new();
256 let mut delete_live_kit_room = false;
257
258 if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
259 room_updated(&refreshed_room.room, &peer);
260 contacts_to_update
261 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
262 contacts_to_update
263 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
264 canceled_calls_to_user_ids =
265 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
266 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
267 delete_live_kit_room = refreshed_room.room.participants.is_empty();
268 }
269
270 {
271 let pool = pool.lock().await;
272 for canceled_user_id in canceled_calls_to_user_ids {
273 for connection_id in pool.user_connection_ids(canceled_user_id) {
274 peer.send(connection_id, proto::CallCanceled {}).trace_err();
275 }
276 }
277 }
278
279 for user_id in contacts_to_update {
280 if let Some((busy, contacts)) = db
281 .is_user_busy(user_id)
282 .await
283 .trace_err()
284 .zip(db.get_contacts(user_id).await.trace_err())
285 {
286 let pool = pool.lock().await;
287 let updated_contact = contact_for_user(user_id, false, busy, &pool);
288 for contact in contacts {
289 if let db::Contact::Accepted {
290 user_id: contact_user_id,
291 ..
292 } = contact
293 {
294 for contact_conn_id in pool.user_connection_ids(contact_user_id)
295 {
296 peer.send(
297 contact_conn_id,
298 proto::UpdateContacts {
299 contacts: vec![updated_contact.clone()],
300 remove_contacts: Default::default(),
301 incoming_requests: Default::default(),
302 remove_incoming_requests: Default::default(),
303 outgoing_requests: Default::default(),
304 remove_outgoing_requests: Default::default(),
305 },
306 )
307 .trace_err();
308 }
309 }
310 }
311 }
312 }
313
314 if let Some(live_kit) = live_kit_client.as_ref() {
315 if delete_live_kit_room {
316 live_kit.delete_room(live_kit_room).await.trace_err();
317 }
318 }
319 }
320 }
321 });
322 Ok(())
323 }
324
325 pub fn teardown(&self) {
326 self.peer.reset();
327 let _ = self.teardown.send(());
328 }
329
330 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
331 where
332 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
333 Fut: 'static + Send + Future<Output = Result<()>>,
334 M: EnvelopedMessage,
335 {
336 let prev_handler = self.handlers.insert(
337 TypeId::of::<M>(),
338 Box::new(move |envelope, session| {
339 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
340 let span = info_span!(
341 "handle message",
342 payload_type = envelope.payload_type_name()
343 );
344 span.in_scope(|| {
345 tracing::info!(
346 payload_type = envelope.payload_type_name(),
347 "message received"
348 );
349 });
350 let future = (handler)(*envelope, session);
351 async move {
352 if let Err(error) = future.await {
353 tracing::error!(%error, "error handling message");
354 }
355 }
356 .instrument(span)
357 .boxed()
358 }),
359 );
360 if prev_handler.is_some() {
361 panic!("registered a handler for the same message twice");
362 }
363 self
364 }
365
366 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
367 where
368 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
369 Fut: 'static + Send + Future<Output = Result<()>>,
370 M: EnvelopedMessage,
371 {
372 self.add_handler(move |envelope, session| handler(envelope.payload, session));
373 self
374 }
375
376 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
377 where
378 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
379 Fut: Send + Future<Output = Result<()>>,
380 M: RequestMessage,
381 {
382 let handler = Arc::new(handler);
383 self.add_handler(move |envelope, session| {
384 let receipt = envelope.receipt();
385 let handler = handler.clone();
386 async move {
387 let peer = session.peer.clone();
388 let responded = Arc::new(AtomicBool::default());
389 let response = Response {
390 peer: peer.clone(),
391 responded: responded.clone(),
392 receipt,
393 };
394 match (handler)(envelope.payload, response, session).await {
395 Ok(()) => {
396 if responded.load(std::sync::atomic::Ordering::SeqCst) {
397 Ok(())
398 } else {
399 Err(anyhow!("handler did not send a response"))?
400 }
401 }
402 Err(error) => {
403 peer.respond_with_error(
404 receipt,
405 proto::Error {
406 message: error.to_string(),
407 },
408 )?;
409 Err(error)
410 }
411 }
412 }
413 })
414 }
415
416 pub fn handle_connection(
417 self: &Arc<Self>,
418 connection: Connection,
419 address: String,
420 user: User,
421 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
422 executor: Executor,
423 ) -> impl Future<Output = Result<()>> {
424 let this = self.clone();
425 let user_id = user.id;
426 let login = user.github_login;
427 let span = info_span!("handle connection", %user_id, %login, %address);
428 let mut teardown = self.teardown.subscribe();
429 async move {
430 let (connection_id, handle_io, mut incoming_rx) = this
431 .peer
432 .add_connection(connection, {
433 let executor = executor.clone();
434 move |duration| executor.sleep(duration)
435 });
436
437 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
438 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
439 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
440
441 if let Some(send_connection_id) = send_connection_id.take() {
442 let _ = send_connection_id.send(connection_id);
443 }
444
445 if !user.connected_once {
446 this.peer.send(connection_id, proto::ShowContacts {})?;
447 this.app_state.db.set_user_connected_once(user_id, true).await?;
448 }
449
450 let (contacts, invite_code) = future::try_join(
451 this.app_state.db.get_contacts(user_id),
452 this.app_state.db.get_invite_code_for_user(user_id)
453 ).await?;
454
455 {
456 let mut pool = this.connection_pool.lock().await;
457 pool.add_connection(connection_id, user_id, user.admin);
458 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
459
460 if let Some((code, count)) = invite_code {
461 this.peer.send(connection_id, proto::UpdateInviteInfo {
462 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
463 count: count as u32,
464 })?;
465 }
466 }
467
468 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
469 this.peer.send(connection_id, incoming_call)?;
470 }
471
472 let session = Session {
473 user_id,
474 connection_id,
475 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
476 peer: this.peer.clone(),
477 connection_pool: this.connection_pool.clone(),
478 live_kit_client: this.app_state.live_kit_client.clone()
479 };
480 update_user_contacts(user_id, &session).await?;
481
482 let handle_io = handle_io.fuse();
483 futures::pin_mut!(handle_io);
484
485 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
486 // This prevents deadlocks when e.g., client A performs a request to client B and
487 // client B performs a request to client A. If both clients stop processing further
488 // messages until their respective request completes, they won't have a chance to
489 // respond to the other client's request and cause a deadlock.
490 //
491 // This arrangement ensures we will attempt to process earlier messages first, but fall
492 // back to processing messages arrived later in the spirit of making progress.
493 let mut foreground_message_handlers = FuturesUnordered::new();
494 loop {
495 let next_message = incoming_rx.next().fuse();
496 futures::pin_mut!(next_message);
497 futures::select_biased! {
498 _ = teardown.changed().fuse() => return Ok(()),
499 result = handle_io => {
500 if let Err(error) = result {
501 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
502 }
503 break;
504 }
505 _ = foreground_message_handlers.next() => {}
506 message = next_message => {
507 if let Some(message) = message {
508 let type_name = message.payload_type_name();
509 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
510 let span_enter = span.enter();
511 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
512 let is_background = message.is_background();
513 let handle_message = (handler)(message, session.clone());
514 drop(span_enter);
515
516 let handle_message = handle_message.instrument(span);
517 if is_background {
518 executor.spawn_detached(handle_message);
519 } else {
520 foreground_message_handlers.push(handle_message);
521 }
522 } else {
523 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
524 }
525 } else {
526 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
527 break;
528 }
529 }
530 }
531 }
532
533 drop(foreground_message_handlers);
534 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
535 if let Err(error) = sign_out(session, teardown, executor).await {
536 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
537 }
538
539 Ok(())
540 }.instrument(span)
541 }
542
543 pub async fn invite_code_redeemed(
544 self: &Arc<Self>,
545 inviter_id: UserId,
546 invitee_id: UserId,
547 ) -> Result<()> {
548 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
549 if let Some(code) = &user.invite_code {
550 let pool = self.connection_pool.lock().await;
551 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
552 for connection_id in pool.user_connection_ids(inviter_id) {
553 self.peer.send(
554 connection_id,
555 proto::UpdateContacts {
556 contacts: vec![invitee_contact.clone()],
557 ..Default::default()
558 },
559 )?;
560 self.peer.send(
561 connection_id,
562 proto::UpdateInviteInfo {
563 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
564 count: user.invite_count as u32,
565 },
566 )?;
567 }
568 }
569 }
570 Ok(())
571 }
572
573 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
574 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
575 if let Some(invite_code) = &user.invite_code {
576 let pool = self.connection_pool.lock().await;
577 for connection_id in pool.user_connection_ids(user_id) {
578 self.peer.send(
579 connection_id,
580 proto::UpdateInviteInfo {
581 url: format!(
582 "{}{}",
583 self.app_state.config.invite_link_prefix, invite_code
584 ),
585 count: user.invite_count as u32,
586 },
587 )?;
588 }
589 }
590 }
591 Ok(())
592 }
593
594 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
595 ServerSnapshot {
596 connection_pool: ConnectionPoolGuard {
597 guard: self.connection_pool.lock().await,
598 _not_send: PhantomData,
599 },
600 peer: &self.peer,
601 }
602 }
603}
604
605impl<'a> Deref for ConnectionPoolGuard<'a> {
606 type Target = ConnectionPool;
607
608 fn deref(&self) -> &Self::Target {
609 &*self.guard
610 }
611}
612
613impl<'a> DerefMut for ConnectionPoolGuard<'a> {
614 fn deref_mut(&mut self) -> &mut Self::Target {
615 &mut *self.guard
616 }
617}
618
619impl<'a> Drop for ConnectionPoolGuard<'a> {
620 fn drop(&mut self) {
621 #[cfg(test)]
622 self.check_invariants();
623 }
624}
625
626fn broadcast<F>(
627 sender_id: ConnectionId,
628 receiver_ids: impl IntoIterator<Item = ConnectionId>,
629 mut f: F,
630) where
631 F: FnMut(ConnectionId) -> anyhow::Result<()>,
632{
633 for receiver_id in receiver_ids {
634 if receiver_id != sender_id {
635 f(receiver_id).trace_err();
636 }
637 }
638}
639
640lazy_static! {
641 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
642}
643
/// Typed form of the `x-zed-protocol-version` header: the client's protocol version.
pub struct ProtocolVersion(u32);
645
646impl Header for ProtocolVersion {
647 fn name() -> &'static HeaderName {
648 &ZED_PROTOCOL_VERSION
649 }
650
651 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
652 where
653 Self: Sized,
654 I: Iterator<Item = &'i axum::http::HeaderValue>,
655 {
656 let version = values
657 .next()
658 .ok_or_else(axum::headers::Error::invalid)?
659 .to_str()
660 .map_err(|_| axum::headers::Error::invalid())?
661 .parse()
662 .map_err(|_| axum::headers::Error::invalid())?;
663 Ok(Self(version))
664 }
665
666 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
667 values.extend([self.0.to_string().parse().unwrap()]);
668 }
669}
670
671pub fn routes(server: Arc<Server>) -> Router<Body> {
672 Router::new()
673 .route("/rpc", get(handle_websocket_request))
674 .layer(
675 ServiceBuilder::new()
676 .layer(Extension(server.app_state.clone()))
677 .layer(middleware::from_fn(auth::validate_header)),
678 )
679 .route("/metrics", get(handle_metrics))
680 .layer(Extension(server))
681}
682
683pub async fn handle_websocket_request(
684 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
685 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
686 Extension(server): Extension<Arc<Server>>,
687 Extension(user): Extension<User>,
688 ws: WebSocketUpgrade,
689) -> axum::response::Response {
690 if protocol_version != rpc::PROTOCOL_VERSION {
691 return (
692 StatusCode::UPGRADE_REQUIRED,
693 "client must be upgraded".to_string(),
694 )
695 .into_response();
696 }
697 let socket_address = socket_address.to_string();
698 ws.on_upgrade(move |socket| {
699 use util::ResultExt;
700 let socket = socket
701 .map_ok(to_tungstenite_message)
702 .err_into()
703 .with(|message| async move { Ok(to_axum_message(message)) });
704 let connection = Connection::new(Box::pin(socket));
705 async move {
706 server
707 .handle_connection(connection, socket_address, user, None, Executor::Production)
708 .await
709 .log_err();
710 }
711 })
712}
713
714pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
715 let connections = server
716 .connection_pool
717 .lock()
718 .await
719 .connections()
720 .filter(|connection| !connection.admin)
721 .count();
722
723 METRIC_CONNECTIONS.set(connections as _);
724
725 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
726 METRIC_SHARED_PROJECTS.set(shared_projects as _);
727
728 let encoder = prometheus::TextEncoder::new();
729 let metric_families = prometheus::gather();
730 let encoded_metrics = encoder
731 .encode_to_string(&metric_families)
732 .map_err(|err| anyhow!("{}", err))?;
733 Ok(encoded_metrics)
734}
735
/// Cleans up after a connection's message loop has ended.
///
/// Immediate steps:
///
/// 1. Disconnect the peer.
/// 2. Remove the connection from the pool.
/// 3. Call `project_left` for each project returned by `connection_lost`.
///
/// It then waits for whichever comes first:
///
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room via
///   `leave_room_for_session`. If the user has no other online connection, the call to
///   them is declined with `decline_call(None, ..)` and any returned room is broadcast.
///   Finally, the user's contacts are updated.
/// - The server is torn down: nothing further happens.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails or if
/// `update_user_contacts` fails. Other failures go through `trace_err`.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Biased select: the reconnect timer is polled before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard lives only for this condition, so it is released before the
            // `.await`s below.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
781
782async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
783 response.send(proto::Ack {})?;
784 Ok(())
785}
786
787async fn create_room(
788 _request: proto::CreateRoom,
789 response: Response<proto::CreateRoom>,
790 session: Session,
791) -> Result<()> {
792 let live_kit_room = nanoid::nanoid!(30);
793 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
794 if let Some(_) = live_kit
795 .create_room(live_kit_room.clone())
796 .await
797 .trace_err()
798 {
799 if let Some(token) = live_kit
800 .room_token(&live_kit_room, &session.connection_id.to_string())
801 .trace_err()
802 {
803 Some(proto::LiveKitConnectionInfo {
804 server_url: live_kit.url().into(),
805 token,
806 })
807 } else {
808 None
809 }
810 } else {
811 None
812 }
813 } else {
814 None
815 };
816
817 {
818 let room = session
819 .db()
820 .await
821 .create_room(session.user_id, session.connection_id, &live_kit_room)
822 .await?;
823
824 response.send(proto::CreateRoomResponse {
825 room: Some(room.clone()),
826 live_kit_connection_info,
827 })?;
828 }
829
830 update_user_contacts(session.user_id, &session).await?;
831 Ok(())
832}
833
834async fn join_room(
835 request: proto::JoinRoom,
836 response: Response<proto::JoinRoom>,
837 session: Session,
838) -> Result<()> {
839 let room = {
840 let room = session
841 .db()
842 .await
843 .join_room(
844 RoomId::from_proto(request.id),
845 session.user_id,
846 session.connection_id,
847 )
848 .await?;
849 room_updated(&room, &session.peer);
850 room.clone()
851 };
852
853 for connection_id in session
854 .connection_pool()
855 .await
856 .user_connection_ids(session.user_id)
857 {
858 session
859 .peer
860 .send(connection_id, proto::CallCanceled {})
861 .trace_err();
862 }
863
864 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
865 if let Some(token) = live_kit
866 .room_token(&room.live_kit_room, &session.connection_id.to_string())
867 .trace_err()
868 {
869 Some(proto::LiveKitConnectionInfo {
870 server_url: live_kit.url().into(),
871 token,
872 })
873 } else {
874 None
875 }
876 } else {
877 None
878 };
879
880 response.send(proto::JoinRoomResponse {
881 room: Some(room),
882 live_kit_connection_info,
883 })?;
884
885 update_user_contacts(session.user_id, &session).await?;
886 Ok(())
887}
888
889async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
890 leave_room_for_session(&session).await
891}
892
893async fn call(
894 request: proto::Call,
895 response: Response<proto::Call>,
896 session: Session,
897) -> Result<()> {
898 let room_id = RoomId::from_proto(request.room_id);
899 let calling_user_id = session.user_id;
900 let calling_connection_id = session.connection_id;
901 let called_user_id = UserId::from_proto(request.called_user_id);
902 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
903 if !session
904 .db()
905 .await
906 .has_contact(calling_user_id, called_user_id)
907 .await?
908 {
909 return Err(anyhow!("cannot call a user who isn't a contact"))?;
910 }
911
912 let incoming_call = {
913 let (room, incoming_call) = &mut *session
914 .db()
915 .await
916 .call(
917 room_id,
918 calling_user_id,
919 calling_connection_id,
920 called_user_id,
921 initial_project_id,
922 )
923 .await?;
924 room_updated(&room, &session.peer);
925 mem::take(incoming_call)
926 };
927 update_user_contacts(called_user_id, &session).await?;
928
929 let mut calls = session
930 .connection_pool()
931 .await
932 .user_connection_ids(called_user_id)
933 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
934 .collect::<FuturesUnordered<_>>();
935
936 while let Some(call_response) = calls.next().await {
937 match call_response.as_ref() {
938 Ok(_) => {
939 response.send(proto::Ack {})?;
940 return Ok(());
941 }
942 Err(_) => {
943 call_response.trace_err();
944 }
945 }
946 }
947
948 {
949 let room = session
950 .db()
951 .await
952 .call_failed(room_id, called_user_id)
953 .await?;
954 room_updated(&room, &session.peer);
955 }
956 update_user_contacts(called_user_id, &session).await?;
957
958 Err(anyhow!("failed to ring user"))?
959}
960
961async fn cancel_call(
962 request: proto::CancelCall,
963 response: Response<proto::CancelCall>,
964 session: Session,
965) -> Result<()> {
966 let called_user_id = UserId::from_proto(request.called_user_id);
967 let room_id = RoomId::from_proto(request.room_id);
968 {
969 let room = session
970 .db()
971 .await
972 .cancel_call(Some(room_id), session.connection_id, called_user_id)
973 .await?;
974 room_updated(&room, &session.peer);
975 }
976
977 for connection_id in session
978 .connection_pool()
979 .await
980 .user_connection_ids(called_user_id)
981 {
982 session
983 .peer
984 .send(connection_id, proto::CallCanceled {})
985 .trace_err();
986 }
987 response.send(proto::Ack {})?;
988
989 update_user_contacts(called_user_id, &session).await?;
990 Ok(())
991}
992
993async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
994 let room_id = RoomId::from_proto(message.room_id);
995 {
996 let room = session
997 .db()
998 .await
999 .decline_call(Some(room_id), session.user_id)
1000 .await?;
1001 room_updated(&room, &session.peer);
1002 }
1003
1004 for connection_id in session
1005 .connection_pool()
1006 .await
1007 .user_connection_ids(session.user_id)
1008 {
1009 session
1010 .peer
1011 .send(connection_id, proto::CallCanceled {})
1012 .trace_err();
1013 }
1014 update_user_contacts(session.user_id, &session).await?;
1015 Ok(())
1016}
1017
1018async fn update_participant_location(
1019 request: proto::UpdateParticipantLocation,
1020 response: Response<proto::UpdateParticipantLocation>,
1021 session: Session,
1022) -> Result<()> {
1023 let room_id = RoomId::from_proto(request.room_id);
1024 let location = request
1025 .location
1026 .ok_or_else(|| anyhow!("invalid location"))?;
1027 let room = session
1028 .db()
1029 .await
1030 .update_room_participant_location(room_id, session.connection_id, location)
1031 .await?;
1032 room_updated(&room, &session.peer);
1033 response.send(proto::Ack {})?;
1034 Ok(())
1035}
1036
1037async fn share_project(
1038 request: proto::ShareProject,
1039 response: Response<proto::ShareProject>,
1040 session: Session,
1041) -> Result<()> {
1042 let (project_id, room) = &*session
1043 .db()
1044 .await
1045 .share_project(
1046 RoomId::from_proto(request.room_id),
1047 session.connection_id,
1048 &request.worktrees,
1049 )
1050 .await?;
1051 response.send(proto::ShareProjectResponse {
1052 project_id: project_id.to_proto(),
1053 })?;
1054 room_updated(&room, &session.peer);
1055
1056 Ok(())
1057}
1058
1059async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1060 let project_id = ProjectId::from_proto(message.project_id);
1061
1062 let (room, guest_connection_ids) = &*session
1063 .db()
1064 .await
1065 .unshare_project(project_id, session.connection_id)
1066 .await?;
1067
1068 broadcast(
1069 session.connection_id,
1070 guest_connection_ids.iter().copied(),
1071 |conn_id| session.peer.send(conn_id, message.clone()),
1072 );
1073 room_updated(&room, &session.peer);
1074
1075 Ok(())
1076}
1077
1078async fn join_project(
1079 request: proto::JoinProject,
1080 response: Response<proto::JoinProject>,
1081 session: Session,
1082) -> Result<()> {
1083 let project_id = ProjectId::from_proto(request.project_id);
1084 let guest_user_id = session.user_id;
1085
1086 tracing::info!(%project_id, "join project");
1087
1088 let (project, replica_id) = &mut *session
1089 .db()
1090 .await
1091 .join_project(project_id, session.connection_id)
1092 .await?;
1093
1094 let collaborators = project
1095 .collaborators
1096 .iter()
1097 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1098 .map(|collaborator| proto::Collaborator {
1099 peer_id: collaborator.connection_id as u32,
1100 replica_id: collaborator.replica_id.0 as u32,
1101 user_id: collaborator.user_id.to_proto(),
1102 })
1103 .collect::<Vec<_>>();
1104 let worktrees = project
1105 .worktrees
1106 .iter()
1107 .map(|(id, worktree)| proto::WorktreeMetadata {
1108 id: *id,
1109 root_name: worktree.root_name.clone(),
1110 visible: worktree.visible,
1111 abs_path: worktree.abs_path.clone(),
1112 })
1113 .collect::<Vec<_>>();
1114
1115 for collaborator in &collaborators {
1116 session
1117 .peer
1118 .send(
1119 ConnectionId(collaborator.peer_id),
1120 proto::AddProjectCollaborator {
1121 project_id: project_id.to_proto(),
1122 collaborator: Some(proto::Collaborator {
1123 peer_id: session.connection_id.0,
1124 replica_id: replica_id.0 as u32,
1125 user_id: guest_user_id.to_proto(),
1126 }),
1127 },
1128 )
1129 .trace_err();
1130 }
1131
1132 // First, we send the metadata associated with each worktree.
1133 response.send(proto::JoinProjectResponse {
1134 worktrees: worktrees.clone(),
1135 replica_id: replica_id.0 as u32,
1136 collaborators: collaborators.clone(),
1137 language_servers: project.language_servers.clone(),
1138 })?;
1139
1140 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1141 #[cfg(any(test, feature = "test-support"))]
1142 const MAX_CHUNK_SIZE: usize = 2;
1143 #[cfg(not(any(test, feature = "test-support")))]
1144 const MAX_CHUNK_SIZE: usize = 256;
1145
1146 // Stream this worktree's entries.
1147 let message = proto::UpdateWorktree {
1148 project_id: project_id.to_proto(),
1149 worktree_id,
1150 abs_path: worktree.abs_path.clone(),
1151 root_name: worktree.root_name,
1152 updated_entries: worktree.entries,
1153 removed_entries: Default::default(),
1154 scan_id: worktree.scan_id,
1155 is_last_update: worktree.is_complete,
1156 };
1157 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1158 session.peer.send(session.connection_id, update.clone())?;
1159 }
1160
1161 // Stream this worktree's diagnostics.
1162 for summary in worktree.diagnostic_summaries {
1163 session.peer.send(
1164 session.connection_id,
1165 proto::UpdateDiagnosticSummary {
1166 project_id: project_id.to_proto(),
1167 worktree_id: worktree.id,
1168 summary: Some(summary),
1169 },
1170 )?;
1171 }
1172 }
1173
1174 for language_server in &project.language_servers {
1175 session.peer.send(
1176 session.connection_id,
1177 proto::UpdateLanguageServer {
1178 project_id: project_id.to_proto(),
1179 language_server_id: language_server.id,
1180 variant: Some(
1181 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1182 proto::LspDiskBasedDiagnosticsUpdated {},
1183 ),
1184 ),
1185 },
1186 )?;
1187 }
1188
1189 Ok(())
1190}
1191
1192async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1193 let sender_id = session.connection_id;
1194 let project_id = ProjectId::from_proto(request.project_id);
1195
1196 let project = session
1197 .db()
1198 .await
1199 .leave_project(project_id, sender_id)
1200 .await?;
1201 tracing::info!(
1202 %project_id,
1203 host_user_id = %project.host_user_id,
1204 host_connection_id = %project.host_connection_id,
1205 "leave project"
1206 );
1207 project_left(&project, &session);
1208
1209 Ok(())
1210}
1211
1212async fn update_project(
1213 request: proto::UpdateProject,
1214 response: Response<proto::UpdateProject>,
1215 session: Session,
1216) -> Result<()> {
1217 let project_id = ProjectId::from_proto(request.project_id);
1218 let (room, guest_connection_ids) = &*session
1219 .db()
1220 .await
1221 .update_project(project_id, session.connection_id, &request.worktrees)
1222 .await?;
1223 broadcast(
1224 session.connection_id,
1225 guest_connection_ids.iter().copied(),
1226 |connection_id| {
1227 session
1228 .peer
1229 .forward_send(session.connection_id, connection_id, request.clone())
1230 },
1231 );
1232 room_updated(&room, &session.peer);
1233 response.send(proto::Ack {})?;
1234
1235 Ok(())
1236}
1237
1238async fn update_worktree(
1239 request: proto::UpdateWorktree,
1240 response: Response<proto::UpdateWorktree>,
1241 session: Session,
1242) -> Result<()> {
1243 let guest_connection_ids = session
1244 .db()
1245 .await
1246 .update_worktree(&request, session.connection_id)
1247 .await?;
1248
1249 broadcast(
1250 session.connection_id,
1251 guest_connection_ids.iter().copied(),
1252 |connection_id| {
1253 session
1254 .peer
1255 .forward_send(session.connection_id, connection_id, request.clone())
1256 },
1257 );
1258 response.send(proto::Ack {})?;
1259 Ok(())
1260}
1261
1262async fn update_diagnostic_summary(
1263 message: proto::UpdateDiagnosticSummary,
1264 session: Session,
1265) -> Result<()> {
1266 let guest_connection_ids = session
1267 .db()
1268 .await
1269 .update_diagnostic_summary(&message, session.connection_id)
1270 .await?;
1271
1272 broadcast(
1273 session.connection_id,
1274 guest_connection_ids.iter().copied(),
1275 |connection_id| {
1276 session
1277 .peer
1278 .forward_send(session.connection_id, connection_id, message.clone())
1279 },
1280 );
1281
1282 Ok(())
1283}
1284
1285async fn start_language_server(
1286 request: proto::StartLanguageServer,
1287 session: Session,
1288) -> Result<()> {
1289 let guest_connection_ids = session
1290 .db()
1291 .await
1292 .start_language_server(&request, session.connection_id)
1293 .await?;
1294
1295 broadcast(
1296 session.connection_id,
1297 guest_connection_ids.iter().copied(),
1298 |connection_id| {
1299 session
1300 .peer
1301 .forward_send(session.connection_id, connection_id, request.clone())
1302 },
1303 );
1304 Ok(())
1305}
1306
1307async fn update_language_server(
1308 request: proto::UpdateLanguageServer,
1309 session: Session,
1310) -> Result<()> {
1311 let project_id = ProjectId::from_proto(request.project_id);
1312 let project_connection_ids = session
1313 .db()
1314 .await
1315 .project_connection_ids(project_id, session.connection_id)
1316 .await?;
1317 broadcast(
1318 session.connection_id,
1319 project_connection_ids.iter().copied(),
1320 |connection_id| {
1321 session
1322 .peer
1323 .forward_send(session.connection_id, connection_id, request.clone())
1324 },
1325 );
1326 Ok(())
1327}
1328
1329async fn forward_project_request<T>(
1330 request: T,
1331 response: Response<T>,
1332 session: Session,
1333) -> Result<()>
1334where
1335 T: EntityMessage + RequestMessage,
1336{
1337 let project_id = ProjectId::from_proto(request.remote_entity_id());
1338 let host_connection_id = {
1339 let collaborators = session
1340 .db()
1341 .await
1342 .project_collaborators(project_id, session.connection_id)
1343 .await?;
1344 ConnectionId(
1345 collaborators
1346 .iter()
1347 .find(|collaborator| collaborator.is_host)
1348 .ok_or_else(|| anyhow!("host not found"))?
1349 .connection_id as u32,
1350 )
1351 };
1352
1353 let payload = session
1354 .peer
1355 .forward_request(session.connection_id, host_connection_id, request)
1356 .await?;
1357
1358 response.send(payload)?;
1359 Ok(())
1360}
1361
1362async fn save_buffer(
1363 request: proto::SaveBuffer,
1364 response: Response<proto::SaveBuffer>,
1365 session: Session,
1366) -> Result<()> {
1367 let project_id = ProjectId::from_proto(request.project_id);
1368 let host_connection_id = {
1369 let collaborators = session
1370 .db()
1371 .await
1372 .project_collaborators(project_id, session.connection_id)
1373 .await?;
1374 let host = collaborators
1375 .iter()
1376 .find(|collaborator| collaborator.is_host)
1377 .ok_or_else(|| anyhow!("host not found"))?;
1378 ConnectionId(host.connection_id as u32)
1379 };
1380 let response_payload = session
1381 .peer
1382 .forward_request(session.connection_id, host_connection_id, request.clone())
1383 .await?;
1384
1385 let mut collaborators = session
1386 .db()
1387 .await
1388 .project_collaborators(project_id, session.connection_id)
1389 .await?;
1390 collaborators
1391 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1392 let project_connection_ids = collaborators
1393 .iter()
1394 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1395 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1396 session
1397 .peer
1398 .forward_send(host_connection_id, conn_id, response_payload.clone())
1399 });
1400 response.send(response_payload)?;
1401 Ok(())
1402}
1403
1404async fn create_buffer_for_peer(
1405 request: proto::CreateBufferForPeer,
1406 session: Session,
1407) -> Result<()> {
1408 session.peer.forward_send(
1409 session.connection_id,
1410 ConnectionId(request.peer_id),
1411 request,
1412 )?;
1413 Ok(())
1414}
1415
1416async fn update_buffer(
1417 request: proto::UpdateBuffer,
1418 response: Response<proto::UpdateBuffer>,
1419 session: Session,
1420) -> Result<()> {
1421 let project_id = ProjectId::from_proto(request.project_id);
1422 let project_connection_ids = session
1423 .db()
1424 .await
1425 .project_connection_ids(project_id, session.connection_id)
1426 .await?;
1427
1428 broadcast(
1429 session.connection_id,
1430 project_connection_ids.iter().copied(),
1431 |connection_id| {
1432 session
1433 .peer
1434 .forward_send(session.connection_id, connection_id, request.clone())
1435 },
1436 );
1437 response.send(proto::Ack {})?;
1438 Ok(())
1439}
1440
1441async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1442 let project_id = ProjectId::from_proto(request.project_id);
1443 let project_connection_ids = session
1444 .db()
1445 .await
1446 .project_connection_ids(project_id, session.connection_id)
1447 .await?;
1448
1449 broadcast(
1450 session.connection_id,
1451 project_connection_ids.iter().copied(),
1452 |connection_id| {
1453 session
1454 .peer
1455 .forward_send(session.connection_id, connection_id, request.clone())
1456 },
1457 );
1458 Ok(())
1459}
1460
1461async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1462 let project_id = ProjectId::from_proto(request.project_id);
1463 let project_connection_ids = session
1464 .db()
1465 .await
1466 .project_connection_ids(project_id, session.connection_id)
1467 .await?;
1468 broadcast(
1469 session.connection_id,
1470 project_connection_ids.iter().copied(),
1471 |connection_id| {
1472 session
1473 .peer
1474 .forward_send(session.connection_id, connection_id, request.clone())
1475 },
1476 );
1477 Ok(())
1478}
1479
1480async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1481 let project_id = ProjectId::from_proto(request.project_id);
1482 let project_connection_ids = session
1483 .db()
1484 .await
1485 .project_connection_ids(project_id, session.connection_id)
1486 .await?;
1487 broadcast(
1488 session.connection_id,
1489 project_connection_ids.iter().copied(),
1490 |connection_id| {
1491 session
1492 .peer
1493 .forward_send(session.connection_id, connection_id, request.clone())
1494 },
1495 );
1496 Ok(())
1497}
1498
1499async fn follow(
1500 request: proto::Follow,
1501 response: Response<proto::Follow>,
1502 session: Session,
1503) -> Result<()> {
1504 let project_id = ProjectId::from_proto(request.project_id);
1505 let leader_id = ConnectionId(request.leader_id);
1506 let follower_id = session.connection_id;
1507 {
1508 let project_connection_ids = session
1509 .db()
1510 .await
1511 .project_connection_ids(project_id, session.connection_id)
1512 .await?;
1513
1514 if !project_connection_ids.contains(&leader_id) {
1515 Err(anyhow!("no such peer"))?;
1516 }
1517 }
1518
1519 let mut response_payload = session
1520 .peer
1521 .forward_request(session.connection_id, leader_id, request)
1522 .await?;
1523 response_payload
1524 .views
1525 .retain(|view| view.leader_id != Some(follower_id.0));
1526 response.send(response_payload)?;
1527 Ok(())
1528}
1529
1530async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1531 let project_id = ProjectId::from_proto(request.project_id);
1532 let leader_id = ConnectionId(request.leader_id);
1533 let project_connection_ids = session
1534 .db()
1535 .await
1536 .project_connection_ids(project_id, session.connection_id)
1537 .await?;
1538 if !project_connection_ids.contains(&leader_id) {
1539 Err(anyhow!("no such peer"))?;
1540 }
1541 session
1542 .peer
1543 .forward_send(session.connection_id, leader_id, request)?;
1544 Ok(())
1545}
1546
1547async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1548 let project_id = ProjectId::from_proto(request.project_id);
1549 let project_connection_ids = session
1550 .db
1551 .lock()
1552 .await
1553 .project_connection_ids(project_id, session.connection_id)
1554 .await?;
1555
1556 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1557 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1558 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1559 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1560 });
1561 for follower_id in &request.follower_ids {
1562 let follower_id = ConnectionId(*follower_id);
1563 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1564 session
1565 .peer
1566 .forward_send(session.connection_id, follower_id, request.clone())?;
1567 }
1568 }
1569 Ok(())
1570}
1571
1572async fn get_users(
1573 request: proto::GetUsers,
1574 response: Response<proto::GetUsers>,
1575 session: Session,
1576) -> Result<()> {
1577 let user_ids = request
1578 .user_ids
1579 .into_iter()
1580 .map(UserId::from_proto)
1581 .collect();
1582 let users = session
1583 .db()
1584 .await
1585 .get_users_by_ids(user_ids)
1586 .await?
1587 .into_iter()
1588 .map(|user| proto::User {
1589 id: user.id.to_proto(),
1590 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1591 github_login: user.github_login,
1592 })
1593 .collect();
1594 response.send(proto::UsersResponse { users })?;
1595 Ok(())
1596}
1597
1598async fn fuzzy_search_users(
1599 request: proto::FuzzySearchUsers,
1600 response: Response<proto::FuzzySearchUsers>,
1601 session: Session,
1602) -> Result<()> {
1603 let query = request.query;
1604 let users = match query.len() {
1605 0 => vec![],
1606 1 | 2 => session
1607 .db()
1608 .await
1609 .get_user_by_github_account(&query, None)
1610 .await?
1611 .into_iter()
1612 .collect(),
1613 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1614 };
1615 let users = users
1616 .into_iter()
1617 .filter(|user| user.id != session.user_id)
1618 .map(|user| proto::User {
1619 id: user.id.to_proto(),
1620 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1621 github_login: user.github_login,
1622 })
1623 .collect();
1624 response.send(proto::UsersResponse { users })?;
1625 Ok(())
1626}
1627
1628async fn request_contact(
1629 request: proto::RequestContact,
1630 response: Response<proto::RequestContact>,
1631 session: Session,
1632) -> Result<()> {
1633 let requester_id = session.user_id;
1634 let responder_id = UserId::from_proto(request.responder_id);
1635 if requester_id == responder_id {
1636 return Err(anyhow!("cannot add yourself as a contact"))?;
1637 }
1638
1639 session
1640 .db()
1641 .await
1642 .send_contact_request(requester_id, responder_id)
1643 .await?;
1644
1645 // Update outgoing contact requests of requester
1646 let mut update = proto::UpdateContacts::default();
1647 update.outgoing_requests.push(responder_id.to_proto());
1648 for connection_id in session
1649 .connection_pool()
1650 .await
1651 .user_connection_ids(requester_id)
1652 {
1653 session.peer.send(connection_id, update.clone())?;
1654 }
1655
1656 // Update incoming contact requests of responder
1657 let mut update = proto::UpdateContacts::default();
1658 update
1659 .incoming_requests
1660 .push(proto::IncomingContactRequest {
1661 requester_id: requester_id.to_proto(),
1662 should_notify: true,
1663 });
1664 for connection_id in session
1665 .connection_pool()
1666 .await
1667 .user_connection_ids(responder_id)
1668 {
1669 session.peer.send(connection_id, update.clone())?;
1670 }
1671
1672 response.send(proto::Ack {})?;
1673 Ok(())
1674}
1675
1676async fn respond_to_contact_request(
1677 request: proto::RespondToContactRequest,
1678 response: Response<proto::RespondToContactRequest>,
1679 session: Session,
1680) -> Result<()> {
1681 let responder_id = session.user_id;
1682 let requester_id = UserId::from_proto(request.requester_id);
1683 let db = session.db().await;
1684 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1685 db.dismiss_contact_notification(responder_id, requester_id)
1686 .await?;
1687 } else {
1688 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1689
1690 db.respond_to_contact_request(responder_id, requester_id, accept)
1691 .await?;
1692 let requester_busy = db.is_user_busy(requester_id).await?;
1693 let responder_busy = db.is_user_busy(responder_id).await?;
1694
1695 let pool = session.connection_pool().await;
1696 // Update responder with new contact
1697 let mut update = proto::UpdateContacts::default();
1698 if accept {
1699 update
1700 .contacts
1701 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1702 }
1703 update
1704 .remove_incoming_requests
1705 .push(requester_id.to_proto());
1706 for connection_id in pool.user_connection_ids(responder_id) {
1707 session.peer.send(connection_id, update.clone())?;
1708 }
1709
1710 // Update requester with new contact
1711 let mut update = proto::UpdateContacts::default();
1712 if accept {
1713 update
1714 .contacts
1715 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1716 }
1717 update
1718 .remove_outgoing_requests
1719 .push(responder_id.to_proto());
1720 for connection_id in pool.user_connection_ids(requester_id) {
1721 session.peer.send(connection_id, update.clone())?;
1722 }
1723 }
1724
1725 response.send(proto::Ack {})?;
1726 Ok(())
1727}
1728
1729async fn remove_contact(
1730 request: proto::RemoveContact,
1731 response: Response<proto::RemoveContact>,
1732 session: Session,
1733) -> Result<()> {
1734 let requester_id = session.user_id;
1735 let responder_id = UserId::from_proto(request.user_id);
1736 let db = session.db().await;
1737 db.remove_contact(requester_id, responder_id).await?;
1738
1739 let pool = session.connection_pool().await;
1740 // Update outgoing contact requests of requester
1741 let mut update = proto::UpdateContacts::default();
1742 update
1743 .remove_outgoing_requests
1744 .push(responder_id.to_proto());
1745 for connection_id in pool.user_connection_ids(requester_id) {
1746 session.peer.send(connection_id, update.clone())?;
1747 }
1748
1749 // Update incoming contact requests of responder
1750 let mut update = proto::UpdateContacts::default();
1751 update
1752 .remove_incoming_requests
1753 .push(requester_id.to_proto());
1754 for connection_id in pool.user_connection_ids(responder_id) {
1755 session.peer.send(connection_id, update.clone())?;
1756 }
1757
1758 response.send(proto::Ack {})?;
1759 Ok(())
1760}
1761
1762async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1763 let project_id = ProjectId::from_proto(request.project_id);
1764 let project_connection_ids = session
1765 .db()
1766 .await
1767 .project_connection_ids(project_id, session.connection_id)
1768 .await?;
1769 broadcast(
1770 session.connection_id,
1771 project_connection_ids.iter().copied(),
1772 |connection_id| {
1773 session
1774 .peer
1775 .forward_send(session.connection_id, connection_id, request.clone())
1776 },
1777 );
1778 Ok(())
1779}
1780
1781async fn get_private_user_info(
1782 _request: proto::GetPrivateUserInfo,
1783 response: Response<proto::GetPrivateUserInfo>,
1784 session: Session,
1785) -> Result<()> {
1786 let metrics_id = session
1787 .db()
1788 .await
1789 .get_user_metrics_id(session.user_id)
1790 .await?;
1791 let user = session
1792 .db()
1793 .await
1794 .get_user_by_id(session.user_id)
1795 .await?
1796 .ok_or_else(|| anyhow!("user not found"))?;
1797 response.send(proto::GetPrivateUserInfoResponse {
1798 metrics_id,
1799 staff: user.admin,
1800 })?;
1801 Ok(())
1802}
1803
1804fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1805 match message {
1806 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1807 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1808 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1809 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1810 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1811 code: frame.code.into(),
1812 reason: frame.reason,
1813 })),
1814 }
1815}
1816
1817fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1818 match message {
1819 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1820 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1821 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1822 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1823 AxumMessage::Close(frame) => {
1824 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1825 code: frame.code.into(),
1826 reason: frame.reason,
1827 }))
1828 }
1829 }
1830}
1831
1832fn build_initial_contacts_update(
1833 contacts: Vec<db::Contact>,
1834 pool: &ConnectionPool,
1835) -> proto::UpdateContacts {
1836 let mut update = proto::UpdateContacts::default();
1837
1838 for contact in contacts {
1839 match contact {
1840 db::Contact::Accepted {
1841 user_id,
1842 should_notify,
1843 busy,
1844 } => {
1845 update
1846 .contacts
1847 .push(contact_for_user(user_id, should_notify, busy, &pool));
1848 }
1849 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1850 db::Contact::Incoming {
1851 user_id,
1852 should_notify,
1853 } => update
1854 .incoming_requests
1855 .push(proto::IncomingContactRequest {
1856 requester_id: user_id.to_proto(),
1857 should_notify,
1858 }),
1859 }
1860 }
1861
1862 update
1863}
1864
1865fn contact_for_user(
1866 user_id: UserId,
1867 should_notify: bool,
1868 busy: bool,
1869 pool: &ConnectionPool,
1870) -> proto::Contact {
1871 proto::Contact {
1872 user_id: user_id.to_proto(),
1873 online: pool.is_user_online(user_id),
1874 busy,
1875 should_notify,
1876 }
1877}
1878
1879fn room_updated(room: &proto::Room, peer: &Peer) {
1880 for participant in &room.participants {
1881 peer.send(
1882 ConnectionId(participant.peer_id),
1883 proto::RoomUpdated {
1884 room: Some(room.clone()),
1885 },
1886 )
1887 .trace_err();
1888 }
1889}
1890
1891async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1892 let db = session.db().await;
1893 let contacts = db.get_contacts(user_id).await?;
1894 let busy = db.is_user_busy(user_id).await?;
1895
1896 let pool = session.connection_pool().await;
1897 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1898 for contact in contacts {
1899 if let db::Contact::Accepted {
1900 user_id: contact_user_id,
1901 ..
1902 } = contact
1903 {
1904 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1905 session
1906 .peer
1907 .send(
1908 contact_conn_id,
1909 proto::UpdateContacts {
1910 contacts: vec![updated_contact.clone()],
1911 remove_contacts: Default::default(),
1912 incoming_requests: Default::default(),
1913 remove_incoming_requests: Default::default(),
1914 outgoing_requests: Default::default(),
1915 remove_outgoing_requests: Default::default(),
1916 },
1917 )
1918 .trace_err();
1919 }
1920 }
1921 }
1922 Ok(())
1923}
1924
/// Removes the session's connection from its room and fans out the
/// consequences.
///
/// Steps, in order:
/// 1. Leaves the room in the database. Notifies the connections of every
///    project that was left (via `project_left`) and broadcasts the new room
///    state to the remaining participants.
/// 2. Sends `CallCanceled` to every connection of each user whose pending
///    call was canceled by this departure.
/// 3. Refreshes the contact status of the leaving user and of those users.
/// 4. If a LiveKit client is configured, removes this connection from the
///    LiveKit room, and deletes the room once no participants remain.
///
/// Database and contact-update errors are propagated. Message sends and
/// LiveKit calls are best-effort and only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values extracted from the database result so they outlive its scope.
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        // Move out the fields needed later so `left_room` can be dropped at
        // the end of this block. Whether it holds any lock is not visible
        // from here; TODO confirm.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    // The connection pool guard is scoped to this block. It must be released
    // before `update_user_contacts`, which locks the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(connection_id, proto::CallCanceled {})
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        // Only tear down the LiveKit room once the last participant has left.
        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
1975
1976fn project_left(project: &db::LeftProject, session: &Session) {
1977 for connection_id in &project.connection_ids {
1978 if project.host_user_id == session.user_id {
1979 session
1980 .peer
1981 .send(
1982 *connection_id,
1983 proto::UnshareProject {
1984 project_id: project.id.to_proto(),
1985 },
1986 )
1987 .trace_err();
1988 } else {
1989 session
1990 .peer
1991 .send(
1992 *connection_id,
1993 proto::RemoveProjectCollaborator {
1994 project_id: project.id.to_proto(),
1995 peer_id: session.connection_id.0,
1996 },
1997 )
1998 .trace_err();
1999 }
2000 }
2001
2002 session
2003 .peer
2004 .send(
2005 session.connection_id,
2006 proto::UnshareProject {
2007 project_id: project.id.to_proto(),
2008 },
2009 )
2010 .trace_err();
2011}
2012
2013pub trait ResultExt {
2014 type Ok;
2015
2016 fn trace_err(self) -> Option<Self::Ok>;
2017}
2018
2019impl<T, E> ResultExt for Result<T, E>
2020where
2021 E: std::fmt::Debug,
2022{
2023 type Ok = T;
2024
2025 fn trace_err(self) -> Option<T> {
2026 match self {
2027 Ok(value) => Some(value),
2028 Err(error) => {
2029 tracing::error!("{:?}", error);
2030 None
2031 }
2032 }
2033 }
2034}