1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::{Duration, Instant},
55};
56use tokio::sync::{watch, Semaphore};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// Grace period `connection_lost` waits after a connection drops before removing the
/// user from their room and refreshing their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay `Server::start` waits before cleaning up rooms left stale by a previous server.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
/// Serializable view of the server's state: the RPC peer and the connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive, because the
/// snapshot owns the pool's guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_message_handler(update_worktree_settings)
204 .add_request_handler(forward_project_request::<proto::GetHover>)
205 .add_request_handler(forward_project_request::<proto::GetDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
207 .add_request_handler(forward_project_request::<proto::GetReferences>)
208 .add_request_handler(forward_project_request::<proto::SearchProject>)
209 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
210 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
213 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
214 .add_request_handler(forward_project_request::<proto::GetCompletions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
216 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
217 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
218 .add_request_handler(forward_project_request::<proto::PrepareRename>)
219 .add_request_handler(forward_project_request::<proto::PerformRename>)
220 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
221 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
222 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
223 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
226 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
227 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
228 .add_message_handler(create_buffer_for_peer)
229 .add_request_handler(update_buffer)
230 .add_message_handler(update_buffer_file)
231 .add_message_handler(buffer_reloaded)
232 .add_message_handler(buffer_saved)
233 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
234 .add_request_handler(get_users)
235 .add_request_handler(fuzzy_search_users)
236 .add_request_handler(request_contact)
237 .add_request_handler(remove_contact)
238 .add_request_handler(respond_to_contact_request)
239 .add_request_handler(follow)
240 .add_message_handler(unfollow)
241 .add_message_handler(update_followers)
242 .add_message_handler(update_diff_base)
243 .add_request_handler(get_private_user_info);
244
245 Arc::new(server)
246 }
247
248 pub async fn start(&self) -> Result<()> {
249 let server_id = *self.id.lock();
250 let app_state = self.app_state.clone();
251 let peer = self.peer.clone();
252 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
253 let pool = self.connection_pool.clone();
254 let live_kit_client = self.app_state.live_kit_client.clone();
255
256 let span = info_span!("start server");
257 self.executor.spawn_detached(
258 async move {
259 tracing::info!("waiting for cleanup timeout");
260 timeout.await;
261 tracing::info!("cleanup timeout expired, retrieving stale rooms");
262 if let Some(room_ids) = app_state
263 .db
264 .stale_room_ids(&app_state.config.zed_environment, server_id)
265 .await
266 .trace_err()
267 {
268 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
269 for room_id in room_ids {
270 let mut contacts_to_update = HashSet::default();
271 let mut canceled_calls_to_user_ids = Vec::new();
272 let mut live_kit_room = String::new();
273 let mut delete_live_kit_room = false;
274
275 if let Some(mut refreshed_room) = app_state
276 .db
277 .refresh_room(room_id, server_id)
278 .await
279 .trace_err()
280 {
281 tracing::info!(
282 room_id = room_id.0,
283 new_participant_count = refreshed_room.room.participants.len(),
284 "refreshed room"
285 );
286 room_updated(&refreshed_room.room, &peer);
287 contacts_to_update
288 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
289 contacts_to_update
290 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
291 canceled_calls_to_user_ids =
292 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
293 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
294 delete_live_kit_room = refreshed_room.room.participants.is_empty();
295 }
296
297 {
298 let pool = pool.lock();
299 for canceled_user_id in canceled_calls_to_user_ids {
300 for connection_id in pool.user_connection_ids(canceled_user_id) {
301 peer.send(
302 connection_id,
303 proto::CallCanceled {
304 room_id: room_id.to_proto(),
305 },
306 )
307 .trace_err();
308 }
309 }
310 }
311
312 for user_id in contacts_to_update {
313 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
314 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
315 if let Some((busy, contacts)) = busy.zip(contacts) {
316 let pool = pool.lock();
317 let updated_contact = contact_for_user(user_id, false, busy, &pool);
318 for contact in contacts {
319 if let db::Contact::Accepted {
320 user_id: contact_user_id,
321 ..
322 } = contact
323 {
324 for contact_conn_id in
325 pool.user_connection_ids(contact_user_id)
326 {
327 peer.send(
328 contact_conn_id,
329 proto::UpdateContacts {
330 contacts: vec![updated_contact.clone()],
331 remove_contacts: Default::default(),
332 incoming_requests: Default::default(),
333 remove_incoming_requests: Default::default(),
334 outgoing_requests: Default::default(),
335 remove_outgoing_requests: Default::default(),
336 },
337 )
338 .trace_err();
339 }
340 }
341 }
342 }
343 }
344
345 if let Some(live_kit) = live_kit_client.as_ref() {
346 if delete_live_kit_room {
347 live_kit.delete_room(live_kit_room).await.trace_err();
348 }
349 }
350 }
351 }
352
353 app_state
354 .db
355 .delete_stale_servers(&app_state.config.zed_environment, server_id)
356 .await
357 .trace_err();
358 }
359 .instrument(span),
360 );
361 Ok(())
362 }
363
364 pub fn teardown(&self) {
365 self.peer.teardown();
366 self.connection_pool.lock().reset();
367 let _ = self.teardown.send(());
368 }
369
370 #[cfg(test)]
371 pub fn reset(&self, id: ServerId) {
372 self.teardown();
373 *self.id.lock() = id;
374 self.peer.reset(id.0 as u32);
375 }
376
377 #[cfg(test)]
378 pub fn id(&self) -> ServerId {
379 *self.id.lock()
380 }
381
382 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
383 where
384 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
385 Fut: 'static + Send + Future<Output = Result<()>>,
386 M: EnvelopedMessage,
387 {
388 let prev_handler = self.handlers.insert(
389 TypeId::of::<M>(),
390 Box::new(move |envelope, session| {
391 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
392 let span = info_span!(
393 "handle message",
394 payload_type = envelope.payload_type_name()
395 );
396 span.in_scope(|| {
397 tracing::info!(
398 payload_type = envelope.payload_type_name(),
399 "message received"
400 );
401 });
402 let start_time = Instant::now();
403 let future = (handler)(*envelope, session);
404 async move {
405 let result = future.await;
406 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
407 match result {
408 Err(error) => {
409 tracing::error!(%error, ?duration_ms, "error handling message")
410 }
411 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
412 }
413 }
414 .instrument(span)
415 .boxed()
416 }),
417 );
418 if prev_handler.is_some() {
419 panic!("registered a handler for the same message twice");
420 }
421 self
422 }
423
424 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
425 where
426 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
427 Fut: 'static + Send + Future<Output = Result<()>>,
428 M: EnvelopedMessage,
429 {
430 self.add_handler(move |envelope, session| handler(envelope.payload, session));
431 self
432 }
433
434 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
435 where
436 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
437 Fut: Send + Future<Output = Result<()>>,
438 M: RequestMessage,
439 {
440 let handler = Arc::new(handler);
441 self.add_handler(move |envelope, session| {
442 let receipt = envelope.receipt();
443 let handler = handler.clone();
444 async move {
445 let peer = session.peer.clone();
446 let responded = Arc::new(AtomicBool::default());
447 let response = Response {
448 peer: peer.clone(),
449 responded: responded.clone(),
450 receipt,
451 };
452 match (handler)(envelope.payload, response, session).await {
453 Ok(()) => {
454 if responded.load(std::sync::atomic::Ordering::SeqCst) {
455 Ok(())
456 } else {
457 Err(anyhow!("handler did not send a response"))?
458 }
459 }
460 Err(error) => {
461 peer.respond_with_error(
462 receipt,
463 proto::Error {
464 message: error.to_string(),
465 },
466 )?;
467 Err(error)
468 }
469 }
470 }
471 })
472 }
473
474 pub fn handle_connection(
475 self: &Arc<Self>,
476 connection: Connection,
477 address: String,
478 user: User,
479 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
480 executor: Executor,
481 ) -> impl Future<Output = Result<()>> {
482 let this = self.clone();
483 let user_id = user.id;
484 let login = user.github_login;
485 let span = info_span!("handle connection", %user_id, %login, %address);
486 let mut teardown = self.teardown.subscribe();
487 async move {
488 let (connection_id, handle_io, mut incoming_rx) = this
489 .peer
490 .add_connection(connection, {
491 let executor = executor.clone();
492 move |duration| executor.sleep(duration)
493 });
494
495 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
496 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
497 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
498
499 if let Some(send_connection_id) = send_connection_id.take() {
500 let _ = send_connection_id.send(connection_id);
501 }
502
503 if !user.connected_once {
504 this.peer.send(connection_id, proto::ShowContacts {})?;
505 this.app_state.db.set_user_connected_once(user_id, true).await?;
506 }
507
508 let (contacts, invite_code) = future::try_join(
509 this.app_state.db.get_contacts(user_id),
510 this.app_state.db.get_invite_code_for_user(user_id)
511 ).await?;
512
513 {
514 let mut pool = this.connection_pool.lock();
515 pool.add_connection(connection_id, user_id, user.admin);
516 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
517
518 if let Some((code, count)) = invite_code {
519 this.peer.send(connection_id, proto::UpdateInviteInfo {
520 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
521 count: count as u32,
522 })?;
523 }
524 }
525
526 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
527 this.peer.send(connection_id, incoming_call)?;
528 }
529
530 let session = Session {
531 user_id,
532 connection_id,
533 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
534 peer: this.peer.clone(),
535 connection_pool: this.connection_pool.clone(),
536 live_kit_client: this.app_state.live_kit_client.clone(),
537 executor: executor.clone(),
538 };
539 update_user_contacts(user_id, &session).await?;
540
541 let handle_io = handle_io.fuse();
542 futures::pin_mut!(handle_io);
543
544 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
545 // This prevents deadlocks when e.g., client A performs a request to client B and
546 // client B performs a request to client A. If both clients stop processing further
547 // messages until their respective request completes, they won't have a chance to
548 // respond to the other client's request and cause a deadlock.
549 //
550 // This arrangement ensures we will attempt to process earlier messages first, but fall
551 // back to processing messages arrived later in the spirit of making progress.
552 let mut foreground_message_handlers = FuturesUnordered::new();
553 let concurrent_handlers = Arc::new(Semaphore::new(256));
554 loop {
555 let next_message = async {
556 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
557 let message = incoming_rx.next().await;
558 (permit, message)
559 }.fuse();
560 futures::pin_mut!(next_message);
561 futures::select_biased! {
562 _ = teardown.changed().fuse() => return Ok(()),
563 result = handle_io => {
564 if let Err(error) = result {
565 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
566 }
567 break;
568 }
569 _ = foreground_message_handlers.next() => {}
570 next_message = next_message => {
571 let (permit, message) = next_message;
572 if let Some(message) = message {
573 let type_name = message.payload_type_name();
574 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
575 let span_enter = span.enter();
576 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
577 let is_background = message.is_background();
578 let handle_message = (handler)(message, session.clone());
579 drop(span_enter);
580
581 let handle_message = async move {
582 handle_message.await;
583 drop(permit);
584 }.instrument(span);
585 if is_background {
586 executor.spawn_detached(handle_message);
587 } else {
588 foreground_message_handlers.push(handle_message);
589 }
590 } else {
591 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
592 }
593 } else {
594 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
595 break;
596 }
597 }
598 }
599 }
600
601 drop(foreground_message_handlers);
602 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
603 if let Err(error) = connection_lost(session, teardown, executor).await {
604 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
605 }
606
607 Ok(())
608 }.instrument(span)
609 }
610
611 pub async fn invite_code_redeemed(
612 self: &Arc<Self>,
613 inviter_id: UserId,
614 invitee_id: UserId,
615 ) -> Result<()> {
616 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
617 if let Some(code) = &user.invite_code {
618 let pool = self.connection_pool.lock();
619 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
620 for connection_id in pool.user_connection_ids(inviter_id) {
621 self.peer.send(
622 connection_id,
623 proto::UpdateContacts {
624 contacts: vec![invitee_contact.clone()],
625 ..Default::default()
626 },
627 )?;
628 self.peer.send(
629 connection_id,
630 proto::UpdateInviteInfo {
631 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
632 count: user.invite_count as u32,
633 },
634 )?;
635 }
636 }
637 }
638 Ok(())
639 }
640
641 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
642 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
643 if let Some(invite_code) = &user.invite_code {
644 let pool = self.connection_pool.lock();
645 for connection_id in pool.user_connection_ids(user_id) {
646 self.peer.send(
647 connection_id,
648 proto::UpdateInviteInfo {
649 url: format!(
650 "{}{}",
651 self.app_state.config.invite_link_prefix, invite_code
652 ),
653 count: user.invite_count as u32,
654 },
655 )?;
656 }
657 }
658 }
659 Ok(())
660 }
661
662 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
663 ServerSnapshot {
664 connection_pool: ConnectionPoolGuard {
665 guard: self.connection_pool.lock(),
666 _not_send: PhantomData,
667 },
668 peer: &self.peer,
669 }
670 }
671}
672
673impl<'a> Deref for ConnectionPoolGuard<'a> {
674 type Target = ConnectionPool;
675
676 fn deref(&self) -> &Self::Target {
677 &*self.guard
678 }
679}
680
681impl<'a> DerefMut for ConnectionPoolGuard<'a> {
682 fn deref_mut(&mut self) -> &mut Self::Target {
683 &mut *self.guard
684 }
685}
686
687impl<'a> Drop for ConnectionPoolGuard<'a> {
688 fn drop(&mut self) {
689 #[cfg(test)]
690 self.check_invariants();
691 }
692}
693
694fn broadcast<F>(
695 sender_id: Option<ConnectionId>,
696 receiver_ids: impl IntoIterator<Item = ConnectionId>,
697 mut f: F,
698) where
699 F: FnMut(ConnectionId) -> anyhow::Result<()>,
700{
701 for receiver_id in receiver_ids {
702 if Some(receiver_id) != sender_id {
703 if let Err(error) = f(receiver_id) {
704 tracing::error!("failed to send to {:?} {}", receiver_id, error);
705 }
706 }
707 }
708}
709
710lazy_static! {
711 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
712}
713
714pub struct ProtocolVersion(u32);
715
716impl Header for ProtocolVersion {
717 fn name() -> &'static HeaderName {
718 &ZED_PROTOCOL_VERSION
719 }
720
721 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
722 where
723 Self: Sized,
724 I: Iterator<Item = &'i axum::http::HeaderValue>,
725 {
726 let version = values
727 .next()
728 .ok_or_else(axum::headers::Error::invalid)?
729 .to_str()
730 .map_err(|_| axum::headers::Error::invalid())?
731 .parse()
732 .map_err(|_| axum::headers::Error::invalid())?;
733 Ok(Self(version))
734 }
735
736 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
737 values.extend([self.0.to_string().parse().unwrap()]);
738 }
739}
740
741pub fn routes(server: Arc<Server>) -> Router<Body> {
742 Router::new()
743 .route("/rpc", get(handle_websocket_request))
744 .layer(
745 ServiceBuilder::new()
746 .layer(Extension(server.app_state.clone()))
747 .layer(middleware::from_fn(auth::validate_header)),
748 )
749 .route("/metrics", get(handle_metrics))
750 .layer(Extension(server))
751}
752
753pub async fn handle_websocket_request(
754 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
755 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
756 Extension(server): Extension<Arc<Server>>,
757 Extension(user): Extension<User>,
758 ws: WebSocketUpgrade,
759) -> axum::response::Response {
760 if protocol_version != rpc::PROTOCOL_VERSION {
761 return (
762 StatusCode::UPGRADE_REQUIRED,
763 "client must be upgraded".to_string(),
764 )
765 .into_response();
766 }
767 let socket_address = socket_address.to_string();
768 ws.on_upgrade(move |socket| {
769 use util::ResultExt;
770 let socket = socket
771 .map_ok(to_tungstenite_message)
772 .err_into()
773 .with(|message| async move { Ok(to_axum_message(message)) });
774 let connection = Connection::new(Box::pin(socket));
775 async move {
776 server
777 .handle_connection(connection, socket_address, user, None, Executor::Production)
778 .await
779 .log_err();
780 }
781 })
782}
783
784pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
785 let connections = server
786 .connection_pool
787 .lock()
788 .connections()
789 .filter(|connection| !connection.admin)
790 .count();
791
792 METRIC_CONNECTIONS.set(connections as _);
793
794 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
795 METRIC_SHARED_PROJECTS.set(shared_projects as _);
796
797 let encoder = prometheus::TextEncoder::new();
798 let metric_families = prometheus::gather();
799 let encoded_metrics = encoder
800 .encode_to_string(&metric_families)
801 .map_err(|err| anyhow!("{}", err))?;
802 Ok(encoded_metrics)
803}
804
805#[instrument(err, skip(executor))]
806async fn connection_lost(
807 session: Session,
808 mut teardown: watch::Receiver<()>,
809 executor: Executor,
810) -> Result<()> {
811 session.peer.disconnect(session.connection_id);
812 session
813 .connection_pool()
814 .await
815 .remove_connection(session.connection_id)?;
816
817 session
818 .db()
819 .await
820 .connection_lost(session.connection_id)
821 .await
822 .trace_err();
823
824 futures::select_biased! {
825 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
826 leave_room_for_session(&session).await.trace_err();
827
828 if !session
829 .connection_pool()
830 .await
831 .is_user_online(session.user_id)
832 {
833 let db = session.db().await;
834 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
835 room_updated(&room, &session.peer);
836 }
837 }
838 update_user_contacts(session.user_id, &session).await?;
839 }
840 _ = teardown.changed().fuse() => {}
841 }
842
843 Ok(())
844}
845
846async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
847 response.send(proto::Ack {})?;
848 Ok(())
849}
850
851async fn create_room(
852 _request: proto::CreateRoom,
853 response: Response<proto::CreateRoom>,
854 session: Session,
855) -> Result<()> {
856 let live_kit_room = nanoid::nanoid!(30);
857 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
858 if let Some(_) = live_kit
859 .create_room(live_kit_room.clone())
860 .await
861 .trace_err()
862 {
863 if let Some(token) = live_kit
864 .room_token(&live_kit_room, &session.user_id.to_string())
865 .trace_err()
866 {
867 Some(proto::LiveKitConnectionInfo {
868 server_url: live_kit.url().into(),
869 token,
870 })
871 } else {
872 None
873 }
874 } else {
875 None
876 }
877 } else {
878 None
879 };
880
881 {
882 let room = session
883 .db()
884 .await
885 .create_room(session.user_id, session.connection_id, &live_kit_room)
886 .await?;
887
888 response.send(proto::CreateRoomResponse {
889 room: Some(room.clone()),
890 live_kit_connection_info,
891 })?;
892 }
893
894 update_user_contacts(session.user_id, &session).await?;
895 Ok(())
896}
897
898async fn join_room(
899 request: proto::JoinRoom,
900 response: Response<proto::JoinRoom>,
901 session: Session,
902) -> Result<()> {
903 let room_id = RoomId::from_proto(request.id);
904 let room = {
905 let room = session
906 .db()
907 .await
908 .join_room(room_id, session.user_id, session.connection_id)
909 .await?;
910 room_updated(&room, &session.peer);
911 room.clone()
912 };
913
914 for connection_id in session
915 .connection_pool()
916 .await
917 .user_connection_ids(session.user_id)
918 {
919 session
920 .peer
921 .send(
922 connection_id,
923 proto::CallCanceled {
924 room_id: room_id.to_proto(),
925 },
926 )
927 .trace_err();
928 }
929
930 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
931 if let Some(token) = live_kit
932 .room_token(&room.live_kit_room, &session.user_id.to_string())
933 .trace_err()
934 {
935 Some(proto::LiveKitConnectionInfo {
936 server_url: live_kit.url().into(),
937 token,
938 })
939 } else {
940 None
941 }
942 } else {
943 None
944 };
945
946 response.send(proto::JoinRoomResponse {
947 room: Some(room),
948 live_kit_connection_info,
949 })?;
950
951 update_user_contacts(session.user_id, &session).await?;
952 Ok(())
953}
954
955async fn rejoin_room(
956 request: proto::RejoinRoom,
957 response: Response<proto::RejoinRoom>,
958 session: Session,
959) -> Result<()> {
960 {
961 let mut rejoined_room = session
962 .db()
963 .await
964 .rejoin_room(request, session.user_id, session.connection_id)
965 .await?;
966
967 response.send(proto::RejoinRoomResponse {
968 room: Some(rejoined_room.room.clone()),
969 reshared_projects: rejoined_room
970 .reshared_projects
971 .iter()
972 .map(|project| proto::ResharedProject {
973 id: project.id.to_proto(),
974 collaborators: project
975 .collaborators
976 .iter()
977 .map(|collaborator| collaborator.to_proto())
978 .collect(),
979 })
980 .collect(),
981 rejoined_projects: rejoined_room
982 .rejoined_projects
983 .iter()
984 .map(|rejoined_project| proto::RejoinedProject {
985 id: rejoined_project.id.to_proto(),
986 worktrees: rejoined_project
987 .worktrees
988 .iter()
989 .map(|worktree| proto::WorktreeMetadata {
990 id: worktree.id,
991 root_name: worktree.root_name.clone(),
992 visible: worktree.visible,
993 abs_path: worktree.abs_path.clone(),
994 })
995 .collect(),
996 collaborators: rejoined_project
997 .collaborators
998 .iter()
999 .map(|collaborator| collaborator.to_proto())
1000 .collect(),
1001 language_servers: rejoined_project.language_servers.clone(),
1002 })
1003 .collect(),
1004 })?;
1005 room_updated(&rejoined_room.room, &session.peer);
1006
1007 for project in &rejoined_room.reshared_projects {
1008 for collaborator in &project.collaborators {
1009 session
1010 .peer
1011 .send(
1012 collaborator.connection_id,
1013 proto::UpdateProjectCollaborator {
1014 project_id: project.id.to_proto(),
1015 old_peer_id: Some(project.old_connection_id.into()),
1016 new_peer_id: Some(session.connection_id.into()),
1017 },
1018 )
1019 .trace_err();
1020 }
1021
1022 broadcast(
1023 Some(session.connection_id),
1024 project
1025 .collaborators
1026 .iter()
1027 .map(|collaborator| collaborator.connection_id),
1028 |connection_id| {
1029 session.peer.forward_send(
1030 session.connection_id,
1031 connection_id,
1032 proto::UpdateProject {
1033 project_id: project.id.to_proto(),
1034 worktrees: project.worktrees.clone(),
1035 },
1036 )
1037 },
1038 );
1039 }
1040
1041 for project in &rejoined_room.rejoined_projects {
1042 for collaborator in &project.collaborators {
1043 session
1044 .peer
1045 .send(
1046 collaborator.connection_id,
1047 proto::UpdateProjectCollaborator {
1048 project_id: project.id.to_proto(),
1049 old_peer_id: Some(project.old_connection_id.into()),
1050 new_peer_id: Some(session.connection_id.into()),
1051 },
1052 )
1053 .trace_err();
1054 }
1055 }
1056
1057 for project in &mut rejoined_room.rejoined_projects {
1058 for worktree in mem::take(&mut project.worktrees) {
1059 #[cfg(any(test, feature = "test-support"))]
1060 const MAX_CHUNK_SIZE: usize = 2;
1061 #[cfg(not(any(test, feature = "test-support")))]
1062 const MAX_CHUNK_SIZE: usize = 256;
1063
1064 // Stream this worktree's entries.
1065 let message = proto::UpdateWorktree {
1066 project_id: project.id.to_proto(),
1067 worktree_id: worktree.id,
1068 abs_path: worktree.abs_path.clone(),
1069 root_name: worktree.root_name,
1070 updated_entries: worktree.updated_entries,
1071 removed_entries: worktree.removed_entries,
1072 scan_id: worktree.scan_id,
1073 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1074 updated_repositories: worktree.updated_repositories,
1075 removed_repositories: worktree.removed_repositories,
1076 };
1077 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1078 session.peer.send(session.connection_id, update.clone())?;
1079 }
1080
1081 // Stream this worktree's diagnostics.
1082 for summary in worktree.diagnostic_summaries {
1083 session.peer.send(
1084 session.connection_id,
1085 proto::UpdateDiagnosticSummary {
1086 project_id: project.id.to_proto(),
1087 worktree_id: worktree.id,
1088 summary: Some(summary),
1089 },
1090 )?;
1091 }
1092
1093 for settings_file in worktree.settings_files {
1094 session.peer.send(
1095 session.connection_id,
1096 proto::UpdateWorktreeSettings {
1097 project_id: project.id.to_proto(),
1098 worktree_id: worktree.id,
1099 path: settings_file.path,
1100 content: Some(settings_file.content),
1101 },
1102 )?;
1103 }
1104 }
1105
1106 for language_server in &project.language_servers {
1107 session.peer.send(
1108 session.connection_id,
1109 proto::UpdateLanguageServer {
1110 project_id: project.id.to_proto(),
1111 language_server_id: language_server.id,
1112 variant: Some(
1113 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1114 proto::LspDiskBasedDiagnosticsUpdated {},
1115 ),
1116 ),
1117 },
1118 )?;
1119 }
1120 }
1121 }
1122
1123 update_user_contacts(session.user_id, &session).await?;
1124 Ok(())
1125}
1126
1127async fn leave_room(
1128 _: proto::LeaveRoom,
1129 response: Response<proto::LeaveRoom>,
1130 session: Session,
1131) -> Result<()> {
1132 leave_room_for_session(&session).await?;
1133 response.send(proto::Ack {})?;
1134 Ok(())
1135}
1136
1137async fn call(
1138 request: proto::Call,
1139 response: Response<proto::Call>,
1140 session: Session,
1141) -> Result<()> {
1142 let room_id = RoomId::from_proto(request.room_id);
1143 let calling_user_id = session.user_id;
1144 let calling_connection_id = session.connection_id;
1145 let called_user_id = UserId::from_proto(request.called_user_id);
1146 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1147 if !session
1148 .db()
1149 .await
1150 .has_contact(calling_user_id, called_user_id)
1151 .await?
1152 {
1153 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1154 }
1155
1156 let incoming_call = {
1157 let (room, incoming_call) = &mut *session
1158 .db()
1159 .await
1160 .call(
1161 room_id,
1162 calling_user_id,
1163 calling_connection_id,
1164 called_user_id,
1165 initial_project_id,
1166 )
1167 .await?;
1168 room_updated(&room, &session.peer);
1169 mem::take(incoming_call)
1170 };
1171 update_user_contacts(called_user_id, &session).await?;
1172
1173 let mut calls = session
1174 .connection_pool()
1175 .await
1176 .user_connection_ids(called_user_id)
1177 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1178 .collect::<FuturesUnordered<_>>();
1179
1180 while let Some(call_response) = calls.next().await {
1181 match call_response.as_ref() {
1182 Ok(_) => {
1183 response.send(proto::Ack {})?;
1184 return Ok(());
1185 }
1186 Err(_) => {
1187 call_response.trace_err();
1188 }
1189 }
1190 }
1191
1192 {
1193 let room = session
1194 .db()
1195 .await
1196 .call_failed(room_id, called_user_id)
1197 .await?;
1198 room_updated(&room, &session.peer);
1199 }
1200 update_user_contacts(called_user_id, &session).await?;
1201
1202 Err(anyhow!("failed to ring user"))?
1203}
1204
1205async fn cancel_call(
1206 request: proto::CancelCall,
1207 response: Response<proto::CancelCall>,
1208 session: Session,
1209) -> Result<()> {
1210 let called_user_id = UserId::from_proto(request.called_user_id);
1211 let room_id = RoomId::from_proto(request.room_id);
1212 {
1213 let room = session
1214 .db()
1215 .await
1216 .cancel_call(room_id, session.connection_id, called_user_id)
1217 .await?;
1218 room_updated(&room, &session.peer);
1219 }
1220
1221 for connection_id in session
1222 .connection_pool()
1223 .await
1224 .user_connection_ids(called_user_id)
1225 {
1226 session
1227 .peer
1228 .send(
1229 connection_id,
1230 proto::CallCanceled {
1231 room_id: room_id.to_proto(),
1232 },
1233 )
1234 .trace_err();
1235 }
1236 response.send(proto::Ack {})?;
1237
1238 update_user_contacts(called_user_id, &session).await?;
1239 Ok(())
1240}
1241
1242async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1243 let room_id = RoomId::from_proto(message.room_id);
1244 {
1245 let room = session
1246 .db()
1247 .await
1248 .decline_call(Some(room_id), session.user_id)
1249 .await?
1250 .ok_or_else(|| anyhow!("failed to decline call"))?;
1251 room_updated(&room, &session.peer);
1252 }
1253
1254 for connection_id in session
1255 .connection_pool()
1256 .await
1257 .user_connection_ids(session.user_id)
1258 {
1259 session
1260 .peer
1261 .send(
1262 connection_id,
1263 proto::CallCanceled {
1264 room_id: room_id.to_proto(),
1265 },
1266 )
1267 .trace_err();
1268 }
1269 update_user_contacts(session.user_id, &session).await?;
1270 Ok(())
1271}
1272
1273async fn update_participant_location(
1274 request: proto::UpdateParticipantLocation,
1275 response: Response<proto::UpdateParticipantLocation>,
1276 session: Session,
1277) -> Result<()> {
1278 let room_id = RoomId::from_proto(request.room_id);
1279 let location = request
1280 .location
1281 .ok_or_else(|| anyhow!("invalid location"))?;
1282 let room = session
1283 .db()
1284 .await
1285 .update_room_participant_location(room_id, session.connection_id, location)
1286 .await?;
1287 room_updated(&room, &session.peer);
1288 response.send(proto::Ack {})?;
1289 Ok(())
1290}
1291
1292async fn share_project(
1293 request: proto::ShareProject,
1294 response: Response<proto::ShareProject>,
1295 session: Session,
1296) -> Result<()> {
1297 let (project_id, room) = &*session
1298 .db()
1299 .await
1300 .share_project(
1301 RoomId::from_proto(request.room_id),
1302 session.connection_id,
1303 &request.worktrees,
1304 )
1305 .await?;
1306 response.send(proto::ShareProjectResponse {
1307 project_id: project_id.to_proto(),
1308 })?;
1309 room_updated(&room, &session.peer);
1310
1311 Ok(())
1312}
1313
1314async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1315 let project_id = ProjectId::from_proto(message.project_id);
1316
1317 let (room, guest_connection_ids) = &*session
1318 .db()
1319 .await
1320 .unshare_project(project_id, session.connection_id)
1321 .await?;
1322
1323 broadcast(
1324 Some(session.connection_id),
1325 guest_connection_ids.iter().copied(),
1326 |conn_id| session.peer.send(conn_id, message.clone()),
1327 );
1328 room_updated(&room, &session.peer);
1329
1330 Ok(())
1331}
1332
1333async fn join_project(
1334 request: proto::JoinProject,
1335 response: Response<proto::JoinProject>,
1336 session: Session,
1337) -> Result<()> {
1338 let project_id = ProjectId::from_proto(request.project_id);
1339 let guest_user_id = session.user_id;
1340
1341 tracing::info!(%project_id, "join project");
1342
1343 let (project, replica_id) = &mut *session
1344 .db()
1345 .await
1346 .join_project(project_id, session.connection_id)
1347 .await?;
1348
1349 let collaborators = project
1350 .collaborators
1351 .iter()
1352 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1353 .map(|collaborator| collaborator.to_proto())
1354 .collect::<Vec<_>>();
1355
1356 let worktrees = project
1357 .worktrees
1358 .iter()
1359 .map(|(id, worktree)| proto::WorktreeMetadata {
1360 id: *id,
1361 root_name: worktree.root_name.clone(),
1362 visible: worktree.visible,
1363 abs_path: worktree.abs_path.clone(),
1364 })
1365 .collect::<Vec<_>>();
1366
1367 for collaborator in &collaborators {
1368 session
1369 .peer
1370 .send(
1371 collaborator.peer_id.unwrap().into(),
1372 proto::AddProjectCollaborator {
1373 project_id: project_id.to_proto(),
1374 collaborator: Some(proto::Collaborator {
1375 peer_id: Some(session.connection_id.into()),
1376 replica_id: replica_id.0 as u32,
1377 user_id: guest_user_id.to_proto(),
1378 }),
1379 },
1380 )
1381 .trace_err();
1382 }
1383
1384 // First, we send the metadata associated with each worktree.
1385 response.send(proto::JoinProjectResponse {
1386 worktrees: worktrees.clone(),
1387 replica_id: replica_id.0 as u32,
1388 collaborators: collaborators.clone(),
1389 language_servers: project.language_servers.clone(),
1390 })?;
1391
1392 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1393 #[cfg(any(test, feature = "test-support"))]
1394 const MAX_CHUNK_SIZE: usize = 2;
1395 #[cfg(not(any(test, feature = "test-support")))]
1396 const MAX_CHUNK_SIZE: usize = 256;
1397
1398 // Stream this worktree's entries.
1399 let message = proto::UpdateWorktree {
1400 project_id: project_id.to_proto(),
1401 worktree_id,
1402 abs_path: worktree.abs_path.clone(),
1403 root_name: worktree.root_name,
1404 updated_entries: worktree.entries,
1405 removed_entries: Default::default(),
1406 scan_id: worktree.scan_id,
1407 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1408 updated_repositories: worktree.repository_entries.into_values().collect(),
1409 removed_repositories: Default::default(),
1410 };
1411 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1412 session.peer.send(session.connection_id, update.clone())?;
1413 }
1414
1415 // Stream this worktree's diagnostics.
1416 for summary in worktree.diagnostic_summaries {
1417 session.peer.send(
1418 session.connection_id,
1419 proto::UpdateDiagnosticSummary {
1420 project_id: project_id.to_proto(),
1421 worktree_id: worktree.id,
1422 summary: Some(summary),
1423 },
1424 )?;
1425 }
1426
1427 for settings_file in worktree.settings_files {
1428 session.peer.send(
1429 session.connection_id,
1430 proto::UpdateWorktreeSettings {
1431 project_id: project_id.to_proto(),
1432 worktree_id: worktree.id,
1433 path: settings_file.path,
1434 content: Some(settings_file.content),
1435 },
1436 )?;
1437 }
1438 }
1439
1440 for language_server in &project.language_servers {
1441 session.peer.send(
1442 session.connection_id,
1443 proto::UpdateLanguageServer {
1444 project_id: project_id.to_proto(),
1445 language_server_id: language_server.id,
1446 variant: Some(
1447 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1448 proto::LspDiskBasedDiagnosticsUpdated {},
1449 ),
1450 ),
1451 },
1452 )?;
1453 }
1454
1455 Ok(())
1456}
1457
1458async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1459 let sender_id = session.connection_id;
1460 let project_id = ProjectId::from_proto(request.project_id);
1461
1462 let (room, project) = &*session
1463 .db()
1464 .await
1465 .leave_project(project_id, sender_id)
1466 .await?;
1467 tracing::info!(
1468 %project_id,
1469 host_user_id = %project.host_user_id,
1470 host_connection_id = %project.host_connection_id,
1471 "leave project"
1472 );
1473
1474 project_left(&project, &session);
1475 room_updated(&room, &session.peer);
1476
1477 Ok(())
1478}
1479
1480async fn update_project(
1481 request: proto::UpdateProject,
1482 response: Response<proto::UpdateProject>,
1483 session: Session,
1484) -> Result<()> {
1485 let project_id = ProjectId::from_proto(request.project_id);
1486 let (room, guest_connection_ids) = &*session
1487 .db()
1488 .await
1489 .update_project(project_id, session.connection_id, &request.worktrees)
1490 .await?;
1491 broadcast(
1492 Some(session.connection_id),
1493 guest_connection_ids.iter().copied(),
1494 |connection_id| {
1495 session
1496 .peer
1497 .forward_send(session.connection_id, connection_id, request.clone())
1498 },
1499 );
1500 room_updated(&room, &session.peer);
1501 response.send(proto::Ack {})?;
1502
1503 Ok(())
1504}
1505
1506async fn update_worktree(
1507 request: proto::UpdateWorktree,
1508 response: Response<proto::UpdateWorktree>,
1509 session: Session,
1510) -> Result<()> {
1511 let guest_connection_ids = session
1512 .db()
1513 .await
1514 .update_worktree(&request, session.connection_id)
1515 .await?;
1516
1517 broadcast(
1518 Some(session.connection_id),
1519 guest_connection_ids.iter().copied(),
1520 |connection_id| {
1521 session
1522 .peer
1523 .forward_send(session.connection_id, connection_id, request.clone())
1524 },
1525 );
1526 response.send(proto::Ack {})?;
1527 Ok(())
1528}
1529
1530async fn update_diagnostic_summary(
1531 message: proto::UpdateDiagnosticSummary,
1532 session: Session,
1533) -> Result<()> {
1534 let guest_connection_ids = session
1535 .db()
1536 .await
1537 .update_diagnostic_summary(&message, session.connection_id)
1538 .await?;
1539
1540 broadcast(
1541 Some(session.connection_id),
1542 guest_connection_ids.iter().copied(),
1543 |connection_id| {
1544 session
1545 .peer
1546 .forward_send(session.connection_id, connection_id, message.clone())
1547 },
1548 );
1549
1550 Ok(())
1551}
1552
1553async fn update_worktree_settings(
1554 message: proto::UpdateWorktreeSettings,
1555 session: Session,
1556) -> Result<()> {
1557 let guest_connection_ids = session
1558 .db()
1559 .await
1560 .update_worktree_settings(&message, session.connection_id)
1561 .await?;
1562
1563 broadcast(
1564 Some(session.connection_id),
1565 guest_connection_ids.iter().copied(),
1566 |connection_id| {
1567 session
1568 .peer
1569 .forward_send(session.connection_id, connection_id, message.clone())
1570 },
1571 );
1572
1573 Ok(())
1574}
1575
1576async fn start_language_server(
1577 request: proto::StartLanguageServer,
1578 session: Session,
1579) -> Result<()> {
1580 let guest_connection_ids = session
1581 .db()
1582 .await
1583 .start_language_server(&request, session.connection_id)
1584 .await?;
1585
1586 broadcast(
1587 Some(session.connection_id),
1588 guest_connection_ids.iter().copied(),
1589 |connection_id| {
1590 session
1591 .peer
1592 .forward_send(session.connection_id, connection_id, request.clone())
1593 },
1594 );
1595 Ok(())
1596}
1597
1598async fn update_language_server(
1599 request: proto::UpdateLanguageServer,
1600 session: Session,
1601) -> Result<()> {
1602 session.executor.record_backtrace();
1603 let project_id = ProjectId::from_proto(request.project_id);
1604 let project_connection_ids = session
1605 .db()
1606 .await
1607 .project_connection_ids(project_id, session.connection_id)
1608 .await?;
1609 broadcast(
1610 Some(session.connection_id),
1611 project_connection_ids.iter().copied(),
1612 |connection_id| {
1613 session
1614 .peer
1615 .forward_send(session.connection_id, connection_id, request.clone())
1616 },
1617 );
1618 Ok(())
1619}
1620
1621async fn forward_project_request<T>(
1622 request: T,
1623 response: Response<T>,
1624 session: Session,
1625) -> Result<()>
1626where
1627 T: EntityMessage + RequestMessage,
1628{
1629 session.executor.record_backtrace();
1630 let project_id = ProjectId::from_proto(request.remote_entity_id());
1631 let host_connection_id = {
1632 let collaborators = session
1633 .db()
1634 .await
1635 .project_collaborators(project_id, session.connection_id)
1636 .await?;
1637 collaborators
1638 .iter()
1639 .find(|collaborator| collaborator.is_host)
1640 .ok_or_else(|| anyhow!("host not found"))?
1641 .connection_id
1642 };
1643
1644 let payload = session
1645 .peer
1646 .forward_request(session.connection_id, host_connection_id, request)
1647 .await?;
1648
1649 response.send(payload)?;
1650 Ok(())
1651}
1652
1653async fn create_buffer_for_peer(
1654 request: proto::CreateBufferForPeer,
1655 session: Session,
1656) -> Result<()> {
1657 session.executor.record_backtrace();
1658 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1659 session
1660 .peer
1661 .forward_send(session.connection_id, peer_id.into(), request)?;
1662 Ok(())
1663}
1664
1665async fn update_buffer(
1666 request: proto::UpdateBuffer,
1667 response: Response<proto::UpdateBuffer>,
1668 session: Session,
1669) -> Result<()> {
1670 session.executor.record_backtrace();
1671 let project_id = ProjectId::from_proto(request.project_id);
1672 let mut guest_connection_ids;
1673 let mut host_connection_id = None;
1674 {
1675 let collaborators = session
1676 .db()
1677 .await
1678 .project_collaborators(project_id, session.connection_id)
1679 .await?;
1680 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1681 for collaborator in collaborators.iter() {
1682 if collaborator.is_host {
1683 host_connection_id = Some(collaborator.connection_id);
1684 } else {
1685 guest_connection_ids.push(collaborator.connection_id);
1686 }
1687 }
1688 }
1689 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1690
1691 session.executor.record_backtrace();
1692 broadcast(
1693 Some(session.connection_id),
1694 guest_connection_ids,
1695 |connection_id| {
1696 session
1697 .peer
1698 .forward_send(session.connection_id, connection_id, request.clone())
1699 },
1700 );
1701 if host_connection_id != session.connection_id {
1702 session
1703 .peer
1704 .forward_request(session.connection_id, host_connection_id, request.clone())
1705 .await?;
1706 }
1707
1708 response.send(proto::Ack {})?;
1709 Ok(())
1710}
1711
1712async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1713 let project_id = ProjectId::from_proto(request.project_id);
1714 let project_connection_ids = session
1715 .db()
1716 .await
1717 .project_connection_ids(project_id, session.connection_id)
1718 .await?;
1719
1720 broadcast(
1721 Some(session.connection_id),
1722 project_connection_ids.iter().copied(),
1723 |connection_id| {
1724 session
1725 .peer
1726 .forward_send(session.connection_id, connection_id, request.clone())
1727 },
1728 );
1729 Ok(())
1730}
1731
1732async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1733 let project_id = ProjectId::from_proto(request.project_id);
1734 let project_connection_ids = session
1735 .db()
1736 .await
1737 .project_connection_ids(project_id, session.connection_id)
1738 .await?;
1739 broadcast(
1740 Some(session.connection_id),
1741 project_connection_ids.iter().copied(),
1742 |connection_id| {
1743 session
1744 .peer
1745 .forward_send(session.connection_id, connection_id, request.clone())
1746 },
1747 );
1748 Ok(())
1749}
1750
1751async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1752 let project_id = ProjectId::from_proto(request.project_id);
1753 let project_connection_ids = session
1754 .db()
1755 .await
1756 .project_connection_ids(project_id, session.connection_id)
1757 .await?;
1758 broadcast(
1759 Some(session.connection_id),
1760 project_connection_ids.iter().copied(),
1761 |connection_id| {
1762 session
1763 .peer
1764 .forward_send(session.connection_id, connection_id, request.clone())
1765 },
1766 );
1767 Ok(())
1768}
1769
1770async fn follow(
1771 request: proto::Follow,
1772 response: Response<proto::Follow>,
1773 session: Session,
1774) -> Result<()> {
1775 let project_id = ProjectId::from_proto(request.project_id);
1776 let leader_id = request
1777 .leader_id
1778 .ok_or_else(|| anyhow!("invalid leader id"))?
1779 .into();
1780 let follower_id = session.connection_id;
1781
1782 {
1783 let project_connection_ids = session
1784 .db()
1785 .await
1786 .project_connection_ids(project_id, session.connection_id)
1787 .await?;
1788
1789 if !project_connection_ids.contains(&leader_id) {
1790 Err(anyhow!("no such peer"))?;
1791 }
1792 }
1793
1794 let mut response_payload = session
1795 .peer
1796 .forward_request(session.connection_id, leader_id, request)
1797 .await?;
1798 response_payload
1799 .views
1800 .retain(|view| view.leader_id != Some(follower_id.into()));
1801 response.send(response_payload)?;
1802
1803 let room = session
1804 .db()
1805 .await
1806 .follow(project_id, leader_id, follower_id)
1807 .await?;
1808 room_updated(&room, &session.peer);
1809
1810 Ok(())
1811}
1812
1813async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1814 let project_id = ProjectId::from_proto(request.project_id);
1815 let leader_id = request
1816 .leader_id
1817 .ok_or_else(|| anyhow!("invalid leader id"))?
1818 .into();
1819 let follower_id = session.connection_id;
1820
1821 if !session
1822 .db()
1823 .await
1824 .project_connection_ids(project_id, session.connection_id)
1825 .await?
1826 .contains(&leader_id)
1827 {
1828 Err(anyhow!("no such peer"))?;
1829 }
1830
1831 session
1832 .peer
1833 .forward_send(session.connection_id, leader_id, request)?;
1834
1835 let room = session
1836 .db()
1837 .await
1838 .unfollow(project_id, leader_id, follower_id)
1839 .await?;
1840 room_updated(&room, &session.peer);
1841
1842 Ok(())
1843}
1844
1845async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1846 let project_id = ProjectId::from_proto(request.project_id);
1847 let project_connection_ids = session
1848 .db
1849 .lock()
1850 .await
1851 .project_connection_ids(project_id, session.connection_id)
1852 .await?;
1853
1854 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1855 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1856 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1857 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1858 });
1859 for follower_peer_id in request.follower_ids.iter().copied() {
1860 let follower_connection_id = follower_peer_id.into();
1861 if project_connection_ids.contains(&follower_connection_id)
1862 && Some(follower_peer_id) != leader_id
1863 {
1864 session.peer.forward_send(
1865 session.connection_id,
1866 follower_connection_id,
1867 request.clone(),
1868 )?;
1869 }
1870 }
1871 Ok(())
1872}
1873
1874async fn get_users(
1875 request: proto::GetUsers,
1876 response: Response<proto::GetUsers>,
1877 session: Session,
1878) -> Result<()> {
1879 let user_ids = request
1880 .user_ids
1881 .into_iter()
1882 .map(UserId::from_proto)
1883 .collect();
1884 let users = session
1885 .db()
1886 .await
1887 .get_users_by_ids(user_ids)
1888 .await?
1889 .into_iter()
1890 .map(|user| proto::User {
1891 id: user.id.to_proto(),
1892 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1893 github_login: user.github_login,
1894 })
1895 .collect();
1896 response.send(proto::UsersResponse { users })?;
1897 Ok(())
1898}
1899
1900async fn fuzzy_search_users(
1901 request: proto::FuzzySearchUsers,
1902 response: Response<proto::FuzzySearchUsers>,
1903 session: Session,
1904) -> Result<()> {
1905 let query = request.query;
1906 let users = match query.len() {
1907 0 => vec![],
1908 1 | 2 => session
1909 .db()
1910 .await
1911 .get_user_by_github_login(&query)
1912 .await?
1913 .into_iter()
1914 .collect(),
1915 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1916 };
1917 let users = users
1918 .into_iter()
1919 .filter(|user| user.id != session.user_id)
1920 .map(|user| proto::User {
1921 id: user.id.to_proto(),
1922 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1923 github_login: user.github_login,
1924 })
1925 .collect();
1926 response.send(proto::UsersResponse { users })?;
1927 Ok(())
1928}
1929
1930async fn request_contact(
1931 request: proto::RequestContact,
1932 response: Response<proto::RequestContact>,
1933 session: Session,
1934) -> Result<()> {
1935 let requester_id = session.user_id;
1936 let responder_id = UserId::from_proto(request.responder_id);
1937 if requester_id == responder_id {
1938 return Err(anyhow!("cannot add yourself as a contact"))?;
1939 }
1940
1941 session
1942 .db()
1943 .await
1944 .send_contact_request(requester_id, responder_id)
1945 .await?;
1946
1947 // Update outgoing contact requests of requester
1948 let mut update = proto::UpdateContacts::default();
1949 update.outgoing_requests.push(responder_id.to_proto());
1950 for connection_id in session
1951 .connection_pool()
1952 .await
1953 .user_connection_ids(requester_id)
1954 {
1955 session.peer.send(connection_id, update.clone())?;
1956 }
1957
1958 // Update incoming contact requests of responder
1959 let mut update = proto::UpdateContacts::default();
1960 update
1961 .incoming_requests
1962 .push(proto::IncomingContactRequest {
1963 requester_id: requester_id.to_proto(),
1964 should_notify: true,
1965 });
1966 for connection_id in session
1967 .connection_pool()
1968 .await
1969 .user_connection_ids(responder_id)
1970 {
1971 session.peer.send(connection_id, update.clone())?;
1972 }
1973
1974 response.send(proto::Ack {})?;
1975 Ok(())
1976}
1977
1978async fn respond_to_contact_request(
1979 request: proto::RespondToContactRequest,
1980 response: Response<proto::RespondToContactRequest>,
1981 session: Session,
1982) -> Result<()> {
1983 let responder_id = session.user_id;
1984 let requester_id = UserId::from_proto(request.requester_id);
1985 let db = session.db().await;
1986 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1987 db.dismiss_contact_notification(responder_id, requester_id)
1988 .await?;
1989 } else {
1990 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1991
1992 db.respond_to_contact_request(responder_id, requester_id, accept)
1993 .await?;
1994 let requester_busy = db.is_user_busy(requester_id).await?;
1995 let responder_busy = db.is_user_busy(responder_id).await?;
1996
1997 let pool = session.connection_pool().await;
1998 // Update responder with new contact
1999 let mut update = proto::UpdateContacts::default();
2000 if accept {
2001 update
2002 .contacts
2003 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2004 }
2005 update
2006 .remove_incoming_requests
2007 .push(requester_id.to_proto());
2008 for connection_id in pool.user_connection_ids(responder_id) {
2009 session.peer.send(connection_id, update.clone())?;
2010 }
2011
2012 // Update requester with new contact
2013 let mut update = proto::UpdateContacts::default();
2014 if accept {
2015 update
2016 .contacts
2017 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2018 }
2019 update
2020 .remove_outgoing_requests
2021 .push(responder_id.to_proto());
2022 for connection_id in pool.user_connection_ids(requester_id) {
2023 session.peer.send(connection_id, update.clone())?;
2024 }
2025 }
2026
2027 response.send(proto::Ack {})?;
2028 Ok(())
2029}
2030
2031async fn remove_contact(
2032 request: proto::RemoveContact,
2033 response: Response<proto::RemoveContact>,
2034 session: Session,
2035) -> Result<()> {
2036 let requester_id = session.user_id;
2037 let responder_id = UserId::from_proto(request.user_id);
2038 let db = session.db().await;
2039 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2040
2041 let pool = session.connection_pool().await;
2042 // Update outgoing contact requests of requester
2043 let mut update = proto::UpdateContacts::default();
2044 if contact_accepted {
2045 update.remove_contacts.push(responder_id.to_proto());
2046 } else {
2047 update
2048 .remove_outgoing_requests
2049 .push(responder_id.to_proto());
2050 }
2051 for connection_id in pool.user_connection_ids(requester_id) {
2052 session.peer.send(connection_id, update.clone())?;
2053 }
2054
2055 // Update incoming contact requests of responder
2056 let mut update = proto::UpdateContacts::default();
2057 if contact_accepted {
2058 update.remove_contacts.push(requester_id.to_proto());
2059 } else {
2060 update
2061 .remove_incoming_requests
2062 .push(requester_id.to_proto());
2063 }
2064 for connection_id in pool.user_connection_ids(responder_id) {
2065 session.peer.send(connection_id, update.clone())?;
2066 }
2067
2068 response.send(proto::Ack {})?;
2069 Ok(())
2070}
2071
2072async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2073 let project_id = ProjectId::from_proto(request.project_id);
2074 let project_connection_ids = session
2075 .db()
2076 .await
2077 .project_connection_ids(project_id, session.connection_id)
2078 .await?;
2079 broadcast(
2080 Some(session.connection_id),
2081 project_connection_ids.iter().copied(),
2082 |connection_id| {
2083 session
2084 .peer
2085 .forward_send(session.connection_id, connection_id, request.clone())
2086 },
2087 );
2088 Ok(())
2089}
2090
2091async fn get_private_user_info(
2092 _request: proto::GetPrivateUserInfo,
2093 response: Response<proto::GetPrivateUserInfo>,
2094 session: Session,
2095) -> Result<()> {
2096 let metrics_id = session
2097 .db()
2098 .await
2099 .get_user_metrics_id(session.user_id)
2100 .await?;
2101 let user = session
2102 .db()
2103 .await
2104 .get_user_by_id(session.user_id)
2105 .await?
2106 .ok_or_else(|| anyhow!("user not found"))?;
2107 response.send(proto::GetPrivateUserInfoResponse {
2108 metrics_id,
2109 staff: user.admin,
2110 })?;
2111 Ok(())
2112}
2113
2114fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2115 match message {
2116 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2117 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2118 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2119 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2120 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2121 code: frame.code.into(),
2122 reason: frame.reason,
2123 })),
2124 }
2125}
2126
2127fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2128 match message {
2129 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2130 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2131 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2132 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2133 AxumMessage::Close(frame) => {
2134 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2135 code: frame.code.into(),
2136 reason: frame.reason,
2137 }))
2138 }
2139 }
2140}
2141
2142fn build_initial_contacts_update(
2143 contacts: Vec<db::Contact>,
2144 pool: &ConnectionPool,
2145) -> proto::UpdateContacts {
2146 let mut update = proto::UpdateContacts::default();
2147
2148 for contact in contacts {
2149 match contact {
2150 db::Contact::Accepted {
2151 user_id,
2152 should_notify,
2153 busy,
2154 } => {
2155 update
2156 .contacts
2157 .push(contact_for_user(user_id, should_notify, busy, &pool));
2158 }
2159 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2160 db::Contact::Incoming {
2161 user_id,
2162 should_notify,
2163 } => update
2164 .incoming_requests
2165 .push(proto::IncomingContactRequest {
2166 requester_id: user_id.to_proto(),
2167 should_notify,
2168 }),
2169 }
2170 }
2171
2172 update
2173}
2174
2175fn contact_for_user(
2176 user_id: UserId,
2177 should_notify: bool,
2178 busy: bool,
2179 pool: &ConnectionPool,
2180) -> proto::Contact {
2181 proto::Contact {
2182 user_id: user_id.to_proto(),
2183 online: pool.is_user_online(user_id),
2184 busy,
2185 should_notify,
2186 }
2187}
2188
2189fn room_updated(room: &proto::Room, peer: &Peer) {
2190 broadcast(
2191 None,
2192 room.participants
2193 .iter()
2194 .filter_map(|participant| Some(participant.peer_id?.into())),
2195 |peer_id| {
2196 peer.send(
2197 peer_id.into(),
2198 proto::RoomUpdated {
2199 room: Some(room.clone()),
2200 },
2201 )
2202 },
2203 );
2204}
2205
2206async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2207 let db = session.db().await;
2208 let contacts = db.get_contacts(user_id).await?;
2209 let busy = db.is_user_busy(user_id).await?;
2210
2211 let pool = session.connection_pool().await;
2212 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2213 for contact in contacts {
2214 if let db::Contact::Accepted {
2215 user_id: contact_user_id,
2216 ..
2217 } = contact
2218 {
2219 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2220 session
2221 .peer
2222 .send(
2223 contact_conn_id,
2224 proto::UpdateContacts {
2225 contacts: vec![updated_contact.clone()],
2226 remove_contacts: Default::default(),
2227 incoming_requests: Default::default(),
2228 remove_incoming_requests: Default::default(),
2229 outgoing_requests: Default::default(),
2230 remove_outgoing_requests: Default::default(),
2231 },
2232 )
2233 .trace_err();
2234 }
2235 }
2236 }
2237 Ok(())
2238}
2239
/// Removes the session's connection from the room it is in, if any, and
/// performs the resulting cleanup.
///
/// Returns `Ok(())` immediately when the database's `leave_room` returns
/// `None`. Otherwise it:
/// - calls `project_left` for every project the connection left, then sends
///   the updated room to its participants via `room_updated`;
/// - sends `CallCanceled` to every connection of each user listed in
///   `canceled_calls_to_user_ids`;
/// - re-broadcasts the contact entry of this user and of those users;
/// - removes this user from the LiveKit room and deletes that LiveKit room
///   when the returned room has no participants. LiveKit errors are traced,
///   not returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entry must be re-sent to their contacts.
    let mut contacts_to_update = HashSet::default();

    // Filled from `left_room` so the values outlive the `if let` below.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // NOTE(review): the `session.db()` guard is a temporary of this `if let`
    // scrutinee; under edition-2021 temporary rules it presumably stays alive
    // until the end of the `if let`/`else`, i.e. while `project_left` and
    // `room_updated` run. Confirm before restructuring this statement.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // The LiveKit room is deleted below only if the room returned by
        // `leave_room` has no participants.
        delete_live_kit_room = left_room.room.participants.is_empty();
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    // Scoped so the connection-pool guard is dropped before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are only traced.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2298
2299fn project_left(project: &db::LeftProject, session: &Session) {
2300 for connection_id in &project.connection_ids {
2301 if project.host_user_id == session.user_id {
2302 session
2303 .peer
2304 .send(
2305 *connection_id,
2306 proto::UnshareProject {
2307 project_id: project.id.to_proto(),
2308 },
2309 )
2310 .trace_err();
2311 } else {
2312 session
2313 .peer
2314 .send(
2315 *connection_id,
2316 proto::RemoveProjectCollaborator {
2317 project_id: project.id.to_proto(),
2318 peer_id: Some(session.connection_id.into()),
2319 },
2320 )
2321 .trace_err();
2322 }
2323 }
2324}
2325
/// Extension trait for turning a fallible value into an `Option`.
///
/// In the implementation for `Result`, the failure is logged rather than
/// propagated.
pub trait ResultExt {
    /// Returns the success value as `Some`, or `None` on failure.
    ///
    /// The `Result` implementation logs the error via `tracing::error!` before
    /// returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The success type yielded by [`ResultExt::trace_err`].
    type Ok;
}
2331
2332impl<T, E> ResultExt for Result<T, E>
2333where
2334 E: std::fmt::Debug,
2335{
2336 type Ok = T;
2337
2338 fn trace_err(self) -> Option<T> {
2339 match self {
2340 Ok(value) => Some(value),
2341 Err(error) => {
2342 tracing::error!("{:?}", error);
2343 None
2344 }
2345 }
2346 }
2347}