1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits for a dropped client to reconnect before removing the
/// session from its room and notifying contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before stale rooms for this server id are refreshed and
/// stale servers are deleted.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
// Prometheus gauges exported by `handle_metrics`. Each is registered with the default
// registry on first access; `.unwrap()` panics if registration fails.
lazy_static! {
    /// Number of non-admin connections in the connection pool (set in `handle_metrics`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Value of `project_count_excluding_admins()` from the database (set in `handle_metrics`).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
/// One-shot responder handed to request handlers registered via `add_request_handler`.
///
/// `send` consumes it, so a handler can respond through it at most once.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`, which checks it after the handler
    // finishes to detect handlers that returned `Ok` without responding.
    responded: Arc<AtomicBool>,
}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection state passed to every message handler.
///
/// `Server::handle_connection` builds one per connection and clones it for each incoming
/// message; clones share the database handle, peer and connection pool through `Arc`s.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // A fresh mutex is created per connection in `handle_connection`, so database access
    // through `Session::db` is serialized per connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared with `Server`; acquire via `Session::connection_pool` to get a non-`Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when no LiveKit client is configured in `AppState`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the `Peer`, the connection pool and the handler table.
pub struct Server {
    // Behind a mutex so tests can swap it in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown`; each connection task and `connection_lost` subscribe to it.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` makes the guard `!Send`, so a `Send` future cannot hold it
/// across an `.await`. In test builds, dropping it checks the pool's invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server's peer and connection pool, produced by `Server::snapshot`.
///
/// The connection-pool lock is held for as long as this value lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the `ConnectionPool` itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
impl Server {
    /// Builds a server identified by `id` and registers every RPC handler it serves.
    ///
    /// Panics (in `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created on demand via `subscribe()`.
            teardown: watch::channel(()).0,
        };

        // Request handlers must respond (checked by `add_request_handler`); message
        // handlers send no response.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Generic relay handlers; `forward_project_request` is defined elsewhere in this file.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }

    /// Spawns a detached cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, the task:
    /// 1. refreshes every room returned by `stale_room_ids` for this server id;
    /// 2. broadcasts the refreshed room and sends `CallCanceled` to users whose calls were
    ///    canceled;
    /// 3. pushes updated contact entries to the accepted contacts of affected users;
    /// 4. deletes the LiveKit room once it has no participants;
    /// 5. finally calls `delete_stale_servers`.
    ///
    /// Errors inside the task are logged via `trace_err` rather than returned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            // Both removed participants and canceled callees may have a
                            // changed "busy" status that their contacts need to see.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the parking_lot guard is released before the next await.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears down the peer, resets the connection pool and signals every teardown subscriber
    /// (connection loops return without running `connection_lost`'s delayed cleanup).
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // An error only means there are no subscribers left.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears down and restarts the server under a new `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only: the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers a type-erased handler for envelopes of message type `M`.
    ///
    /// The stored closure downcasts the envelope to `TypedEnvelope<M>`, runs `handler` inside a
    /// "handle message" span, and logs the duration and any error it returns.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered. The stored closure also `unwrap`s
    /// the downcast; this presumably cannot fail because dispatch in `handle_connection` looks
    /// handlers up by `payload_type_id()` — confirm that it always equals `TypeId::of::<M>()`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Microseconds converted to fractional milliseconds.
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for fire-and-forget messages; it receives only the payload.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for requests that expect a response.
    ///
    /// The handler receives a `Response<M>` to reply with. If it returns `Err`, the error
    /// text is sent back to the requester as `proto::Error`. If it returns `Ok` without having
    /// called `Response::send`, an error is logged (no response is sent to the requester).
    ///
    /// NOTE(review): the `Err` branch calls `respond_with_error` even when the handler already
    /// responded (e.g. it errored after `response.send`), which would reuse the same receipt
    /// for a second reply — confirm how `Peer` treats that.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives one authenticated client connection until it closes or the server tears down.
    ///
    /// It first registers the connection with `peer` and sends `Hello`. It then sends
    /// `ShowContacts` on the user's first ever connection, followed by the initial contacts,
    /// channels, invite info and any pending incoming call. After that it dispatches incoming
    /// messages to the registered handlers:
    /// - background messages are spawned detached;
    /// - foreground handlers are polled concurrently on this task;
    /// - at most 256 handlers hold a semaphore permit at any time.
    ///
    /// When the I/O future ends or the incoming stream closes, `connection_lost` runs. On
    /// teardown it returns `Ok(())` immediately without running it.
    ///
    /// If `send_connection_id` is provided, the assigned `ConnectionId` is sent through it.
    ///
    /// NOTE(review): the `?`s during setup exit without calling `connection_lost`. That covers
    /// everything after `peer.add_connection`, including paths after `pool.add_connection`. On
    /// those paths the connection does not appear to be disconnected from the peer or removed
    /// from the pool — confirm whether that cleanup happens elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed eagerly so a teardown signalled before this future is polled is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may have been dropped; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code, channels, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels(user_id),
                this.app_state.db.get_channel_invites(user_id)
            ).await?;

            // Scoped so the parking_lot guard is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(channels, channel_invites))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before reading the next message, so at most 256 handlers
                // (foreground or background) are in flight. The permit moves into the handler
                // future and is released when it completes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are polled in order: teardown, I/O, in-flight handlers, new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Tells every connection of `inviter_id` that `invitee_id` joined via their invite.
    ///
    /// Each connection receives an `UpdateContacts` with the invitee and refreshed invite info.
    /// Does nothing if the inviter is not found or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and remaining count) to every connection of `user_id`.
    ///
    /// Does nothing if the user is not found or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The snapshot holds the connection-pool lock until it is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
