1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::{watch, Mutex, MutexGuard};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
60pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
61
62lazy_static! {
63 static ref METRIC_CONNECTIONS: IntGauge =
64 register_int_gauge!("connections", "number of connections").unwrap();
65 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
66 "shared_projects",
67 "number of open projects with one or more guests"
68 )
69 .unwrap();
70}
71
72type MessageHandler =
73 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
74
75struct Response<R> {
76 peer: Arc<Peer>,
77 receipt: Receipt<R>,
78 responded: Arc<AtomicBool>,
79}
80
81impl<R: RequestMessage> Response<R> {
82 fn send(self, payload: R::Response) -> Result<()> {
83 self.responded.store(true, SeqCst);
84 self.peer.respond(self.receipt, payload)?;
85 Ok(())
86 }
87}
88
89#[derive(Clone)]
90struct Session {
91 user_id: UserId,
92 connection_id: ConnectionId,
93 db: Arc<Mutex<DbHandle>>,
94 peer: Arc<Peer>,
95 connection_pool: Arc<Mutex<ConnectionPool>>,
96 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
97}
98
99impl Session {
100 async fn db(&self) -> MutexGuard<DbHandle> {
101 #[cfg(test)]
102 tokio::task::yield_now().await;
103 let guard = self.db.lock().await;
104 #[cfg(test)]
105 tokio::task::yield_now().await;
106 guard
107 }
108
109 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
110 #[cfg(test)]
111 tokio::task::yield_now().await;
112 let guard = self.connection_pool.lock().await;
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
/// The collaboration server: dispatches incoming RPC messages to their handlers
/// and tracks the set of connected clients.
pub struct Server {
    /// RPC peer through which all connections send and receive messages.
    peer: Arc<Peer>,
    /// Registry of live connections; every `Session` holds a clone of this `Arc`.
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached background tasks.
    executor: Executor,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown()` sends on it, and connection loops and
    /// pending sign-outs subscribe to it.
    teardown: watch::Sender<()>,
}
149
/// Lock guard for the shared [`ConnectionPool`]; dereferences to the pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Handler futures are
/// required to be `Send` (see `Server::add_handler`), so holding the pool lock
/// across an `.await` inside a handler becomes a compile error.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying async mutex guard.
    guard: MutexGuard<'a, ConnectionPool>,
    /// Marker that opts the guard out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
154
/// Serializable point-in-time view of the server, produced by `Server::snapshot`.
///
/// It owns a connection-pool guard, so the pool stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's RPC peer.
    peer: &'a Peer,
    /// The locked connection pool, serialized as the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 peer: Peer::new(),
175 app_state,
176 executor,
177 connection_pool: Default::default(),
178 handlers: Default::default(),
179 teardown: watch::channel(()).0,
180 };
181
182 server
183 .add_request_handler(ping)
184 .add_request_handler(create_room)
185 .add_request_handler(join_room)
186 .add_message_handler(leave_room)
187 .add_request_handler(call)
188 .add_request_handler(cancel_call)
189 .add_message_handler(decline_call)
190 .add_request_handler(update_participant_location)
191 .add_request_handler(share_project)
192 .add_message_handler(unshare_project)
193 .add_request_handler(join_project)
194 .add_message_handler(leave_project)
195 .add_request_handler(update_project)
196 .add_request_handler(update_worktree)
197 .add_message_handler(start_language_server)
198 .add_message_handler(update_language_server)
199 .add_message_handler(update_diagnostic_summary)
200 .add_request_handler(forward_project_request::<proto::GetHover>)
201 .add_request_handler(forward_project_request::<proto::GetDefinition>)
202 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
203 .add_request_handler(forward_project_request::<proto::GetReferences>)
204 .add_request_handler(forward_project_request::<proto::SearchProject>)
205 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
206 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
210 .add_request_handler(forward_project_request::<proto::GetCompletions>)
211 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
212 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
213 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
214 .add_request_handler(forward_project_request::<proto::PrepareRename>)
215 .add_request_handler(forward_project_request::<proto::PerformRename>)
216 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
217 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
218 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
222 .add_message_handler(create_buffer_for_peer)
223 .add_request_handler(update_buffer)
224 .add_message_handler(update_buffer_file)
225 .add_message_handler(buffer_reloaded)
226 .add_message_handler(buffer_saved)
227 .add_request_handler(save_buffer)
228 .add_request_handler(get_users)
229 .add_request_handler(fuzzy_search_users)
230 .add_request_handler(request_contact)
231 .add_request_handler(remove_contact)
232 .add_request_handler(respond_to_contact_request)
233 .add_request_handler(follow)
234 .add_message_handler(unfollow)
235 .add_message_handler(update_followers)
236 .add_message_handler(update_diff_base)
237 .add_request_handler(get_private_user_info);
238
239 Arc::new(server)
240 }
241
242 pub async fn start(&self) -> Result<()> {
243 self.app_state.db.delete_stale_projects().await?;
244 let db = self.app_state.db.clone();
245 let peer = self.peer.clone();
246 let timeout = self.executor.sleep(RECONNECT_TIMEOUT);
247 let pool = self.connection_pool.clone();
248 let live_kit_client = self.app_state.live_kit_client.clone();
249 self.executor.spawn_detached(async move {
250 timeout.await;
251 if let Some(room_ids) = db.outdated_room_ids().await.trace_err() {
252 for room_id in room_ids {
253 let mut contacts_to_update = HashSet::default();
254 let mut canceled_calls_to_user_ids = Vec::new();
255 let mut live_kit_room = String::new();
256 let mut delete_live_kit_room = false;
257
258 if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
259 room_updated(&refreshed_room.room, &peer);
260 contacts_to_update
261 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
262 contacts_to_update
263 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
264 canceled_calls_to_user_ids =
265 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
266 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
267 delete_live_kit_room = refreshed_room.room.participants.is_empty();
268 }
269
270 {
271 let pool = pool.lock().await;
272 for canceled_user_id in canceled_calls_to_user_ids {
273 for connection_id in pool.user_connection_ids(canceled_user_id) {
274 peer.send(
275 connection_id,
276 proto::CallCanceled {
277 room_id: room_id.to_proto(),
278 },
279 )
280 .trace_err();
281 }
282 }
283 }
284
285 for user_id in contacts_to_update {
286 let busy = db.is_user_busy(user_id).await.trace_err();
287 let contacts = db.get_contacts(user_id).await.trace_err();
288 if let Some((busy, contacts)) = busy.zip(contacts) {
289 let pool = pool.lock().await;
290 let updated_contact = contact_for_user(user_id, false, busy, &pool);
291 for contact in contacts {
292 if let db::Contact::Accepted {
293 user_id: contact_user_id,
294 ..
295 } = contact
296 {
297 for contact_conn_id in pool.user_connection_ids(contact_user_id)
298 {
299 peer.send(
300 contact_conn_id,
301 proto::UpdateContacts {
302 contacts: vec![updated_contact.clone()],
303 remove_contacts: Default::default(),
304 incoming_requests: Default::default(),
305 remove_incoming_requests: Default::default(),
306 outgoing_requests: Default::default(),
307 remove_outgoing_requests: Default::default(),
308 },
309 )
310 .trace_err();
311 }
312 }
313 }
314 }
315 }
316
317 if let Some(live_kit) = live_kit_client.as_ref() {
318 if delete_live_kit_room {
319 live_kit.delete_room(live_kit_room).await.trace_err();
320 }
321 }
322 }
323 }
324 });
325 Ok(())
326 }
327
328 pub fn teardown(&self) {
329 self.peer.reset();
330 let _ = self.teardown.send(());
331 }
332
333 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
334 where
335 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
336 Fut: 'static + Send + Future<Output = Result<()>>,
337 M: EnvelopedMessage,
338 {
339 let prev_handler = self.handlers.insert(
340 TypeId::of::<M>(),
341 Box::new(move |envelope, session| {
342 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
343 let span = info_span!(
344 "handle message",
345 payload_type = envelope.payload_type_name()
346 );
347 span.in_scope(|| {
348 tracing::info!(
349 payload_type = envelope.payload_type_name(),
350 "message received"
351 );
352 });
353 let future = (handler)(*envelope, session);
354 async move {
355 if let Err(error) = future.await {
356 tracing::error!(%error, "error handling message");
357 }
358 }
359 .instrument(span)
360 .boxed()
361 }),
362 );
363 if prev_handler.is_some() {
364 panic!("registered a handler for the same message twice");
365 }
366 self
367 }
368
369 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
370 where
371 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
372 Fut: 'static + Send + Future<Output = Result<()>>,
373 M: EnvelopedMessage,
374 {
375 self.add_handler(move |envelope, session| handler(envelope.payload, session));
376 self
377 }
378
379 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
380 where
381 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
382 Fut: Send + Future<Output = Result<()>>,
383 M: RequestMessage,
384 {
385 let handler = Arc::new(handler);
386 self.add_handler(move |envelope, session| {
387 let receipt = envelope.receipt();
388 let handler = handler.clone();
389 async move {
390 let peer = session.peer.clone();
391 let responded = Arc::new(AtomicBool::default());
392 let response = Response {
393 peer: peer.clone(),
394 responded: responded.clone(),
395 receipt,
396 };
397 match (handler)(envelope.payload, response, session).await {
398 Ok(()) => {
399 if responded.load(std::sync::atomic::Ordering::SeqCst) {
400 Ok(())
401 } else {
402 Err(anyhow!("handler did not send a response"))?
403 }
404 }
405 Err(error) => {
406 peer.respond_with_error(
407 receipt,
408 proto::Error {
409 message: error.to_string(),
410 },
411 )?;
412 Err(error)
413 }
414 }
415 }
416 })
417 }
418
419 pub fn handle_connection(
420 self: &Arc<Self>,
421 connection: Connection,
422 address: String,
423 user: User,
424 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
425 executor: Executor,
426 ) -> impl Future<Output = Result<()>> {
427 let this = self.clone();
428 let user_id = user.id;
429 let login = user.github_login;
430 let span = info_span!("handle connection", %user_id, %login, %address);
431 let mut teardown = self.teardown.subscribe();
432 async move {
433 let (connection_id, handle_io, mut incoming_rx) = this
434 .peer
435 .add_connection(connection, {
436 let executor = executor.clone();
437 move |duration| executor.sleep(duration)
438 });
439
440 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
441 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
442 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
443
444 if let Some(send_connection_id) = send_connection_id.take() {
445 let _ = send_connection_id.send(connection_id);
446 }
447
448 if !user.connected_once {
449 this.peer.send(connection_id, proto::ShowContacts {})?;
450 this.app_state.db.set_user_connected_once(user_id, true).await?;
451 }
452
453 let (contacts, invite_code) = future::try_join(
454 this.app_state.db.get_contacts(user_id),
455 this.app_state.db.get_invite_code_for_user(user_id)
456 ).await?;
457
458 {
459 let mut pool = this.connection_pool.lock().await;
460 pool.add_connection(connection_id, user_id, user.admin);
461 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
462
463 if let Some((code, count)) = invite_code {
464 this.peer.send(connection_id, proto::UpdateInviteInfo {
465 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
466 count: count as u32,
467 })?;
468 }
469 }
470
471 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
472 this.peer.send(connection_id, incoming_call)?;
473 }
474
475 let session = Session {
476 user_id,
477 connection_id,
478 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
479 peer: this.peer.clone(),
480 connection_pool: this.connection_pool.clone(),
481 live_kit_client: this.app_state.live_kit_client.clone()
482 };
483 update_user_contacts(user_id, &session).await?;
484
485 let handle_io = handle_io.fuse();
486 futures::pin_mut!(handle_io);
487
488 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
489 // This prevents deadlocks when e.g., client A performs a request to client B and
490 // client B performs a request to client A. If both clients stop processing further
491 // messages until their respective request completes, they won't have a chance to
492 // respond to the other client's request and cause a deadlock.
493 //
494 // This arrangement ensures we will attempt to process earlier messages first, but fall
495 // back to processing messages arrived later in the spirit of making progress.
496 let mut foreground_message_handlers = FuturesUnordered::new();
497 loop {
498 let next_message = incoming_rx.next().fuse();
499 futures::pin_mut!(next_message);
500 futures::select_biased! {
501 _ = teardown.changed().fuse() => return Ok(()),
502 result = handle_io => {
503 if let Err(error) = result {
504 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
505 }
506 break;
507 }
508 _ = foreground_message_handlers.next() => {}
509 message = next_message => {
510 if let Some(message) = message {
511 let type_name = message.payload_type_name();
512 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
513 let span_enter = span.enter();
514 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
515 let is_background = message.is_background();
516 let handle_message = (handler)(message, session.clone());
517 drop(span_enter);
518
519 let handle_message = handle_message.instrument(span);
520 if is_background {
521 executor.spawn_detached(handle_message);
522 } else {
523 foreground_message_handlers.push(handle_message);
524 }
525 } else {
526 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
527 }
528 } else {
529 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
530 break;
531 }
532 }
533 }
534 }
535
536 drop(foreground_message_handlers);
537 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
538 if let Err(error) = sign_out(session, teardown, executor).await {
539 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
540 }
541
542 Ok(())
543 }.instrument(span)
544 }
545
546 pub async fn invite_code_redeemed(
547 self: &Arc<Self>,
548 inviter_id: UserId,
549 invitee_id: UserId,
550 ) -> Result<()> {
551 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
552 if let Some(code) = &user.invite_code {
553 let pool = self.connection_pool.lock().await;
554 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
555 for connection_id in pool.user_connection_ids(inviter_id) {
556 self.peer.send(
557 connection_id,
558 proto::UpdateContacts {
559 contacts: vec![invitee_contact.clone()],
560 ..Default::default()
561 },
562 )?;
563 self.peer.send(
564 connection_id,
565 proto::UpdateInviteInfo {
566 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
567 count: user.invite_count as u32,
568 },
569 )?;
570 }
571 }
572 }
573 Ok(())
574 }
575
576 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
577 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
578 if let Some(invite_code) = &user.invite_code {
579 let pool = self.connection_pool.lock().await;
580 for connection_id in pool.user_connection_ids(user_id) {
581 self.peer.send(
582 connection_id,
583 proto::UpdateInviteInfo {
584 url: format!(
585 "{}{}",
586 self.app_state.config.invite_link_prefix, invite_code
587 ),
588 count: user.invite_count as u32,
589 },
590 )?;
591 }
592 }
593 }
594 Ok(())
595 }
596
597 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
598 ServerSnapshot {
599 connection_pool: ConnectionPoolGuard {
600 guard: self.connection_pool.lock().await,
601 _not_send: PhantomData,
602 },
603 peer: &self.peer,
604 }
605 }
606}
607
608impl<'a> Deref for ConnectionPoolGuard<'a> {
609 type Target = ConnectionPool;
610
611 fn deref(&self) -> &Self::Target {
612 &*self.guard
613 }
614}
615
616impl<'a> DerefMut for ConnectionPoolGuard<'a> {
617 fn deref_mut(&mut self) -> &mut Self::Target {
618 &mut *self.guard
619 }
620}
621
622impl<'a> Drop for ConnectionPoolGuard<'a> {
623 fn drop(&mut self) {
624 #[cfg(test)]
625 self.check_invariants();
626 }
627}
628
629fn broadcast<F>(
630 sender_id: ConnectionId,
631 receiver_ids: impl IntoIterator<Item = ConnectionId>,
632 mut f: F,
633) where
634 F: FnMut(ConnectionId) -> anyhow::Result<()>,
635{
636 for receiver_id in receiver_ids {
637 if receiver_id != sender_id {
638 f(receiver_id).trace_err();
639 }
640 }
641}
642
643lazy_static! {
644 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
645}
646
647pub struct ProtocolVersion(u32);
648
649impl Header for ProtocolVersion {
650 fn name() -> &'static HeaderName {
651 &ZED_PROTOCOL_VERSION
652 }
653
654 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
655 where
656 Self: Sized,
657 I: Iterator<Item = &'i axum::http::HeaderValue>,
658 {
659 let version = values
660 .next()
661 .ok_or_else(axum::headers::Error::invalid)?
662 .to_str()
663 .map_err(|_| axum::headers::Error::invalid())?
664 .parse()
665 .map_err(|_| axum::headers::Error::invalid())?;
666 Ok(Self(version))
667 }
668
669 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
670 values.extend([self.0.to_string().parse().unwrap()]);
671 }
672}
673
674pub fn routes(server: Arc<Server>) -> Router<Body> {
675 Router::new()
676 .route("/rpc", get(handle_websocket_request))
677 .layer(
678 ServiceBuilder::new()
679 .layer(Extension(server.app_state.clone()))
680 .layer(middleware::from_fn(auth::validate_header)),
681 )
682 .route("/metrics", get(handle_metrics))
683 .layer(Extension(server))
684}
685
686pub async fn handle_websocket_request(
687 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
688 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
689 Extension(server): Extension<Arc<Server>>,
690 Extension(user): Extension<User>,
691 ws: WebSocketUpgrade,
692) -> axum::response::Response {
693 if protocol_version != rpc::PROTOCOL_VERSION {
694 return (
695 StatusCode::UPGRADE_REQUIRED,
696 "client must be upgraded".to_string(),
697 )
698 .into_response();
699 }
700 let socket_address = socket_address.to_string();
701 ws.on_upgrade(move |socket| {
702 use util::ResultExt;
703 let socket = socket
704 .map_ok(to_tungstenite_message)
705 .err_into()
706 .with(|message| async move { Ok(to_axum_message(message)) });
707 let connection = Connection::new(Box::pin(socket));
708 async move {
709 server
710 .handle_connection(connection, socket_address, user, None, Executor::Production)
711 .await
712 .log_err();
713 }
714 })
715}
716
717pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
718 let connections = server
719 .connection_pool
720 .lock()
721 .await
722 .connections()
723 .filter(|connection| !connection.admin)
724 .count();
725
726 METRIC_CONNECTIONS.set(connections as _);
727
728 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
729 METRIC_SHARED_PROJECTS.set(shared_projects as _);
730
731 let encoder = prometheus::TextEncoder::new();
732 let metric_families = prometheus::gather();
733 let encoded_metrics = encoder
734 .encode_to_string(&metric_families)
735 .map_err(|err| anyhow!("{}", err))?;
736 Ok(encoded_metrics)
737}
738
/// Cleans up after a connection whose message loop has exited.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - calls `connection_lost` in the database, running `project_left` for each
///   project it reports.
///
/// It then waits for either `RECONNECT_TIMEOUT` or the server teardown signal.
/// On teardown nothing more happens. After the timeout:
/// - `leave_room_for_session` runs;
/// - `decline_call(None, ..)` runs if the user has no connection left online;
/// - `update_user_contacts` runs for the user.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A database error here goes through `trace_err` and is otherwise ignored,
    // so the rest of the sign-out still runs.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls branches in order, so the timer branch wins if
    // both are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                // `db` is dropped at the end of this block, before
                // `update_user_contacts` below runs.
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
784
785async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
786 response.send(proto::Ack {})?;
787 Ok(())
788}
789
790async fn create_room(
791 _request: proto::CreateRoom,
792 response: Response<proto::CreateRoom>,
793 session: Session,
794) -> Result<()> {
795 let live_kit_room = nanoid::nanoid!(30);
796 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
797 if let Some(_) = live_kit
798 .create_room(live_kit_room.clone())
799 .await
800 .trace_err()
801 {
802 if let Some(token) = live_kit
803 .room_token(&live_kit_room, &session.connection_id.to_string())
804 .trace_err()
805 {
806 Some(proto::LiveKitConnectionInfo {
807 server_url: live_kit.url().into(),
808 token,
809 })
810 } else {
811 None
812 }
813 } else {
814 None
815 }
816 } else {
817 None
818 };
819
820 {
821 let room = session
822 .db()
823 .await
824 .create_room(session.user_id, session.connection_id, &live_kit_room)
825 .await?;
826
827 response.send(proto::CreateRoomResponse {
828 room: Some(room.clone()),
829 live_kit_connection_info,
830 })?;
831 }
832
833 update_user_contacts(session.user_id, &session).await?;
834 Ok(())
835}
836
837async fn join_room(
838 request: proto::JoinRoom,
839 response: Response<proto::JoinRoom>,
840 session: Session,
841) -> Result<()> {
842 let room_id = RoomId::from_proto(request.id);
843 let room = {
844 let room = session
845 .db()
846 .await
847 .join_room(room_id, session.user_id, session.connection_id)
848 .await?;
849 room_updated(&room, &session.peer);
850 room.clone()
851 };
852
853 for connection_id in session
854 .connection_pool()
855 .await
856 .user_connection_ids(session.user_id)
857 {
858 session
859 .peer
860 .send(
861 connection_id,
862 proto::CallCanceled {
863 room_id: room_id.to_proto(),
864 },
865 )
866 .trace_err();
867 }
868
869 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
870 if let Some(token) = live_kit
871 .room_token(&room.live_kit_room, &session.connection_id.to_string())
872 .trace_err()
873 {
874 Some(proto::LiveKitConnectionInfo {
875 server_url: live_kit.url().into(),
876 token,
877 })
878 } else {
879 None
880 }
881 } else {
882 None
883 };
884
885 response.send(proto::JoinRoomResponse {
886 room: Some(room),
887 live_kit_connection_info,
888 })?;
889
890 update_user_contacts(session.user_id, &session).await?;
891 Ok(())
892}
893
894async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
895 leave_room_for_session(&session).await
896}
897
898async fn call(
899 request: proto::Call,
900 response: Response<proto::Call>,
901 session: Session,
902) -> Result<()> {
903 let room_id = RoomId::from_proto(request.room_id);
904 let calling_user_id = session.user_id;
905 let calling_connection_id = session.connection_id;
906 let called_user_id = UserId::from_proto(request.called_user_id);
907 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
908 if !session
909 .db()
910 .await
911 .has_contact(calling_user_id, called_user_id)
912 .await?
913 {
914 return Err(anyhow!("cannot call a user who isn't a contact"))?;
915 }
916
917 let incoming_call = {
918 let (room, incoming_call) = &mut *session
919 .db()
920 .await
921 .call(
922 room_id,
923 calling_user_id,
924 calling_connection_id,
925 called_user_id,
926 initial_project_id,
927 )
928 .await?;
929 room_updated(&room, &session.peer);
930 mem::take(incoming_call)
931 };
932 update_user_contacts(called_user_id, &session).await?;
933
934 let mut calls = session
935 .connection_pool()
936 .await
937 .user_connection_ids(called_user_id)
938 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
939 .collect::<FuturesUnordered<_>>();
940
941 while let Some(call_response) = calls.next().await {
942 match call_response.as_ref() {
943 Ok(_) => {
944 response.send(proto::Ack {})?;
945 return Ok(());
946 }
947 Err(_) => {
948 call_response.trace_err();
949 }
950 }
951 }
952
953 {
954 let room = session
955 .db()
956 .await
957 .call_failed(room_id, called_user_id)
958 .await?;
959 room_updated(&room, &session.peer);
960 }
961 update_user_contacts(called_user_id, &session).await?;
962
963 Err(anyhow!("failed to ring user"))?
964}
965
966async fn cancel_call(
967 request: proto::CancelCall,
968 response: Response<proto::CancelCall>,
969 session: Session,
970) -> Result<()> {
971 let called_user_id = UserId::from_proto(request.called_user_id);
972 let room_id = RoomId::from_proto(request.room_id);
973 {
974 let room = session
975 .db()
976 .await
977 .cancel_call(Some(room_id), session.connection_id, called_user_id)
978 .await?;
979 room_updated(&room, &session.peer);
980 }
981
982 for connection_id in session
983 .connection_pool()
984 .await
985 .user_connection_ids(called_user_id)
986 {
987 session
988 .peer
989 .send(
990 connection_id,
991 proto::CallCanceled {
992 room_id: room_id.to_proto(),
993 },
994 )
995 .trace_err();
996 }
997 response.send(proto::Ack {})?;
998
999 update_user_contacts(called_user_id, &session).await?;
1000 Ok(())
1001}
1002
1003async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1004 let room_id = RoomId::from_proto(message.room_id);
1005 {
1006 let room = session
1007 .db()
1008 .await
1009 .decline_call(Some(room_id), session.user_id)
1010 .await?;
1011 room_updated(&room, &session.peer);
1012 }
1013
1014 for connection_id in session
1015 .connection_pool()
1016 .await
1017 .user_connection_ids(session.user_id)
1018 {
1019 session
1020 .peer
1021 .send(
1022 connection_id,
1023 proto::CallCanceled {
1024 room_id: room_id.to_proto(),
1025 },
1026 )
1027 .trace_err();
1028 }
1029 update_user_contacts(session.user_id, &session).await?;
1030 Ok(())
1031}
1032
1033async fn update_participant_location(
1034 request: proto::UpdateParticipantLocation,
1035 response: Response<proto::UpdateParticipantLocation>,
1036 session: Session,
1037) -> Result<()> {
1038 let room_id = RoomId::from_proto(request.room_id);
1039 let location = request
1040 .location
1041 .ok_or_else(|| anyhow!("invalid location"))?;
1042 let room = session
1043 .db()
1044 .await
1045 .update_room_participant_location(room_id, session.connection_id, location)
1046 .await?;
1047 room_updated(&room, &session.peer);
1048 response.send(proto::Ack {})?;
1049 Ok(())
1050}
1051
1052async fn share_project(
1053 request: proto::ShareProject,
1054 response: Response<proto::ShareProject>,
1055 session: Session,
1056) -> Result<()> {
1057 let (project_id, room) = &*session
1058 .db()
1059 .await
1060 .share_project(
1061 RoomId::from_proto(request.room_id),
1062 session.connection_id,
1063 &request.worktrees,
1064 )
1065 .await?;
1066 response.send(proto::ShareProjectResponse {
1067 project_id: project_id.to_proto(),
1068 })?;
1069 room_updated(&room, &session.peer);
1070
1071 Ok(())
1072}
1073
1074async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1075 let project_id = ProjectId::from_proto(message.project_id);
1076
1077 let (room, guest_connection_ids) = &*session
1078 .db()
1079 .await
1080 .unshare_project(project_id, session.connection_id)
1081 .await?;
1082
1083 broadcast(
1084 session.connection_id,
1085 guest_connection_ids.iter().copied(),
1086 |conn_id| session.peer.send(conn_id, message.clone()),
1087 );
1088 room_updated(&room, &session.peer);
1089
1090 Ok(())
1091}
1092
1093async fn join_project(
1094 request: proto::JoinProject,
1095 response: Response<proto::JoinProject>,
1096 session: Session,
1097) -> Result<()> {
1098 let project_id = ProjectId::from_proto(request.project_id);
1099 let guest_user_id = session.user_id;
1100
1101 tracing::info!(%project_id, "join project");
1102
1103 let (project, replica_id) = &mut *session
1104 .db()
1105 .await
1106 .join_project(project_id, session.connection_id)
1107 .await?;
1108
1109 let collaborators = project
1110 .collaborators
1111 .iter()
1112 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1113 .map(|collaborator| proto::Collaborator {
1114 peer_id: collaborator.connection_id as u32,
1115 replica_id: collaborator.replica_id.0 as u32,
1116 user_id: collaborator.user_id.to_proto(),
1117 })
1118 .collect::<Vec<_>>();
1119 let worktrees = project
1120 .worktrees
1121 .iter()
1122 .map(|(id, worktree)| proto::WorktreeMetadata {
1123 id: *id,
1124 root_name: worktree.root_name.clone(),
1125 visible: worktree.visible,
1126 abs_path: worktree.abs_path.clone(),
1127 })
1128 .collect::<Vec<_>>();
1129
1130 for collaborator in &collaborators {
1131 session
1132 .peer
1133 .send(
1134 ConnectionId(collaborator.peer_id),
1135 proto::AddProjectCollaborator {
1136 project_id: project_id.to_proto(),
1137 collaborator: Some(proto::Collaborator {
1138 peer_id: session.connection_id.0,
1139 replica_id: replica_id.0 as u32,
1140 user_id: guest_user_id.to_proto(),
1141 }),
1142 },
1143 )
1144 .trace_err();
1145 }
1146
1147 // First, we send the metadata associated with each worktree.
1148 response.send(proto::JoinProjectResponse {
1149 worktrees: worktrees.clone(),
1150 replica_id: replica_id.0 as u32,
1151 collaborators: collaborators.clone(),
1152 language_servers: project.language_servers.clone(),
1153 })?;
1154
1155 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1156 #[cfg(any(test, feature = "test-support"))]
1157 const MAX_CHUNK_SIZE: usize = 2;
1158 #[cfg(not(any(test, feature = "test-support")))]
1159 const MAX_CHUNK_SIZE: usize = 256;
1160
1161 // Stream this worktree's entries.
1162 let message = proto::UpdateWorktree {
1163 project_id: project_id.to_proto(),
1164 worktree_id,
1165 abs_path: worktree.abs_path.clone(),
1166 root_name: worktree.root_name,
1167 updated_entries: worktree.entries,
1168 removed_entries: Default::default(),
1169 scan_id: worktree.scan_id,
1170 is_last_update: worktree.is_complete,
1171 };
1172 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1173 session.peer.send(session.connection_id, update.clone())?;
1174 }
1175
1176 // Stream this worktree's diagnostics.
1177 for summary in worktree.diagnostic_summaries {
1178 session.peer.send(
1179 session.connection_id,
1180 proto::UpdateDiagnosticSummary {
1181 project_id: project_id.to_proto(),
1182 worktree_id: worktree.id,
1183 summary: Some(summary),
1184 },
1185 )?;
1186 }
1187 }
1188
1189 for language_server in &project.language_servers {
1190 session.peer.send(
1191 session.connection_id,
1192 proto::UpdateLanguageServer {
1193 project_id: project_id.to_proto(),
1194 language_server_id: language_server.id,
1195 variant: Some(
1196 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1197 proto::LspDiskBasedDiagnosticsUpdated {},
1198 ),
1199 ),
1200 },
1201 )?;
1202 }
1203
1204 Ok(())
1205}
1206
1207async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1208 let sender_id = session.connection_id;
1209 let project_id = ProjectId::from_proto(request.project_id);
1210
1211 let project = session
1212 .db()
1213 .await
1214 .leave_project(project_id, sender_id)
1215 .await?;
1216 tracing::info!(
1217 %project_id,
1218 host_user_id = %project.host_user_id,
1219 host_connection_id = %project.host_connection_id,
1220 "leave project"
1221 );
1222 project_left(&project, &session);
1223
1224 Ok(())
1225}
1226
1227async fn update_project(
1228 request: proto::UpdateProject,
1229 response: Response<proto::UpdateProject>,
1230 session: Session,
1231) -> Result<()> {
1232 let project_id = ProjectId::from_proto(request.project_id);
1233 let (room, guest_connection_ids) = &*session
1234 .db()
1235 .await
1236 .update_project(project_id, session.connection_id, &request.worktrees)
1237 .await?;
1238 broadcast(
1239 session.connection_id,
1240 guest_connection_ids.iter().copied(),
1241 |connection_id| {
1242 session
1243 .peer
1244 .forward_send(session.connection_id, connection_id, request.clone())
1245 },
1246 );
1247 room_updated(&room, &session.peer);
1248 response.send(proto::Ack {})?;
1249
1250 Ok(())
1251}
1252
1253async fn update_worktree(
1254 request: proto::UpdateWorktree,
1255 response: Response<proto::UpdateWorktree>,
1256 session: Session,
1257) -> Result<()> {
1258 let guest_connection_ids = session
1259 .db()
1260 .await
1261 .update_worktree(&request, session.connection_id)
1262 .await?;
1263
1264 broadcast(
1265 session.connection_id,
1266 guest_connection_ids.iter().copied(),
1267 |connection_id| {
1268 session
1269 .peer
1270 .forward_send(session.connection_id, connection_id, request.clone())
1271 },
1272 );
1273 response.send(proto::Ack {})?;
1274 Ok(())
1275}
1276
1277async fn update_diagnostic_summary(
1278 message: proto::UpdateDiagnosticSummary,
1279 session: Session,
1280) -> Result<()> {
1281 let guest_connection_ids = session
1282 .db()
1283 .await
1284 .update_diagnostic_summary(&message, session.connection_id)
1285 .await?;
1286
1287 broadcast(
1288 session.connection_id,
1289 guest_connection_ids.iter().copied(),
1290 |connection_id| {
1291 session
1292 .peer
1293 .forward_send(session.connection_id, connection_id, message.clone())
1294 },
1295 );
1296
1297 Ok(())
1298}
1299
1300async fn start_language_server(
1301 request: proto::StartLanguageServer,
1302 session: Session,
1303) -> Result<()> {
1304 let guest_connection_ids = session
1305 .db()
1306 .await
1307 .start_language_server(&request, session.connection_id)
1308 .await?;
1309
1310 broadcast(
1311 session.connection_id,
1312 guest_connection_ids.iter().copied(),
1313 |connection_id| {
1314 session
1315 .peer
1316 .forward_send(session.connection_id, connection_id, request.clone())
1317 },
1318 );
1319 Ok(())
1320}
1321
1322async fn update_language_server(
1323 request: proto::UpdateLanguageServer,
1324 session: Session,
1325) -> Result<()> {
1326 let project_id = ProjectId::from_proto(request.project_id);
1327 let project_connection_ids = session
1328 .db()
1329 .await
1330 .project_connection_ids(project_id, session.connection_id)
1331 .await?;
1332 broadcast(
1333 session.connection_id,
1334 project_connection_ids.iter().copied(),
1335 |connection_id| {
1336 session
1337 .peer
1338 .forward_send(session.connection_id, connection_id, request.clone())
1339 },
1340 );
1341 Ok(())
1342}
1343
1344async fn forward_project_request<T>(
1345 request: T,
1346 response: Response<T>,
1347 session: Session,
1348) -> Result<()>
1349where
1350 T: EntityMessage + RequestMessage,
1351{
1352 let project_id = ProjectId::from_proto(request.remote_entity_id());
1353 let host_connection_id = {
1354 let collaborators = session
1355 .db()
1356 .await
1357 .project_collaborators(project_id, session.connection_id)
1358 .await?;
1359 ConnectionId(
1360 collaborators
1361 .iter()
1362 .find(|collaborator| collaborator.is_host)
1363 .ok_or_else(|| anyhow!("host not found"))?
1364 .connection_id as u32,
1365 )
1366 };
1367
1368 let payload = session
1369 .peer
1370 .forward_request(session.connection_id, host_connection_id, request)
1371 .await?;
1372
1373 response.send(payload)?;
1374 Ok(())
1375}
1376
1377async fn save_buffer(
1378 request: proto::SaveBuffer,
1379 response: Response<proto::SaveBuffer>,
1380 session: Session,
1381) -> Result<()> {
1382 let project_id = ProjectId::from_proto(request.project_id);
1383 let host_connection_id = {
1384 let collaborators = session
1385 .db()
1386 .await
1387 .project_collaborators(project_id, session.connection_id)
1388 .await?;
1389 let host = collaborators
1390 .iter()
1391 .find(|collaborator| collaborator.is_host)
1392 .ok_or_else(|| anyhow!("host not found"))?;
1393 ConnectionId(host.connection_id as u32)
1394 };
1395 let response_payload = session
1396 .peer
1397 .forward_request(session.connection_id, host_connection_id, request.clone())
1398 .await?;
1399
1400 let mut collaborators = session
1401 .db()
1402 .await
1403 .project_collaborators(project_id, session.connection_id)
1404 .await?;
1405 collaborators
1406 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1407 let project_connection_ids = collaborators
1408 .iter()
1409 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1410 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1411 session
1412 .peer
1413 .forward_send(host_connection_id, conn_id, response_payload.clone())
1414 });
1415 response.send(response_payload)?;
1416 Ok(())
1417}
1418
1419async fn create_buffer_for_peer(
1420 request: proto::CreateBufferForPeer,
1421 session: Session,
1422) -> Result<()> {
1423 session.peer.forward_send(
1424 session.connection_id,
1425 ConnectionId(request.peer_id),
1426 request,
1427 )?;
1428 Ok(())
1429}
1430
1431async fn update_buffer(
1432 request: proto::UpdateBuffer,
1433 response: Response<proto::UpdateBuffer>,
1434 session: Session,
1435) -> Result<()> {
1436 let project_id = ProjectId::from_proto(request.project_id);
1437 let project_connection_ids = session
1438 .db()
1439 .await
1440 .project_connection_ids(project_id, session.connection_id)
1441 .await?;
1442
1443 broadcast(
1444 session.connection_id,
1445 project_connection_ids.iter().copied(),
1446 |connection_id| {
1447 session
1448 .peer
1449 .forward_send(session.connection_id, connection_id, request.clone())
1450 },
1451 );
1452 response.send(proto::Ack {})?;
1453 Ok(())
1454}
1455
1456async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1457 let project_id = ProjectId::from_proto(request.project_id);
1458 let project_connection_ids = session
1459 .db()
1460 .await
1461 .project_connection_ids(project_id, session.connection_id)
1462 .await?;
1463
1464 broadcast(
1465 session.connection_id,
1466 project_connection_ids.iter().copied(),
1467 |connection_id| {
1468 session
1469 .peer
1470 .forward_send(session.connection_id, connection_id, request.clone())
1471 },
1472 );
1473 Ok(())
1474}
1475
1476async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1477 let project_id = ProjectId::from_proto(request.project_id);
1478 let project_connection_ids = session
1479 .db()
1480 .await
1481 .project_connection_ids(project_id, session.connection_id)
1482 .await?;
1483 broadcast(
1484 session.connection_id,
1485 project_connection_ids.iter().copied(),
1486 |connection_id| {
1487 session
1488 .peer
1489 .forward_send(session.connection_id, connection_id, request.clone())
1490 },
1491 );
1492 Ok(())
1493}
1494
1495async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1496 let project_id = ProjectId::from_proto(request.project_id);
1497 let project_connection_ids = session
1498 .db()
1499 .await
1500 .project_connection_ids(project_id, session.connection_id)
1501 .await?;
1502 broadcast(
1503 session.connection_id,
1504 project_connection_ids.iter().copied(),
1505 |connection_id| {
1506 session
1507 .peer
1508 .forward_send(session.connection_id, connection_id, request.clone())
1509 },
1510 );
1511 Ok(())
1512}
1513
1514async fn follow(
1515 request: proto::Follow,
1516 response: Response<proto::Follow>,
1517 session: Session,
1518) -> Result<()> {
1519 let project_id = ProjectId::from_proto(request.project_id);
1520 let leader_id = ConnectionId(request.leader_id);
1521 let follower_id = session.connection_id;
1522 {
1523 let project_connection_ids = session
1524 .db()
1525 .await
1526 .project_connection_ids(project_id, session.connection_id)
1527 .await?;
1528
1529 if !project_connection_ids.contains(&leader_id) {
1530 Err(anyhow!("no such peer"))?;
1531 }
1532 }
1533
1534 let mut response_payload = session
1535 .peer
1536 .forward_request(session.connection_id, leader_id, request)
1537 .await?;
1538 response_payload
1539 .views
1540 .retain(|view| view.leader_id != Some(follower_id.0));
1541 response.send(response_payload)?;
1542 Ok(())
1543}
1544
1545async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1546 let project_id = ProjectId::from_proto(request.project_id);
1547 let leader_id = ConnectionId(request.leader_id);
1548 let project_connection_ids = session
1549 .db()
1550 .await
1551 .project_connection_ids(project_id, session.connection_id)
1552 .await?;
1553 if !project_connection_ids.contains(&leader_id) {
1554 Err(anyhow!("no such peer"))?;
1555 }
1556 session
1557 .peer
1558 .forward_send(session.connection_id, leader_id, request)?;
1559 Ok(())
1560}
1561
1562async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1563 let project_id = ProjectId::from_proto(request.project_id);
1564 let project_connection_ids = session
1565 .db
1566 .lock()
1567 .await
1568 .project_connection_ids(project_id, session.connection_id)
1569 .await?;
1570
1571 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1572 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1573 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1574 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1575 });
1576 for follower_id in &request.follower_ids {
1577 let follower_id = ConnectionId(*follower_id);
1578 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1579 session
1580 .peer
1581 .forward_send(session.connection_id, follower_id, request.clone())?;
1582 }
1583 }
1584 Ok(())
1585}
1586
1587async fn get_users(
1588 request: proto::GetUsers,
1589 response: Response<proto::GetUsers>,
1590 session: Session,
1591) -> Result<()> {
1592 let user_ids = request
1593 .user_ids
1594 .into_iter()
1595 .map(UserId::from_proto)
1596 .collect();
1597 let users = session
1598 .db()
1599 .await
1600 .get_users_by_ids(user_ids)
1601 .await?
1602 .into_iter()
1603 .map(|user| proto::User {
1604 id: user.id.to_proto(),
1605 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1606 github_login: user.github_login,
1607 })
1608 .collect();
1609 response.send(proto::UsersResponse { users })?;
1610 Ok(())
1611}
1612
1613async fn fuzzy_search_users(
1614 request: proto::FuzzySearchUsers,
1615 response: Response<proto::FuzzySearchUsers>,
1616 session: Session,
1617) -> Result<()> {
1618 let query = request.query;
1619 let users = match query.len() {
1620 0 => vec![],
1621 1 | 2 => session
1622 .db()
1623 .await
1624 .get_user_by_github_account(&query, None)
1625 .await?
1626 .into_iter()
1627 .collect(),
1628 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1629 };
1630 let users = users
1631 .into_iter()
1632 .filter(|user| user.id != session.user_id)
1633 .map(|user| proto::User {
1634 id: user.id.to_proto(),
1635 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1636 github_login: user.github_login,
1637 })
1638 .collect();
1639 response.send(proto::UsersResponse { users })?;
1640 Ok(())
1641}
1642
1643async fn request_contact(
1644 request: proto::RequestContact,
1645 response: Response<proto::RequestContact>,
1646 session: Session,
1647) -> Result<()> {
1648 let requester_id = session.user_id;
1649 let responder_id = UserId::from_proto(request.responder_id);
1650 if requester_id == responder_id {
1651 return Err(anyhow!("cannot add yourself as a contact"))?;
1652 }
1653
1654 session
1655 .db()
1656 .await
1657 .send_contact_request(requester_id, responder_id)
1658 .await?;
1659
1660 // Update outgoing contact requests of requester
1661 let mut update = proto::UpdateContacts::default();
1662 update.outgoing_requests.push(responder_id.to_proto());
1663 for connection_id in session
1664 .connection_pool()
1665 .await
1666 .user_connection_ids(requester_id)
1667 {
1668 session.peer.send(connection_id, update.clone())?;
1669 }
1670
1671 // Update incoming contact requests of responder
1672 let mut update = proto::UpdateContacts::default();
1673 update
1674 .incoming_requests
1675 .push(proto::IncomingContactRequest {
1676 requester_id: requester_id.to_proto(),
1677 should_notify: true,
1678 });
1679 for connection_id in session
1680 .connection_pool()
1681 .await
1682 .user_connection_ids(responder_id)
1683 {
1684 session.peer.send(connection_id, update.clone())?;
1685 }
1686
1687 response.send(proto::Ack {})?;
1688 Ok(())
1689}
1690
1691async fn respond_to_contact_request(
1692 request: proto::RespondToContactRequest,
1693 response: Response<proto::RespondToContactRequest>,
1694 session: Session,
1695) -> Result<()> {
1696 let responder_id = session.user_id;
1697 let requester_id = UserId::from_proto(request.requester_id);
1698 let db = session.db().await;
1699 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1700 db.dismiss_contact_notification(responder_id, requester_id)
1701 .await?;
1702 } else {
1703 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1704
1705 db.respond_to_contact_request(responder_id, requester_id, accept)
1706 .await?;
1707 let requester_busy = db.is_user_busy(requester_id).await?;
1708 let responder_busy = db.is_user_busy(responder_id).await?;
1709
1710 let pool = session.connection_pool().await;
1711 // Update responder with new contact
1712 let mut update = proto::UpdateContacts::default();
1713 if accept {
1714 update
1715 .contacts
1716 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1717 }
1718 update
1719 .remove_incoming_requests
1720 .push(requester_id.to_proto());
1721 for connection_id in pool.user_connection_ids(responder_id) {
1722 session.peer.send(connection_id, update.clone())?;
1723 }
1724
1725 // Update requester with new contact
1726 let mut update = proto::UpdateContacts::default();
1727 if accept {
1728 update
1729 .contacts
1730 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1731 }
1732 update
1733 .remove_outgoing_requests
1734 .push(responder_id.to_proto());
1735 for connection_id in pool.user_connection_ids(requester_id) {
1736 session.peer.send(connection_id, update.clone())?;
1737 }
1738 }
1739
1740 response.send(proto::Ack {})?;
1741 Ok(())
1742}
1743
1744async fn remove_contact(
1745 request: proto::RemoveContact,
1746 response: Response<proto::RemoveContact>,
1747 session: Session,
1748) -> Result<()> {
1749 let requester_id = session.user_id;
1750 let responder_id = UserId::from_proto(request.user_id);
1751 let db = session.db().await;
1752 db.remove_contact(requester_id, responder_id).await?;
1753
1754 let pool = session.connection_pool().await;
1755 // Update outgoing contact requests of requester
1756 let mut update = proto::UpdateContacts::default();
1757 update
1758 .remove_outgoing_requests
1759 .push(responder_id.to_proto());
1760 for connection_id in pool.user_connection_ids(requester_id) {
1761 session.peer.send(connection_id, update.clone())?;
1762 }
1763
1764 // Update incoming contact requests of responder
1765 let mut update = proto::UpdateContacts::default();
1766 update
1767 .remove_incoming_requests
1768 .push(requester_id.to_proto());
1769 for connection_id in pool.user_connection_ids(responder_id) {
1770 session.peer.send(connection_id, update.clone())?;
1771 }
1772
1773 response.send(proto::Ack {})?;
1774 Ok(())
1775}
1776
1777async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1778 let project_id = ProjectId::from_proto(request.project_id);
1779 let project_connection_ids = session
1780 .db()
1781 .await
1782 .project_connection_ids(project_id, session.connection_id)
1783 .await?;
1784 broadcast(
1785 session.connection_id,
1786 project_connection_ids.iter().copied(),
1787 |connection_id| {
1788 session
1789 .peer
1790 .forward_send(session.connection_id, connection_id, request.clone())
1791 },
1792 );
1793 Ok(())
1794}
1795
1796async fn get_private_user_info(
1797 _request: proto::GetPrivateUserInfo,
1798 response: Response<proto::GetPrivateUserInfo>,
1799 session: Session,
1800) -> Result<()> {
1801 let metrics_id = session
1802 .db()
1803 .await
1804 .get_user_metrics_id(session.user_id)
1805 .await?;
1806 let user = session
1807 .db()
1808 .await
1809 .get_user_by_id(session.user_id)
1810 .await?
1811 .ok_or_else(|| anyhow!("user not found"))?;
1812 response.send(proto::GetPrivateUserInfoResponse {
1813 metrics_id,
1814 staff: user.admin,
1815 })?;
1816 Ok(())
1817}
1818
1819fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1820 match message {
1821 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1822 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1823 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1824 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1825 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1826 code: frame.code.into(),
1827 reason: frame.reason,
1828 })),
1829 }
1830}
1831
1832fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1833 match message {
1834 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1835 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1836 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1837 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1838 AxumMessage::Close(frame) => {
1839 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1840 code: frame.code.into(),
1841 reason: frame.reason,
1842 }))
1843 }
1844 }
1845}
1846
1847fn build_initial_contacts_update(
1848 contacts: Vec<db::Contact>,
1849 pool: &ConnectionPool,
1850) -> proto::UpdateContacts {
1851 let mut update = proto::UpdateContacts::default();
1852
1853 for contact in contacts {
1854 match contact {
1855 db::Contact::Accepted {
1856 user_id,
1857 should_notify,
1858 busy,
1859 } => {
1860 update
1861 .contacts
1862 .push(contact_for_user(user_id, should_notify, busy, &pool));
1863 }
1864 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1865 db::Contact::Incoming {
1866 user_id,
1867 should_notify,
1868 } => update
1869 .incoming_requests
1870 .push(proto::IncomingContactRequest {
1871 requester_id: user_id.to_proto(),
1872 should_notify,
1873 }),
1874 }
1875 }
1876
1877 update
1878}
1879
1880fn contact_for_user(
1881 user_id: UserId,
1882 should_notify: bool,
1883 busy: bool,
1884 pool: &ConnectionPool,
1885) -> proto::Contact {
1886 proto::Contact {
1887 user_id: user_id.to_proto(),
1888 online: pool.is_user_online(user_id),
1889 busy,
1890 should_notify,
1891 }
1892}
1893
1894fn room_updated(room: &proto::Room, peer: &Peer) {
1895 for participant in &room.participants {
1896 peer.send(
1897 ConnectionId(participant.peer_id),
1898 proto::RoomUpdated {
1899 room: Some(room.clone()),
1900 },
1901 )
1902 .trace_err();
1903 }
1904}
1905
1906async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1907 let db = session.db().await;
1908 let contacts = db.get_contacts(user_id).await?;
1909 let busy = db.is_user_busy(user_id).await?;
1910
1911 let pool = session.connection_pool().await;
1912 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1913 for contact in contacts {
1914 if let db::Contact::Accepted {
1915 user_id: contact_user_id,
1916 ..
1917 } = contact
1918 {
1919 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1920 session
1921 .peer
1922 .send(
1923 contact_conn_id,
1924 proto::UpdateContacts {
1925 contacts: vec![updated_contact.clone()],
1926 remove_contacts: Default::default(),
1927 incoming_requests: Default::default(),
1928 remove_incoming_requests: Default::default(),
1929 outgoing_requests: Default::default(),
1930 remove_outgoing_requests: Default::default(),
1931 },
1932 )
1933 .trace_err();
1934 }
1935 }
1936 }
1937 Ok(())
1938}
1939
1940async fn leave_room_for_session(session: &Session) -> Result<()> {
1941 let mut contacts_to_update = HashSet::default();
1942
1943 let room_id;
1944 let canceled_calls_to_user_ids;
1945 let live_kit_room;
1946 let delete_live_kit_room;
1947 {
1948 let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1949 contacts_to_update.insert(session.user_id);
1950
1951 for project in left_room.left_projects.values() {
1952 project_left(project, session);
1953 }
1954
1955 room_updated(&left_room.room, &session.peer);
1956 room_id = RoomId::from_proto(left_room.room.id);
1957 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1958 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1959 delete_live_kit_room = left_room.room.participants.is_empty();
1960 }
1961
1962 {
1963 let pool = session.connection_pool().await;
1964 for canceled_user_id in canceled_calls_to_user_ids {
1965 for connection_id in pool.user_connection_ids(canceled_user_id) {
1966 session
1967 .peer
1968 .send(
1969 connection_id,
1970 proto::CallCanceled {
1971 room_id: room_id.to_proto(),
1972 },
1973 )
1974 .trace_err();
1975 }
1976 contacts_to_update.insert(canceled_user_id);
1977 }
1978 }
1979
1980 for contact_user_id in contacts_to_update {
1981 update_user_contacts(contact_user_id, &session).await?;
1982 }
1983
1984 if let Some(live_kit) = session.live_kit_client.as_ref() {
1985 live_kit
1986 .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1987 .await
1988 .trace_err();
1989
1990 if delete_live_kit_room {
1991 live_kit.delete_room(live_kit_room).await.trace_err();
1992 }
1993 }
1994
1995 Ok(())
1996}
1997
1998fn project_left(project: &db::LeftProject, session: &Session) {
1999 for connection_id in &project.connection_ids {
2000 if project.host_user_id == session.user_id {
2001 session
2002 .peer
2003 .send(
2004 *connection_id,
2005 proto::UnshareProject {
2006 project_id: project.id.to_proto(),
2007 },
2008 )
2009 .trace_err();
2010 } else {
2011 session
2012 .peer
2013 .send(
2014 *connection_id,
2015 proto::RemoveProjectCollaborator {
2016 project_id: project.id.to_proto(),
2017 peer_id: session.connection_id.0,
2018 },
2019 )
2020 .trace_err();
2021 }
2022 }
2023
2024 session
2025 .peer
2026 .send(
2027 session.connection_id,
2028 proto::UnshareProject {
2029 project_id: project.id.to_proto(),
2030 },
2031 )
2032 .trace_err();
2033}
2034
2035pub trait ResultExt {
2036 type Ok;
2037
2038 fn trace_err(self) -> Option<Self::Ok>;
2039}
2040
2041impl<T, E> ResultExt for Result<T, E>
2042where
2043 E: std::fmt::Debug,
2044{
2045 type Ok = T;
2046
2047 fn trace_err(self) -> Option<T> {
2048 match self {
2049 Ok(value) => Some(value),
2050 Err(error) => {
2051 tracing::error!("{:?}", error);
2052 None
2053 }
2054 }
2055 }
2056}