1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::{Duration, Instant},
55};
56use tokio::sync::{watch, Semaphore};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// Grace period after a connection is lost before the user is removed from their room
/// (awaited in `connection_lost`).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before rooms left behind by previous servers are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
// Prometheus gauges exported by `handle_metrics`. Registration happens lazily on first
// access and panics (via `unwrap`) if the gauge cannot be registered.
lazy_static! {
    /// Gauge "connections": set by `handle_metrics` to the number of non-admin connections
    /// currently in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge "shared_projects": set by `handle_metrics` from
    /// `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
/// Serializable view of the server: its RPC peer and its connection pool.
///
/// The snapshot owns a connection-pool lock guard, so the pool stays locked until the
/// snapshot is dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The pool is held through a lock guard; serialize the pool the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
227 .add_message_handler(create_buffer_for_peer)
228 .add_request_handler(update_buffer)
229 .add_message_handler(update_buffer_file)
230 .add_message_handler(buffer_reloaded)
231 .add_message_handler(buffer_saved)
232 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
233 .add_request_handler(get_users)
234 .add_request_handler(fuzzy_search_users)
235 .add_request_handler(request_contact)
236 .add_request_handler(remove_contact)
237 .add_request_handler(respond_to_contact_request)
238 .add_request_handler(follow)
239 .add_message_handler(unfollow)
240 .add_message_handler(update_followers)
241 .add_message_handler(update_diff_base)
242 .add_request_handler(get_private_user_info);
243
244 Arc::new(server)
245 }
246
247 pub async fn start(&self) -> Result<()> {
248 let server_id = *self.id.lock();
249 let app_state = self.app_state.clone();
250 let peer = self.peer.clone();
251 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
252 let pool = self.connection_pool.clone();
253 let live_kit_client = self.app_state.live_kit_client.clone();
254
255 let span = info_span!("start server");
256 self.executor.spawn_detached(
257 async move {
258 tracing::info!("waiting for cleanup timeout");
259 timeout.await;
260 tracing::info!("cleanup timeout expired, retrieving stale rooms");
261 if let Some(room_ids) = app_state
262 .db
263 .stale_room_ids(&app_state.config.zed_environment, server_id)
264 .await
265 .trace_err()
266 {
267 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
268 for room_id in room_ids {
269 let mut contacts_to_update = HashSet::default();
270 let mut canceled_calls_to_user_ids = Vec::new();
271 let mut live_kit_room = String::new();
272 let mut delete_live_kit_room = false;
273
274 if let Some(mut refreshed_room) = app_state
275 .db
276 .refresh_room(room_id, server_id)
277 .await
278 .trace_err()
279 {
280 tracing::info!(
281 room_id = room_id.0,
282 new_participant_count = refreshed_room.room.participants.len(),
283 "refreshed room"
284 );
285 room_updated(&refreshed_room.room, &peer);
286 contacts_to_update
287 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
288 contacts_to_update
289 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
290 canceled_calls_to_user_ids =
291 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
292 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
293 delete_live_kit_room = refreshed_room.room.participants.is_empty();
294 }
295
296 {
297 let pool = pool.lock();
298 for canceled_user_id in canceled_calls_to_user_ids {
299 for connection_id in pool.user_connection_ids(canceled_user_id) {
300 peer.send(
301 connection_id,
302 proto::CallCanceled {
303 room_id: room_id.to_proto(),
304 },
305 )
306 .trace_err();
307 }
308 }
309 }
310
311 for user_id in contacts_to_update {
312 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
313 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
314 if let Some((busy, contacts)) = busy.zip(contacts) {
315 let pool = pool.lock();
316 let updated_contact = contact_for_user(user_id, false, busy, &pool);
317 for contact in contacts {
318 if let db::Contact::Accepted {
319 user_id: contact_user_id,
320 ..
321 } = contact
322 {
323 for contact_conn_id in
324 pool.user_connection_ids(contact_user_id)
325 {
326 peer.send(
327 contact_conn_id,
328 proto::UpdateContacts {
329 contacts: vec![updated_contact.clone()],
330 remove_contacts: Default::default(),
331 incoming_requests: Default::default(),
332 remove_incoming_requests: Default::default(),
333 outgoing_requests: Default::default(),
334 remove_outgoing_requests: Default::default(),
335 },
336 )
337 .trace_err();
338 }
339 }
340 }
341 }
342 }
343
344 if let Some(live_kit) = live_kit_client.as_ref() {
345 if delete_live_kit_room {
346 live_kit.delete_room(live_kit_room).await.trace_err();
347 }
348 }
349 }
350 }
351
352 app_state
353 .db
354 .delete_stale_servers(&app_state.config.zed_environment, server_id)
355 .await
356 .trace_err();
357 }
358 .instrument(span),
359 );
360 Ok(())
361 }
362
363 pub fn teardown(&self) {
364 self.peer.teardown();
365 self.connection_pool.lock().reset();
366 let _ = self.teardown.send(());
367 }
368
369 #[cfg(test)]
370 pub fn reset(&self, id: ServerId) {
371 self.teardown();
372 *self.id.lock() = id;
373 self.peer.reset(id.0 as u32);
374 }
375
376 #[cfg(test)]
377 pub fn id(&self) -> ServerId {
378 *self.id.lock()
379 }
380
381 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
382 where
383 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
384 Fut: 'static + Send + Future<Output = Result<()>>,
385 M: EnvelopedMessage,
386 {
387 let prev_handler = self.handlers.insert(
388 TypeId::of::<M>(),
389 Box::new(move |envelope, session| {
390 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
391 let span = info_span!(
392 "handle message",
393 payload_type = envelope.payload_type_name()
394 );
395 span.in_scope(|| {
396 tracing::info!(
397 payload_type = envelope.payload_type_name(),
398 "message received"
399 );
400 });
401 let start_time = Instant::now();
402 let future = (handler)(*envelope, session);
403 async move {
404 let result = future.await;
405 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
406 match result {
407 Err(error) => {
408 tracing::error!(%error, ?duration_ms, "error handling message")
409 }
410 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
411 }
412 }
413 .instrument(span)
414 .boxed()
415 }),
416 );
417 if prev_handler.is_some() {
418 panic!("registered a handler for the same message twice");
419 }
420 self
421 }
422
423 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
424 where
425 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
426 Fut: 'static + Send + Future<Output = Result<()>>,
427 M: EnvelopedMessage,
428 {
429 self.add_handler(move |envelope, session| handler(envelope.payload, session));
430 self
431 }
432
433 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
434 where
435 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
436 Fut: Send + Future<Output = Result<()>>,
437 M: RequestMessage,
438 {
439 let handler = Arc::new(handler);
440 self.add_handler(move |envelope, session| {
441 let receipt = envelope.receipt();
442 let handler = handler.clone();
443 async move {
444 let peer = session.peer.clone();
445 let responded = Arc::new(AtomicBool::default());
446 let response = Response {
447 peer: peer.clone(),
448 responded: responded.clone(),
449 receipt,
450 };
451 match (handler)(envelope.payload, response, session).await {
452 Ok(()) => {
453 if responded.load(std::sync::atomic::Ordering::SeqCst) {
454 Ok(())
455 } else {
456 Err(anyhow!("handler did not send a response"))?
457 }
458 }
459 Err(error) => {
460 peer.respond_with_error(
461 receipt,
462 proto::Error {
463 message: error.to_string(),
464 },
465 )?;
466 Err(error)
467 }
468 }
469 }
470 })
471 }
472
473 pub fn handle_connection(
474 self: &Arc<Self>,
475 connection: Connection,
476 address: String,
477 user: User,
478 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
479 executor: Executor,
480 ) -> impl Future<Output = Result<()>> {
481 let this = self.clone();
482 let user_id = user.id;
483 let login = user.github_login;
484 let span = info_span!("handle connection", %user_id, %login, %address);
485 let mut teardown = self.teardown.subscribe();
486 async move {
487 let (connection_id, handle_io, mut incoming_rx) = this
488 .peer
489 .add_connection(connection, {
490 let executor = executor.clone();
491 move |duration| executor.sleep(duration)
492 });
493
494 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
495 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
496 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
497
498 if let Some(send_connection_id) = send_connection_id.take() {
499 let _ = send_connection_id.send(connection_id);
500 }
501
502 if !user.connected_once {
503 this.peer.send(connection_id, proto::ShowContacts {})?;
504 this.app_state.db.set_user_connected_once(user_id, true).await?;
505 }
506
507 let (contacts, invite_code) = future::try_join(
508 this.app_state.db.get_contacts(user_id),
509 this.app_state.db.get_invite_code_for_user(user_id)
510 ).await?;
511
512 {
513 let mut pool = this.connection_pool.lock();
514 pool.add_connection(connection_id, user_id, user.admin);
515 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
516
517 if let Some((code, count)) = invite_code {
518 this.peer.send(connection_id, proto::UpdateInviteInfo {
519 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
520 count: count as u32,
521 })?;
522 }
523 }
524
525 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
526 this.peer.send(connection_id, incoming_call)?;
527 }
528
529 let session = Session {
530 user_id,
531 connection_id,
532 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
533 peer: this.peer.clone(),
534 connection_pool: this.connection_pool.clone(),
535 live_kit_client: this.app_state.live_kit_client.clone(),
536 executor: executor.clone(),
537 };
538 update_user_contacts(user_id, &session).await?;
539
540 let handle_io = handle_io.fuse();
541 futures::pin_mut!(handle_io);
542
543 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
544 // This prevents deadlocks when e.g., client A performs a request to client B and
545 // client B performs a request to client A. If both clients stop processing further
546 // messages until their respective request completes, they won't have a chance to
547 // respond to the other client's request and cause a deadlock.
548 //
549 // This arrangement ensures we will attempt to process earlier messages first, but fall
550 // back to processing messages arrived later in the spirit of making progress.
551 let mut foreground_message_handlers = FuturesUnordered::new();
552 let concurrent_handlers = Arc::new(Semaphore::new(256));
553 loop {
554 let next_message = async {
555 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
556 let message = incoming_rx.next().await;
557 (permit, message)
558 }.fuse();
559 futures::pin_mut!(next_message);
560 futures::select_biased! {
561 _ = teardown.changed().fuse() => return Ok(()),
562 result = handle_io => {
563 if let Err(error) = result {
564 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
565 }
566 break;
567 }
568 _ = foreground_message_handlers.next() => {}
569 next_message = next_message => {
570 let (permit, message) = next_message;
571 if let Some(message) = message {
572 let type_name = message.payload_type_name();
573 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
574 let span_enter = span.enter();
575 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
576 let is_background = message.is_background();
577 let handle_message = (handler)(message, session.clone());
578 drop(span_enter);
579
580 let handle_message = async move {
581 handle_message.await;
582 drop(permit);
583 }.instrument(span);
584 if is_background {
585 executor.spawn_detached(handle_message);
586 } else {
587 foreground_message_handlers.push(handle_message);
588 }
589 } else {
590 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
591 }
592 } else {
593 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
594 break;
595 }
596 }
597 }
598 }
599
600 drop(foreground_message_handlers);
601 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
602 if let Err(error) = connection_lost(session, teardown, executor).await {
603 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
604 }
605
606 Ok(())
607 }.instrument(span)
608 }
609
610 pub async fn invite_code_redeemed(
611 self: &Arc<Self>,
612 inviter_id: UserId,
613 invitee_id: UserId,
614 ) -> Result<()> {
615 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
616 if let Some(code) = &user.invite_code {
617 let pool = self.connection_pool.lock();
618 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
619 for connection_id in pool.user_connection_ids(inviter_id) {
620 self.peer.send(
621 connection_id,
622 proto::UpdateContacts {
623 contacts: vec![invitee_contact.clone()],
624 ..Default::default()
625 },
626 )?;
627 self.peer.send(
628 connection_id,
629 proto::UpdateInviteInfo {
630 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
631 count: user.invite_count as u32,
632 },
633 )?;
634 }
635 }
636 }
637 Ok(())
638 }
639
640 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
641 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
642 if let Some(invite_code) = &user.invite_code {
643 let pool = self.connection_pool.lock();
644 for connection_id in pool.user_connection_ids(user_id) {
645 self.peer.send(
646 connection_id,
647 proto::UpdateInviteInfo {
648 url: format!(
649 "{}{}",
650 self.app_state.config.invite_link_prefix, invite_code
651 ),
652 count: user.invite_count as u32,
653 },
654 )?;
655 }
656 }
657 }
658 Ok(())
659 }
660
661 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
662 ServerSnapshot {
663 connection_pool: ConnectionPoolGuard {
664 guard: self.connection_pool.lock(),
665 _not_send: PhantomData,
666 },
667 peer: &self.peer,
668 }
669 }
670}
671
672impl<'a> Deref for ConnectionPoolGuard<'a> {
673 type Target = ConnectionPool;
674
675 fn deref(&self) -> &Self::Target {
676 &*self.guard
677 }
678}
679
680impl<'a> DerefMut for ConnectionPoolGuard<'a> {
681 fn deref_mut(&mut self) -> &mut Self::Target {
682 &mut *self.guard
683 }
684}
685
686impl<'a> Drop for ConnectionPoolGuard<'a> {
687 fn drop(&mut self) {
688 #[cfg(test)]
689 self.check_invariants();
690 }
691}
692
693fn broadcast<F>(
694 sender_id: Option<ConnectionId>,
695 receiver_ids: impl IntoIterator<Item = ConnectionId>,
696 mut f: F,
697) where
698 F: FnMut(ConnectionId) -> anyhow::Result<()>,
699{
700 for receiver_id in receiver_ids {
701 if Some(receiver_id) != sender_id {
702 if let Err(error) = f(receiver_id) {
703 tracing::error!("failed to send to {:?} {}", receiver_id, error);
704 }
705 }
706 }
707}
708
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version; it is
    /// decoded into [`ProtocolVersion`].
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
712
713pub struct ProtocolVersion(u32);
714
715impl Header for ProtocolVersion {
716 fn name() -> &'static HeaderName {
717 &ZED_PROTOCOL_VERSION
718 }
719
720 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
721 where
722 Self: Sized,
723 I: Iterator<Item = &'i axum::http::HeaderValue>,
724 {
725 let version = values
726 .next()
727 .ok_or_else(axum::headers::Error::invalid)?
728 .to_str()
729 .map_err(|_| axum::headers::Error::invalid())?
730 .parse()
731 .map_err(|_| axum::headers::Error::invalid())?;
732 Ok(Self(version))
733 }
734
735 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
736 values.extend([self.0.to_string().parse().unwrap()]);
737 }
738}
739
740pub fn routes(server: Arc<Server>) -> Router<Body> {
741 Router::new()
742 .route("/rpc", get(handle_websocket_request))
743 .layer(
744 ServiceBuilder::new()
745 .layer(Extension(server.app_state.clone()))
746 .layer(middleware::from_fn(auth::validate_header)),
747 )
748 .route("/metrics", get(handle_metrics))
749 .layer(Extension(server))
750}
751
752pub async fn handle_websocket_request(
753 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
754 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
755 Extension(server): Extension<Arc<Server>>,
756 Extension(user): Extension<User>,
757 ws: WebSocketUpgrade,
758) -> axum::response::Response {
759 if protocol_version != rpc::PROTOCOL_VERSION {
760 return (
761 StatusCode::UPGRADE_REQUIRED,
762 "client must be upgraded".to_string(),
763 )
764 .into_response();
765 }
766 let socket_address = socket_address.to_string();
767 ws.on_upgrade(move |socket| {
768 use util::ResultExt;
769 let socket = socket
770 .map_ok(to_tungstenite_message)
771 .err_into()
772 .with(|message| async move { Ok(to_axum_message(message)) });
773 let connection = Connection::new(Box::pin(socket));
774 async move {
775 server
776 .handle_connection(connection, socket_address, user, None, Executor::Production)
777 .await
778 .log_err();
779 }
780 })
781}
782
783pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
784 let connections = server
785 .connection_pool
786 .lock()
787 .connections()
788 .filter(|connection| !connection.admin)
789 .count();
790
791 METRIC_CONNECTIONS.set(connections as _);
792
793 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
794 METRIC_SHARED_PROJECTS.set(shared_projects as _);
795
796 let encoder = prometheus::TextEncoder::new();
797 let metric_families = prometheus::gather();
798 let encoded_metrics = encoder
799 .encode_to_string(&metric_families)
800 .map_err(|err| anyhow!("{}", err))?;
801 Ok(encoded_metrics)
802}
803
804#[instrument(err, skip(executor))]
805async fn connection_lost(
806 session: Session,
807 mut teardown: watch::Receiver<()>,
808 executor: Executor,
809) -> Result<()> {
810 session.peer.disconnect(session.connection_id);
811 session
812 .connection_pool()
813 .await
814 .remove_connection(session.connection_id)?;
815
816 session
817 .db()
818 .await
819 .connection_lost(session.connection_id)
820 .await
821 .trace_err();
822
823 futures::select_biased! {
824 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
825 leave_room_for_session(&session).await.trace_err();
826
827 if !session
828 .connection_pool()
829 .await
830 .is_user_online(session.user_id)
831 {
832 let db = session.db().await;
833 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
834 room_updated(&room, &session.peer);
835 }
836 }
837 update_user_contacts(session.user_id, &session).await?;
838 }
839 _ = teardown.changed().fuse() => {}
840 }
841
842 Ok(())
843}
844
845async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
846 response.send(proto::Ack {})?;
847 Ok(())
848}
849
850async fn create_room(
851 _request: proto::CreateRoom,
852 response: Response<proto::CreateRoom>,
853 session: Session,
854) -> Result<()> {
855 let live_kit_room = nanoid::nanoid!(30);
856 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
857 if let Some(_) = live_kit
858 .create_room(live_kit_room.clone())
859 .await
860 .trace_err()
861 {
862 if let Some(token) = live_kit
863 .room_token(&live_kit_room, &session.user_id.to_string())
864 .trace_err()
865 {
866 Some(proto::LiveKitConnectionInfo {
867 server_url: live_kit.url().into(),
868 token,
869 })
870 } else {
871 None
872 }
873 } else {
874 None
875 }
876 } else {
877 None
878 };
879
880 {
881 let room = session
882 .db()
883 .await
884 .create_room(session.user_id, session.connection_id, &live_kit_room)
885 .await?;
886
887 response.send(proto::CreateRoomResponse {
888 room: Some(room.clone()),
889 live_kit_connection_info,
890 })?;
891 }
892
893 update_user_contacts(session.user_id, &session).await?;
894 Ok(())
895}
896
897async fn join_room(
898 request: proto::JoinRoom,
899 response: Response<proto::JoinRoom>,
900 session: Session,
901) -> Result<()> {
902 let room_id = RoomId::from_proto(request.id);
903 let room = {
904 let room = session
905 .db()
906 .await
907 .join_room(room_id, session.user_id, session.connection_id)
908 .await?;
909 room_updated(&room, &session.peer);
910 room.clone()
911 };
912
913 for connection_id in session
914 .connection_pool()
915 .await
916 .user_connection_ids(session.user_id)
917 {
918 session
919 .peer
920 .send(
921 connection_id,
922 proto::CallCanceled {
923 room_id: room_id.to_proto(),
924 },
925 )
926 .trace_err();
927 }
928
929 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
930 if let Some(token) = live_kit
931 .room_token(&room.live_kit_room, &session.user_id.to_string())
932 .trace_err()
933 {
934 Some(proto::LiveKitConnectionInfo {
935 server_url: live_kit.url().into(),
936 token,
937 })
938 } else {
939 None
940 }
941 } else {
942 None
943 };
944
945 response.send(proto::JoinRoomResponse {
946 room: Some(room),
947 live_kit_connection_info,
948 })?;
949
950 update_user_contacts(session.user_id, &session).await?;
951 Ok(())
952}
953
954async fn rejoin_room(
955 request: proto::RejoinRoom,
956 response: Response<proto::RejoinRoom>,
957 session: Session,
958) -> Result<()> {
959 {
960 let mut rejoined_room = session
961 .db()
962 .await
963 .rejoin_room(request, session.user_id, session.connection_id)
964 .await?;
965
966 response.send(proto::RejoinRoomResponse {
967 room: Some(rejoined_room.room.clone()),
968 reshared_projects: rejoined_room
969 .reshared_projects
970 .iter()
971 .map(|project| proto::ResharedProject {
972 id: project.id.to_proto(),
973 collaborators: project
974 .collaborators
975 .iter()
976 .map(|collaborator| collaborator.to_proto())
977 .collect(),
978 })
979 .collect(),
980 rejoined_projects: rejoined_room
981 .rejoined_projects
982 .iter()
983 .map(|rejoined_project| proto::RejoinedProject {
984 id: rejoined_project.id.to_proto(),
985 worktrees: rejoined_project
986 .worktrees
987 .iter()
988 .map(|worktree| proto::WorktreeMetadata {
989 id: worktree.id,
990 root_name: worktree.root_name.clone(),
991 visible: worktree.visible,
992 abs_path: worktree.abs_path.clone(),
993 })
994 .collect(),
995 collaborators: rejoined_project
996 .collaborators
997 .iter()
998 .map(|collaborator| collaborator.to_proto())
999 .collect(),
1000 language_servers: rejoined_project.language_servers.clone(),
1001 })
1002 .collect(),
1003 })?;
1004 room_updated(&rejoined_room.room, &session.peer);
1005
1006 for project in &rejoined_room.reshared_projects {
1007 for collaborator in &project.collaborators {
1008 session
1009 .peer
1010 .send(
1011 collaborator.connection_id,
1012 proto::UpdateProjectCollaborator {
1013 project_id: project.id.to_proto(),
1014 old_peer_id: Some(project.old_connection_id.into()),
1015 new_peer_id: Some(session.connection_id.into()),
1016 },
1017 )
1018 .trace_err();
1019 }
1020
1021 broadcast(
1022 Some(session.connection_id),
1023 project
1024 .collaborators
1025 .iter()
1026 .map(|collaborator| collaborator.connection_id),
1027 |connection_id| {
1028 session.peer.forward_send(
1029 session.connection_id,
1030 connection_id,
1031 proto::UpdateProject {
1032 project_id: project.id.to_proto(),
1033 worktrees: project.worktrees.clone(),
1034 },
1035 )
1036 },
1037 );
1038 }
1039
1040 for project in &rejoined_room.rejoined_projects {
1041 for collaborator in &project.collaborators {
1042 session
1043 .peer
1044 .send(
1045 collaborator.connection_id,
1046 proto::UpdateProjectCollaborator {
1047 project_id: project.id.to_proto(),
1048 old_peer_id: Some(project.old_connection_id.into()),
1049 new_peer_id: Some(session.connection_id.into()),
1050 },
1051 )
1052 .trace_err();
1053 }
1054 }
1055
1056 for project in &mut rejoined_room.rejoined_projects {
1057 for worktree in mem::take(&mut project.worktrees) {
1058 #[cfg(any(test, feature = "test-support"))]
1059 const MAX_CHUNK_SIZE: usize = 2;
1060 #[cfg(not(any(test, feature = "test-support")))]
1061 const MAX_CHUNK_SIZE: usize = 256;
1062
1063 // Stream this worktree's entries.
1064 let message = proto::UpdateWorktree {
1065 project_id: project.id.to_proto(),
1066 worktree_id: worktree.id,
1067 abs_path: worktree.abs_path.clone(),
1068 root_name: worktree.root_name,
1069 updated_entries: worktree.updated_entries,
1070 removed_entries: worktree.removed_entries,
1071 scan_id: worktree.scan_id,
1072 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1073 updated_repositories: worktree.updated_repositories,
1074 removed_repositories: worktree.removed_repositories,
1075 };
1076 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1077 session.peer.send(session.connection_id, update.clone())?;
1078 }
1079
1080 // Stream this worktree's diagnostics.
1081 for summary in worktree.diagnostic_summaries {
1082 session.peer.send(
1083 session.connection_id,
1084 proto::UpdateDiagnosticSummary {
1085 project_id: project.id.to_proto(),
1086 worktree_id: worktree.id,
1087 summary: Some(summary),
1088 },
1089 )?;
1090 }
1091 }
1092
1093 for language_server in &project.language_servers {
1094 session.peer.send(
1095 session.connection_id,
1096 proto::UpdateLanguageServer {
1097 project_id: project.id.to_proto(),
1098 language_server_id: language_server.id,
1099 variant: Some(
1100 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1101 proto::LspDiskBasedDiagnosticsUpdated {},
1102 ),
1103 ),
1104 },
1105 )?;
1106 }
1107 }
1108 }
1109
1110 update_user_contacts(session.user_id, &session).await?;
1111 Ok(())
1112}
1113
1114async fn leave_room(
1115 _: proto::LeaveRoom,
1116 response: Response<proto::LeaveRoom>,
1117 session: Session,
1118) -> Result<()> {
1119 leave_room_for_session(&session).await?;
1120 response.send(proto::Ack {})?;
1121 Ok(())
1122}
1123
1124async fn call(
1125 request: proto::Call,
1126 response: Response<proto::Call>,
1127 session: Session,
1128) -> Result<()> {
1129 let room_id = RoomId::from_proto(request.room_id);
1130 let calling_user_id = session.user_id;
1131 let calling_connection_id = session.connection_id;
1132 let called_user_id = UserId::from_proto(request.called_user_id);
1133 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1134 if !session
1135 .db()
1136 .await
1137 .has_contact(calling_user_id, called_user_id)
1138 .await?
1139 {
1140 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1141 }
1142
1143 let incoming_call = {
1144 let (room, incoming_call) = &mut *session
1145 .db()
1146 .await
1147 .call(
1148 room_id,
1149 calling_user_id,
1150 calling_connection_id,
1151 called_user_id,
1152 initial_project_id,
1153 )
1154 .await?;
1155 room_updated(&room, &session.peer);
1156 mem::take(incoming_call)
1157 };
1158 update_user_contacts(called_user_id, &session).await?;
1159
1160 let mut calls = session
1161 .connection_pool()
1162 .await
1163 .user_connection_ids(called_user_id)
1164 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1165 .collect::<FuturesUnordered<_>>();
1166
1167 while let Some(call_response) = calls.next().await {
1168 match call_response.as_ref() {
1169 Ok(_) => {
1170 response.send(proto::Ack {})?;
1171 return Ok(());
1172 }
1173 Err(_) => {
1174 call_response.trace_err();
1175 }
1176 }
1177 }
1178
1179 {
1180 let room = session
1181 .db()
1182 .await
1183 .call_failed(room_id, called_user_id)
1184 .await?;
1185 room_updated(&room, &session.peer);
1186 }
1187 update_user_contacts(called_user_id, &session).await?;
1188
1189 Err(anyhow!("failed to ring user"))?
1190}
1191
1192async fn cancel_call(
1193 request: proto::CancelCall,
1194 response: Response<proto::CancelCall>,
1195 session: Session,
1196) -> Result<()> {
1197 let called_user_id = UserId::from_proto(request.called_user_id);
1198 let room_id = RoomId::from_proto(request.room_id);
1199 {
1200 let room = session
1201 .db()
1202 .await
1203 .cancel_call(room_id, session.connection_id, called_user_id)
1204 .await?;
1205 room_updated(&room, &session.peer);
1206 }
1207
1208 for connection_id in session
1209 .connection_pool()
1210 .await
1211 .user_connection_ids(called_user_id)
1212 {
1213 session
1214 .peer
1215 .send(
1216 connection_id,
1217 proto::CallCanceled {
1218 room_id: room_id.to_proto(),
1219 },
1220 )
1221 .trace_err();
1222 }
1223 response.send(proto::Ack {})?;
1224
1225 update_user_contacts(called_user_id, &session).await?;
1226 Ok(())
1227}
1228
1229async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1230 let room_id = RoomId::from_proto(message.room_id);
1231 {
1232 let room = session
1233 .db()
1234 .await
1235 .decline_call(Some(room_id), session.user_id)
1236 .await?
1237 .ok_or_else(|| anyhow!("failed to decline call"))?;
1238 room_updated(&room, &session.peer);
1239 }
1240
1241 for connection_id in session
1242 .connection_pool()
1243 .await
1244 .user_connection_ids(session.user_id)
1245 {
1246 session
1247 .peer
1248 .send(
1249 connection_id,
1250 proto::CallCanceled {
1251 room_id: room_id.to_proto(),
1252 },
1253 )
1254 .trace_err();
1255 }
1256 update_user_contacts(session.user_id, &session).await?;
1257 Ok(())
1258}
1259
1260async fn update_participant_location(
1261 request: proto::UpdateParticipantLocation,
1262 response: Response<proto::UpdateParticipantLocation>,
1263 session: Session,
1264) -> Result<()> {
1265 let room_id = RoomId::from_proto(request.room_id);
1266 let location = request
1267 .location
1268 .ok_or_else(|| anyhow!("invalid location"))?;
1269 let room = session
1270 .db()
1271 .await
1272 .update_room_participant_location(room_id, session.connection_id, location)
1273 .await?;
1274 room_updated(&room, &session.peer);
1275 response.send(proto::Ack {})?;
1276 Ok(())
1277}
1278
1279async fn share_project(
1280 request: proto::ShareProject,
1281 response: Response<proto::ShareProject>,
1282 session: Session,
1283) -> Result<()> {
1284 let (project_id, room) = &*session
1285 .db()
1286 .await
1287 .share_project(
1288 RoomId::from_proto(request.room_id),
1289 session.connection_id,
1290 &request.worktrees,
1291 )
1292 .await?;
1293 response.send(proto::ShareProjectResponse {
1294 project_id: project_id.to_proto(),
1295 })?;
1296 room_updated(&room, &session.peer);
1297
1298 Ok(())
1299}
1300
1301async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1302 let project_id = ProjectId::from_proto(message.project_id);
1303
1304 let (room, guest_connection_ids) = &*session
1305 .db()
1306 .await
1307 .unshare_project(project_id, session.connection_id)
1308 .await?;
1309
1310 broadcast(
1311 Some(session.connection_id),
1312 guest_connection_ids.iter().copied(),
1313 |conn_id| session.peer.send(conn_id, message.clone()),
1314 );
1315 room_updated(&room, &session.peer);
1316
1317 Ok(())
1318}
1319
1320async fn join_project(
1321 request: proto::JoinProject,
1322 response: Response<proto::JoinProject>,
1323 session: Session,
1324) -> Result<()> {
1325 let project_id = ProjectId::from_proto(request.project_id);
1326 let guest_user_id = session.user_id;
1327
1328 tracing::info!(%project_id, "join project");
1329
1330 let (project, replica_id) = &mut *session
1331 .db()
1332 .await
1333 .join_project(project_id, session.connection_id)
1334 .await?;
1335
1336 let collaborators = project
1337 .collaborators
1338 .iter()
1339 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1340 .map(|collaborator| collaborator.to_proto())
1341 .collect::<Vec<_>>();
1342
1343 let worktrees = project
1344 .worktrees
1345 .iter()
1346 .map(|(id, worktree)| proto::WorktreeMetadata {
1347 id: *id,
1348 root_name: worktree.root_name.clone(),
1349 visible: worktree.visible,
1350 abs_path: worktree.abs_path.clone(),
1351 })
1352 .collect::<Vec<_>>();
1353
1354 for collaborator in &collaborators {
1355 session
1356 .peer
1357 .send(
1358 collaborator.peer_id.unwrap().into(),
1359 proto::AddProjectCollaborator {
1360 project_id: project_id.to_proto(),
1361 collaborator: Some(proto::Collaborator {
1362 peer_id: Some(session.connection_id.into()),
1363 replica_id: replica_id.0 as u32,
1364 user_id: guest_user_id.to_proto(),
1365 }),
1366 },
1367 )
1368 .trace_err();
1369 }
1370
1371 // First, we send the metadata associated with each worktree.
1372 response.send(proto::JoinProjectResponse {
1373 worktrees: worktrees.clone(),
1374 replica_id: replica_id.0 as u32,
1375 collaborators: collaborators.clone(),
1376 language_servers: project.language_servers.clone(),
1377 })?;
1378
1379 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1380 #[cfg(any(test, feature = "test-support"))]
1381 const MAX_CHUNK_SIZE: usize = 2;
1382 #[cfg(not(any(test, feature = "test-support")))]
1383 const MAX_CHUNK_SIZE: usize = 256;
1384
1385 // Stream this worktree's entries.
1386 let message = proto::UpdateWorktree {
1387 project_id: project_id.to_proto(),
1388 worktree_id,
1389 abs_path: worktree.abs_path.clone(),
1390 root_name: worktree.root_name,
1391 updated_entries: worktree.entries,
1392 removed_entries: Default::default(),
1393 scan_id: worktree.scan_id,
1394 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1395 updated_repositories: worktree.repository_entries.into_values().collect(),
1396 removed_repositories: Default::default(),
1397 };
1398 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1399 session.peer.send(session.connection_id, update.clone())?;
1400 }
1401
1402 // Stream this worktree's diagnostics.
1403 for summary in worktree.diagnostic_summaries {
1404 session.peer.send(
1405 session.connection_id,
1406 proto::UpdateDiagnosticSummary {
1407 project_id: project_id.to_proto(),
1408 worktree_id: worktree.id,
1409 summary: Some(summary),
1410 },
1411 )?;
1412 }
1413 }
1414
1415 for language_server in &project.language_servers {
1416 session.peer.send(
1417 session.connection_id,
1418 proto::UpdateLanguageServer {
1419 project_id: project_id.to_proto(),
1420 language_server_id: language_server.id,
1421 variant: Some(
1422 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1423 proto::LspDiskBasedDiagnosticsUpdated {},
1424 ),
1425 ),
1426 },
1427 )?;
1428 }
1429
1430 Ok(())
1431}
1432
1433async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1434 let sender_id = session.connection_id;
1435 let project_id = ProjectId::from_proto(request.project_id);
1436
1437 let (room, project) = &*session
1438 .db()
1439 .await
1440 .leave_project(project_id, sender_id)
1441 .await?;
1442 tracing::info!(
1443 %project_id,
1444 host_user_id = %project.host_user_id,
1445 host_connection_id = %project.host_connection_id,
1446 "leave project"
1447 );
1448
1449 project_left(&project, &session);
1450 room_updated(&room, &session.peer);
1451
1452 Ok(())
1453}
1454
1455async fn update_project(
1456 request: proto::UpdateProject,
1457 response: Response<proto::UpdateProject>,
1458 session: Session,
1459) -> Result<()> {
1460 let project_id = ProjectId::from_proto(request.project_id);
1461 let (room, guest_connection_ids) = &*session
1462 .db()
1463 .await
1464 .update_project(project_id, session.connection_id, &request.worktrees)
1465 .await?;
1466 broadcast(
1467 Some(session.connection_id),
1468 guest_connection_ids.iter().copied(),
1469 |connection_id| {
1470 session
1471 .peer
1472 .forward_send(session.connection_id, connection_id, request.clone())
1473 },
1474 );
1475 room_updated(&room, &session.peer);
1476 response.send(proto::Ack {})?;
1477
1478 Ok(())
1479}
1480
1481async fn update_worktree(
1482 request: proto::UpdateWorktree,
1483 response: Response<proto::UpdateWorktree>,
1484 session: Session,
1485) -> Result<()> {
1486 let guest_connection_ids = session
1487 .db()
1488 .await
1489 .update_worktree(&request, session.connection_id)
1490 .await?;
1491
1492 broadcast(
1493 Some(session.connection_id),
1494 guest_connection_ids.iter().copied(),
1495 |connection_id| {
1496 session
1497 .peer
1498 .forward_send(session.connection_id, connection_id, request.clone())
1499 },
1500 );
1501 response.send(proto::Ack {})?;
1502 Ok(())
1503}
1504
1505async fn update_diagnostic_summary(
1506 message: proto::UpdateDiagnosticSummary,
1507 session: Session,
1508) -> Result<()> {
1509 let guest_connection_ids = session
1510 .db()
1511 .await
1512 .update_diagnostic_summary(&message, session.connection_id)
1513 .await?;
1514
1515 broadcast(
1516 Some(session.connection_id),
1517 guest_connection_ids.iter().copied(),
1518 |connection_id| {
1519 session
1520 .peer
1521 .forward_send(session.connection_id, connection_id, message.clone())
1522 },
1523 );
1524
1525 Ok(())
1526}
1527
1528async fn start_language_server(
1529 request: proto::StartLanguageServer,
1530 session: Session,
1531) -> Result<()> {
1532 let guest_connection_ids = session
1533 .db()
1534 .await
1535 .start_language_server(&request, session.connection_id)
1536 .await?;
1537
1538 broadcast(
1539 Some(session.connection_id),
1540 guest_connection_ids.iter().copied(),
1541 |connection_id| {
1542 session
1543 .peer
1544 .forward_send(session.connection_id, connection_id, request.clone())
1545 },
1546 );
1547 Ok(())
1548}
1549
1550async fn update_language_server(
1551 request: proto::UpdateLanguageServer,
1552 session: Session,
1553) -> Result<()> {
1554 session.executor.record_backtrace();
1555 let project_id = ProjectId::from_proto(request.project_id);
1556 let project_connection_ids = session
1557 .db()
1558 .await
1559 .project_connection_ids(project_id, session.connection_id)
1560 .await?;
1561 broadcast(
1562 Some(session.connection_id),
1563 project_connection_ids.iter().copied(),
1564 |connection_id| {
1565 session
1566 .peer
1567 .forward_send(session.connection_id, connection_id, request.clone())
1568 },
1569 );
1570 Ok(())
1571}
1572
1573async fn forward_project_request<T>(
1574 request: T,
1575 response: Response<T>,
1576 session: Session,
1577) -> Result<()>
1578where
1579 T: EntityMessage + RequestMessage,
1580{
1581 session.executor.record_backtrace();
1582 let project_id = ProjectId::from_proto(request.remote_entity_id());
1583 let host_connection_id = {
1584 let collaborators = session
1585 .db()
1586 .await
1587 .project_collaborators(project_id, session.connection_id)
1588 .await?;
1589 collaborators
1590 .iter()
1591 .find(|collaborator| collaborator.is_host)
1592 .ok_or_else(|| anyhow!("host not found"))?
1593 .connection_id
1594 };
1595
1596 let payload = session
1597 .peer
1598 .forward_request(session.connection_id, host_connection_id, request)
1599 .await?;
1600
1601 response.send(payload)?;
1602 Ok(())
1603}
1604
1605async fn create_buffer_for_peer(
1606 request: proto::CreateBufferForPeer,
1607 session: Session,
1608) -> Result<()> {
1609 session.executor.record_backtrace();
1610 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1611 session
1612 .peer
1613 .forward_send(session.connection_id, peer_id.into(), request)?;
1614 Ok(())
1615}
1616
1617async fn update_buffer(
1618 request: proto::UpdateBuffer,
1619 response: Response<proto::UpdateBuffer>,
1620 session: Session,
1621) -> Result<()> {
1622 session.executor.record_backtrace();
1623 let project_id = ProjectId::from_proto(request.project_id);
1624 let mut guest_connection_ids;
1625 let mut host_connection_id = None;
1626 {
1627 let collaborators = session
1628 .db()
1629 .await
1630 .project_collaborators(project_id, session.connection_id)
1631 .await?;
1632 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1633 for collaborator in collaborators.iter() {
1634 if collaborator.is_host {
1635 host_connection_id = Some(collaborator.connection_id);
1636 } else {
1637 guest_connection_ids.push(collaborator.connection_id);
1638 }
1639 }
1640 }
1641 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1642
1643 session.executor.record_backtrace();
1644 broadcast(
1645 Some(session.connection_id),
1646 guest_connection_ids,
1647 |connection_id| {
1648 session
1649 .peer
1650 .forward_send(session.connection_id, connection_id, request.clone())
1651 },
1652 );
1653 if host_connection_id != session.connection_id {
1654 session
1655 .peer
1656 .forward_request(session.connection_id, host_connection_id, request.clone())
1657 .await?;
1658 }
1659
1660 response.send(proto::Ack {})?;
1661 Ok(())
1662}
1663
1664async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1665 let project_id = ProjectId::from_proto(request.project_id);
1666 let project_connection_ids = session
1667 .db()
1668 .await
1669 .project_connection_ids(project_id, session.connection_id)
1670 .await?;
1671
1672 broadcast(
1673 Some(session.connection_id),
1674 project_connection_ids.iter().copied(),
1675 |connection_id| {
1676 session
1677 .peer
1678 .forward_send(session.connection_id, connection_id, request.clone())
1679 },
1680 );
1681 Ok(())
1682}
1683
1684async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1685 let project_id = ProjectId::from_proto(request.project_id);
1686 let project_connection_ids = session
1687 .db()
1688 .await
1689 .project_connection_ids(project_id, session.connection_id)
1690 .await?;
1691 broadcast(
1692 Some(session.connection_id),
1693 project_connection_ids.iter().copied(),
1694 |connection_id| {
1695 session
1696 .peer
1697 .forward_send(session.connection_id, connection_id, request.clone())
1698 },
1699 );
1700 Ok(())
1701}
1702
1703async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1704 let project_id = ProjectId::from_proto(request.project_id);
1705 let project_connection_ids = session
1706 .db()
1707 .await
1708 .project_connection_ids(project_id, session.connection_id)
1709 .await?;
1710 broadcast(
1711 Some(session.connection_id),
1712 project_connection_ids.iter().copied(),
1713 |connection_id| {
1714 session
1715 .peer
1716 .forward_send(session.connection_id, connection_id, request.clone())
1717 },
1718 );
1719 Ok(())
1720}
1721
1722async fn follow(
1723 request: proto::Follow,
1724 response: Response<proto::Follow>,
1725 session: Session,
1726) -> Result<()> {
1727 let project_id = ProjectId::from_proto(request.project_id);
1728 let leader_id = request
1729 .leader_id
1730 .ok_or_else(|| anyhow!("invalid leader id"))?
1731 .into();
1732 let follower_id = session.connection_id;
1733
1734 {
1735 let project_connection_ids = session
1736 .db()
1737 .await
1738 .project_connection_ids(project_id, session.connection_id)
1739 .await?;
1740
1741 if !project_connection_ids.contains(&leader_id) {
1742 Err(anyhow!("no such peer"))?;
1743 }
1744 }
1745
1746 let mut response_payload = session
1747 .peer
1748 .forward_request(session.connection_id, leader_id, request)
1749 .await?;
1750 response_payload
1751 .views
1752 .retain(|view| view.leader_id != Some(follower_id.into()));
1753 response.send(response_payload)?;
1754
1755 let room = session
1756 .db()
1757 .await
1758 .follow(project_id, leader_id, follower_id)
1759 .await?;
1760 room_updated(&room, &session.peer);
1761
1762 Ok(())
1763}
1764
1765async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1766 let project_id = ProjectId::from_proto(request.project_id);
1767 let leader_id = request
1768 .leader_id
1769 .ok_or_else(|| anyhow!("invalid leader id"))?
1770 .into();
1771 let follower_id = session.connection_id;
1772
1773 if !session
1774 .db()
1775 .await
1776 .project_connection_ids(project_id, session.connection_id)
1777 .await?
1778 .contains(&leader_id)
1779 {
1780 Err(anyhow!("no such peer"))?;
1781 }
1782
1783 session
1784 .peer
1785 .forward_send(session.connection_id, leader_id, request)?;
1786
1787 let room = session
1788 .db()
1789 .await
1790 .unfollow(project_id, leader_id, follower_id)
1791 .await?;
1792 room_updated(&room, &session.peer);
1793
1794 Ok(())
1795}
1796
1797async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1798 let project_id = ProjectId::from_proto(request.project_id);
1799 let project_connection_ids = session
1800 .db
1801 .lock()
1802 .await
1803 .project_connection_ids(project_id, session.connection_id)
1804 .await?;
1805
1806 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1807 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1808 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1809 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1810 });
1811 for follower_peer_id in request.follower_ids.iter().copied() {
1812 let follower_connection_id = follower_peer_id.into();
1813 if project_connection_ids.contains(&follower_connection_id)
1814 && Some(follower_peer_id) != leader_id
1815 {
1816 session.peer.forward_send(
1817 session.connection_id,
1818 follower_connection_id,
1819 request.clone(),
1820 )?;
1821 }
1822 }
1823 Ok(())
1824}
1825
1826async fn get_users(
1827 request: proto::GetUsers,
1828 response: Response<proto::GetUsers>,
1829 session: Session,
1830) -> Result<()> {
1831 let user_ids = request
1832 .user_ids
1833 .into_iter()
1834 .map(UserId::from_proto)
1835 .collect();
1836 let users = session
1837 .db()
1838 .await
1839 .get_users_by_ids(user_ids)
1840 .await?
1841 .into_iter()
1842 .map(|user| proto::User {
1843 id: user.id.to_proto(),
1844 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1845 github_login: user.github_login,
1846 })
1847 .collect();
1848 response.send(proto::UsersResponse { users })?;
1849 Ok(())
1850}
1851
1852async fn fuzzy_search_users(
1853 request: proto::FuzzySearchUsers,
1854 response: Response<proto::FuzzySearchUsers>,
1855 session: Session,
1856) -> Result<()> {
1857 let query = request.query;
1858 let users = match query.len() {
1859 0 => vec![],
1860 1 | 2 => session
1861 .db()
1862 .await
1863 .get_user_by_github_login(&query)
1864 .await?
1865 .into_iter()
1866 .collect(),
1867 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1868 };
1869 let users = users
1870 .into_iter()
1871 .filter(|user| user.id != session.user_id)
1872 .map(|user| proto::User {
1873 id: user.id.to_proto(),
1874 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1875 github_login: user.github_login,
1876 })
1877 .collect();
1878 response.send(proto::UsersResponse { users })?;
1879 Ok(())
1880}
1881
1882async fn request_contact(
1883 request: proto::RequestContact,
1884 response: Response<proto::RequestContact>,
1885 session: Session,
1886) -> Result<()> {
1887 let requester_id = session.user_id;
1888 let responder_id = UserId::from_proto(request.responder_id);
1889 if requester_id == responder_id {
1890 return Err(anyhow!("cannot add yourself as a contact"))?;
1891 }
1892
1893 session
1894 .db()
1895 .await
1896 .send_contact_request(requester_id, responder_id)
1897 .await?;
1898
1899 // Update outgoing contact requests of requester
1900 let mut update = proto::UpdateContacts::default();
1901 update.outgoing_requests.push(responder_id.to_proto());
1902 for connection_id in session
1903 .connection_pool()
1904 .await
1905 .user_connection_ids(requester_id)
1906 {
1907 session.peer.send(connection_id, update.clone())?;
1908 }
1909
1910 // Update incoming contact requests of responder
1911 let mut update = proto::UpdateContacts::default();
1912 update
1913 .incoming_requests
1914 .push(proto::IncomingContactRequest {
1915 requester_id: requester_id.to_proto(),
1916 should_notify: true,
1917 });
1918 for connection_id in session
1919 .connection_pool()
1920 .await
1921 .user_connection_ids(responder_id)
1922 {
1923 session.peer.send(connection_id, update.clone())?;
1924 }
1925
1926 response.send(proto::Ack {})?;
1927 Ok(())
1928}
1929
1930async fn respond_to_contact_request(
1931 request: proto::RespondToContactRequest,
1932 response: Response<proto::RespondToContactRequest>,
1933 session: Session,
1934) -> Result<()> {
1935 let responder_id = session.user_id;
1936 let requester_id = UserId::from_proto(request.requester_id);
1937 let db = session.db().await;
1938 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1939 db.dismiss_contact_notification(responder_id, requester_id)
1940 .await?;
1941 } else {
1942 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1943
1944 db.respond_to_contact_request(responder_id, requester_id, accept)
1945 .await?;
1946 let requester_busy = db.is_user_busy(requester_id).await?;
1947 let responder_busy = db.is_user_busy(responder_id).await?;
1948
1949 let pool = session.connection_pool().await;
1950 // Update responder with new contact
1951 let mut update = proto::UpdateContacts::default();
1952 if accept {
1953 update
1954 .contacts
1955 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1956 }
1957 update
1958 .remove_incoming_requests
1959 .push(requester_id.to_proto());
1960 for connection_id in pool.user_connection_ids(responder_id) {
1961 session.peer.send(connection_id, update.clone())?;
1962 }
1963
1964 // Update requester with new contact
1965 let mut update = proto::UpdateContacts::default();
1966 if accept {
1967 update
1968 .contacts
1969 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1970 }
1971 update
1972 .remove_outgoing_requests
1973 .push(responder_id.to_proto());
1974 for connection_id in pool.user_connection_ids(requester_id) {
1975 session.peer.send(connection_id, update.clone())?;
1976 }
1977 }
1978
1979 response.send(proto::Ack {})?;
1980 Ok(())
1981}
1982
1983async fn remove_contact(
1984 request: proto::RemoveContact,
1985 response: Response<proto::RemoveContact>,
1986 session: Session,
1987) -> Result<()> {
1988 let requester_id = session.user_id;
1989 let responder_id = UserId::from_proto(request.user_id);
1990 let db = session.db().await;
1991 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
1992
1993 let pool = session.connection_pool().await;
1994 // Update outgoing contact requests of requester
1995 let mut update = proto::UpdateContacts::default();
1996 if contact_accepted {
1997 update.remove_contacts.push(responder_id.to_proto());
1998 } else {
1999 update
2000 .remove_outgoing_requests
2001 .push(responder_id.to_proto());
2002 }
2003 for connection_id in pool.user_connection_ids(requester_id) {
2004 session.peer.send(connection_id, update.clone())?;
2005 }
2006
2007 // Update incoming contact requests of responder
2008 let mut update = proto::UpdateContacts::default();
2009 if contact_accepted {
2010 update.remove_contacts.push(requester_id.to_proto());
2011 } else {
2012 update
2013 .remove_incoming_requests
2014 .push(requester_id.to_proto());
2015 }
2016 for connection_id in pool.user_connection_ids(responder_id) {
2017 session.peer.send(connection_id, update.clone())?;
2018 }
2019
2020 response.send(proto::Ack {})?;
2021 Ok(())
2022}
2023
2024async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2025 let project_id = ProjectId::from_proto(request.project_id);
2026 let project_connection_ids = session
2027 .db()
2028 .await
2029 .project_connection_ids(project_id, session.connection_id)
2030 .await?;
2031 broadcast(
2032 Some(session.connection_id),
2033 project_connection_ids.iter().copied(),
2034 |connection_id| {
2035 session
2036 .peer
2037 .forward_send(session.connection_id, connection_id, request.clone())
2038 },
2039 );
2040 Ok(())
2041}
2042
2043async fn get_private_user_info(
2044 _request: proto::GetPrivateUserInfo,
2045 response: Response<proto::GetPrivateUserInfo>,
2046 session: Session,
2047) -> Result<()> {
2048 let metrics_id = session
2049 .db()
2050 .await
2051 .get_user_metrics_id(session.user_id)
2052 .await?;
2053 let user = session
2054 .db()
2055 .await
2056 .get_user_by_id(session.user_id)
2057 .await?
2058 .ok_or_else(|| anyhow!("user not found"))?;
2059 response.send(proto::GetPrivateUserInfoResponse {
2060 metrics_id,
2061 staff: user.admin,
2062 })?;
2063 Ok(())
2064}
2065
2066fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2067 match message {
2068 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2069 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2070 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2071 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2072 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2073 code: frame.code.into(),
2074 reason: frame.reason,
2075 })),
2076 }
2077}
2078
2079fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2080 match message {
2081 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2082 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2083 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2084 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2085 AxumMessage::Close(frame) => {
2086 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2087 code: frame.code.into(),
2088 reason: frame.reason,
2089 }))
2090 }
2091 }
2092}
2093
2094fn build_initial_contacts_update(
2095 contacts: Vec<db::Contact>,
2096 pool: &ConnectionPool,
2097) -> proto::UpdateContacts {
2098 let mut update = proto::UpdateContacts::default();
2099
2100 for contact in contacts {
2101 match contact {
2102 db::Contact::Accepted {
2103 user_id,
2104 should_notify,
2105 busy,
2106 } => {
2107 update
2108 .contacts
2109 .push(contact_for_user(user_id, should_notify, busy, &pool));
2110 }
2111 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2112 db::Contact::Incoming {
2113 user_id,
2114 should_notify,
2115 } => update
2116 .incoming_requests
2117 .push(proto::IncomingContactRequest {
2118 requester_id: user_id.to_proto(),
2119 should_notify,
2120 }),
2121 }
2122 }
2123
2124 update
2125}
2126
2127fn contact_for_user(
2128 user_id: UserId,
2129 should_notify: bool,
2130 busy: bool,
2131 pool: &ConnectionPool,
2132) -> proto::Contact {
2133 proto::Contact {
2134 user_id: user_id.to_proto(),
2135 online: pool.is_user_online(user_id),
2136 busy,
2137 should_notify,
2138 }
2139}
2140
2141fn room_updated(room: &proto::Room, peer: &Peer) {
2142 broadcast(
2143 None,
2144 room.participants
2145 .iter()
2146 .filter_map(|participant| Some(participant.peer_id?.into())),
2147 |peer_id| {
2148 peer.send(
2149 peer_id.into(),
2150 proto::RoomUpdated {
2151 room: Some(room.clone()),
2152 },
2153 )
2154 },
2155 );
2156}
2157
2158async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2159 let db = session.db().await;
2160 let contacts = db.get_contacts(user_id).await?;
2161 let busy = db.is_user_busy(user_id).await?;
2162
2163 let pool = session.connection_pool().await;
2164 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2165 for contact in contacts {
2166 if let db::Contact::Accepted {
2167 user_id: contact_user_id,
2168 ..
2169 } = contact
2170 {
2171 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2172 session
2173 .peer
2174 .send(
2175 contact_conn_id,
2176 proto::UpdateContacts {
2177 contacts: vec![updated_contact.clone()],
2178 remove_contacts: Default::default(),
2179 incoming_requests: Default::default(),
2180 remove_incoming_requests: Default::default(),
2181 outgoing_requests: Default::default(),
2182 remove_outgoing_requests: Default::default(),
2183 },
2184 )
2185 .trace_err();
2186 }
2187 }
2188 }
2189 Ok(())
2190}
2191
2192async fn leave_room_for_session(session: &Session) -> Result<()> {
2193 let mut contacts_to_update = HashSet::default();
2194
2195 let room_id;
2196 let canceled_calls_to_user_ids;
2197 let live_kit_room;
2198 let delete_live_kit_room;
2199 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2200 contacts_to_update.insert(session.user_id);
2201
2202 for project in left_room.left_projects.values() {
2203 project_left(project, session);
2204 }
2205
2206 room_updated(&left_room.room, &session.peer);
2207 room_id = RoomId::from_proto(left_room.room.id);
2208 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2209 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2210 delete_live_kit_room = left_room.room.participants.is_empty();
2211 } else {
2212 return Ok(());
2213 }
2214
2215 {
2216 let pool = session.connection_pool().await;
2217 for canceled_user_id in canceled_calls_to_user_ids {
2218 for connection_id in pool.user_connection_ids(canceled_user_id) {
2219 session
2220 .peer
2221 .send(
2222 connection_id,
2223 proto::CallCanceled {
2224 room_id: room_id.to_proto(),
2225 },
2226 )
2227 .trace_err();
2228 }
2229 contacts_to_update.insert(canceled_user_id);
2230 }
2231 }
2232
2233 for contact_user_id in contacts_to_update {
2234 update_user_contacts(contact_user_id, &session).await?;
2235 }
2236
2237 if let Some(live_kit) = session.live_kit_client.as_ref() {
2238 live_kit
2239 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2240 .await
2241 .trace_err();
2242
2243 if delete_live_kit_room {
2244 live_kit.delete_room(live_kit_room).await.trace_err();
2245 }
2246 }
2247
2248 Ok(())
2249}
2250
2251fn project_left(project: &db::LeftProject, session: &Session) {
2252 for connection_id in &project.connection_ids {
2253 if project.host_user_id == session.user_id {
2254 session
2255 .peer
2256 .send(
2257 *connection_id,
2258 proto::UnshareProject {
2259 project_id: project.id.to_proto(),
2260 },
2261 )
2262 .trace_err();
2263 } else {
2264 session
2265 .peer
2266 .send(
2267 *connection_id,
2268 proto::RemoveProjectCollaborator {
2269 project_id: project.id.to_proto(),
2270 peer_id: Some(session.connection_id.into()),
2271 },
2272 )
2273 .trace_err();
2274 }
2275 }
2276}
2277
2278pub trait ResultExt {
2279 type Ok;
2280
2281 fn trace_err(self) -> Option<Self::Ok>;
2282}
2283
2284impl<T, E> ResultExt for Result<T, E>
2285where
2286 E: std::fmt::Debug,
2287{
2288 type Ok = T;
2289
2290 fn trace_err(self) -> Option<T> {
2291 match self {
2292 Ok(value) => Some(value),
2293 Err(error) => {
2294 tracing::error!("{:?}", error);
2295 None
2296 }
2297 }
2298 }
2299}