687
688impl<'a> Deref for ConnectionPoolGuard<'a> {
689 type Target = ConnectionPool;
690
691 fn deref(&self) -> &Self::Target {
692 &*self.guard
693 }
694}
695
696impl<'a> DerefMut for ConnectionPoolGuard<'a> {
697 fn deref_mut(&mut self) -> &mut Self::Target {
698 &mut *self.guard
699 }
700}
701
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants each time the guard is released (the
    /// lock is still held at this point); in other builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
708
709fn broadcast<F>(
710 sender_id: Option<ConnectionId>,
711 receiver_ids: impl IntoIterator<Item = ConnectionId>,
712 mut f: F,
713) where
714 F: FnMut(ConnectionId) -> anyhow::Result<()>,
715{
716 for receiver_id in receiver_ids {
717 if Some(receiver_id) != sender_id {
718 if let Err(error) = f(receiver_id) {
719 tracing::error!("failed to send to {:?} {}", receiver_id, error);
720 }
721 }
722 }
723}
724
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version
    /// (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
728
/// Typed header holding the client's RPC protocol version; `handle_websocket_request`
/// compares it against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
730
731impl Header for ProtocolVersion {
732 fn name() -> &'static HeaderName {
733 &ZED_PROTOCOL_VERSION
734 }
735
736 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
737 where
738 Self: Sized,
739 I: Iterator<Item = &'i axum::http::HeaderValue>,
740 {
741 let version = values
742 .next()
743 .ok_or_else(axum::headers::Error::invalid)?
744 .to_str()
745 .map_err(|_| axum::headers::Error::invalid())?
746 .parse()
747 .map_err(|_| axum::headers::Error::invalid())?;
748 Ok(Self(version))
749 }
750
751 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
752 values.extend([self.0.to_string().parse().unwrap()]);
753 }
754}
755
756pub fn routes(server: Arc<Server>) -> Router<Body> {
757 Router::new()
758 .route("/rpc", get(handle_websocket_request))
759 .layer(
760 ServiceBuilder::new()
761 .layer(Extension(server.app_state.clone()))
762 .layer(middleware::from_fn(auth::validate_header)),
763 )
764 .route("/metrics", get(handle_metrics))
765 .layer(Extension(server))
766}
767
768pub async fn handle_websocket_request(
769 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
770 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
771 Extension(server): Extension<Arc<Server>>,
772 Extension(user): Extension<User>,
773 ws: WebSocketUpgrade,
774) -> axum::response::Response {
775 if protocol_version != rpc::PROTOCOL_VERSION {
776 return (
777 StatusCode::UPGRADE_REQUIRED,
778 "client must be upgraded".to_string(),
779 )
780 .into_response();
781 }
782 let socket_address = socket_address.to_string();
783 ws.on_upgrade(move |socket| {
784 use util::ResultExt;
785 let socket = socket
786 .map_ok(to_tungstenite_message)
787 .err_into()
788 .with(|message| async move { Ok(to_axum_message(message)) });
789 let connection = Connection::new(Box::pin(socket));
790 async move {
791 server
792 .handle_connection(connection, socket_address, user, None, Executor::Production)
793 .await
794 .log_err();
795 }
796 })
797}
798
799pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
800 let connections = server
801 .connection_pool
802 .lock()
803 .connections()
804 .filter(|connection| !connection.admin)
805 .count();
806
807 METRIC_CONNECTIONS.set(connections as _);
808
809 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
810 METRIC_SHARED_PROJECTS.set(shared_projects as _);
811
812 let encoder = prometheus::TextEncoder::new();
813 let metric_families = prometheus::gather();
814 let encoded_metrics = encoder
815 .encode_to_string(&metric_families)
816 .map_err(|err| anyhow!("{}", err))?;
817 Ok(encoded_metrics)
818}
819
/// Cleans up after a client connection ends.
///
/// Immediately disconnects the peer, removes the connection from the pool and calls the
/// database's `connection_lost` (errors there are only logged). It then races
/// `RECONNECT_TIMEOUT` against server teardown; teardown wins ties because of
/// `select_biased!`. On teardown it returns without further work. When the timeout elapses:
/// - `leave_room_for_session` runs (errors logged);
/// - if the user has no other connection online, `decline_call(None, ..)` runs and any
///   returned room is broadcast;
/// - `update_user_contacts` runs.
///
/// # Errors
/// Propagates errors from the pool's `remove_connection` and from `update_user_contacts`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary of the `if` condition and is dropped before the body.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
860
861async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
862 response.send(proto::Ack {})?;
863 Ok(())
864}
865
866async fn create_room(
867 _request: proto::CreateRoom,
868 response: Response<proto::CreateRoom>,
869 session: Session,
870) -> Result<()> {
871 let live_kit_room = nanoid::nanoid!(30);
872
873 let live_kit_connection_info = {
874 let live_kit_room = live_kit_room.clone();
875 let live_kit = session.live_kit_client.as_ref();
876
877 util::async_iife!({
878 let live_kit = live_kit?;
879
880 live_kit
881 .create_room(live_kit_room.clone())
882 .await
883 .trace_err()?;
884
885 let token = live_kit
886 .room_token(&live_kit_room, &session.user_id.to_string())
887 .trace_err()?;
888
889 Some(proto::LiveKitConnectionInfo {
890 server_url: live_kit.url().into(),
891 token,
892 })
893 })
894 }
895 .await;
896
897 let room = session
898 .db()
899 .await
900 .create_room(session.user_id, session.connection_id, &live_kit_room)
901 .await?;
902
903 response.send(proto::CreateRoomResponse {
904 room: Some(room.clone()),
905 live_kit_connection_info,
906 })?;
907
908 update_user_contacts(session.user_id, &session).await?;
909 Ok(())
910}
911
912async fn join_room(
913 request: proto::JoinRoom,
914 response: Response<proto::JoinRoom>,
915 session: Session,
916) -> Result<()> {
917 let room_id = RoomId::from_proto(request.id);
918 let room = {
919 let room = session
920 .db()
921 .await
922 .join_room(room_id, session.user_id, None, session.connection_id)
923 .await?;
924 room_updated(&room, &session.peer);
925 room.clone()
926 };
927
928 for connection_id in session
929 .connection_pool()
930 .await
931 .user_connection_ids(session.user_id)
932 {
933 session
934 .peer
935 .send(
936 connection_id,
937 proto::CallCanceled {
938 room_id: room_id.to_proto(),
939 },
940 )
941 .trace_err();
942 }
943
944 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
945 if let Some(token) = live_kit
946 .room_token(&room.live_kit_room, &session.user_id.to_string())
947 .trace_err()
948 {
949 Some(proto::LiveKitConnectionInfo {
950 server_url: live_kit.url().into(),
951 token,
952 })
953 } else {
954 None
955 }
956 } else {
957 None
958 };
959
960 response.send(proto::JoinRoomResponse {
961 room: Some(room),
962 live_kit_connection_info,
963 })?;
964
965 update_user_contacts(session.user_id, &session).await?;
966 Ok(())
967}
968
969async fn rejoin_room(
970 request: proto::RejoinRoom,
971 response: Response<proto::RejoinRoom>,
972 session: Session,
973) -> Result<()> {
974 {
975 let mut rejoined_room = session
976 .db()
977 .await
978 .rejoin_room(request, session.user_id, session.connection_id)
979 .await?;
980
981 response.send(proto::RejoinRoomResponse {
982 room: Some(rejoined_room.room.clone()),
983 reshared_projects: rejoined_room
984 .reshared_projects
985 .iter()
986 .map(|project| proto::ResharedProject {
987 id: project.id.to_proto(),
988 collaborators: project
989 .collaborators
990 .iter()
991 .map(|collaborator| collaborator.to_proto())
992 .collect(),
993 })
994 .collect(),
995 rejoined_projects: rejoined_room
996 .rejoined_projects
997 .iter()
998 .map(|rejoined_project| proto::RejoinedProject {
999 id: rejoined_project.id.to_proto(),
1000 worktrees: rejoined_project
1001 .worktrees
1002 .iter()
1003 .map(|worktree| proto::WorktreeMetadata {
1004 id: worktree.id,
1005 root_name: worktree.root_name.clone(),
1006 visible: worktree.visible,
1007 abs_path: worktree.abs_path.clone(),
1008 })
1009 .collect(),
1010 collaborators: rejoined_project
1011 .collaborators
1012 .iter()
1013 .map(|collaborator| collaborator.to_proto())
1014 .collect(),
1015 language_servers: rejoined_project.language_servers.clone(),
1016 })
1017 .collect(),
1018 })?;
1019 room_updated(&rejoined_room.room, &session.peer);
1020
1021 for project in &rejoined_room.reshared_projects {
1022 for collaborator in &project.collaborators {
1023 session
1024 .peer
1025 .send(
1026 collaborator.connection_id,
1027 proto::UpdateProjectCollaborator {
1028 project_id: project.id.to_proto(),
1029 old_peer_id: Some(project.old_connection_id.into()),
1030 new_peer_id: Some(session.connection_id.into()),
1031 },
1032 )
1033 .trace_err();
1034 }
1035
1036 broadcast(
1037 Some(session.connection_id),
1038 project
1039 .collaborators
1040 .iter()
1041 .map(|collaborator| collaborator.connection_id),
1042 |connection_id| {
1043 session.peer.forward_send(
1044 session.connection_id,
1045 connection_id,
1046 proto::UpdateProject {
1047 project_id: project.id.to_proto(),
1048 worktrees: project.worktrees.clone(),
1049 },
1050 )
1051 },
1052 );
1053 }
1054
1055 for project in &rejoined_room.rejoined_projects {
1056 for collaborator in &project.collaborators {
1057 session
1058 .peer
1059 .send(
1060 collaborator.connection_id,
1061 proto::UpdateProjectCollaborator {
1062 project_id: project.id.to_proto(),
1063 old_peer_id: Some(project.old_connection_id.into()),
1064 new_peer_id: Some(session.connection_id.into()),
1065 },
1066 )
1067 .trace_err();
1068 }
1069 }
1070
1071 for project in &mut rejoined_room.rejoined_projects {
1072 for worktree in mem::take(&mut project.worktrees) {
1073 #[cfg(any(test, feature = "test-support"))]
1074 const MAX_CHUNK_SIZE: usize = 2;
1075 #[cfg(not(any(test, feature = "test-support")))]
1076 const MAX_CHUNK_SIZE: usize = 256;
1077
1078 // Stream this worktree's entries.
1079 let message = proto::UpdateWorktree {
1080 project_id: project.id.to_proto(),
1081 worktree_id: worktree.id,
1082 abs_path: worktree.abs_path.clone(),
1083 root_name: worktree.root_name,
1084 updated_entries: worktree.updated_entries,
1085 removed_entries: worktree.removed_entries,
1086 scan_id: worktree.scan_id,
1087 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1088 updated_repositories: worktree.updated_repositories,
1089 removed_repositories: worktree.removed_repositories,
1090 };
1091 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1092 session.peer.send(session.connection_id, update.clone())?;
1093 }
1094
1095 // Stream this worktree's diagnostics.
1096 for summary in worktree.diagnostic_summaries {
1097 session.peer.send(
1098 session.connection_id,
1099 proto::UpdateDiagnosticSummary {
1100 project_id: project.id.to_proto(),
1101 worktree_id: worktree.id,
1102 summary: Some(summary),
1103 },
1104 )?;
1105 }
1106
1107 for settings_file in worktree.settings_files {
1108 session.peer.send(
1109 session.connection_id,
1110 proto::UpdateWorktreeSettings {
1111 project_id: project.id.to_proto(),
1112 worktree_id: worktree.id,
1113 path: settings_file.path,
1114 content: Some(settings_file.content),
1115 },
1116 )?;
1117 }
1118 }
1119
1120 for language_server in &project.language_servers {
1121 session.peer.send(
1122 session.connection_id,
1123 proto::UpdateLanguageServer {
1124 project_id: project.id.to_proto(),
1125 language_server_id: language_server.id,
1126 variant: Some(
1127 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1128 proto::LspDiskBasedDiagnosticsUpdated {},
1129 ),
1130 ),
1131 },
1132 )?;
1133 }
1134 }
1135 }
1136
1137 update_user_contacts(session.user_id, &session).await?;
1138 Ok(())
1139}
1140
1141async fn leave_room(
1142 _: proto::LeaveRoom,
1143 response: Response<proto::LeaveRoom>,
1144 session: Session,
1145) -> Result<()> {
1146 leave_room_for_session(&session).await?;
1147 response.send(proto::Ack {})?;
1148 Ok(())
1149}
1150
1151async fn call(
1152 request: proto::Call,
1153 response: Response<proto::Call>,
1154 session: Session,
1155) -> Result<()> {
1156 let room_id = RoomId::from_proto(request.room_id);
1157 let calling_user_id = session.user_id;
1158 let calling_connection_id = session.connection_id;
1159 let called_user_id = UserId::from_proto(request.called_user_id);
1160 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1161 if !session
1162 .db()
1163 .await
1164 .has_contact(calling_user_id, called_user_id)
1165 .await?
1166 {
1167 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1168 }
1169
1170 let incoming_call = {
1171 let (room, incoming_call) = &mut *session
1172 .db()
1173 .await
1174 .call(
1175 room_id,
1176 calling_user_id,
1177 calling_connection_id,
1178 called_user_id,
1179 initial_project_id,
1180 )
1181 .await?;
1182 room_updated(&room, &session.peer);
1183 mem::take(incoming_call)
1184 };
1185 update_user_contacts(called_user_id, &session).await?;
1186
1187 let mut calls = session
1188 .connection_pool()
1189 .await
1190 .user_connection_ids(called_user_id)
1191 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1192 .collect::<FuturesUnordered<_>>();
1193
1194 while let Some(call_response) = calls.next().await {
1195 match call_response.as_ref() {
1196 Ok(_) => {
1197 response.send(proto::Ack {})?;
1198 return Ok(());
1199 }
1200 Err(_) => {
1201 call_response.trace_err();
1202 }
1203 }
1204 }
1205
1206 {
1207 let room = session
1208 .db()
1209 .await
1210 .call_failed(room_id, called_user_id)
1211 .await?;
1212 room_updated(&room, &session.peer);
1213 }
1214 update_user_contacts(called_user_id, &session).await?;
1215
1216 Err(anyhow!("failed to ring user"))?
1217}
1218
1219async fn cancel_call(
1220 request: proto::CancelCall,
1221 response: Response<proto::CancelCall>,
1222 session: Session,
1223) -> Result<()> {
1224 let called_user_id = UserId::from_proto(request.called_user_id);
1225 let room_id = RoomId::from_proto(request.room_id);
1226 {
1227 let room = session
1228 .db()
1229 .await
1230 .cancel_call(room_id, session.connection_id, called_user_id)
1231 .await?;
1232 room_updated(&room, &session.peer);
1233 }
1234
1235 for connection_id in session
1236 .connection_pool()
1237 .await
1238 .user_connection_ids(called_user_id)
1239 {
1240 session
1241 .peer
1242 .send(
1243 connection_id,
1244 proto::CallCanceled {
1245 room_id: room_id.to_proto(),
1246 },
1247 )
1248 .trace_err();
1249 }
1250 response.send(proto::Ack {})?;
1251
1252 update_user_contacts(called_user_id, &session).await?;
1253 Ok(())
1254}
1255
1256async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1257 let room_id = RoomId::from_proto(message.room_id);
1258 {
1259 let room = session
1260 .db()
1261 .await
1262 .decline_call(Some(room_id), session.user_id)
1263 .await?
1264 .ok_or_else(|| anyhow!("failed to decline call"))?;
1265 room_updated(&room, &session.peer);
1266 }
1267
1268 for connection_id in session
1269 .connection_pool()
1270 .await
1271 .user_connection_ids(session.user_id)
1272 {
1273 session
1274 .peer
1275 .send(
1276 connection_id,
1277 proto::CallCanceled {
1278 room_id: room_id.to_proto(),
1279 },
1280 )
1281 .trace_err();
1282 }
1283 update_user_contacts(session.user_id, &session).await?;
1284 Ok(())
1285}
1286
1287async fn update_participant_location(
1288 request: proto::UpdateParticipantLocation,
1289 response: Response<proto::UpdateParticipantLocation>,
1290 session: Session,
1291) -> Result<()> {
1292 let room_id = RoomId::from_proto(request.room_id);
1293 let location = request
1294 .location
1295 .ok_or_else(|| anyhow!("invalid location"))?;
1296
1297 let db = session.db().await;
1298 let room = db
1299 .update_room_participant_location(room_id, session.connection_id, location)
1300 .await?;
1301
1302 room_updated(&room, &session.peer);
1303 response.send(proto::Ack {})?;
1304 Ok(())
1305}
1306
1307async fn share_project(
1308 request: proto::ShareProject,
1309 response: Response<proto::ShareProject>,
1310 session: Session,
1311) -> Result<()> {
1312 let (project_id, room) = &*session
1313 .db()
1314 .await
1315 .share_project(
1316 RoomId::from_proto(request.room_id),
1317 session.connection_id,
1318 &request.worktrees,
1319 )
1320 .await?;
1321 response.send(proto::ShareProjectResponse {
1322 project_id: project_id.to_proto(),
1323 })?;
1324 room_updated(&room, &session.peer);
1325
1326 Ok(())
1327}
1328
1329async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1330 let project_id = ProjectId::from_proto(message.project_id);
1331
1332 let (room, guest_connection_ids) = &*session
1333 .db()
1334 .await
1335 .unshare_project(project_id, session.connection_id)
1336 .await?;
1337
1338 broadcast(
1339 Some(session.connection_id),
1340 guest_connection_ids.iter().copied(),
1341 |conn_id| session.peer.send(conn_id, message.clone()),
1342 );
1343 room_updated(&room, &session.peer);
1344
1345 Ok(())
1346}
1347
1348async fn join_project(
1349 request: proto::JoinProject,
1350 response: Response<proto::JoinProject>,
1351 session: Session,
1352) -> Result<()> {
1353 let project_id = ProjectId::from_proto(request.project_id);
1354 let guest_user_id = session.user_id;
1355
1356 tracing::info!(%project_id, "join project");
1357
1358 let (project, replica_id) = &mut *session
1359 .db()
1360 .await
1361 .join_project(project_id, session.connection_id)
1362 .await?;
1363
1364 let collaborators = project
1365 .collaborators
1366 .iter()
1367 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1368 .map(|collaborator| collaborator.to_proto())
1369 .collect::<Vec<_>>();
1370
1371 let worktrees = project
1372 .worktrees
1373 .iter()
1374 .map(|(id, worktree)| proto::WorktreeMetadata {
1375 id: *id,
1376 root_name: worktree.root_name.clone(),
1377 visible: worktree.visible,
1378 abs_path: worktree.abs_path.clone(),
1379 })
1380 .collect::<Vec<_>>();
1381
1382 for collaborator in &collaborators {
1383 session
1384 .peer
1385 .send(
1386 collaborator.peer_id.unwrap().into(),
1387 proto::AddProjectCollaborator {
1388 project_id: project_id.to_proto(),
1389 collaborator: Some(proto::Collaborator {
1390 peer_id: Some(session.connection_id.into()),
1391 replica_id: replica_id.0 as u32,
1392 user_id: guest_user_id.to_proto(),
1393 }),
1394 },
1395 )
1396 .trace_err();
1397 }
1398
1399 // First, we send the metadata associated with each worktree.
1400 response.send(proto::JoinProjectResponse {
1401 worktrees: worktrees.clone(),
1402 replica_id: replica_id.0 as u32,
1403 collaborators: collaborators.clone(),
1404 language_servers: project.language_servers.clone(),
1405 })?;
1406
1407 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1408 #[cfg(any(test, feature = "test-support"))]
1409 const MAX_CHUNK_SIZE: usize = 2;
1410 #[cfg(not(any(test, feature = "test-support")))]
1411 const MAX_CHUNK_SIZE: usize = 256;
1412
1413 // Stream this worktree's entries.
1414 let message = proto::UpdateWorktree {
1415 project_id: project_id.to_proto(),
1416 worktree_id,
1417 abs_path: worktree.abs_path.clone(),
1418 root_name: worktree.root_name,
1419 updated_entries: worktree.entries,
1420 removed_entries: Default::default(),
1421 scan_id: worktree.scan_id,
1422 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1423 updated_repositories: worktree.repository_entries.into_values().collect(),
1424 removed_repositories: Default::default(),
1425 };
1426 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1427 session.peer.send(session.connection_id, update.clone())?;
1428 }
1429
1430 // Stream this worktree's diagnostics.
1431 for summary in worktree.diagnostic_summaries {
1432 session.peer.send(
1433 session.connection_id,
1434 proto::UpdateDiagnosticSummary {
1435 project_id: project_id.to_proto(),
1436 worktree_id: worktree.id,
1437 summary: Some(summary),
1438 },
1439 )?;
1440 }
1441
1442 for settings_file in worktree.settings_files {
1443 session.peer.send(
1444 session.connection_id,
1445 proto::UpdateWorktreeSettings {
1446 project_id: project_id.to_proto(),
1447 worktree_id: worktree.id,
1448 path: settings_file.path,
1449 content: Some(settings_file.content),
1450 },
1451 )?;
1452 }
1453 }
1454
1455 for language_server in &project.language_servers {
1456 session.peer.send(
1457 session.connection_id,
1458 proto::UpdateLanguageServer {
1459 project_id: project_id.to_proto(),
1460 language_server_id: language_server.id,
1461 variant: Some(
1462 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1463 proto::LspDiskBasedDiagnosticsUpdated {},
1464 ),
1465 ),
1466 },
1467 )?;
1468 }
1469
1470 Ok(())
1471}
1472
1473async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1474 let sender_id = session.connection_id;
1475 let project_id = ProjectId::from_proto(request.project_id);
1476
1477 let (room, project) = &*session
1478 .db()
1479 .await
1480 .leave_project(project_id, sender_id)
1481 .await?;
1482 tracing::info!(
1483 %project_id,
1484 host_user_id = %project.host_user_id,
1485 host_connection_id = %project.host_connection_id,
1486 "leave project"
1487 );
1488
1489 project_left(&project, &session);
1490 room_updated(&room, &session.peer);
1491
1492 Ok(())
1493}
1494
1495async fn update_project(
1496 request: proto::UpdateProject,
1497 response: Response<proto::UpdateProject>,
1498 session: Session,
1499) -> Result<()> {
1500 let project_id = ProjectId::from_proto(request.project_id);
1501 let (room, guest_connection_ids) = &*session
1502 .db()
1503 .await
1504 .update_project(project_id, session.connection_id, &request.worktrees)
1505 .await?;
1506 broadcast(
1507 Some(session.connection_id),
1508 guest_connection_ids.iter().copied(),
1509 |connection_id| {
1510 session
1511 .peer
1512 .forward_send(session.connection_id, connection_id, request.clone())
1513 },
1514 );
1515 room_updated(&room, &session.peer);
1516 response.send(proto::Ack {})?;
1517
1518 Ok(())
1519}
1520
1521async fn update_worktree(
1522 request: proto::UpdateWorktree,
1523 response: Response<proto::UpdateWorktree>,
1524 session: Session,
1525) -> Result<()> {
1526 let guest_connection_ids = session
1527 .db()
1528 .await
1529 .update_worktree(&request, session.connection_id)
1530 .await?;
1531
1532 broadcast(
1533 Some(session.connection_id),
1534 guest_connection_ids.iter().copied(),
1535 |connection_id| {
1536 session
1537 .peer
1538 .forward_send(session.connection_id, connection_id, request.clone())
1539 },
1540 );
1541 response.send(proto::Ack {})?;
1542 Ok(())
1543}
1544
1545async fn update_diagnostic_summary(
1546 message: proto::UpdateDiagnosticSummary,
1547 session: Session,
1548) -> Result<()> {
1549 let guest_connection_ids = session
1550 .db()
1551 .await
1552 .update_diagnostic_summary(&message, session.connection_id)
1553 .await?;
1554
1555 broadcast(
1556 Some(session.connection_id),
1557 guest_connection_ids.iter().copied(),
1558 |connection_id| {
1559 session
1560 .peer
1561 .forward_send(session.connection_id, connection_id, message.clone())
1562 },
1563 );
1564
1565 Ok(())
1566}
1567
1568async fn update_worktree_settings(
1569 message: proto::UpdateWorktreeSettings,
1570 session: Session,
1571) -> Result<()> {
1572 let guest_connection_ids = session
1573 .db()
1574 .await
1575 .update_worktree_settings(&message, session.connection_id)
1576 .await?;
1577
1578 broadcast(
1579 Some(session.connection_id),
1580 guest_connection_ids.iter().copied(),
1581 |connection_id| {
1582 session
1583 .peer
1584 .forward_send(session.connection_id, connection_id, message.clone())
1585 },
1586 );
1587
1588 Ok(())
1589}
1590
1591async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1592 broadcast_project_message(request.project_id, request, session).await
1593}
1594
1595async fn start_language_server(
1596 request: proto::StartLanguageServer,
1597 session: Session,
1598) -> Result<()> {
1599 let guest_connection_ids = session
1600 .db()
1601 .await
1602 .start_language_server(&request, session.connection_id)
1603 .await?;
1604
1605 broadcast(
1606 Some(session.connection_id),
1607 guest_connection_ids.iter().copied(),
1608 |connection_id| {
1609 session
1610 .peer
1611 .forward_send(session.connection_id, connection_id, request.clone())
1612 },
1613 );
1614 Ok(())
1615}
1616
1617async fn update_language_server(
1618 request: proto::UpdateLanguageServer,
1619 session: Session,
1620) -> Result<()> {
1621 session.executor.record_backtrace();
1622 let project_id = ProjectId::from_proto(request.project_id);
1623 let project_connection_ids = session
1624 .db()
1625 .await
1626 .project_connection_ids(project_id, session.connection_id)
1627 .await?;
1628 broadcast(
1629 Some(session.connection_id),
1630 project_connection_ids.iter().copied(),
1631 |connection_id| {
1632 session
1633 .peer
1634 .forward_send(session.connection_id, connection_id, request.clone())
1635 },
1636 );
1637 Ok(())
1638}
1639
1640async fn forward_project_request<T>(
1641 request: T,
1642 response: Response<T>,
1643 session: Session,
1644) -> Result<()>
1645where
1646 T: EntityMessage + RequestMessage,
1647{
1648 session.executor.record_backtrace();
1649 let project_id = ProjectId::from_proto(request.remote_entity_id());
1650 let host_connection_id = {
1651 let collaborators = session
1652 .db()
1653 .await
1654 .project_collaborators(project_id, session.connection_id)
1655 .await?;
1656 collaborators
1657 .iter()
1658 .find(|collaborator| collaborator.is_host)
1659 .ok_or_else(|| anyhow!("host not found"))?
1660 .connection_id
1661 };
1662
1663 let payload = session
1664 .peer
1665 .forward_request(session.connection_id, host_connection_id, request)
1666 .await?;
1667
1668 response.send(payload)?;
1669 Ok(())
1670}
1671
1672async fn create_buffer_for_peer(
1673 request: proto::CreateBufferForPeer,
1674 session: Session,
1675) -> Result<()> {
1676 session.executor.record_backtrace();
1677 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1678 session
1679 .peer
1680 .forward_send(session.connection_id, peer_id.into(), request)?;
1681 Ok(())
1682}
1683
1684async fn update_buffer(
1685 request: proto::UpdateBuffer,
1686 response: Response<proto::UpdateBuffer>,
1687 session: Session,
1688) -> Result<()> {
1689 session.executor.record_backtrace();
1690 let project_id = ProjectId::from_proto(request.project_id);
1691 let mut guest_connection_ids;
1692 let mut host_connection_id = None;
1693 {
1694 let collaborators = session
1695 .db()
1696 .await
1697 .project_collaborators(project_id, session.connection_id)
1698 .await?;
1699 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1700 for collaborator in collaborators.iter() {
1701 if collaborator.is_host {
1702 host_connection_id = Some(collaborator.connection_id);
1703 } else {
1704 guest_connection_ids.push(collaborator.connection_id);
1705 }
1706 }
1707 }
1708 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1709
1710 session.executor.record_backtrace();
1711 broadcast(
1712 Some(session.connection_id),
1713 guest_connection_ids,
1714 |connection_id| {
1715 session
1716 .peer
1717 .forward_send(session.connection_id, connection_id, request.clone())
1718 },
1719 );
1720 if host_connection_id != session.connection_id {
1721 session
1722 .peer
1723 .forward_request(session.connection_id, host_connection_id, request.clone())
1724 .await?;
1725 }
1726
1727 response.send(proto::Ack {})?;
1728 Ok(())
1729}
1730
1731async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1732 let project_id = ProjectId::from_proto(request.project_id);
1733 let project_connection_ids = session
1734 .db()
1735 .await
1736 .project_connection_ids(project_id, session.connection_id)
1737 .await?;
1738
1739 broadcast(
1740 Some(session.connection_id),
1741 project_connection_ids.iter().copied(),
1742 |connection_id| {
1743 session
1744 .peer
1745 .forward_send(session.connection_id, connection_id, request.clone())
1746 },
1747 );
1748 Ok(())
1749}
1750
1751async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1752 let project_id = ProjectId::from_proto(request.project_id);
1753 let project_connection_ids = session
1754 .db()
1755 .await
1756 .project_connection_ids(project_id, session.connection_id)
1757 .await?;
1758 broadcast(
1759 Some(session.connection_id),
1760 project_connection_ids.iter().copied(),
1761 |connection_id| {
1762 session
1763 .peer
1764 .forward_send(session.connection_id, connection_id, request.clone())
1765 },
1766 );
1767 Ok(())
1768}
1769
1770async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1771 broadcast_project_message(request.project_id, request, session).await
1772}
1773
1774async fn broadcast_project_message<T: EnvelopedMessage>(
1775 project_id: u64,
1776 request: T,
1777 session: Session,
1778) -> Result<()> {
1779 let project_id = ProjectId::from_proto(project_id);
1780 let project_connection_ids = session
1781 .db()
1782 .await
1783 .project_connection_ids(project_id, session.connection_id)
1784 .await?;
1785 broadcast(
1786 Some(session.connection_id),
1787 project_connection_ids.iter().copied(),
1788 |connection_id| {
1789 session
1790 .peer
1791 .forward_send(session.connection_id, connection_id, request.clone())
1792 },
1793 );
1794 Ok(())
1795}
1796
1797async fn follow(
1798 request: proto::Follow,
1799 response: Response<proto::Follow>,
1800 session: Session,
1801) -> Result<()> {
1802 let project_id = ProjectId::from_proto(request.project_id);
1803 let leader_id = request
1804 .leader_id
1805 .ok_or_else(|| anyhow!("invalid leader id"))?
1806 .into();
1807 let follower_id = session.connection_id;
1808
1809 {
1810 let project_connection_ids = session
1811 .db()
1812 .await
1813 .project_connection_ids(project_id, session.connection_id)
1814 .await?;
1815
1816 if !project_connection_ids.contains(&leader_id) {
1817 Err(anyhow!("no such peer"))?;
1818 }
1819 }
1820
1821 let mut response_payload = session
1822 .peer
1823 .forward_request(session.connection_id, leader_id, request)
1824 .await?;
1825 response_payload
1826 .views
1827 .retain(|view| view.leader_id != Some(follower_id.into()));
1828 response.send(response_payload)?;
1829
1830 let room = session
1831 .db()
1832 .await
1833 .follow(project_id, leader_id, follower_id)
1834 .await?;
1835 room_updated(&room, &session.peer);
1836
1837 Ok(())
1838}
1839
1840async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1841 let project_id = ProjectId::from_proto(request.project_id);
1842 let leader_id = request
1843 .leader_id
1844 .ok_or_else(|| anyhow!("invalid leader id"))?
1845 .into();
1846 let follower_id = session.connection_id;
1847
1848 if !session
1849 .db()
1850 .await
1851 .project_connection_ids(project_id, session.connection_id)
1852 .await?
1853 .contains(&leader_id)
1854 {
1855 Err(anyhow!("no such peer"))?;
1856 }
1857
1858 session
1859 .peer
1860 .forward_send(session.connection_id, leader_id, request)?;
1861
1862 let room = session
1863 .db()
1864 .await
1865 .unfollow(project_id, leader_id, follower_id)
1866 .await?;
1867 room_updated(&room, &session.peer);
1868
1869 Ok(())
1870}
1871
1872async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1873 let project_id = ProjectId::from_proto(request.project_id);
1874 let project_connection_ids = session
1875 .db
1876 .lock()
1877 .await
1878 .project_connection_ids(project_id, session.connection_id)
1879 .await?;
1880
1881 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1882 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1883 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1884 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1885 });
1886 for follower_peer_id in request.follower_ids.iter().copied() {
1887 let follower_connection_id = follower_peer_id.into();
1888 if project_connection_ids.contains(&follower_connection_id)
1889 && Some(follower_peer_id) != leader_id
1890 {
1891 session.peer.forward_send(
1892 session.connection_id,
1893 follower_connection_id,
1894 request.clone(),
1895 )?;
1896 }
1897 }
1898 Ok(())
1899}
1900
1901async fn get_users(
1902 request: proto::GetUsers,
1903 response: Response<proto::GetUsers>,
1904 session: Session,
1905) -> Result<()> {
1906 let user_ids = request
1907 .user_ids
1908 .into_iter()
1909 .map(UserId::from_proto)
1910 .collect();
1911 let users = session
1912 .db()
1913 .await
1914 .get_users_by_ids(user_ids)
1915 .await?
1916 .into_iter()
1917 .map(|user| proto::User {
1918 id: user.id.to_proto(),
1919 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1920 github_login: user.github_login,
1921 })
1922 .collect();
1923 response.send(proto::UsersResponse { users })?;
1924 Ok(())
1925}
1926
1927async fn fuzzy_search_users(
1928 request: proto::FuzzySearchUsers,
1929 response: Response<proto::FuzzySearchUsers>,
1930 session: Session,
1931) -> Result<()> {
1932 let query = request.query;
1933 let users = match query.len() {
1934 0 => vec![],
1935 1 | 2 => session
1936 .db()
1937 .await
1938 .get_user_by_github_login(&query)
1939 .await?
1940 .into_iter()
1941 .collect(),
1942 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1943 };
1944 let users = users
1945 .into_iter()
1946 .filter(|user| user.id != session.user_id)
1947 .map(|user| proto::User {
1948 id: user.id.to_proto(),
1949 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1950 github_login: user.github_login,
1951 })
1952 .collect();
1953 response.send(proto::UsersResponse { users })?;
1954 Ok(())
1955}
1956
1957async fn request_contact(
1958 request: proto::RequestContact,
1959 response: Response<proto::RequestContact>,
1960 session: Session,
1961) -> Result<()> {
1962 let requester_id = session.user_id;
1963 let responder_id = UserId::from_proto(request.responder_id);
1964 if requester_id == responder_id {
1965 return Err(anyhow!("cannot add yourself as a contact"))?;
1966 }
1967
1968 session
1969 .db()
1970 .await
1971 .send_contact_request(requester_id, responder_id)
1972 .await?;
1973
1974 // Update outgoing contact requests of requester
1975 let mut update = proto::UpdateContacts::default();
1976 update.outgoing_requests.push(responder_id.to_proto());
1977 for connection_id in session
1978 .connection_pool()
1979 .await
1980 .user_connection_ids(requester_id)
1981 {
1982 session.peer.send(connection_id, update.clone())?;
1983 }
1984
1985 // Update incoming contact requests of responder
1986 let mut update = proto::UpdateContacts::default();
1987 update
1988 .incoming_requests
1989 .push(proto::IncomingContactRequest {
1990 requester_id: requester_id.to_proto(),
1991 should_notify: true,
1992 });
1993 for connection_id in session
1994 .connection_pool()
1995 .await
1996 .user_connection_ids(responder_id)
1997 {
1998 session.peer.send(connection_id, update.clone())?;
1999 }
2000
2001 response.send(proto::Ack {})?;
2002 Ok(())
2003}
2004
2005async fn respond_to_contact_request(
2006 request: proto::RespondToContactRequest,
2007 response: Response<proto::RespondToContactRequest>,
2008 session: Session,
2009) -> Result<()> {
2010 let responder_id = session.user_id;
2011 let requester_id = UserId::from_proto(request.requester_id);
2012 let db = session.db().await;
2013 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2014 db.dismiss_contact_notification(responder_id, requester_id)
2015 .await?;
2016 } else {
2017 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2018
2019 db.respond_to_contact_request(responder_id, requester_id, accept)
2020 .await?;
2021 let requester_busy = db.is_user_busy(requester_id).await?;
2022 let responder_busy = db.is_user_busy(responder_id).await?;
2023
2024 let pool = session.connection_pool().await;
2025 // Update responder with new contact
2026 let mut update = proto::UpdateContacts::default();
2027 if accept {
2028 update
2029 .contacts
2030 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2031 }
2032 update
2033 .remove_incoming_requests
2034 .push(requester_id.to_proto());
2035 for connection_id in pool.user_connection_ids(responder_id) {
2036 session.peer.send(connection_id, update.clone())?;
2037 }
2038
2039 // Update requester with new contact
2040 let mut update = proto::UpdateContacts::default();
2041 if accept {
2042 update
2043 .contacts
2044 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2045 }
2046 update
2047 .remove_outgoing_requests
2048 .push(responder_id.to_proto());
2049 for connection_id in pool.user_connection_ids(requester_id) {
2050 session.peer.send(connection_id, update.clone())?;
2051 }
2052 }
2053
2054 response.send(proto::Ack {})?;
2055 Ok(())
2056}
2057
2058async fn remove_contact(
2059 request: proto::RemoveContact,
2060 response: Response<proto::RemoveContact>,
2061 session: Session,
2062) -> Result<()> {
2063 let requester_id = session.user_id;
2064 let responder_id = UserId::from_proto(request.user_id);
2065 let db = session.db().await;
2066 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2067
2068 let pool = session.connection_pool().await;
2069 // Update outgoing contact requests of requester
2070 let mut update = proto::UpdateContacts::default();
2071 if contact_accepted {
2072 update.remove_contacts.push(responder_id.to_proto());
2073 } else {
2074 update
2075 .remove_outgoing_requests
2076 .push(responder_id.to_proto());
2077 }
2078 for connection_id in pool.user_connection_ids(requester_id) {
2079 session.peer.send(connection_id, update.clone())?;
2080 }
2081
2082 // Update incoming contact requests of responder
2083 let mut update = proto::UpdateContacts::default();
2084 if contact_accepted {
2085 update.remove_contacts.push(requester_id.to_proto());
2086 } else {
2087 update
2088 .remove_incoming_requests
2089 .push(requester_id.to_proto());
2090 }
2091 for connection_id in pool.user_connection_ids(responder_id) {
2092 session.peer.send(connection_id, update.clone())?;
2093 }
2094
2095 response.send(proto::Ack {})?;
2096 Ok(())
2097}
2098
2099async fn create_channel(
2100 request: proto::CreateChannel,
2101 response: Response<proto::CreateChannel>,
2102 session: Session,
2103) -> Result<()> {
2104 let db = session.db().await;
2105 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2106
2107 if let Some(live_kit) = session.live_kit_client.as_ref() {
2108 live_kit.create_room(live_kit_room.clone()).await?;
2109 }
2110
2111 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2112 let id = db
2113 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2114 .await?;
2115
2116 response.send(proto::CreateChannelResponse {
2117 channel_id: id.to_proto(),
2118 })?;
2119
2120 let mut update = proto::UpdateChannels::default();
2121 update.channels.push(proto::Channel {
2122 id: id.to_proto(),
2123 name: request.name,
2124 parent_id: request.parent_id,
2125 });
2126
2127 if let Some(parent_id) = parent_id {
2128 let member_ids = db.get_channel_members(parent_id).await?;
2129 let connection_pool = session.connection_pool().await;
2130 for member_id in member_ids {
2131 for connection_id in connection_pool.user_connection_ids(member_id) {
2132 session.peer.send(connection_id, update.clone())?;
2133 }
2134 }
2135 } else {
2136 session.peer.send(session.connection_id, update)?;
2137 }
2138
2139 Ok(())
2140}
2141
2142async fn remove_channel(
2143 request: proto::RemoveChannel,
2144 response: Response<proto::RemoveChannel>,
2145 session: Session,
2146) -> Result<()> {
2147 let db = session.db().await;
2148
2149 let channel_id = request.channel_id;
2150 let (removed_channels, member_ids) = db
2151 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2152 .await?;
2153 response.send(proto::Ack {})?;
2154
2155 // Notify members of removed channels
2156 let mut update = proto::UpdateChannels::default();
2157 update
2158 .remove_channels
2159 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2160
2161 let connection_pool = session.connection_pool().await;
2162 for member_id in member_ids {
2163 for connection_id in connection_pool.user_connection_ids(member_id) {
2164 session.peer.send(connection_id, update.clone())?;
2165 }
2166 }
2167
2168 Ok(())
2169}
2170
2171async fn invite_channel_member(
2172 request: proto::InviteChannelMember,
2173 response: Response<proto::InviteChannelMember>,
2174 session: Session,
2175) -> Result<()> {
2176 let db = session.db().await;
2177 let channel_id = ChannelId::from_proto(request.channel_id);
2178 let channel = db
2179 .get_channel(channel_id)
2180 .await?
2181 .ok_or_else(|| anyhow!("channel not found"))?;
2182 let invitee_id = UserId::from_proto(request.user_id);
2183 db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2184 .await?;
2185
2186 let mut update = proto::UpdateChannels::default();
2187 update.channel_invitations.push(proto::Channel {
2188 id: channel.id.to_proto(),
2189 name: channel.name,
2190 parent_id: None,
2191 });
2192 for connection_id in session
2193 .connection_pool()
2194 .await
2195 .user_connection_ids(invitee_id)
2196 {
2197 session.peer.send(connection_id, update.clone())?;
2198 }
2199
2200 response.send(proto::Ack {})?;
2201 Ok(())
2202}
2203
2204async fn remove_channel_member(
2205 request: proto::RemoveChannelMember,
2206 response: Response<proto::RemoveChannelMember>,
2207 session: Session,
2208) -> Result<()> {
2209 Ok(())
2210}
2211
2212async fn respond_to_channel_invite(
2213 request: proto::RespondToChannelInvite,
2214 response: Response<proto::RespondToChannelInvite>,
2215 session: Session,
2216) -> Result<()> {
2217 let db = session.db().await;
2218 let channel_id = ChannelId::from_proto(request.channel_id);
2219 let channel = db
2220 .get_channel(channel_id)
2221 .await?
2222 .ok_or_else(|| anyhow!("no such channel"))?;
2223 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2224 .await?;
2225
2226 let mut update = proto::UpdateChannels::default();
2227 update
2228 .remove_channel_invitations
2229 .push(channel_id.to_proto());
2230 if request.accept {
2231 update.channels.push(proto::Channel {
2232 id: channel.id.to_proto(),
2233 name: channel.name,
2234 parent_id: None,
2235 });
2236 }
2237 session.peer.send(session.connection_id, update)?;
2238 response.send(proto::Ack {})?;
2239
2240 Ok(())
2241}
2242
2243async fn join_channel(
2244 request: proto::JoinChannel,
2245 response: Response<proto::JoinChannel>,
2246 session: Session,
2247) -> Result<()> {
2248 let channel_id = ChannelId::from_proto(request.channel_id);
2249
2250 {
2251 let db = session.db().await;
2252 let room_id = db.get_channel_room(channel_id).await?;
2253
2254 let room = db
2255 .join_room(
2256 room_id,
2257 session.user_id,
2258 Some(channel_id),
2259 session.connection_id,
2260 )
2261 .await?;
2262
2263 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2264 let token = live_kit
2265 .room_token(&room.live_kit_room, &session.user_id.to_string())
2266 .trace_err()?;
2267
2268 Some(LiveKitConnectionInfo {
2269 server_url: live_kit.url().into(),
2270 token,
2271 })
2272 });
2273
2274 response.send(proto::JoinRoomResponse {
2275 room: Some(room.clone()),
2276 live_kit_connection_info,
2277 })?;
2278
2279 room_updated(&room, &session.peer);
2280 }
2281
2282 update_user_contacts(session.user_id, &session).await?;
2283
2284 Ok(())
2285}
2286
2287async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2288 let project_id = ProjectId::from_proto(request.project_id);
2289 let project_connection_ids = session
2290 .db()
2291 .await
2292 .project_connection_ids(project_id, session.connection_id)
2293 .await?;
2294 broadcast(
2295 Some(session.connection_id),
2296 project_connection_ids.iter().copied(),
2297 |connection_id| {
2298 session
2299 .peer
2300 .forward_send(session.connection_id, connection_id, request.clone())
2301 },
2302 );
2303 Ok(())
2304}
2305
2306async fn get_private_user_info(
2307 _request: proto::GetPrivateUserInfo,
2308 response: Response<proto::GetPrivateUserInfo>,
2309 session: Session,
2310) -> Result<()> {
2311 let metrics_id = session
2312 .db()
2313 .await
2314 .get_user_metrics_id(session.user_id)
2315 .await?;
2316 let user = session
2317 .db()
2318 .await
2319 .get_user_by_id(session.user_id)
2320 .await?
2321 .ok_or_else(|| anyhow!("user not found"))?;
2322 response.send(proto::GetPrivateUserInfoResponse {
2323 metrics_id,
2324 staff: user.admin,
2325 })?;
2326 Ok(())
2327}
2328
2329fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2330 match message {
2331 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2332 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2333 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2334 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2335 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2336 code: frame.code.into(),
2337 reason: frame.reason,
2338 })),
2339 }
2340}
2341
2342fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2343 match message {
2344 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2345 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2346 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2347 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2348 AxumMessage::Close(frame) => {
2349 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2350 code: frame.code.into(),
2351 reason: frame.reason,
2352 }))
2353 }
2354 }
2355}
2356
2357fn build_initial_channels_update(
2358 channels: Vec<db::Channel>,
2359 channel_invites: Vec<db::Channel>,
2360) -> proto::UpdateChannels {
2361 let mut update = proto::UpdateChannels::default();
2362
2363 for channel in channels {
2364 update.channels.push(proto::Channel {
2365 id: channel.id.to_proto(),
2366 name: channel.name,
2367 parent_id: channel.parent_id.map(|id| id.to_proto()),
2368 });
2369 }
2370
2371 for channel in channel_invites {
2372 update.channel_invitations.push(proto::Channel {
2373 id: channel.id.to_proto(),
2374 name: channel.name,
2375 parent_id: None,
2376 });
2377 }
2378
2379 update
2380}
2381
2382fn build_initial_contacts_update(
2383 contacts: Vec<db::Contact>,
2384 pool: &ConnectionPool,
2385) -> proto::UpdateContacts {
2386 let mut update = proto::UpdateContacts::default();
2387
2388 for contact in contacts {
2389 match contact {
2390 db::Contact::Accepted {
2391 user_id,
2392 should_notify,
2393 busy,
2394 } => {
2395 update
2396 .contacts
2397 .push(contact_for_user(user_id, should_notify, busy, &pool));
2398 }
2399 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2400 db::Contact::Incoming {
2401 user_id,
2402 should_notify,
2403 } => update
2404 .incoming_requests
2405 .push(proto::IncomingContactRequest {
2406 requester_id: user_id.to_proto(),
2407 should_notify,
2408 }),
2409 }
2410 }
2411
2412 update
2413}
2414
2415fn contact_for_user(
2416 user_id: UserId,
2417 should_notify: bool,
2418 busy: bool,
2419 pool: &ConnectionPool,
2420) -> proto::Contact {
2421 proto::Contact {
2422 user_id: user_id.to_proto(),
2423 online: pool.is_user_online(user_id),
2424 busy,
2425 should_notify,
2426 }
2427}
2428
2429fn room_updated(room: &proto::Room, peer: &Peer) {
2430 broadcast(
2431 None,
2432 room.participants
2433 .iter()
2434 .filter_map(|participant| Some(participant.peer_id?.into())),
2435 |peer_id| {
2436 peer.send(
2437 peer_id.into(),
2438 proto::RoomUpdated {
2439 room: Some(room.clone()),
2440 },
2441 )
2442 },
2443 );
2444}
2445
2446async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2447 let db = session.db().await;
2448 let contacts = db.get_contacts(user_id).await?;
2449 let busy = db.is_user_busy(user_id).await?;
2450
2451 let pool = session.connection_pool().await;
2452 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2453 for contact in contacts {
2454 if let db::Contact::Accepted {
2455 user_id: contact_user_id,
2456 ..
2457 } = contact
2458 {
2459 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2460 session
2461 .peer
2462 .send(
2463 contact_conn_id,
2464 proto::UpdateContacts {
2465 contacts: vec![updated_contact.clone()],
2466 remove_contacts: Default::default(),
2467 incoming_requests: Default::default(),
2468 remove_incoming_requests: Default::default(),
2469 outgoing_requests: Default::default(),
2470 remove_outgoing_requests: Default::default(),
2471 },
2472 )
2473 .trace_err();
2474 }
2475 }
2476 }
2477 Ok(())
2478}
2479
2480async fn leave_room_for_session(session: &Session) -> Result<()> {
2481 let mut contacts_to_update = HashSet::default();
2482
2483 let room_id;
2484 let canceled_calls_to_user_ids;
2485 let live_kit_room;
2486 let delete_live_kit_room;
2487 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2488 contacts_to_update.insert(session.user_id);
2489
2490 for project in left_room.left_projects.values() {
2491 project_left(project, session);
2492 }
2493
2494 room_updated(&left_room.room, &session.peer);
2495 room_id = RoomId::from_proto(left_room.room.id);
2496 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2497 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2498 delete_live_kit_room = left_room.deleted;
2499 } else {
2500 return Ok(());
2501 }
2502
2503 {
2504 let pool = session.connection_pool().await;
2505 for canceled_user_id in canceled_calls_to_user_ids {
2506 for connection_id in pool.user_connection_ids(canceled_user_id) {
2507 session
2508 .peer
2509 .send(
2510 connection_id,
2511 proto::CallCanceled {
2512 room_id: room_id.to_proto(),
2513 },
2514 )
2515 .trace_err();
2516 }
2517 contacts_to_update.insert(canceled_user_id);
2518 }
2519 }
2520
2521 for contact_user_id in contacts_to_update {
2522 update_user_contacts(contact_user_id, &session).await?;
2523 }
2524
2525 if let Some(live_kit) = session.live_kit_client.as_ref() {
2526 live_kit
2527 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2528 .await
2529 .trace_err();
2530
2531 if delete_live_kit_room {
2532 live_kit.delete_room(live_kit_room).await.trace_err();
2533 }
2534 }
2535
2536 Ok(())
2537}
2538
2539fn project_left(project: &db::LeftProject, session: &Session) {
2540 for connection_id in &project.connection_ids {
2541 if project.host_user_id == session.user_id {
2542 session
2543 .peer
2544 .send(
2545 *connection_id,
2546 proto::UnshareProject {
2547 project_id: project.id.to_proto(),
2548 },
2549 )
2550 .trace_err();
2551 } else {
2552 session
2553 .peer
2554 .send(
2555 *connection_id,
2556 proto::RemoveProjectCollaborator {
2557 project_id: project.id.to_proto(),
2558 peer_id: Some(session.connection_id.into()),
2559 },
2560 )
2561 .trace_err();
2562 }
2563 }
2564}
2565
2566pub trait ResultExt {
2567 type Ok;
2568
2569 fn trace_err(self) -> Option<Self::Ok>;
2570}
2571
2572impl<T, E> ResultExt for Result<T, E>
2573where
2574 E: std::fmt::Debug,
2575{
2576 type Ok = T;
2577
2578 fn trace_err(self) -> Option<T> {
2579 match self {
2580 Ok(value) => Some(value),
2581 Err(error) => {
2582 tracing::error!("{:?}", error);
2583 None
2584 }
2585 }
2586 }
2587}