1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::{Duration, Instant},
55};
56use tokio::sync::{watch, Semaphore};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// Grace period the server waits after startup before cleaning up rooms that
/// belonged to previous server instances (see `Server::start`).
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
/// How long `connection_lost` waits for a dropped connection before leaving the
/// user's room and refreshing their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
156#[derive(Serialize)]
157pub struct ServerSnapshot<'a> {
158 peer: &'a Peer,
159 #[serde(serialize_with = "serialize_deref")]
160 connection_pool: ConnectionPoolGuard<'a>,
161}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_message_handler(update_worktree_settings)
204 .add_request_handler(forward_project_request::<proto::GetHover>)
205 .add_request_handler(forward_project_request::<proto::GetDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
207 .add_request_handler(forward_project_request::<proto::GetReferences>)
208 .add_request_handler(forward_project_request::<proto::SearchProject>)
209 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
210 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
213 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
214 .add_request_handler(forward_project_request::<proto::GetCompletions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
216 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
217 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
218 .add_request_handler(forward_project_request::<proto::PrepareRename>)
219 .add_request_handler(forward_project_request::<proto::PerformRename>)
220 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
221 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
222 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
223 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
226 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
227 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
229 .add_message_handler(create_buffer_for_peer)
230 .add_request_handler(update_buffer)
231 .add_message_handler(update_buffer_file)
232 .add_message_handler(buffer_reloaded)
233 .add_message_handler(buffer_saved)
234 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
235 .add_request_handler(get_users)
236 .add_request_handler(fuzzy_search_users)
237 .add_request_handler(request_contact)
238 .add_request_handler(remove_contact)
239 .add_request_handler(respond_to_contact_request)
240 .add_request_handler(follow)
241 .add_message_handler(unfollow)
242 .add_message_handler(update_followers)
243 .add_message_handler(update_diff_base)
244 .add_request_handler(get_private_user_info);
245
246 Arc::new(server)
247 }
248
249 pub async fn start(&self) -> Result<()> {
250 let server_id = *self.id.lock();
251 let app_state = self.app_state.clone();
252 let peer = self.peer.clone();
253 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
254 let pool = self.connection_pool.clone();
255 let live_kit_client = self.app_state.live_kit_client.clone();
256
257 let span = info_span!("start server");
258 self.executor.spawn_detached(
259 async move {
260 tracing::info!("waiting for cleanup timeout");
261 timeout.await;
262 tracing::info!("cleanup timeout expired, retrieving stale rooms");
263 if let Some(room_ids) = app_state
264 .db
265 .stale_room_ids(&app_state.config.zed_environment, server_id)
266 .await
267 .trace_err()
268 {
269 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
270 for room_id in room_ids {
271 let mut contacts_to_update = HashSet::default();
272 let mut canceled_calls_to_user_ids = Vec::new();
273 let mut live_kit_room = String::new();
274 let mut delete_live_kit_room = false;
275
276 if let Some(mut refreshed_room) = app_state
277 .db
278 .refresh_room(room_id, server_id)
279 .await
280 .trace_err()
281 {
282 tracing::info!(
283 room_id = room_id.0,
284 new_participant_count = refreshed_room.room.participants.len(),
285 "refreshed room"
286 );
287 room_updated(&refreshed_room.room, &peer);
288 contacts_to_update
289 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
290 contacts_to_update
291 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
292 canceled_calls_to_user_ids =
293 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
294 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
295 delete_live_kit_room = refreshed_room.room.participants.is_empty();
296 }
297
298 {
299 let pool = pool.lock();
300 for canceled_user_id in canceled_calls_to_user_ids {
301 for connection_id in pool.user_connection_ids(canceled_user_id) {
302 peer.send(
303 connection_id,
304 proto::CallCanceled {
305 room_id: room_id.to_proto(),
306 },
307 )
308 .trace_err();
309 }
310 }
311 }
312
313 for user_id in contacts_to_update {
314 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
315 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
316 if let Some((busy, contacts)) = busy.zip(contacts) {
317 let pool = pool.lock();
318 let updated_contact = contact_for_user(user_id, false, busy, &pool);
319 for contact in contacts {
320 if let db::Contact::Accepted {
321 user_id: contact_user_id,
322 ..
323 } = contact
324 {
325 for contact_conn_id in
326 pool.user_connection_ids(contact_user_id)
327 {
328 peer.send(
329 contact_conn_id,
330 proto::UpdateContacts {
331 contacts: vec![updated_contact.clone()],
332 remove_contacts: Default::default(),
333 incoming_requests: Default::default(),
334 remove_incoming_requests: Default::default(),
335 outgoing_requests: Default::default(),
336 remove_outgoing_requests: Default::default(),
337 },
338 )
339 .trace_err();
340 }
341 }
342 }
343 }
344 }
345
346 if let Some(live_kit) = live_kit_client.as_ref() {
347 if delete_live_kit_room {
348 live_kit.delete_room(live_kit_room).await.trace_err();
349 }
350 }
351 }
352 }
353
354 app_state
355 .db
356 .delete_stale_servers(&app_state.config.zed_environment, server_id)
357 .await
358 .trace_err();
359 }
360 .instrument(span),
361 );
362 Ok(())
363 }
364
365 pub fn teardown(&self) {
366 self.peer.teardown();
367 self.connection_pool.lock().reset();
368 let _ = self.teardown.send(());
369 }
370
371 #[cfg(test)]
372 pub fn reset(&self, id: ServerId) {
373 self.teardown();
374 *self.id.lock() = id;
375 self.peer.reset(id.0 as u32);
376 }
377
378 #[cfg(test)]
379 pub fn id(&self) -> ServerId {
380 *self.id.lock()
381 }
382
383 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
384 where
385 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
386 Fut: 'static + Send + Future<Output = Result<()>>,
387 M: EnvelopedMessage,
388 {
389 let prev_handler = self.handlers.insert(
390 TypeId::of::<M>(),
391 Box::new(move |envelope, session| {
392 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
393 let span = info_span!(
394 "handle message",
395 payload_type = envelope.payload_type_name()
396 );
397 span.in_scope(|| {
398 tracing::info!(
399 payload_type = envelope.payload_type_name(),
400 "message received"
401 );
402 });
403 let start_time = Instant::now();
404 let future = (handler)(*envelope, session);
405 async move {
406 let result = future.await;
407 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
408 match result {
409 Err(error) => {
410 tracing::error!(%error, ?duration_ms, "error handling message")
411 }
412 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
413 }
414 }
415 .instrument(span)
416 .boxed()
417 }),
418 );
419 if prev_handler.is_some() {
420 panic!("registered a handler for the same message twice");
421 }
422 self
423 }
424
425 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
426 where
427 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
428 Fut: 'static + Send + Future<Output = Result<()>>,
429 M: EnvelopedMessage,
430 {
431 self.add_handler(move |envelope, session| handler(envelope.payload, session));
432 self
433 }
434
435 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
436 where
437 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
438 Fut: Send + Future<Output = Result<()>>,
439 M: RequestMessage,
440 {
441 let handler = Arc::new(handler);
442 self.add_handler(move |envelope, session| {
443 let receipt = envelope.receipt();
444 let handler = handler.clone();
445 async move {
446 let peer = session.peer.clone();
447 let responded = Arc::new(AtomicBool::default());
448 let response = Response {
449 peer: peer.clone(),
450 responded: responded.clone(),
451 receipt,
452 };
453 match (handler)(envelope.payload, response, session).await {
454 Ok(()) => {
455 if responded.load(std::sync::atomic::Ordering::SeqCst) {
456 Ok(())
457 } else {
458 Err(anyhow!("handler did not send a response"))?
459 }
460 }
461 Err(error) => {
462 peer.respond_with_error(
463 receipt,
464 proto::Error {
465 message: error.to_string(),
466 },
467 )?;
468 Err(error)
469 }
470 }
471 }
472 })
473 }
474
475 pub fn handle_connection(
476 self: &Arc<Self>,
477 connection: Connection,
478 address: String,
479 user: User,
480 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
481 executor: Executor,
482 ) -> impl Future<Output = Result<()>> {
483 let this = self.clone();
484 let user_id = user.id;
485 let login = user.github_login;
486 let span = info_span!("handle connection", %user_id, %login, %address);
487 let mut teardown = self.teardown.subscribe();
488 async move {
489 let (connection_id, handle_io, mut incoming_rx) = this
490 .peer
491 .add_connection(connection, {
492 let executor = executor.clone();
493 move |duration| executor.sleep(duration)
494 });
495
496 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
497 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
498 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
499
500 if let Some(send_connection_id) = send_connection_id.take() {
501 let _ = send_connection_id.send(connection_id);
502 }
503
504 if !user.connected_once {
505 this.peer.send(connection_id, proto::ShowContacts {})?;
506 this.app_state.db.set_user_connected_once(user_id, true).await?;
507 }
508
509 let (contacts, invite_code) = future::try_join(
510 this.app_state.db.get_contacts(user_id),
511 this.app_state.db.get_invite_code_for_user(user_id)
512 ).await?;
513
514 {
515 let mut pool = this.connection_pool.lock();
516 pool.add_connection(connection_id, user_id, user.admin);
517 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
518
519 if let Some((code, count)) = invite_code {
520 this.peer.send(connection_id, proto::UpdateInviteInfo {
521 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
522 count: count as u32,
523 })?;
524 }
525 }
526
527 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
528 this.peer.send(connection_id, incoming_call)?;
529 }
530
531 let session = Session {
532 user_id,
533 connection_id,
534 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
535 peer: this.peer.clone(),
536 connection_pool: this.connection_pool.clone(),
537 live_kit_client: this.app_state.live_kit_client.clone(),
538 executor: executor.clone(),
539 };
540 update_user_contacts(user_id, &session).await?;
541
542 let handle_io = handle_io.fuse();
543 futures::pin_mut!(handle_io);
544
545 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
546 // This prevents deadlocks when e.g., client A performs a request to client B and
547 // client B performs a request to client A. If both clients stop processing further
548 // messages until their respective request completes, they won't have a chance to
549 // respond to the other client's request and cause a deadlock.
550 //
551 // This arrangement ensures we will attempt to process earlier messages first, but fall
552 // back to processing messages arrived later in the spirit of making progress.
553 let mut foreground_message_handlers = FuturesUnordered::new();
554 let concurrent_handlers = Arc::new(Semaphore::new(256));
555 loop {
556 let next_message = async {
557 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
558 let message = incoming_rx.next().await;
559 (permit, message)
560 }.fuse();
561 futures::pin_mut!(next_message);
562 futures::select_biased! {
563 _ = teardown.changed().fuse() => return Ok(()),
564 result = handle_io => {
565 if let Err(error) = result {
566 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
567 }
568 break;
569 }
570 _ = foreground_message_handlers.next() => {}
571 next_message = next_message => {
572 let (permit, message) = next_message;
573 if let Some(message) = message {
574 let type_name = message.payload_type_name();
575 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
576 let span_enter = span.enter();
577 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
578 let is_background = message.is_background();
579 let handle_message = (handler)(message, session.clone());
580 drop(span_enter);
581
582 let handle_message = async move {
583 handle_message.await;
584 drop(permit);
585 }.instrument(span);
586 if is_background {
587 executor.spawn_detached(handle_message);
588 } else {
589 foreground_message_handlers.push(handle_message);
590 }
591 } else {
592 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
593 }
594 } else {
595 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
596 break;
597 }
598 }
599 }
600 }
601
602 drop(foreground_message_handlers);
603 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
604 if let Err(error) = connection_lost(session, teardown, executor).await {
605 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
606 }
607
608 Ok(())
609 }.instrument(span)
610 }
611
612 pub async fn invite_code_redeemed(
613 self: &Arc<Self>,
614 inviter_id: UserId,
615 invitee_id: UserId,
616 ) -> Result<()> {
617 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
618 if let Some(code) = &user.invite_code {
619 let pool = self.connection_pool.lock();
620 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
621 for connection_id in pool.user_connection_ids(inviter_id) {
622 self.peer.send(
623 connection_id,
624 proto::UpdateContacts {
625 contacts: vec![invitee_contact.clone()],
626 ..Default::default()
627 },
628 )?;
629 self.peer.send(
630 connection_id,
631 proto::UpdateInviteInfo {
632 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
633 count: user.invite_count as u32,
634 },
635 )?;
636 }
637 }
638 }
639 Ok(())
640 }
641
642 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
643 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
644 if let Some(invite_code) = &user.invite_code {
645 let pool = self.connection_pool.lock();
646 for connection_id in pool.user_connection_ids(user_id) {
647 self.peer.send(
648 connection_id,
649 proto::UpdateInviteInfo {
650 url: format!(
651 "{}{}",
652 self.app_state.config.invite_link_prefix, invite_code
653 ),
654 count: user.invite_count as u32,
655 },
656 )?;
657 }
658 }
659 }
660 Ok(())
661 }
662
663 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
664 ServerSnapshot {
665 connection_pool: ConnectionPoolGuard {
666 guard: self.connection_pool.lock(),
667 _not_send: PhantomData,
668 },
669 peer: &self.peer,
670 }
671 }
672}
673
674impl<'a> Deref for ConnectionPoolGuard<'a> {
675 type Target = ConnectionPool;
676
677 fn deref(&self) -> &Self::Target {
678 &*self.guard
679 }
680}
681
682impl<'a> DerefMut for ConnectionPoolGuard<'a> {
683 fn deref_mut(&mut self) -> &mut Self::Target {
684 &mut *self.guard
685 }
686}
687
688impl<'a> Drop for ConnectionPoolGuard<'a> {
689 fn drop(&mut self) {
690 #[cfg(test)]
691 self.check_invariants();
692 }
693}
694
695fn broadcast<F>(
696 sender_id: Option<ConnectionId>,
697 receiver_ids: impl IntoIterator<Item = ConnectionId>,
698 mut f: F,
699) where
700 F: FnMut(ConnectionId) -> anyhow::Result<()>,
701{
702 for receiver_id in receiver_ids {
703 if Some(receiver_id) != sender_id {
704 if let Err(error) = f(receiver_id) {
705 tracing::error!("failed to send to {:?} {}", receiver_id, error);
706 }
707 }
708 }
709}
710
711lazy_static! {
712 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
713}
714
715pub struct ProtocolVersion(u32);
716
717impl Header for ProtocolVersion {
718 fn name() -> &'static HeaderName {
719 &ZED_PROTOCOL_VERSION
720 }
721
722 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
723 where
724 Self: Sized,
725 I: Iterator<Item = &'i axum::http::HeaderValue>,
726 {
727 let version = values
728 .next()
729 .ok_or_else(axum::headers::Error::invalid)?
730 .to_str()
731 .map_err(|_| axum::headers::Error::invalid())?
732 .parse()
733 .map_err(|_| axum::headers::Error::invalid())?;
734 Ok(Self(version))
735 }
736
737 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
738 values.extend([self.0.to_string().parse().unwrap()]);
739 }
740}
741
742pub fn routes(server: Arc<Server>) -> Router<Body> {
743 Router::new()
744 .route("/rpc", get(handle_websocket_request))
745 .layer(
746 ServiceBuilder::new()
747 .layer(Extension(server.app_state.clone()))
748 .layer(middleware::from_fn(auth::validate_header)),
749 )
750 .route("/metrics", get(handle_metrics))
751 .layer(Extension(server))
752}
753
754pub async fn handle_websocket_request(
755 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
756 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
757 Extension(server): Extension<Arc<Server>>,
758 Extension(user): Extension<User>,
759 ws: WebSocketUpgrade,
760) -> axum::response::Response {
761 if protocol_version != rpc::PROTOCOL_VERSION {
762 return (
763 StatusCode::UPGRADE_REQUIRED,
764 "client must be upgraded".to_string(),
765 )
766 .into_response();
767 }
768 let socket_address = socket_address.to_string();
769 ws.on_upgrade(move |socket| {
770 use util::ResultExt;
771 let socket = socket
772 .map_ok(to_tungstenite_message)
773 .err_into()
774 .with(|message| async move { Ok(to_axum_message(message)) });
775 let connection = Connection::new(Box::pin(socket));
776 async move {
777 server
778 .handle_connection(connection, socket_address, user, None, Executor::Production)
779 .await
780 .log_err();
781 }
782 })
783}
784
785pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
786 let connections = server
787 .connection_pool
788 .lock()
789 .connections()
790 .filter(|connection| !connection.admin)
791 .count();
792
793 METRIC_CONNECTIONS.set(connections as _);
794
795 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
796 METRIC_SHARED_PROJECTS.set(shared_projects as _);
797
798 let encoder = prometheus::TextEncoder::new();
799 let metric_families = prometheus::gather();
800 let encoded_metrics = encoder
801 .encode_to_string(&metric_families)
802 .map_err(|err| anyhow!("{}", err))?;
803 Ok(encoded_metrics)
804}
805
/// Cleans up server-side state after a client connection has ended.
///
/// Steps, in order:
/// 1. Disconnect the connection from the peer and remove it from the pool.
/// 2. Tell the database the connection was lost; a failure here is only logged.
/// 3. Wait until either `RECONNECT_TIMEOUT` elapses or the server's teardown
///    signal fires (`select_biased!` polls the timeout branch first).
///
/// If the timeout wins, the session leaves its room. If the user then has no other
/// connection online, `decline_call(None, user_id)` is invoked and any returned
/// room is broadcast. Finally the user's contacts are refreshed. If teardown wins,
/// nothing further happens.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails, or if the
/// final `update_user_contacts` fails. Other failures are logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Reconnect window elapsed without teardown: finish signing the user out.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline calls when no other connection of this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        // Server is tearing down; skip the remaining cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
846
847async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
848 response.send(proto::Ack {})?;
849 Ok(())
850}
851
852async fn create_room(
853 _request: proto::CreateRoom,
854 response: Response<proto::CreateRoom>,
855 session: Session,
856) -> Result<()> {
857 let live_kit_room = nanoid::nanoid!(30);
858 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
859 if let Some(_) = live_kit
860 .create_room(live_kit_room.clone())
861 .await
862 .trace_err()
863 {
864 if let Some(token) = live_kit
865 .room_token(&live_kit_room, &session.user_id.to_string())
866 .trace_err()
867 {
868 Some(proto::LiveKitConnectionInfo {
869 server_url: live_kit.url().into(),
870 token,
871 })
872 } else {
873 None
874 }
875 } else {
876 None
877 }
878 } else {
879 None
880 };
881
882 {
883 let room = session
884 .db()
885 .await
886 .create_room(session.user_id, session.connection_id, &live_kit_room)
887 .await?;
888
889 response.send(proto::CreateRoomResponse {
890 room: Some(room.clone()),
891 live_kit_connection_info,
892 })?;
893 }
894
895 update_user_contacts(session.user_id, &session).await?;
896 Ok(())
897}
898
899async fn join_room(
900 request: proto::JoinRoom,
901 response: Response<proto::JoinRoom>,
902 session: Session,
903) -> Result<()> {
904 let room_id = RoomId::from_proto(request.id);
905 let room = {
906 let room = session
907 .db()
908 .await
909 .join_room(room_id, session.user_id, session.connection_id)
910 .await?;
911 room_updated(&room, &session.peer);
912 room.clone()
913 };
914
915 for connection_id in session
916 .connection_pool()
917 .await
918 .user_connection_ids(session.user_id)
919 {
920 session
921 .peer
922 .send(
923 connection_id,
924 proto::CallCanceled {
925 room_id: room_id.to_proto(),
926 },
927 )
928 .trace_err();
929 }
930
931 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
932 if let Some(token) = live_kit
933 .room_token(&room.live_kit_room, &session.user_id.to_string())
934 .trace_err()
935 {
936 Some(proto::LiveKitConnectionInfo {
937 server_url: live_kit.url().into(),
938 token,
939 })
940 } else {
941 None
942 }
943 } else {
944 None
945 };
946
947 response.send(proto::JoinRoomResponse {
948 room: Some(room),
949 live_kit_connection_info,
950 })?;
951
952 update_user_contacts(session.user_id, &session).await?;
953 Ok(())
954}
955
956async fn rejoin_room(
957 request: proto::RejoinRoom,
958 response: Response<proto::RejoinRoom>,
959 session: Session,
960) -> Result<()> {
961 {
962 let mut rejoined_room = session
963 .db()
964 .await
965 .rejoin_room(request, session.user_id, session.connection_id)
966 .await?;
967
968 response.send(proto::RejoinRoomResponse {
969 room: Some(rejoined_room.room.clone()),
970 reshared_projects: rejoined_room
971 .reshared_projects
972 .iter()
973 .map(|project| proto::ResharedProject {
974 id: project.id.to_proto(),
975 collaborators: project
976 .collaborators
977 .iter()
978 .map(|collaborator| collaborator.to_proto())
979 .collect(),
980 })
981 .collect(),
982 rejoined_projects: rejoined_room
983 .rejoined_projects
984 .iter()
985 .map(|rejoined_project| proto::RejoinedProject {
986 id: rejoined_project.id.to_proto(),
987 worktrees: rejoined_project
988 .worktrees
989 .iter()
990 .map(|worktree| proto::WorktreeMetadata {
991 id: worktree.id,
992 root_name: worktree.root_name.clone(),
993 visible: worktree.visible,
994 abs_path: worktree.abs_path.clone(),
995 })
996 .collect(),
997 collaborators: rejoined_project
998 .collaborators
999 .iter()
1000 .map(|collaborator| collaborator.to_proto())
1001 .collect(),
1002 language_servers: rejoined_project.language_servers.clone(),
1003 })
1004 .collect(),
1005 })?;
1006 room_updated(&rejoined_room.room, &session.peer);
1007
1008 for project in &rejoined_room.reshared_projects {
1009 for collaborator in &project.collaborators {
1010 session
1011 .peer
1012 .send(
1013 collaborator.connection_id,
1014 proto::UpdateProjectCollaborator {
1015 project_id: project.id.to_proto(),
1016 old_peer_id: Some(project.old_connection_id.into()),
1017 new_peer_id: Some(session.connection_id.into()),
1018 },
1019 )
1020 .trace_err();
1021 }
1022
1023 broadcast(
1024 Some(session.connection_id),
1025 project
1026 .collaborators
1027 .iter()
1028 .map(|collaborator| collaborator.connection_id),
1029 |connection_id| {
1030 session.peer.forward_send(
1031 session.connection_id,
1032 connection_id,
1033 proto::UpdateProject {
1034 project_id: project.id.to_proto(),
1035 worktrees: project.worktrees.clone(),
1036 },
1037 )
1038 },
1039 );
1040 }
1041
1042 for project in &rejoined_room.rejoined_projects {
1043 for collaborator in &project.collaborators {
1044 session
1045 .peer
1046 .send(
1047 collaborator.connection_id,
1048 proto::UpdateProjectCollaborator {
1049 project_id: project.id.to_proto(),
1050 old_peer_id: Some(project.old_connection_id.into()),
1051 new_peer_id: Some(session.connection_id.into()),
1052 },
1053 )
1054 .trace_err();
1055 }
1056 }
1057
1058 for project in &mut rejoined_room.rejoined_projects {
1059 for worktree in mem::take(&mut project.worktrees) {
1060 #[cfg(any(test, feature = "test-support"))]
1061 const MAX_CHUNK_SIZE: usize = 2;
1062 #[cfg(not(any(test, feature = "test-support")))]
1063 const MAX_CHUNK_SIZE: usize = 256;
1064
1065 // Stream this worktree's entries.
1066 let message = proto::UpdateWorktree {
1067 project_id: project.id.to_proto(),
1068 worktree_id: worktree.id,
1069 abs_path: worktree.abs_path.clone(),
1070 root_name: worktree.root_name,
1071 updated_entries: worktree.updated_entries,
1072 removed_entries: worktree.removed_entries,
1073 scan_id: worktree.scan_id,
1074 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1075 updated_repositories: worktree.updated_repositories,
1076 removed_repositories: worktree.removed_repositories,
1077 };
1078 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1079 session.peer.send(session.connection_id, update.clone())?;
1080 }
1081
1082 // Stream this worktree's diagnostics.
1083 for summary in worktree.diagnostic_summaries {
1084 session.peer.send(
1085 session.connection_id,
1086 proto::UpdateDiagnosticSummary {
1087 project_id: project.id.to_proto(),
1088 worktree_id: worktree.id,
1089 summary: Some(summary),
1090 },
1091 )?;
1092 }
1093
1094 for settings_file in worktree.settings_files {
1095 session.peer.send(
1096 session.connection_id,
1097 proto::UpdateWorktreeSettings {
1098 project_id: project.id.to_proto(),
1099 worktree_id: worktree.id,
1100 path: settings_file.path,
1101 content: Some(settings_file.content),
1102 },
1103 )?;
1104 }
1105 }
1106
1107 for language_server in &project.language_servers {
1108 session.peer.send(
1109 session.connection_id,
1110 proto::UpdateLanguageServer {
1111 project_id: project.id.to_proto(),
1112 language_server_id: language_server.id,
1113 variant: Some(
1114 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1115 proto::LspDiskBasedDiagnosticsUpdated {},
1116 ),
1117 ),
1118 },
1119 )?;
1120 }
1121 }
1122 }
1123
1124 update_user_contacts(session.user_id, &session).await?;
1125 Ok(())
1126}
1127
1128async fn leave_room(
1129 _: proto::LeaveRoom,
1130 response: Response<proto::LeaveRoom>,
1131 session: Session,
1132) -> Result<()> {
1133 leave_room_for_session(&session).await?;
1134 response.send(proto::Ack {})?;
1135 Ok(())
1136}
1137
1138async fn call(
1139 request: proto::Call,
1140 response: Response<proto::Call>,
1141 session: Session,
1142) -> Result<()> {
1143 let room_id = RoomId::from_proto(request.room_id);
1144 let calling_user_id = session.user_id;
1145 let calling_connection_id = session.connection_id;
1146 let called_user_id = UserId::from_proto(request.called_user_id);
1147 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1148 if !session
1149 .db()
1150 .await
1151 .has_contact(calling_user_id, called_user_id)
1152 .await?
1153 {
1154 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1155 }
1156
1157 let incoming_call = {
1158 let (room, incoming_call) = &mut *session
1159 .db()
1160 .await
1161 .call(
1162 room_id,
1163 calling_user_id,
1164 calling_connection_id,
1165 called_user_id,
1166 initial_project_id,
1167 )
1168 .await?;
1169 room_updated(&room, &session.peer);
1170 mem::take(incoming_call)
1171 };
1172 update_user_contacts(called_user_id, &session).await?;
1173
1174 let mut calls = session
1175 .connection_pool()
1176 .await
1177 .user_connection_ids(called_user_id)
1178 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1179 .collect::<FuturesUnordered<_>>();
1180
1181 while let Some(call_response) = calls.next().await {
1182 match call_response.as_ref() {
1183 Ok(_) => {
1184 response.send(proto::Ack {})?;
1185 return Ok(());
1186 }
1187 Err(_) => {
1188 call_response.trace_err();
1189 }
1190 }
1191 }
1192
1193 {
1194 let room = session
1195 .db()
1196 .await
1197 .call_failed(room_id, called_user_id)
1198 .await?;
1199 room_updated(&room, &session.peer);
1200 }
1201 update_user_contacts(called_user_id, &session).await?;
1202
1203 Err(anyhow!("failed to ring user"))?
1204}
1205
1206async fn cancel_call(
1207 request: proto::CancelCall,
1208 response: Response<proto::CancelCall>,
1209 session: Session,
1210) -> Result<()> {
1211 let called_user_id = UserId::from_proto(request.called_user_id);
1212 let room_id = RoomId::from_proto(request.room_id);
1213 {
1214 let room = session
1215 .db()
1216 .await
1217 .cancel_call(room_id, session.connection_id, called_user_id)
1218 .await?;
1219 room_updated(&room, &session.peer);
1220 }
1221
1222 for connection_id in session
1223 .connection_pool()
1224 .await
1225 .user_connection_ids(called_user_id)
1226 {
1227 session
1228 .peer
1229 .send(
1230 connection_id,
1231 proto::CallCanceled {
1232 room_id: room_id.to_proto(),
1233 },
1234 )
1235 .trace_err();
1236 }
1237 response.send(proto::Ack {})?;
1238
1239 update_user_contacts(called_user_id, &session).await?;
1240 Ok(())
1241}
1242
1243async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1244 let room_id = RoomId::from_proto(message.room_id);
1245 {
1246 let room = session
1247 .db()
1248 .await
1249 .decline_call(Some(room_id), session.user_id)
1250 .await?
1251 .ok_or_else(|| anyhow!("failed to decline call"))?;
1252 room_updated(&room, &session.peer);
1253 }
1254
1255 for connection_id in session
1256 .connection_pool()
1257 .await
1258 .user_connection_ids(session.user_id)
1259 {
1260 session
1261 .peer
1262 .send(
1263 connection_id,
1264 proto::CallCanceled {
1265 room_id: room_id.to_proto(),
1266 },
1267 )
1268 .trace_err();
1269 }
1270 update_user_contacts(session.user_id, &session).await?;
1271 Ok(())
1272}
1273
1274async fn update_participant_location(
1275 request: proto::UpdateParticipantLocation,
1276 response: Response<proto::UpdateParticipantLocation>,
1277 session: Session,
1278) -> Result<()> {
1279 let room_id = RoomId::from_proto(request.room_id);
1280 let location = request
1281 .location
1282 .ok_or_else(|| anyhow!("invalid location"))?;
1283 let room = session
1284 .db()
1285 .await
1286 .update_room_participant_location(room_id, session.connection_id, location)
1287 .await?;
1288 room_updated(&room, &session.peer);
1289 response.send(proto::Ack {})?;
1290 Ok(())
1291}
1292
1293async fn share_project(
1294 request: proto::ShareProject,
1295 response: Response<proto::ShareProject>,
1296 session: Session,
1297) -> Result<()> {
1298 let (project_id, room) = &*session
1299 .db()
1300 .await
1301 .share_project(
1302 RoomId::from_proto(request.room_id),
1303 session.connection_id,
1304 &request.worktrees,
1305 )
1306 .await?;
1307 response.send(proto::ShareProjectResponse {
1308 project_id: project_id.to_proto(),
1309 })?;
1310 room_updated(&room, &session.peer);
1311
1312 Ok(())
1313}
1314
1315async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1316 let project_id = ProjectId::from_proto(message.project_id);
1317
1318 let (room, guest_connection_ids) = &*session
1319 .db()
1320 .await
1321 .unshare_project(project_id, session.connection_id)
1322 .await?;
1323
1324 broadcast(
1325 Some(session.connection_id),
1326 guest_connection_ids.iter().copied(),
1327 |conn_id| session.peer.send(conn_id, message.clone()),
1328 );
1329 room_updated(&room, &session.peer);
1330
1331 Ok(())
1332}
1333
1334async fn join_project(
1335 request: proto::JoinProject,
1336 response: Response<proto::JoinProject>,
1337 session: Session,
1338) -> Result<()> {
1339 let project_id = ProjectId::from_proto(request.project_id);
1340 let guest_user_id = session.user_id;
1341
1342 tracing::info!(%project_id, "join project");
1343
1344 let (project, replica_id) = &mut *session
1345 .db()
1346 .await
1347 .join_project(project_id, session.connection_id)
1348 .await?;
1349
1350 let collaborators = project
1351 .collaborators
1352 .iter()
1353 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1354 .map(|collaborator| collaborator.to_proto())
1355 .collect::<Vec<_>>();
1356
1357 let worktrees = project
1358 .worktrees
1359 .iter()
1360 .map(|(id, worktree)| proto::WorktreeMetadata {
1361 id: *id,
1362 root_name: worktree.root_name.clone(),
1363 visible: worktree.visible,
1364 abs_path: worktree.abs_path.clone(),
1365 })
1366 .collect::<Vec<_>>();
1367
1368 for collaborator in &collaborators {
1369 session
1370 .peer
1371 .send(
1372 collaborator.peer_id.unwrap().into(),
1373 proto::AddProjectCollaborator {
1374 project_id: project_id.to_proto(),
1375 collaborator: Some(proto::Collaborator {
1376 peer_id: Some(session.connection_id.into()),
1377 replica_id: replica_id.0 as u32,
1378 user_id: guest_user_id.to_proto(),
1379 }),
1380 },
1381 )
1382 .trace_err();
1383 }
1384
1385 // First, we send the metadata associated with each worktree.
1386 response.send(proto::JoinProjectResponse {
1387 worktrees: worktrees.clone(),
1388 replica_id: replica_id.0 as u32,
1389 collaborators: collaborators.clone(),
1390 language_servers: project.language_servers.clone(),
1391 })?;
1392
1393 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1394 #[cfg(any(test, feature = "test-support"))]
1395 const MAX_CHUNK_SIZE: usize = 2;
1396 #[cfg(not(any(test, feature = "test-support")))]
1397 const MAX_CHUNK_SIZE: usize = 256;
1398
1399 // Stream this worktree's entries.
1400 let message = proto::UpdateWorktree {
1401 project_id: project_id.to_proto(),
1402 worktree_id,
1403 abs_path: worktree.abs_path.clone(),
1404 root_name: worktree.root_name,
1405 updated_entries: worktree.entries,
1406 removed_entries: Default::default(),
1407 scan_id: worktree.scan_id,
1408 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1409 updated_repositories: worktree.repository_entries.into_values().collect(),
1410 removed_repositories: Default::default(),
1411 };
1412 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1413 session.peer.send(session.connection_id, update.clone())?;
1414 }
1415
1416 // Stream this worktree's diagnostics.
1417 for summary in worktree.diagnostic_summaries {
1418 session.peer.send(
1419 session.connection_id,
1420 proto::UpdateDiagnosticSummary {
1421 project_id: project_id.to_proto(),
1422 worktree_id: worktree.id,
1423 summary: Some(summary),
1424 },
1425 )?;
1426 }
1427
1428 for settings_file in worktree.settings_files {
1429 session.peer.send(
1430 session.connection_id,
1431 proto::UpdateWorktreeSettings {
1432 project_id: project_id.to_proto(),
1433 worktree_id: worktree.id,
1434 path: settings_file.path,
1435 content: Some(settings_file.content),
1436 },
1437 )?;
1438 }
1439 }
1440
1441 for language_server in &project.language_servers {
1442 session.peer.send(
1443 session.connection_id,
1444 proto::UpdateLanguageServer {
1445 project_id: project_id.to_proto(),
1446 language_server_id: language_server.id,
1447 variant: Some(
1448 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1449 proto::LspDiskBasedDiagnosticsUpdated {},
1450 ),
1451 ),
1452 },
1453 )?;
1454 }
1455
1456 Ok(())
1457}
1458
1459async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1460 let sender_id = session.connection_id;
1461 let project_id = ProjectId::from_proto(request.project_id);
1462
1463 let (room, project) = &*session
1464 .db()
1465 .await
1466 .leave_project(project_id, sender_id)
1467 .await?;
1468 tracing::info!(
1469 %project_id,
1470 host_user_id = %project.host_user_id,
1471 host_connection_id = %project.host_connection_id,
1472 "leave project"
1473 );
1474
1475 project_left(&project, &session);
1476 room_updated(&room, &session.peer);
1477
1478 Ok(())
1479}
1480
1481async fn update_project(
1482 request: proto::UpdateProject,
1483 response: Response<proto::UpdateProject>,
1484 session: Session,
1485) -> Result<()> {
1486 let project_id = ProjectId::from_proto(request.project_id);
1487 let (room, guest_connection_ids) = &*session
1488 .db()
1489 .await
1490 .update_project(project_id, session.connection_id, &request.worktrees)
1491 .await?;
1492 broadcast(
1493 Some(session.connection_id),
1494 guest_connection_ids.iter().copied(),
1495 |connection_id| {
1496 session
1497 .peer
1498 .forward_send(session.connection_id, connection_id, request.clone())
1499 },
1500 );
1501 room_updated(&room, &session.peer);
1502 response.send(proto::Ack {})?;
1503
1504 Ok(())
1505}
1506
1507async fn update_worktree(
1508 request: proto::UpdateWorktree,
1509 response: Response<proto::UpdateWorktree>,
1510 session: Session,
1511) -> Result<()> {
1512 let guest_connection_ids = session
1513 .db()
1514 .await
1515 .update_worktree(&request, session.connection_id)
1516 .await?;
1517
1518 broadcast(
1519 Some(session.connection_id),
1520 guest_connection_ids.iter().copied(),
1521 |connection_id| {
1522 session
1523 .peer
1524 .forward_send(session.connection_id, connection_id, request.clone())
1525 },
1526 );
1527 response.send(proto::Ack {})?;
1528 Ok(())
1529}
1530
1531async fn update_diagnostic_summary(
1532 message: proto::UpdateDiagnosticSummary,
1533 session: Session,
1534) -> Result<()> {
1535 let guest_connection_ids = session
1536 .db()
1537 .await
1538 .update_diagnostic_summary(&message, session.connection_id)
1539 .await?;
1540
1541 broadcast(
1542 Some(session.connection_id),
1543 guest_connection_ids.iter().copied(),
1544 |connection_id| {
1545 session
1546 .peer
1547 .forward_send(session.connection_id, connection_id, message.clone())
1548 },
1549 );
1550
1551 Ok(())
1552}
1553
1554async fn update_worktree_settings(
1555 message: proto::UpdateWorktreeSettings,
1556 session: Session,
1557) -> Result<()> {
1558 let guest_connection_ids = session
1559 .db()
1560 .await
1561 .update_worktree_settings(&message, session.connection_id)
1562 .await?;
1563
1564 broadcast(
1565 Some(session.connection_id),
1566 guest_connection_ids.iter().copied(),
1567 |connection_id| {
1568 session
1569 .peer
1570 .forward_send(session.connection_id, connection_id, message.clone())
1571 },
1572 );
1573
1574 Ok(())
1575}
1576
1577async fn start_language_server(
1578 request: proto::StartLanguageServer,
1579 session: Session,
1580) -> Result<()> {
1581 let guest_connection_ids = session
1582 .db()
1583 .await
1584 .start_language_server(&request, session.connection_id)
1585 .await?;
1586
1587 broadcast(
1588 Some(session.connection_id),
1589 guest_connection_ids.iter().copied(),
1590 |connection_id| {
1591 session
1592 .peer
1593 .forward_send(session.connection_id, connection_id, request.clone())
1594 },
1595 );
1596 Ok(())
1597}
1598
1599async fn update_language_server(
1600 request: proto::UpdateLanguageServer,
1601 session: Session,
1602) -> Result<()> {
1603 session.executor.record_backtrace();
1604 let project_id = ProjectId::from_proto(request.project_id);
1605 let project_connection_ids = session
1606 .db()
1607 .await
1608 .project_connection_ids(project_id, session.connection_id)
1609 .await?;
1610 broadcast(
1611 Some(session.connection_id),
1612 project_connection_ids.iter().copied(),
1613 |connection_id| {
1614 session
1615 .peer
1616 .forward_send(session.connection_id, connection_id, request.clone())
1617 },
1618 );
1619 Ok(())
1620}
1621
1622async fn forward_project_request<T>(
1623 request: T,
1624 response: Response<T>,
1625 session: Session,
1626) -> Result<()>
1627where
1628 T: EntityMessage + RequestMessage,
1629{
1630 session.executor.record_backtrace();
1631 let project_id = ProjectId::from_proto(request.remote_entity_id());
1632 let host_connection_id = {
1633 let collaborators = session
1634 .db()
1635 .await
1636 .project_collaborators(project_id, session.connection_id)
1637 .await?;
1638 collaborators
1639 .iter()
1640 .find(|collaborator| collaborator.is_host)
1641 .ok_or_else(|| anyhow!("host not found"))?
1642 .connection_id
1643 };
1644
1645 let payload = session
1646 .peer
1647 .forward_request(session.connection_id, host_connection_id, request)
1648 .await?;
1649
1650 response.send(payload)?;
1651 Ok(())
1652}
1653
1654async fn create_buffer_for_peer(
1655 request: proto::CreateBufferForPeer,
1656 session: Session,
1657) -> Result<()> {
1658 session.executor.record_backtrace();
1659 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1660 session
1661 .peer
1662 .forward_send(session.connection_id, peer_id.into(), request)?;
1663 Ok(())
1664}
1665
1666async fn update_buffer(
1667 request: proto::UpdateBuffer,
1668 response: Response<proto::UpdateBuffer>,
1669 session: Session,
1670) -> Result<()> {
1671 session.executor.record_backtrace();
1672 let project_id = ProjectId::from_proto(request.project_id);
1673 let mut guest_connection_ids;
1674 let mut host_connection_id = None;
1675 {
1676 let collaborators = session
1677 .db()
1678 .await
1679 .project_collaborators(project_id, session.connection_id)
1680 .await?;
1681 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1682 for collaborator in collaborators.iter() {
1683 if collaborator.is_host {
1684 host_connection_id = Some(collaborator.connection_id);
1685 } else {
1686 guest_connection_ids.push(collaborator.connection_id);
1687 }
1688 }
1689 }
1690 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1691
1692 session.executor.record_backtrace();
1693 broadcast(
1694 Some(session.connection_id),
1695 guest_connection_ids,
1696 |connection_id| {
1697 session
1698 .peer
1699 .forward_send(session.connection_id, connection_id, request.clone())
1700 },
1701 );
1702 if host_connection_id != session.connection_id {
1703 session
1704 .peer
1705 .forward_request(session.connection_id, host_connection_id, request.clone())
1706 .await?;
1707 }
1708
1709 response.send(proto::Ack {})?;
1710 Ok(())
1711}
1712
1713async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1714 let project_id = ProjectId::from_proto(request.project_id);
1715 let project_connection_ids = session
1716 .db()
1717 .await
1718 .project_connection_ids(project_id, session.connection_id)
1719 .await?;
1720
1721 broadcast(
1722 Some(session.connection_id),
1723 project_connection_ids.iter().copied(),
1724 |connection_id| {
1725 session
1726 .peer
1727 .forward_send(session.connection_id, connection_id, request.clone())
1728 },
1729 );
1730 Ok(())
1731}
1732
1733async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1734 let project_id = ProjectId::from_proto(request.project_id);
1735 let project_connection_ids = session
1736 .db()
1737 .await
1738 .project_connection_ids(project_id, session.connection_id)
1739 .await?;
1740 broadcast(
1741 Some(session.connection_id),
1742 project_connection_ids.iter().copied(),
1743 |connection_id| {
1744 session
1745 .peer
1746 .forward_send(session.connection_id, connection_id, request.clone())
1747 },
1748 );
1749 Ok(())
1750}
1751
1752async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1753 let project_id = ProjectId::from_proto(request.project_id);
1754 let project_connection_ids = session
1755 .db()
1756 .await
1757 .project_connection_ids(project_id, session.connection_id)
1758 .await?;
1759 broadcast(
1760 Some(session.connection_id),
1761 project_connection_ids.iter().copied(),
1762 |connection_id| {
1763 session
1764 .peer
1765 .forward_send(session.connection_id, connection_id, request.clone())
1766 },
1767 );
1768 Ok(())
1769}
1770
1771async fn follow(
1772 request: proto::Follow,
1773 response: Response<proto::Follow>,
1774 session: Session,
1775) -> Result<()> {
1776 let project_id = ProjectId::from_proto(request.project_id);
1777 let leader_id = request
1778 .leader_id
1779 .ok_or_else(|| anyhow!("invalid leader id"))?
1780 .into();
1781 let follower_id = session.connection_id;
1782
1783 {
1784 let project_connection_ids = session
1785 .db()
1786 .await
1787 .project_connection_ids(project_id, session.connection_id)
1788 .await?;
1789
1790 if !project_connection_ids.contains(&leader_id) {
1791 Err(anyhow!("no such peer"))?;
1792 }
1793 }
1794
1795 let mut response_payload = session
1796 .peer
1797 .forward_request(session.connection_id, leader_id, request)
1798 .await?;
1799 response_payload
1800 .views
1801 .retain(|view| view.leader_id != Some(follower_id.into()));
1802 response.send(response_payload)?;
1803
1804 let room = session
1805 .db()
1806 .await
1807 .follow(project_id, leader_id, follower_id)
1808 .await?;
1809 room_updated(&room, &session.peer);
1810
1811 Ok(())
1812}
1813
1814async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1815 let project_id = ProjectId::from_proto(request.project_id);
1816 let leader_id = request
1817 .leader_id
1818 .ok_or_else(|| anyhow!("invalid leader id"))?
1819 .into();
1820 let follower_id = session.connection_id;
1821
1822 if !session
1823 .db()
1824 .await
1825 .project_connection_ids(project_id, session.connection_id)
1826 .await?
1827 .contains(&leader_id)
1828 {
1829 Err(anyhow!("no such peer"))?;
1830 }
1831
1832 session
1833 .peer
1834 .forward_send(session.connection_id, leader_id, request)?;
1835
1836 let room = session
1837 .db()
1838 .await
1839 .unfollow(project_id, leader_id, follower_id)
1840 .await?;
1841 room_updated(&room, &session.peer);
1842
1843 Ok(())
1844}
1845
1846async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1847 let project_id = ProjectId::from_proto(request.project_id);
1848 let project_connection_ids = session
1849 .db
1850 .lock()
1851 .await
1852 .project_connection_ids(project_id, session.connection_id)
1853 .await?;
1854
1855 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1856 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1857 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1858 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1859 });
1860 for follower_peer_id in request.follower_ids.iter().copied() {
1861 let follower_connection_id = follower_peer_id.into();
1862 if project_connection_ids.contains(&follower_connection_id)
1863 && Some(follower_peer_id) != leader_id
1864 {
1865 session.peer.forward_send(
1866 session.connection_id,
1867 follower_connection_id,
1868 request.clone(),
1869 )?;
1870 }
1871 }
1872 Ok(())
1873}
1874
1875async fn get_users(
1876 request: proto::GetUsers,
1877 response: Response<proto::GetUsers>,
1878 session: Session,
1879) -> Result<()> {
1880 let user_ids = request
1881 .user_ids
1882 .into_iter()
1883 .map(UserId::from_proto)
1884 .collect();
1885 let users = session
1886 .db()
1887 .await
1888 .get_users_by_ids(user_ids)
1889 .await?
1890 .into_iter()
1891 .map(|user| proto::User {
1892 id: user.id.to_proto(),
1893 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1894 github_login: user.github_login,
1895 })
1896 .collect();
1897 response.send(proto::UsersResponse { users })?;
1898 Ok(())
1899}
1900
1901async fn fuzzy_search_users(
1902 request: proto::FuzzySearchUsers,
1903 response: Response<proto::FuzzySearchUsers>,
1904 session: Session,
1905) -> Result<()> {
1906 let query = request.query;
1907 let users = match query.len() {
1908 0 => vec![],
1909 1 | 2 => session
1910 .db()
1911 .await
1912 .get_user_by_github_login(&query)
1913 .await?
1914 .into_iter()
1915 .collect(),
1916 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1917 };
1918 let users = users
1919 .into_iter()
1920 .filter(|user| user.id != session.user_id)
1921 .map(|user| proto::User {
1922 id: user.id.to_proto(),
1923 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1924 github_login: user.github_login,
1925 })
1926 .collect();
1927 response.send(proto::UsersResponse { users })?;
1928 Ok(())
1929}
1930
1931async fn request_contact(
1932 request: proto::RequestContact,
1933 response: Response<proto::RequestContact>,
1934 session: Session,
1935) -> Result<()> {
1936 let requester_id = session.user_id;
1937 let responder_id = UserId::from_proto(request.responder_id);
1938 if requester_id == responder_id {
1939 return Err(anyhow!("cannot add yourself as a contact"))?;
1940 }
1941
1942 session
1943 .db()
1944 .await
1945 .send_contact_request(requester_id, responder_id)
1946 .await?;
1947
1948 // Update outgoing contact requests of requester
1949 let mut update = proto::UpdateContacts::default();
1950 update.outgoing_requests.push(responder_id.to_proto());
1951 for connection_id in session
1952 .connection_pool()
1953 .await
1954 .user_connection_ids(requester_id)
1955 {
1956 session.peer.send(connection_id, update.clone())?;
1957 }
1958
1959 // Update incoming contact requests of responder
1960 let mut update = proto::UpdateContacts::default();
1961 update
1962 .incoming_requests
1963 .push(proto::IncomingContactRequest {
1964 requester_id: requester_id.to_proto(),
1965 should_notify: true,
1966 });
1967 for connection_id in session
1968 .connection_pool()
1969 .await
1970 .user_connection_ids(responder_id)
1971 {
1972 session.peer.send(connection_id, update.clone())?;
1973 }
1974
1975 response.send(proto::Ack {})?;
1976 Ok(())
1977}
1978
1979async fn respond_to_contact_request(
1980 request: proto::RespondToContactRequest,
1981 response: Response<proto::RespondToContactRequest>,
1982 session: Session,
1983) -> Result<()> {
1984 let responder_id = session.user_id;
1985 let requester_id = UserId::from_proto(request.requester_id);
1986 let db = session.db().await;
1987 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1988 db.dismiss_contact_notification(responder_id, requester_id)
1989 .await?;
1990 } else {
1991 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1992
1993 db.respond_to_contact_request(responder_id, requester_id, accept)
1994 .await?;
1995 let requester_busy = db.is_user_busy(requester_id).await?;
1996 let responder_busy = db.is_user_busy(responder_id).await?;
1997
1998 let pool = session.connection_pool().await;
1999 // Update responder with new contact
2000 let mut update = proto::UpdateContacts::default();
2001 if accept {
2002 update
2003 .contacts
2004 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2005 }
2006 update
2007 .remove_incoming_requests
2008 .push(requester_id.to_proto());
2009 for connection_id in pool.user_connection_ids(responder_id) {
2010 session.peer.send(connection_id, update.clone())?;
2011 }
2012
2013 // Update requester with new contact
2014 let mut update = proto::UpdateContacts::default();
2015 if accept {
2016 update
2017 .contacts
2018 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2019 }
2020 update
2021 .remove_outgoing_requests
2022 .push(responder_id.to_proto());
2023 for connection_id in pool.user_connection_ids(requester_id) {
2024 session.peer.send(connection_id, update.clone())?;
2025 }
2026 }
2027
2028 response.send(proto::Ack {})?;
2029 Ok(())
2030}
2031
2032async fn remove_contact(
2033 request: proto::RemoveContact,
2034 response: Response<proto::RemoveContact>,
2035 session: Session,
2036) -> Result<()> {
2037 let requester_id = session.user_id;
2038 let responder_id = UserId::from_proto(request.user_id);
2039 let db = session.db().await;
2040 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2041
2042 let pool = session.connection_pool().await;
2043 // Update outgoing contact requests of requester
2044 let mut update = proto::UpdateContacts::default();
2045 if contact_accepted {
2046 update.remove_contacts.push(responder_id.to_proto());
2047 } else {
2048 update
2049 .remove_outgoing_requests
2050 .push(responder_id.to_proto());
2051 }
2052 for connection_id in pool.user_connection_ids(requester_id) {
2053 session.peer.send(connection_id, update.clone())?;
2054 }
2055
2056 // Update incoming contact requests of responder
2057 let mut update = proto::UpdateContacts::default();
2058 if contact_accepted {
2059 update.remove_contacts.push(requester_id.to_proto());
2060 } else {
2061 update
2062 .remove_incoming_requests
2063 .push(requester_id.to_proto());
2064 }
2065 for connection_id in pool.user_connection_ids(responder_id) {
2066 session.peer.send(connection_id, update.clone())?;
2067 }
2068
2069 response.send(proto::Ack {})?;
2070 Ok(())
2071}
2072
2073async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2074 let project_id = ProjectId::from_proto(request.project_id);
2075 let project_connection_ids = session
2076 .db()
2077 .await
2078 .project_connection_ids(project_id, session.connection_id)
2079 .await?;
2080 broadcast(
2081 Some(session.connection_id),
2082 project_connection_ids.iter().copied(),
2083 |connection_id| {
2084 session
2085 .peer
2086 .forward_send(session.connection_id, connection_id, request.clone())
2087 },
2088 );
2089 Ok(())
2090}
2091
2092async fn get_private_user_info(
2093 _request: proto::GetPrivateUserInfo,
2094 response: Response<proto::GetPrivateUserInfo>,
2095 session: Session,
2096) -> Result<()> {
2097 let metrics_id = session
2098 .db()
2099 .await
2100 .get_user_metrics_id(session.user_id)
2101 .await?;
2102 let user = session
2103 .db()
2104 .await
2105 .get_user_by_id(session.user_id)
2106 .await?
2107 .ok_or_else(|| anyhow!("user not found"))?;
2108 response.send(proto::GetPrivateUserInfoResponse {
2109 metrics_id,
2110 staff: user.admin,
2111 })?;
2112 Ok(())
2113}
2114
2115fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2116 match message {
2117 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2118 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2119 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2120 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2121 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2122 code: frame.code.into(),
2123 reason: frame.reason,
2124 })),
2125 }
2126}
2127
2128fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2129 match message {
2130 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2131 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2132 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2133 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2134 AxumMessage::Close(frame) => {
2135 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2136 code: frame.code.into(),
2137 reason: frame.reason,
2138 }))
2139 }
2140 }
2141}
2142
2143fn build_initial_contacts_update(
2144 contacts: Vec<db::Contact>,
2145 pool: &ConnectionPool,
2146) -> proto::UpdateContacts {
2147 let mut update = proto::UpdateContacts::default();
2148
2149 for contact in contacts {
2150 match contact {
2151 db::Contact::Accepted {
2152 user_id,
2153 should_notify,
2154 busy,
2155 } => {
2156 update
2157 .contacts
2158 .push(contact_for_user(user_id, should_notify, busy, &pool));
2159 }
2160 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2161 db::Contact::Incoming {
2162 user_id,
2163 should_notify,
2164 } => update
2165 .incoming_requests
2166 .push(proto::IncomingContactRequest {
2167 requester_id: user_id.to_proto(),
2168 should_notify,
2169 }),
2170 }
2171 }
2172
2173 update
2174}
2175
2176fn contact_for_user(
2177 user_id: UserId,
2178 should_notify: bool,
2179 busy: bool,
2180 pool: &ConnectionPool,
2181) -> proto::Contact {
2182 proto::Contact {
2183 user_id: user_id.to_proto(),
2184 online: pool.is_user_online(user_id),
2185 busy,
2186 should_notify,
2187 }
2188}
2189
2190fn room_updated(room: &proto::Room, peer: &Peer) {
2191 broadcast(
2192 None,
2193 room.participants
2194 .iter()
2195 .filter_map(|participant| Some(participant.peer_id?.into())),
2196 |peer_id| {
2197 peer.send(
2198 peer_id.into(),
2199 proto::RoomUpdated {
2200 room: Some(room.clone()),
2201 },
2202 )
2203 },
2204 );
2205}
2206
2207async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2208 let db = session.db().await;
2209 let contacts = db.get_contacts(user_id).await?;
2210 let busy = db.is_user_busy(user_id).await?;
2211
2212 let pool = session.connection_pool().await;
2213 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2214 for contact in contacts {
2215 if let db::Contact::Accepted {
2216 user_id: contact_user_id,
2217 ..
2218 } = contact
2219 {
2220 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2221 session
2222 .peer
2223 .send(
2224 contact_conn_id,
2225 proto::UpdateContacts {
2226 contacts: vec![updated_contact.clone()],
2227 remove_contacts: Default::default(),
2228 incoming_requests: Default::default(),
2229 remove_incoming_requests: Default::default(),
2230 outgoing_requests: Default::default(),
2231 remove_outgoing_requests: Default::default(),
2232 },
2233 )
2234 .trace_err();
2235 }
2236 }
2237 }
2238 Ok(())
2239}
2240
2241async fn leave_room_for_session(session: &Session) -> Result<()> {
2242 let mut contacts_to_update = HashSet::default();
2243
2244 let room_id;
2245 let canceled_calls_to_user_ids;
2246 let live_kit_room;
2247 let delete_live_kit_room;
2248 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2249 contacts_to_update.insert(session.user_id);
2250
2251 for project in left_room.left_projects.values() {
2252 project_left(project, session);
2253 }
2254
2255 room_updated(&left_room.room, &session.peer);
2256 room_id = RoomId::from_proto(left_room.room.id);
2257 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2258 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2259 delete_live_kit_room = left_room.room.participants.is_empty();
2260 } else {
2261 return Ok(());
2262 }
2263
2264 {
2265 let pool = session.connection_pool().await;
2266 for canceled_user_id in canceled_calls_to_user_ids {
2267 for connection_id in pool.user_connection_ids(canceled_user_id) {
2268 session
2269 .peer
2270 .send(
2271 connection_id,
2272 proto::CallCanceled {
2273 room_id: room_id.to_proto(),
2274 },
2275 )
2276 .trace_err();
2277 }
2278 contacts_to_update.insert(canceled_user_id);
2279 }
2280 }
2281
2282 for contact_user_id in contacts_to_update {
2283 update_user_contacts(contact_user_id, &session).await?;
2284 }
2285
2286 if let Some(live_kit) = session.live_kit_client.as_ref() {
2287 live_kit
2288 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2289 .await
2290 .trace_err();
2291
2292 if delete_live_kit_room {
2293 live_kit.delete_room(live_kit_room).await.trace_err();
2294 }
2295 }
2296
2297 Ok(())
2298}
2299
2300fn project_left(project: &db::LeftProject, session: &Session) {
2301 for connection_id in &project.connection_ids {
2302 if project.host_user_id == session.user_id {
2303 session
2304 .peer
2305 .send(
2306 *connection_id,
2307 proto::UnshareProject {
2308 project_id: project.id.to_proto(),
2309 },
2310 )
2311 .trace_err();
2312 } else {
2313 session
2314 .peer
2315 .send(
2316 *connection_id,
2317 proto::RemoveProjectCollaborator {
2318 project_id: project.id.to_proto(),
2319 peer_id: Some(session.connection_id.into()),
2320 },
2321 )
2322 .trace_err();
2323 }
2324 }
2325}
2326
2327pub trait ResultExt {
2328 type Ok;
2329
2330 fn trace_err(self) -> Option<Self::Ok>;
2331}
2332
2333impl<T, E> ResultExt for Result<T, E>
2334where
2335 E: std::fmt::Debug,
2336{
2337 type Ok = T;
2338
2339 fn trace_err(self) -> Option<T> {
2340 match self {
2341 Ok(value) => Some(value),
2342 Err(error) => {
2343 tracing::error!("{:?}", error);
2344 None
2345 }
2346 }
2347 }
2348}