1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits after a connection drops before making the user leave
/// their room, declining pending calls and refreshing their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30 * 1000);
/// How long `Server::start` waits before refreshing the stale rooms recorded for this
/// server id and deleting stale server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10 * 1000);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(respond_to_channel_invite)
250 .add_request_handler(join_channel)
251 .add_request_handler(follow)
252 .add_message_handler(unfollow)
253 .add_message_handler(update_followers)
254 .add_message_handler(update_diff_base)
255 .add_request_handler(get_private_user_info);
256
257 Arc::new(server)
258 }
259
260 pub async fn start(&self) -> Result<()> {
261 let server_id = *self.id.lock();
262 let app_state = self.app_state.clone();
263 let peer = self.peer.clone();
264 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
265 let pool = self.connection_pool.clone();
266 let live_kit_client = self.app_state.live_kit_client.clone();
267
268 let span = info_span!("start server");
269 self.executor.spawn_detached(
270 async move {
271 tracing::info!("waiting for cleanup timeout");
272 timeout.await;
273 tracing::info!("cleanup timeout expired, retrieving stale rooms");
274 if let Some(room_ids) = app_state
275 .db
276 .stale_room_ids(&app_state.config.zed_environment, server_id)
277 .await
278 .trace_err()
279 {
280 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
281 for room_id in room_ids {
282 let mut contacts_to_update = HashSet::default();
283 let mut canceled_calls_to_user_ids = Vec::new();
284 let mut live_kit_room = String::new();
285 let mut delete_live_kit_room = false;
286
287 if let Some(mut refreshed_room) = app_state
288 .db
289 .refresh_room(room_id, server_id)
290 .await
291 .trace_err()
292 {
293 tracing::info!(
294 room_id = room_id.0,
295 new_participant_count = refreshed_room.room.participants.len(),
296 "refreshed room"
297 );
298 room_updated(&refreshed_room.room, &peer);
299 if let Some(channel_id) = refreshed_room.channel_id {
300 channel_updated(
301 channel_id,
302 &refreshed_room.room,
303 &refreshed_room.channel_members,
304 &peer,
305 &*pool.lock(),
306 );
307 }
308 contacts_to_update
309 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
310 contacts_to_update
311 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
312 canceled_calls_to_user_ids =
313 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
314 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
315 delete_live_kit_room = refreshed_room.room.participants.is_empty();
316 }
317
318 {
319 let pool = pool.lock();
320 for canceled_user_id in canceled_calls_to_user_ids {
321 for connection_id in pool.user_connection_ids(canceled_user_id) {
322 peer.send(
323 connection_id,
324 proto::CallCanceled {
325 room_id: room_id.to_proto(),
326 },
327 )
328 .trace_err();
329 }
330 }
331 }
332
333 for user_id in contacts_to_update {
334 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
335 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
336 if let Some((busy, contacts)) = busy.zip(contacts) {
337 let pool = pool.lock();
338 let updated_contact = contact_for_user(user_id, false, busy, &pool);
339 for contact in contacts {
340 if let db::Contact::Accepted {
341 user_id: contact_user_id,
342 ..
343 } = contact
344 {
345 for contact_conn_id in
346 pool.user_connection_ids(contact_user_id)
347 {
348 peer.send(
349 contact_conn_id,
350 proto::UpdateContacts {
351 contacts: vec![updated_contact.clone()],
352 remove_contacts: Default::default(),
353 incoming_requests: Default::default(),
354 remove_incoming_requests: Default::default(),
355 outgoing_requests: Default::default(),
356 remove_outgoing_requests: Default::default(),
357 },
358 )
359 .trace_err();
360 }
361 }
362 }
363 }
364 }
365
366 if let Some(live_kit) = live_kit_client.as_ref() {
367 if delete_live_kit_room {
368 live_kit.delete_room(live_kit_room).await.trace_err();
369 }
370 }
371 }
372 }
373
374 app_state
375 .db
376 .delete_stale_servers(&app_state.config.zed_environment, server_id)
377 .await
378 .trace_err();
379 }
380 .instrument(span),
381 );
382 Ok(())
383 }
384
385 pub fn teardown(&self) {
386 self.peer.teardown();
387 self.connection_pool.lock().reset();
388 let _ = self.teardown.send(());
389 }
390
391 #[cfg(test)]
392 pub fn reset(&self, id: ServerId) {
393 self.teardown();
394 *self.id.lock() = id;
395 self.peer.reset(id.0 as u32);
396 }
397
398 #[cfg(test)]
399 pub fn id(&self) -> ServerId {
400 *self.id.lock()
401 }
402
403 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
404 where
405 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
406 Fut: 'static + Send + Future<Output = Result<()>>,
407 M: EnvelopedMessage,
408 {
409 let prev_handler = self.handlers.insert(
410 TypeId::of::<M>(),
411 Box::new(move |envelope, session| {
412 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
413 let span = info_span!(
414 "handle message",
415 payload_type = envelope.payload_type_name()
416 );
417 span.in_scope(|| {
418 tracing::info!(
419 payload_type = envelope.payload_type_name(),
420 "message received"
421 );
422 });
423 let start_time = Instant::now();
424 let future = (handler)(*envelope, session);
425 async move {
426 let result = future.await;
427 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
428 match result {
429 Err(error) => {
430 tracing::error!(%error, ?duration_ms, "error handling message")
431 }
432 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
433 }
434 }
435 .instrument(span)
436 .boxed()
437 }),
438 );
439 if prev_handler.is_some() {
440 panic!("registered a handler for the same message twice");
441 }
442 self
443 }
444
445 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
446 where
447 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
448 Fut: 'static + Send + Future<Output = Result<()>>,
449 M: EnvelopedMessage,
450 {
451 self.add_handler(move |envelope, session| handler(envelope.payload, session));
452 self
453 }
454
455 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
456 where
457 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
458 Fut: Send + Future<Output = Result<()>>,
459 M: RequestMessage,
460 {
461 let handler = Arc::new(handler);
462 self.add_handler(move |envelope, session| {
463 let receipt = envelope.receipt();
464 let handler = handler.clone();
465 async move {
466 let peer = session.peer.clone();
467 let responded = Arc::new(AtomicBool::default());
468 let response = Response {
469 peer: peer.clone(),
470 responded: responded.clone(),
471 receipt,
472 };
473 match (handler)(envelope.payload, response, session).await {
474 Ok(()) => {
475 if responded.load(std::sync::atomic::Ordering::SeqCst) {
476 Ok(())
477 } else {
478 Err(anyhow!("handler did not send a response"))?
479 }
480 }
481 Err(error) => {
482 peer.respond_with_error(
483 receipt,
484 proto::Error {
485 message: error.to_string(),
486 },
487 )?;
488 Err(error)
489 }
490 }
491 }
492 })
493 }
494
495 pub fn handle_connection(
496 self: &Arc<Self>,
497 connection: Connection,
498 address: String,
499 user: User,
500 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
501 executor: Executor,
502 ) -> impl Future<Output = Result<()>> {
503 let this = self.clone();
504 let user_id = user.id;
505 let login = user.github_login;
506 let span = info_span!("handle connection", %user_id, %login, %address);
507 let mut teardown = self.teardown.subscribe();
508 async move {
509 let (connection_id, handle_io, mut incoming_rx) = this
510 .peer
511 .add_connection(connection, {
512 let executor = executor.clone();
513 move |duration| executor.sleep(duration)
514 });
515
516 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
517 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
518 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
519
520 if let Some(send_connection_id) = send_connection_id.take() {
521 let _ = send_connection_id.send(connection_id);
522 }
523
524 if !user.connected_once {
525 this.peer.send(connection_id, proto::ShowContacts {})?;
526 this.app_state.db.set_user_connected_once(user_id, true).await?;
527 }
528
529 let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
530 this.app_state.db.get_contacts(user_id),
531 this.app_state.db.get_invite_code_for_user(user_id),
532 this.app_state.db.get_channels(user_id),
533 this.app_state.db.get_channel_invites(user_id)
534 ).await?;
535
536 {
537 let mut pool = this.connection_pool.lock();
538 pool.add_connection(connection_id, user_id, user.admin);
539 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
540 this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;
541
542 if let Some((code, count)) = invite_code {
543 this.peer.send(connection_id, proto::UpdateInviteInfo {
544 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
545 count: count as u32,
546 })?;
547 }
548 }
549
550 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
551 this.peer.send(connection_id, incoming_call)?;
552 }
553
554 let session = Session {
555 user_id,
556 connection_id,
557 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
558 peer: this.peer.clone(),
559 connection_pool: this.connection_pool.clone(),
560 live_kit_client: this.app_state.live_kit_client.clone(),
561 executor: executor.clone(),
562 };
563 update_user_contacts(user_id, &session).await?;
564
565 let handle_io = handle_io.fuse();
566 futures::pin_mut!(handle_io);
567
568 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
569 // This prevents deadlocks when e.g., client A performs a request to client B and
570 // client B performs a request to client A. If both clients stop processing further
571 // messages until their respective request completes, they won't have a chance to
572 // respond to the other client's request and cause a deadlock.
573 //
574 // This arrangement ensures we will attempt to process earlier messages first, but fall
575 // back to processing messages arrived later in the spirit of making progress.
576 let mut foreground_message_handlers = FuturesUnordered::new();
577 let concurrent_handlers = Arc::new(Semaphore::new(256));
578 loop {
579 let next_message = async {
580 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
581 let message = incoming_rx.next().await;
582 (permit, message)
583 }.fuse();
584 futures::pin_mut!(next_message);
585 futures::select_biased! {
586 _ = teardown.changed().fuse() => return Ok(()),
587 result = handle_io => {
588 if let Err(error) = result {
589 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
590 }
591 break;
592 }
593 _ = foreground_message_handlers.next() => {}
594 next_message = next_message => {
595 let (permit, message) = next_message;
596 if let Some(message) = message {
597 let type_name = message.payload_type_name();
598 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
599 let span_enter = span.enter();
600 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
601 let is_background = message.is_background();
602 let handle_message = (handler)(message, session.clone());
603 drop(span_enter);
604
605 let handle_message = async move {
606 handle_message.await;
607 drop(permit);
608 }.instrument(span);
609 if is_background {
610 executor.spawn_detached(handle_message);
611 } else {
612 foreground_message_handlers.push(handle_message);
613 }
614 } else {
615 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
616 }
617 } else {
618 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
619 break;
620 }
621 }
622 }
623 }
624
625 drop(foreground_message_handlers);
626 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
627 if let Err(error) = connection_lost(session, teardown, executor).await {
628 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
629 }
630
631 Ok(())
632 }.instrument(span)
633 }
634
635 pub async fn invite_code_redeemed(
636 self: &Arc<Self>,
637 inviter_id: UserId,
638 invitee_id: UserId,
639 ) -> Result<()> {
640 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
641 if let Some(code) = &user.invite_code {
642 let pool = self.connection_pool.lock();
643 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
644 for connection_id in pool.user_connection_ids(inviter_id) {
645 self.peer.send(
646 connection_id,
647 proto::UpdateContacts {
648 contacts: vec![invitee_contact.clone()],
649 ..Default::default()
650 },
651 )?;
652 self.peer.send(
653 connection_id,
654 proto::UpdateInviteInfo {
655 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
656 count: user.invite_count as u32,
657 },
658 )?;
659 }
660 }
661 }
662 Ok(())
663 }
664
665 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
666 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
667 if let Some(invite_code) = &user.invite_code {
668 let pool = self.connection_pool.lock();
669 for connection_id in pool.user_connection_ids(user_id) {
670 self.peer.send(
671 connection_id,
672 proto::UpdateInviteInfo {
673 url: format!(
674 "{}{}",
675 self.app_state.config.invite_link_prefix, invite_code
676 ),
677 count: user.invite_count as u32,
678 },
679 )?;
680 }
681 }
682 }
683 Ok(())
684 }
685
686 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
687 ServerSnapshot {
688 connection_pool: ConnectionPoolGuard {
689 guard: self.connection_pool.lock(),
690 _not_send: PhantomData,
691 },
692 peer: &self.peer,
693 }
694 }
695}
696
697impl<'a> Deref for ConnectionPoolGuard<'a> {
698 type Target = ConnectionPool;
699
700 fn deref(&self) -> &Self::Target {
701 &*self.guard
702 }
703}
704
705impl<'a> DerefMut for ConnectionPoolGuard<'a> {
706 fn deref_mut(&mut self) -> &mut Self::Target {
707 &mut *self.guard
708 }
709}
710
711impl<'a> Drop for ConnectionPoolGuard<'a> {
712 fn drop(&mut self) {
713 #[cfg(test)]
714 self.check_invariants();
715 }
716}
717
718fn broadcast<F>(
719 sender_id: Option<ConnectionId>,
720 receiver_ids: impl IntoIterator<Item = ConnectionId>,
721 mut f: F,
722) where
723 F: FnMut(ConnectionId) -> anyhow::Result<()>,
724{
725 for receiver_id in receiver_ids {
726 if Some(receiver_id) != sender_id {
727 if let Err(error) = f(receiver_id) {
728 tracing::error!("failed to send to {:?} {}", receiver_id, error);
729 }
730 }
731 }
732}
733
734lazy_static! {
735 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
736}
737
738pub struct ProtocolVersion(u32);
739
740impl Header for ProtocolVersion {
741 fn name() -> &'static HeaderName {
742 &ZED_PROTOCOL_VERSION
743 }
744
745 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
746 where
747 Self: Sized,
748 I: Iterator<Item = &'i axum::http::HeaderValue>,
749 {
750 let version = values
751 .next()
752 .ok_or_else(axum::headers::Error::invalid)?
753 .to_str()
754 .map_err(|_| axum::headers::Error::invalid())?
755 .parse()
756 .map_err(|_| axum::headers::Error::invalid())?;
757 Ok(Self(version))
758 }
759
760 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
761 values.extend([self.0.to_string().parse().unwrap()]);
762 }
763}
764
765pub fn routes(server: Arc<Server>) -> Router<Body> {
766 Router::new()
767 .route("/rpc", get(handle_websocket_request))
768 .layer(
769 ServiceBuilder::new()
770 .layer(Extension(server.app_state.clone()))
771 .layer(middleware::from_fn(auth::validate_header)),
772 )
773 .route("/metrics", get(handle_metrics))
774 .layer(Extension(server))
775}
776
777pub async fn handle_websocket_request(
778 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
779 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
780 Extension(server): Extension<Arc<Server>>,
781 Extension(user): Extension<User>,
782 ws: WebSocketUpgrade,
783) -> axum::response::Response {
784 if protocol_version != rpc::PROTOCOL_VERSION {
785 return (
786 StatusCode::UPGRADE_REQUIRED,
787 "client must be upgraded".to_string(),
788 )
789 .into_response();
790 }
791 let socket_address = socket_address.to_string();
792 ws.on_upgrade(move |socket| {
793 use util::ResultExt;
794 let socket = socket
795 .map_ok(to_tungstenite_message)
796 .err_into()
797 .with(|message| async move { Ok(to_axum_message(message)) });
798 let connection = Connection::new(Box::pin(socket));
799 async move {
800 server
801 .handle_connection(connection, socket_address, user, None, Executor::Production)
802 .await
803 .log_err();
804 }
805 })
806}
807
808pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
809 let connections = server
810 .connection_pool
811 .lock()
812 .connections()
813 .filter(|connection| !connection.admin)
814 .count();
815
816 METRIC_CONNECTIONS.set(connections as _);
817
818 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
819 METRIC_SHARED_PROJECTS.set(shared_projects as _);
820
821 let encoder = prometheus::TextEncoder::new();
822 let metric_families = prometheus::gather();
823 let encoded_metrics = encoder
824 .encode_to_string(&metric_families)
825 .map_err(|err| anyhow!("{}", err))?;
826 Ok(encoded_metrics)
827}
828
829#[instrument(err, skip(executor))]
830async fn connection_lost(
831 session: Session,
832 mut teardown: watch::Receiver<()>,
833 executor: Executor,
834) -> Result<()> {
835 session.peer.disconnect(session.connection_id);
836 session
837 .connection_pool()
838 .await
839 .remove_connection(session.connection_id)?;
840
841 session
842 .db()
843 .await
844 .connection_lost(session.connection_id)
845 .await
846 .trace_err();
847
848 futures::select_biased! {
849 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
850 leave_room_for_session(&session).await.trace_err();
851
852 if !session
853 .connection_pool()
854 .await
855 .is_user_online(session.user_id)
856 {
857 let db = session.db().await;
858 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
859 room_updated(&room, &session.peer);
860 }
861 }
862 update_user_contacts(session.user_id, &session).await?;
863 }
864 _ = teardown.changed().fuse() => {}
865 }
866
867 Ok(())
868}
869
870async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
871 response.send(proto::Ack {})?;
872 Ok(())
873}
874
875async fn create_room(
876 _request: proto::CreateRoom,
877 response: Response<proto::CreateRoom>,
878 session: Session,
879) -> Result<()> {
880 let live_kit_room = nanoid::nanoid!(30);
881
882 let live_kit_connection_info = {
883 let live_kit_room = live_kit_room.clone();
884 let live_kit = session.live_kit_client.as_ref();
885
886 util::async_iife!({
887 let live_kit = live_kit?;
888
889 live_kit
890 .create_room(live_kit_room.clone())
891 .await
892 .trace_err()?;
893
894 let token = live_kit
895 .room_token(&live_kit_room, &session.user_id.to_string())
896 .trace_err()?;
897
898 Some(proto::LiveKitConnectionInfo {
899 server_url: live_kit.url().into(),
900 token,
901 })
902 })
903 }
904 .await;
905
906 let room = session
907 .db()
908 .await
909 .create_room(session.user_id, session.connection_id, &live_kit_room)
910 .await?;
911
912 response.send(proto::CreateRoomResponse {
913 room: Some(room.clone()),
914 live_kit_connection_info,
915 })?;
916
917 update_user_contacts(session.user_id, &session).await?;
918 Ok(())
919}
920
921async fn join_room(
922 request: proto::JoinRoom,
923 response: Response<proto::JoinRoom>,
924 session: Session,
925) -> Result<()> {
926 let room_id = RoomId::from_proto(request.id);
927 let room = {
928 let room = session
929 .db()
930 .await
931 .join_room(room_id, session.user_id, None, session.connection_id)
932 .await?;
933 room_updated(&room.room, &session.peer);
934 room.room.clone()
935 };
936
937 for connection_id in session
938 .connection_pool()
939 .await
940 .user_connection_ids(session.user_id)
941 {
942 session
943 .peer
944 .send(
945 connection_id,
946 proto::CallCanceled {
947 room_id: room_id.to_proto(),
948 },
949 )
950 .trace_err();
951 }
952
953 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
954 if let Some(token) = live_kit
955 .room_token(&room.live_kit_room, &session.user_id.to_string())
956 .trace_err()
957 {
958 Some(proto::LiveKitConnectionInfo {
959 server_url: live_kit.url().into(),
960 token,
961 })
962 } else {
963 None
964 }
965 } else {
966 None
967 };
968
969 response.send(proto::JoinRoomResponse {
970 room: Some(room),
971 live_kit_connection_info,
972 })?;
973
974 update_user_contacts(session.user_id, &session).await?;
975 Ok(())
976}
977
978async fn rejoin_room(
979 request: proto::RejoinRoom,
980 response: Response<proto::RejoinRoom>,
981 session: Session,
982) -> Result<()> {
983 let room;
984 let channel_id;
985 let channel_members;
986 {
987 let mut rejoined_room = session
988 .db()
989 .await
990 .rejoin_room(request, session.user_id, session.connection_id)
991 .await?;
992
993 response.send(proto::RejoinRoomResponse {
994 room: Some(rejoined_room.room.clone()),
995 reshared_projects: rejoined_room
996 .reshared_projects
997 .iter()
998 .map(|project| proto::ResharedProject {
999 id: project.id.to_proto(),
1000 collaborators: project
1001 .collaborators
1002 .iter()
1003 .map(|collaborator| collaborator.to_proto())
1004 .collect(),
1005 })
1006 .collect(),
1007 rejoined_projects: rejoined_room
1008 .rejoined_projects
1009 .iter()
1010 .map(|rejoined_project| proto::RejoinedProject {
1011 id: rejoined_project.id.to_proto(),
1012 worktrees: rejoined_project
1013 .worktrees
1014 .iter()
1015 .map(|worktree| proto::WorktreeMetadata {
1016 id: worktree.id,
1017 root_name: worktree.root_name.clone(),
1018 visible: worktree.visible,
1019 abs_path: worktree.abs_path.clone(),
1020 })
1021 .collect(),
1022 collaborators: rejoined_project
1023 .collaborators
1024 .iter()
1025 .map(|collaborator| collaborator.to_proto())
1026 .collect(),
1027 language_servers: rejoined_project.language_servers.clone(),
1028 })
1029 .collect(),
1030 })?;
1031 room_updated(&rejoined_room.room, &session.peer);
1032
1033 for project in &rejoined_room.reshared_projects {
1034 for collaborator in &project.collaborators {
1035 session
1036 .peer
1037 .send(
1038 collaborator.connection_id,
1039 proto::UpdateProjectCollaborator {
1040 project_id: project.id.to_proto(),
1041 old_peer_id: Some(project.old_connection_id.into()),
1042 new_peer_id: Some(session.connection_id.into()),
1043 },
1044 )
1045 .trace_err();
1046 }
1047
1048 broadcast(
1049 Some(session.connection_id),
1050 project
1051 .collaborators
1052 .iter()
1053 .map(|collaborator| collaborator.connection_id),
1054 |connection_id| {
1055 session.peer.forward_send(
1056 session.connection_id,
1057 connection_id,
1058 proto::UpdateProject {
1059 project_id: project.id.to_proto(),
1060 worktrees: project.worktrees.clone(),
1061 },
1062 )
1063 },
1064 );
1065 }
1066
1067 for project in &rejoined_room.rejoined_projects {
1068 for collaborator in &project.collaborators {
1069 session
1070 .peer
1071 .send(
1072 collaborator.connection_id,
1073 proto::UpdateProjectCollaborator {
1074 project_id: project.id.to_proto(),
1075 old_peer_id: Some(project.old_connection_id.into()),
1076 new_peer_id: Some(session.connection_id.into()),
1077 },
1078 )
1079 .trace_err();
1080 }
1081 }
1082
1083 for project in &mut rejoined_room.rejoined_projects {
1084 for worktree in mem::take(&mut project.worktrees) {
1085 #[cfg(any(test, feature = "test-support"))]
1086 const MAX_CHUNK_SIZE: usize = 2;
1087 #[cfg(not(any(test, feature = "test-support")))]
1088 const MAX_CHUNK_SIZE: usize = 256;
1089
1090 // Stream this worktree's entries.
1091 let message = proto::UpdateWorktree {
1092 project_id: project.id.to_proto(),
1093 worktree_id: worktree.id,
1094 abs_path: worktree.abs_path.clone(),
1095 root_name: worktree.root_name,
1096 updated_entries: worktree.updated_entries,
1097 removed_entries: worktree.removed_entries,
1098 scan_id: worktree.scan_id,
1099 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1100 updated_repositories: worktree.updated_repositories,
1101 removed_repositories: worktree.removed_repositories,
1102 };
1103 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1104 session.peer.send(session.connection_id, update.clone())?;
1105 }
1106
1107 // Stream this worktree's diagnostics.
1108 for summary in worktree.diagnostic_summaries {
1109 session.peer.send(
1110 session.connection_id,
1111 proto::UpdateDiagnosticSummary {
1112 project_id: project.id.to_proto(),
1113 worktree_id: worktree.id,
1114 summary: Some(summary),
1115 },
1116 )?;
1117 }
1118
1119 for settings_file in worktree.settings_files {
1120 session.peer.send(
1121 session.connection_id,
1122 proto::UpdateWorktreeSettings {
1123 project_id: project.id.to_proto(),
1124 worktree_id: worktree.id,
1125 path: settings_file.path,
1126 content: Some(settings_file.content),
1127 },
1128 )?;
1129 }
1130 }
1131
1132 for language_server in &project.language_servers {
1133 session.peer.send(
1134 session.connection_id,
1135 proto::UpdateLanguageServer {
1136 project_id: project.id.to_proto(),
1137 language_server_id: language_server.id,
1138 variant: Some(
1139 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1140 proto::LspDiskBasedDiagnosticsUpdated {},
1141 ),
1142 ),
1143 },
1144 )?;
1145 }
1146 }
1147
1148 room = mem::take(&mut rejoined_room.room);
1149 channel_id = rejoined_room.channel_id;
1150 channel_members = mem::take(&mut rejoined_room.channel_members);
1151 }
1152
1153 //TODO: move this into the room guard
1154 if let Some(channel_id) = channel_id {
1155 channel_updated(
1156 channel_id,
1157 &room,
1158 &channel_members,
1159 &session.peer,
1160 &*session.connection_pool().await,
1161 );
1162 }
1163
1164 update_user_contacts(session.user_id, &session).await?;
1165 Ok(())
1166}
1167
1168async fn leave_room(
1169 _: proto::LeaveRoom,
1170 response: Response<proto::LeaveRoom>,
1171 session: Session,
1172) -> Result<()> {
1173 leave_room_for_session(&session).await?;
1174 response.send(proto::Ack {})?;
1175 Ok(())
1176}
1177
1178async fn call(
1179 request: proto::Call,
1180 response: Response<proto::Call>,
1181 session: Session,
1182) -> Result<()> {
1183 let room_id = RoomId::from_proto(request.room_id);
1184 let calling_user_id = session.user_id;
1185 let calling_connection_id = session.connection_id;
1186 let called_user_id = UserId::from_proto(request.called_user_id);
1187 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1188 if !session
1189 .db()
1190 .await
1191 .has_contact(calling_user_id, called_user_id)
1192 .await?
1193 {
1194 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1195 }
1196
1197 let incoming_call = {
1198 let (room, incoming_call) = &mut *session
1199 .db()
1200 .await
1201 .call(
1202 room_id,
1203 calling_user_id,
1204 calling_connection_id,
1205 called_user_id,
1206 initial_project_id,
1207 )
1208 .await?;
1209 room_updated(&room, &session.peer);
1210 mem::take(incoming_call)
1211 };
1212 update_user_contacts(called_user_id, &session).await?;
1213
1214 let mut calls = session
1215 .connection_pool()
1216 .await
1217 .user_connection_ids(called_user_id)
1218 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1219 .collect::<FuturesUnordered<_>>();
1220
1221 while let Some(call_response) = calls.next().await {
1222 match call_response.as_ref() {
1223 Ok(_) => {
1224 response.send(proto::Ack {})?;
1225 return Ok(());
1226 }
1227 Err(_) => {
1228 call_response.trace_err();
1229 }
1230 }
1231 }
1232
1233 {
1234 let room = session
1235 .db()
1236 .await
1237 .call_failed(room_id, called_user_id)
1238 .await?;
1239 room_updated(&room, &session.peer);
1240 }
1241 update_user_contacts(called_user_id, &session).await?;
1242
1243 Err(anyhow!("failed to ring user"))?
1244}
1245
1246async fn cancel_call(
1247 request: proto::CancelCall,
1248 response: Response<proto::CancelCall>,
1249 session: Session,
1250) -> Result<()> {
1251 let called_user_id = UserId::from_proto(request.called_user_id);
1252 let room_id = RoomId::from_proto(request.room_id);
1253 {
1254 let room = session
1255 .db()
1256 .await
1257 .cancel_call(room_id, session.connection_id, called_user_id)
1258 .await?;
1259 room_updated(&room, &session.peer);
1260 }
1261
1262 for connection_id in session
1263 .connection_pool()
1264 .await
1265 .user_connection_ids(called_user_id)
1266 {
1267 session
1268 .peer
1269 .send(
1270 connection_id,
1271 proto::CallCanceled {
1272 room_id: room_id.to_proto(),
1273 },
1274 )
1275 .trace_err();
1276 }
1277 response.send(proto::Ack {})?;
1278
1279 update_user_contacts(called_user_id, &session).await?;
1280 Ok(())
1281}
1282
1283async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1284 let room_id = RoomId::from_proto(message.room_id);
1285 {
1286 let room = session
1287 .db()
1288 .await
1289 .decline_call(Some(room_id), session.user_id)
1290 .await?
1291 .ok_or_else(|| anyhow!("failed to decline call"))?;
1292 room_updated(&room, &session.peer);
1293 }
1294
1295 for connection_id in session
1296 .connection_pool()
1297 .await
1298 .user_connection_ids(session.user_id)
1299 {
1300 session
1301 .peer
1302 .send(
1303 connection_id,
1304 proto::CallCanceled {
1305 room_id: room_id.to_proto(),
1306 },
1307 )
1308 .trace_err();
1309 }
1310 update_user_contacts(session.user_id, &session).await?;
1311 Ok(())
1312}
1313
1314async fn update_participant_location(
1315 request: proto::UpdateParticipantLocation,
1316 response: Response<proto::UpdateParticipantLocation>,
1317 session: Session,
1318) -> Result<()> {
1319 let room_id = RoomId::from_proto(request.room_id);
1320 let location = request
1321 .location
1322 .ok_or_else(|| anyhow!("invalid location"))?;
1323
1324 let db = session.db().await;
1325 let room = db
1326 .update_room_participant_location(room_id, session.connection_id, location)
1327 .await?;
1328
1329 room_updated(&room, &session.peer);
1330 response.send(proto::Ack {})?;
1331 Ok(())
1332}
1333
1334async fn share_project(
1335 request: proto::ShareProject,
1336 response: Response<proto::ShareProject>,
1337 session: Session,
1338) -> Result<()> {
1339 let (project_id, room) = &*session
1340 .db()
1341 .await
1342 .share_project(
1343 RoomId::from_proto(request.room_id),
1344 session.connection_id,
1345 &request.worktrees,
1346 )
1347 .await?;
1348 response.send(proto::ShareProjectResponse {
1349 project_id: project_id.to_proto(),
1350 })?;
1351 room_updated(&room, &session.peer);
1352
1353 Ok(())
1354}
1355
1356async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1357 let project_id = ProjectId::from_proto(message.project_id);
1358
1359 let (room, guest_connection_ids) = &*session
1360 .db()
1361 .await
1362 .unshare_project(project_id, session.connection_id)
1363 .await?;
1364
1365 broadcast(
1366 Some(session.connection_id),
1367 guest_connection_ids.iter().copied(),
1368 |conn_id| session.peer.send(conn_id, message.clone()),
1369 );
1370 room_updated(&room, &session.peer);
1371
1372 Ok(())
1373}
1374
1375async fn join_project(
1376 request: proto::JoinProject,
1377 response: Response<proto::JoinProject>,
1378 session: Session,
1379) -> Result<()> {
1380 let project_id = ProjectId::from_proto(request.project_id);
1381 let guest_user_id = session.user_id;
1382
1383 tracing::info!(%project_id, "join project");
1384
1385 let (project, replica_id) = &mut *session
1386 .db()
1387 .await
1388 .join_project(project_id, session.connection_id)
1389 .await?;
1390
1391 let collaborators = project
1392 .collaborators
1393 .iter()
1394 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1395 .map(|collaborator| collaborator.to_proto())
1396 .collect::<Vec<_>>();
1397
1398 let worktrees = project
1399 .worktrees
1400 .iter()
1401 .map(|(id, worktree)| proto::WorktreeMetadata {
1402 id: *id,
1403 root_name: worktree.root_name.clone(),
1404 visible: worktree.visible,
1405 abs_path: worktree.abs_path.clone(),
1406 })
1407 .collect::<Vec<_>>();
1408
1409 for collaborator in &collaborators {
1410 session
1411 .peer
1412 .send(
1413 collaborator.peer_id.unwrap().into(),
1414 proto::AddProjectCollaborator {
1415 project_id: project_id.to_proto(),
1416 collaborator: Some(proto::Collaborator {
1417 peer_id: Some(session.connection_id.into()),
1418 replica_id: replica_id.0 as u32,
1419 user_id: guest_user_id.to_proto(),
1420 }),
1421 },
1422 )
1423 .trace_err();
1424 }
1425
1426 // First, we send the metadata associated with each worktree.
1427 response.send(proto::JoinProjectResponse {
1428 worktrees: worktrees.clone(),
1429 replica_id: replica_id.0 as u32,
1430 collaborators: collaborators.clone(),
1431 language_servers: project.language_servers.clone(),
1432 })?;
1433
1434 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1435 #[cfg(any(test, feature = "test-support"))]
1436 const MAX_CHUNK_SIZE: usize = 2;
1437 #[cfg(not(any(test, feature = "test-support")))]
1438 const MAX_CHUNK_SIZE: usize = 256;
1439
1440 // Stream this worktree's entries.
1441 let message = proto::UpdateWorktree {
1442 project_id: project_id.to_proto(),
1443 worktree_id,
1444 abs_path: worktree.abs_path.clone(),
1445 root_name: worktree.root_name,
1446 updated_entries: worktree.entries,
1447 removed_entries: Default::default(),
1448 scan_id: worktree.scan_id,
1449 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1450 updated_repositories: worktree.repository_entries.into_values().collect(),
1451 removed_repositories: Default::default(),
1452 };
1453 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1454 session.peer.send(session.connection_id, update.clone())?;
1455 }
1456
1457 // Stream this worktree's diagnostics.
1458 for summary in worktree.diagnostic_summaries {
1459 session.peer.send(
1460 session.connection_id,
1461 proto::UpdateDiagnosticSummary {
1462 project_id: project_id.to_proto(),
1463 worktree_id: worktree.id,
1464 summary: Some(summary),
1465 },
1466 )?;
1467 }
1468
1469 for settings_file in worktree.settings_files {
1470 session.peer.send(
1471 session.connection_id,
1472 proto::UpdateWorktreeSettings {
1473 project_id: project_id.to_proto(),
1474 worktree_id: worktree.id,
1475 path: settings_file.path,
1476 content: Some(settings_file.content),
1477 },
1478 )?;
1479 }
1480 }
1481
1482 for language_server in &project.language_servers {
1483 session.peer.send(
1484 session.connection_id,
1485 proto::UpdateLanguageServer {
1486 project_id: project_id.to_proto(),
1487 language_server_id: language_server.id,
1488 variant: Some(
1489 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1490 proto::LspDiskBasedDiagnosticsUpdated {},
1491 ),
1492 ),
1493 },
1494 )?;
1495 }
1496
1497 Ok(())
1498}
1499
1500async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1501 let sender_id = session.connection_id;
1502 let project_id = ProjectId::from_proto(request.project_id);
1503
1504 let (room, project) = &*session
1505 .db()
1506 .await
1507 .leave_project(project_id, sender_id)
1508 .await?;
1509 tracing::info!(
1510 %project_id,
1511 host_user_id = %project.host_user_id,
1512 host_connection_id = %project.host_connection_id,
1513 "leave project"
1514 );
1515
1516 project_left(&project, &session);
1517 room_updated(&room, &session.peer);
1518
1519 Ok(())
1520}
1521
1522async fn update_project(
1523 request: proto::UpdateProject,
1524 response: Response<proto::UpdateProject>,
1525 session: Session,
1526) -> Result<()> {
1527 let project_id = ProjectId::from_proto(request.project_id);
1528 let (room, guest_connection_ids) = &*session
1529 .db()
1530 .await
1531 .update_project(project_id, session.connection_id, &request.worktrees)
1532 .await?;
1533 broadcast(
1534 Some(session.connection_id),
1535 guest_connection_ids.iter().copied(),
1536 |connection_id| {
1537 session
1538 .peer
1539 .forward_send(session.connection_id, connection_id, request.clone())
1540 },
1541 );
1542 room_updated(&room, &session.peer);
1543 response.send(proto::Ack {})?;
1544
1545 Ok(())
1546}
1547
1548async fn update_worktree(
1549 request: proto::UpdateWorktree,
1550 response: Response<proto::UpdateWorktree>,
1551 session: Session,
1552) -> Result<()> {
1553 let guest_connection_ids = session
1554 .db()
1555 .await
1556 .update_worktree(&request, session.connection_id)
1557 .await?;
1558
1559 broadcast(
1560 Some(session.connection_id),
1561 guest_connection_ids.iter().copied(),
1562 |connection_id| {
1563 session
1564 .peer
1565 .forward_send(session.connection_id, connection_id, request.clone())
1566 },
1567 );
1568 response.send(proto::Ack {})?;
1569 Ok(())
1570}
1571
1572async fn update_diagnostic_summary(
1573 message: proto::UpdateDiagnosticSummary,
1574 session: Session,
1575) -> Result<()> {
1576 let guest_connection_ids = session
1577 .db()
1578 .await
1579 .update_diagnostic_summary(&message, session.connection_id)
1580 .await?;
1581
1582 broadcast(
1583 Some(session.connection_id),
1584 guest_connection_ids.iter().copied(),
1585 |connection_id| {
1586 session
1587 .peer
1588 .forward_send(session.connection_id, connection_id, message.clone())
1589 },
1590 );
1591
1592 Ok(())
1593}
1594
1595async fn update_worktree_settings(
1596 message: proto::UpdateWorktreeSettings,
1597 session: Session,
1598) -> Result<()> {
1599 let guest_connection_ids = session
1600 .db()
1601 .await
1602 .update_worktree_settings(&message, session.connection_id)
1603 .await?;
1604
1605 broadcast(
1606 Some(session.connection_id),
1607 guest_connection_ids.iter().copied(),
1608 |connection_id| {
1609 session
1610 .peer
1611 .forward_send(session.connection_id, connection_id, message.clone())
1612 },
1613 );
1614
1615 Ok(())
1616}
1617
1618async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1619 broadcast_project_message(request.project_id, request, session).await
1620}
1621
1622async fn start_language_server(
1623 request: proto::StartLanguageServer,
1624 session: Session,
1625) -> Result<()> {
1626 let guest_connection_ids = session
1627 .db()
1628 .await
1629 .start_language_server(&request, session.connection_id)
1630 .await?;
1631
1632 broadcast(
1633 Some(session.connection_id),
1634 guest_connection_ids.iter().copied(),
1635 |connection_id| {
1636 session
1637 .peer
1638 .forward_send(session.connection_id, connection_id, request.clone())
1639 },
1640 );
1641 Ok(())
1642}
1643
1644async fn update_language_server(
1645 request: proto::UpdateLanguageServer,
1646 session: Session,
1647) -> Result<()> {
1648 session.executor.record_backtrace();
1649 let project_id = ProjectId::from_proto(request.project_id);
1650 let project_connection_ids = session
1651 .db()
1652 .await
1653 .project_connection_ids(project_id, session.connection_id)
1654 .await?;
1655 broadcast(
1656 Some(session.connection_id),
1657 project_connection_ids.iter().copied(),
1658 |connection_id| {
1659 session
1660 .peer
1661 .forward_send(session.connection_id, connection_id, request.clone())
1662 },
1663 );
1664 Ok(())
1665}
1666
1667async fn forward_project_request<T>(
1668 request: T,
1669 response: Response<T>,
1670 session: Session,
1671) -> Result<()>
1672where
1673 T: EntityMessage + RequestMessage,
1674{
1675 session.executor.record_backtrace();
1676 let project_id = ProjectId::from_proto(request.remote_entity_id());
1677 let host_connection_id = {
1678 let collaborators = session
1679 .db()
1680 .await
1681 .project_collaborators(project_id, session.connection_id)
1682 .await?;
1683 collaborators
1684 .iter()
1685 .find(|collaborator| collaborator.is_host)
1686 .ok_or_else(|| anyhow!("host not found"))?
1687 .connection_id
1688 };
1689
1690 let payload = session
1691 .peer
1692 .forward_request(session.connection_id, host_connection_id, request)
1693 .await?;
1694
1695 response.send(payload)?;
1696 Ok(())
1697}
1698
1699async fn create_buffer_for_peer(
1700 request: proto::CreateBufferForPeer,
1701 session: Session,
1702) -> Result<()> {
1703 session.executor.record_backtrace();
1704 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1705 session
1706 .peer
1707 .forward_send(session.connection_id, peer_id.into(), request)?;
1708 Ok(())
1709}
1710
1711async fn update_buffer(
1712 request: proto::UpdateBuffer,
1713 response: Response<proto::UpdateBuffer>,
1714 session: Session,
1715) -> Result<()> {
1716 session.executor.record_backtrace();
1717 let project_id = ProjectId::from_proto(request.project_id);
1718 let mut guest_connection_ids;
1719 let mut host_connection_id = None;
1720 {
1721 let collaborators = session
1722 .db()
1723 .await
1724 .project_collaborators(project_id, session.connection_id)
1725 .await?;
1726 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1727 for collaborator in collaborators.iter() {
1728 if collaborator.is_host {
1729 host_connection_id = Some(collaborator.connection_id);
1730 } else {
1731 guest_connection_ids.push(collaborator.connection_id);
1732 }
1733 }
1734 }
1735 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1736
1737 session.executor.record_backtrace();
1738 broadcast(
1739 Some(session.connection_id),
1740 guest_connection_ids,
1741 |connection_id| {
1742 session
1743 .peer
1744 .forward_send(session.connection_id, connection_id, request.clone())
1745 },
1746 );
1747 if host_connection_id != session.connection_id {
1748 session
1749 .peer
1750 .forward_request(session.connection_id, host_connection_id, request.clone())
1751 .await?;
1752 }
1753
1754 response.send(proto::Ack {})?;
1755 Ok(())
1756}
1757
1758async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1759 let project_id = ProjectId::from_proto(request.project_id);
1760 let project_connection_ids = session
1761 .db()
1762 .await
1763 .project_connection_ids(project_id, session.connection_id)
1764 .await?;
1765
1766 broadcast(
1767 Some(session.connection_id),
1768 project_connection_ids.iter().copied(),
1769 |connection_id| {
1770 session
1771 .peer
1772 .forward_send(session.connection_id, connection_id, request.clone())
1773 },
1774 );
1775 Ok(())
1776}
1777
1778async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1779 let project_id = ProjectId::from_proto(request.project_id);
1780 let project_connection_ids = session
1781 .db()
1782 .await
1783 .project_connection_ids(project_id, session.connection_id)
1784 .await?;
1785 broadcast(
1786 Some(session.connection_id),
1787 project_connection_ids.iter().copied(),
1788 |connection_id| {
1789 session
1790 .peer
1791 .forward_send(session.connection_id, connection_id, request.clone())
1792 },
1793 );
1794 Ok(())
1795}
1796
1797async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1798 broadcast_project_message(request.project_id, request, session).await
1799}
1800
1801async fn broadcast_project_message<T: EnvelopedMessage>(
1802 project_id: u64,
1803 request: T,
1804 session: Session,
1805) -> Result<()> {
1806 let project_id = ProjectId::from_proto(project_id);
1807 let project_connection_ids = session
1808 .db()
1809 .await
1810 .project_connection_ids(project_id, session.connection_id)
1811 .await?;
1812 broadcast(
1813 Some(session.connection_id),
1814 project_connection_ids.iter().copied(),
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 Ok(())
1822}
1823
1824async fn follow(
1825 request: proto::Follow,
1826 response: Response<proto::Follow>,
1827 session: Session,
1828) -> Result<()> {
1829 let project_id = ProjectId::from_proto(request.project_id);
1830 let leader_id = request
1831 .leader_id
1832 .ok_or_else(|| anyhow!("invalid leader id"))?
1833 .into();
1834 let follower_id = session.connection_id;
1835
1836 {
1837 let project_connection_ids = session
1838 .db()
1839 .await
1840 .project_connection_ids(project_id, session.connection_id)
1841 .await?;
1842
1843 if !project_connection_ids.contains(&leader_id) {
1844 Err(anyhow!("no such peer"))?;
1845 }
1846 }
1847
1848 let mut response_payload = session
1849 .peer
1850 .forward_request(session.connection_id, leader_id, request)
1851 .await?;
1852 response_payload
1853 .views
1854 .retain(|view| view.leader_id != Some(follower_id.into()));
1855 response.send(response_payload)?;
1856
1857 let room = session
1858 .db()
1859 .await
1860 .follow(project_id, leader_id, follower_id)
1861 .await?;
1862 room_updated(&room, &session.peer);
1863
1864 Ok(())
1865}
1866
1867async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1868 let project_id = ProjectId::from_proto(request.project_id);
1869 let leader_id = request
1870 .leader_id
1871 .ok_or_else(|| anyhow!("invalid leader id"))?
1872 .into();
1873 let follower_id = session.connection_id;
1874
1875 if !session
1876 .db()
1877 .await
1878 .project_connection_ids(project_id, session.connection_id)
1879 .await?
1880 .contains(&leader_id)
1881 {
1882 Err(anyhow!("no such peer"))?;
1883 }
1884
1885 session
1886 .peer
1887 .forward_send(session.connection_id, leader_id, request)?;
1888
1889 let room = session
1890 .db()
1891 .await
1892 .unfollow(project_id, leader_id, follower_id)
1893 .await?;
1894 room_updated(&room, &session.peer);
1895
1896 Ok(())
1897}
1898
1899async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1900 let project_id = ProjectId::from_proto(request.project_id);
1901 let project_connection_ids = session
1902 .db
1903 .lock()
1904 .await
1905 .project_connection_ids(project_id, session.connection_id)
1906 .await?;
1907
1908 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1909 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1910 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1911 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1912 });
1913 for follower_peer_id in request.follower_ids.iter().copied() {
1914 let follower_connection_id = follower_peer_id.into();
1915 if project_connection_ids.contains(&follower_connection_id)
1916 && Some(follower_peer_id) != leader_id
1917 {
1918 session.peer.forward_send(
1919 session.connection_id,
1920 follower_connection_id,
1921 request.clone(),
1922 )?;
1923 }
1924 }
1925 Ok(())
1926}
1927
1928async fn get_users(
1929 request: proto::GetUsers,
1930 response: Response<proto::GetUsers>,
1931 session: Session,
1932) -> Result<()> {
1933 let user_ids = request
1934 .user_ids
1935 .into_iter()
1936 .map(UserId::from_proto)
1937 .collect();
1938 let users = session
1939 .db()
1940 .await
1941 .get_users_by_ids(user_ids)
1942 .await?
1943 .into_iter()
1944 .map(|user| proto::User {
1945 id: user.id.to_proto(),
1946 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1947 github_login: user.github_login,
1948 })
1949 .collect();
1950 response.send(proto::UsersResponse { users })?;
1951 Ok(())
1952}
1953
1954async fn fuzzy_search_users(
1955 request: proto::FuzzySearchUsers,
1956 response: Response<proto::FuzzySearchUsers>,
1957 session: Session,
1958) -> Result<()> {
1959 let query = request.query;
1960 let users = match query.len() {
1961 0 => vec![],
1962 1 | 2 => session
1963 .db()
1964 .await
1965 .get_user_by_github_login(&query)
1966 .await?
1967 .into_iter()
1968 .collect(),
1969 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1970 };
1971 let users = users
1972 .into_iter()
1973 .filter(|user| user.id != session.user_id)
1974 .map(|user| proto::User {
1975 id: user.id.to_proto(),
1976 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1977 github_login: user.github_login,
1978 })
1979 .collect();
1980 response.send(proto::UsersResponse { users })?;
1981 Ok(())
1982}
1983
1984async fn request_contact(
1985 request: proto::RequestContact,
1986 response: Response<proto::RequestContact>,
1987 session: Session,
1988) -> Result<()> {
1989 let requester_id = session.user_id;
1990 let responder_id = UserId::from_proto(request.responder_id);
1991 if requester_id == responder_id {
1992 return Err(anyhow!("cannot add yourself as a contact"))?;
1993 }
1994
1995 session
1996 .db()
1997 .await
1998 .send_contact_request(requester_id, responder_id)
1999 .await?;
2000
2001 // Update outgoing contact requests of requester
2002 let mut update = proto::UpdateContacts::default();
2003 update.outgoing_requests.push(responder_id.to_proto());
2004 for connection_id in session
2005 .connection_pool()
2006 .await
2007 .user_connection_ids(requester_id)
2008 {
2009 session.peer.send(connection_id, update.clone())?;
2010 }
2011
2012 // Update incoming contact requests of responder
2013 let mut update = proto::UpdateContacts::default();
2014 update
2015 .incoming_requests
2016 .push(proto::IncomingContactRequest {
2017 requester_id: requester_id.to_proto(),
2018 should_notify: true,
2019 });
2020 for connection_id in session
2021 .connection_pool()
2022 .await
2023 .user_connection_ids(responder_id)
2024 {
2025 session.peer.send(connection_id, update.clone())?;
2026 }
2027
2028 response.send(proto::Ack {})?;
2029 Ok(())
2030}
2031
2032async fn respond_to_contact_request(
2033 request: proto::RespondToContactRequest,
2034 response: Response<proto::RespondToContactRequest>,
2035 session: Session,
2036) -> Result<()> {
2037 let responder_id = session.user_id;
2038 let requester_id = UserId::from_proto(request.requester_id);
2039 let db = session.db().await;
2040 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2041 db.dismiss_contact_notification(responder_id, requester_id)
2042 .await?;
2043 } else {
2044 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2045
2046 db.respond_to_contact_request(responder_id, requester_id, accept)
2047 .await?;
2048 let requester_busy = db.is_user_busy(requester_id).await?;
2049 let responder_busy = db.is_user_busy(responder_id).await?;
2050
2051 let pool = session.connection_pool().await;
2052 // Update responder with new contact
2053 let mut update = proto::UpdateContacts::default();
2054 if accept {
2055 update
2056 .contacts
2057 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2058 }
2059 update
2060 .remove_incoming_requests
2061 .push(requester_id.to_proto());
2062 for connection_id in pool.user_connection_ids(responder_id) {
2063 session.peer.send(connection_id, update.clone())?;
2064 }
2065
2066 // Update requester with new contact
2067 let mut update = proto::UpdateContacts::default();
2068 if accept {
2069 update
2070 .contacts
2071 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2072 }
2073 update
2074 .remove_outgoing_requests
2075 .push(responder_id.to_proto());
2076 for connection_id in pool.user_connection_ids(requester_id) {
2077 session.peer.send(connection_id, update.clone())?;
2078 }
2079 }
2080
2081 response.send(proto::Ack {})?;
2082 Ok(())
2083}
2084
2085async fn remove_contact(
2086 request: proto::RemoveContact,
2087 response: Response<proto::RemoveContact>,
2088 session: Session,
2089) -> Result<()> {
2090 let requester_id = session.user_id;
2091 let responder_id = UserId::from_proto(request.user_id);
2092 let db = session.db().await;
2093 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2094
2095 let pool = session.connection_pool().await;
2096 // Update outgoing contact requests of requester
2097 let mut update = proto::UpdateContacts::default();
2098 if contact_accepted {
2099 update.remove_contacts.push(responder_id.to_proto());
2100 } else {
2101 update
2102 .remove_outgoing_requests
2103 .push(responder_id.to_proto());
2104 }
2105 for connection_id in pool.user_connection_ids(requester_id) {
2106 session.peer.send(connection_id, update.clone())?;
2107 }
2108
2109 // Update incoming contact requests of responder
2110 let mut update = proto::UpdateContacts::default();
2111 if contact_accepted {
2112 update.remove_contacts.push(requester_id.to_proto());
2113 } else {
2114 update
2115 .remove_incoming_requests
2116 .push(requester_id.to_proto());
2117 }
2118 for connection_id in pool.user_connection_ids(responder_id) {
2119 session.peer.send(connection_id, update.clone())?;
2120 }
2121
2122 response.send(proto::Ack {})?;
2123 Ok(())
2124}
2125
2126async fn create_channel(
2127 request: proto::CreateChannel,
2128 response: Response<proto::CreateChannel>,
2129 session: Session,
2130) -> Result<()> {
2131 let db = session.db().await;
2132 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2133
2134 if let Some(live_kit) = session.live_kit_client.as_ref() {
2135 live_kit.create_room(live_kit_room.clone()).await?;
2136 }
2137
2138 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2139 let id = db
2140 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2141 .await?;
2142
2143 response.send(proto::CreateChannelResponse {
2144 channel_id: id.to_proto(),
2145 })?;
2146
2147 let mut update = proto::UpdateChannels::default();
2148 update.channels.push(proto::Channel {
2149 id: id.to_proto(),
2150 name: request.name,
2151 parent_id: request.parent_id,
2152 });
2153
2154 if let Some(parent_id) = parent_id {
2155 let member_ids = db.get_channel_members(parent_id).await?;
2156 let connection_pool = session.connection_pool().await;
2157 for member_id in member_ids {
2158 for connection_id in connection_pool.user_connection_ids(member_id) {
2159 session.peer.send(connection_id, update.clone())?;
2160 }
2161 }
2162 } else {
2163 session.peer.send(session.connection_id, update)?;
2164 }
2165
2166 Ok(())
2167}
2168
2169async fn remove_channel(
2170 request: proto::RemoveChannel,
2171 response: Response<proto::RemoveChannel>,
2172 session: Session,
2173) -> Result<()> {
2174 let db = session.db().await;
2175
2176 let channel_id = request.channel_id;
2177 let (removed_channels, member_ids) = db
2178 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2179 .await?;
2180 response.send(proto::Ack {})?;
2181
2182 // Notify members of removed channels
2183 let mut update = proto::UpdateChannels::default();
2184 update
2185 .remove_channels
2186 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2187
2188 let connection_pool = session.connection_pool().await;
2189 for member_id in member_ids {
2190 for connection_id in connection_pool.user_connection_ids(member_id) {
2191 session.peer.send(connection_id, update.clone())?;
2192 }
2193 }
2194
2195 Ok(())
2196}
2197
2198async fn invite_channel_member(
2199 request: proto::InviteChannelMember,
2200 response: Response<proto::InviteChannelMember>,
2201 session: Session,
2202) -> Result<()> {
2203 let db = session.db().await;
2204 let channel_id = ChannelId::from_proto(request.channel_id);
2205 let channel = db
2206 .get_channel(channel_id)
2207 .await?
2208 .ok_or_else(|| anyhow!("channel not found"))?;
2209 let invitee_id = UserId::from_proto(request.user_id);
2210 db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2211 .await?;
2212
2213 let mut update = proto::UpdateChannels::default();
2214 update.channel_invitations.push(proto::Channel {
2215 id: channel.id.to_proto(),
2216 name: channel.name,
2217 parent_id: None,
2218 });
2219 for connection_id in session
2220 .connection_pool()
2221 .await
2222 .user_connection_ids(invitee_id)
2223 {
2224 session.peer.send(connection_id, update.clone())?;
2225 }
2226
2227 response.send(proto::Ack {})?;
2228 Ok(())
2229}
2230
2231async fn remove_channel_member(
2232 _request: proto::RemoveChannelMember,
2233 _response: Response<proto::RemoveChannelMember>,
2234 _session: Session,
2235) -> Result<()> {
2236 Ok(())
2237}
2238
2239async fn respond_to_channel_invite(
2240 request: proto::RespondToChannelInvite,
2241 response: Response<proto::RespondToChannelInvite>,
2242 session: Session,
2243) -> Result<()> {
2244 let db = session.db().await;
2245 let channel_id = ChannelId::from_proto(request.channel_id);
2246 let channel = db
2247 .get_channel(channel_id)
2248 .await?
2249 .ok_or_else(|| anyhow!("no such channel"))?;
2250 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2251 .await?;
2252
2253 let mut update = proto::UpdateChannels::default();
2254 update
2255 .remove_channel_invitations
2256 .push(channel_id.to_proto());
2257 if request.accept {
2258 update.channels.push(proto::Channel {
2259 id: channel.id.to_proto(),
2260 name: channel.name,
2261 parent_id: None,
2262 });
2263 }
2264 session.peer.send(session.connection_id, update)?;
2265 response.send(proto::Ack {})?;
2266
2267 Ok(())
2268}
2269
/// Joins the room that backs `request.channel_id`.
///
/// It replies with the joined room and, when a LiveKit client is configured,
/// LiveKit connection info. It then broadcasts the room to its participants,
/// sends channel-participant updates to the channel's members (via
/// `channel_updated`), and refreshes the user's contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The database handle and the value returned by `join_room` (presumably a
    // room guard, per the TODO below) live only inside this block. Only a
    // clone of the joined room escapes it.
    let joined_room = {
        let db = session.db().await;
        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // Token generation failures are logged by `trace_err` and result in
        // `None`; the join itself still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2329
2330async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2331 let project_id = ProjectId::from_proto(request.project_id);
2332 let project_connection_ids = session
2333 .db()
2334 .await
2335 .project_connection_ids(project_id, session.connection_id)
2336 .await?;
2337 broadcast(
2338 Some(session.connection_id),
2339 project_connection_ids.iter().copied(),
2340 |connection_id| {
2341 session
2342 .peer
2343 .forward_send(session.connection_id, connection_id, request.clone())
2344 },
2345 );
2346 Ok(())
2347}
2348
2349async fn get_private_user_info(
2350 _request: proto::GetPrivateUserInfo,
2351 response: Response<proto::GetPrivateUserInfo>,
2352 session: Session,
2353) -> Result<()> {
2354 let metrics_id = session
2355 .db()
2356 .await
2357 .get_user_metrics_id(session.user_id)
2358 .await?;
2359 let user = session
2360 .db()
2361 .await
2362 .get_user_by_id(session.user_id)
2363 .await?
2364 .ok_or_else(|| anyhow!("user not found"))?;
2365 response.send(proto::GetPrivateUserInfoResponse {
2366 metrics_id,
2367 staff: user.admin,
2368 })?;
2369 Ok(())
2370}
2371
2372fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2373 match message {
2374 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2375 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2376 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2377 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2378 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2379 code: frame.code.into(),
2380 reason: frame.reason,
2381 })),
2382 }
2383}
2384
2385fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2386 match message {
2387 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2388 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2389 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2390 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2391 AxumMessage::Close(frame) => {
2392 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2393 code: frame.code.into(),
2394 reason: frame.reason,
2395 }))
2396 }
2397 }
2398}
2399
2400fn build_initial_channels_update(
2401 channels: Vec<db::Channel>,
2402 channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2403 channel_invites: Vec<db::Channel>,
2404) -> proto::UpdateChannels {
2405 let mut update = proto::UpdateChannels::default();
2406
2407 for channel in channels {
2408 update.channels.push(proto::Channel {
2409 id: channel.id.to_proto(),
2410 name: channel.name,
2411 parent_id: channel.parent_id.map(|id| id.to_proto()),
2412 });
2413 }
2414
2415 for channel in channel_invites {
2416 update.channel_invitations.push(proto::Channel {
2417 id: channel.id.to_proto(),
2418 name: channel.name,
2419 parent_id: None,
2420 });
2421 }
2422
2423 update
2424}
2425
2426fn build_initial_contacts_update(
2427 contacts: Vec<db::Contact>,
2428 pool: &ConnectionPool,
2429) -> proto::UpdateContacts {
2430 let mut update = proto::UpdateContacts::default();
2431
2432 for contact in contacts {
2433 match contact {
2434 db::Contact::Accepted {
2435 user_id,
2436 should_notify,
2437 busy,
2438 } => {
2439 update
2440 .contacts
2441 .push(contact_for_user(user_id, should_notify, busy, &pool));
2442 }
2443 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2444 db::Contact::Incoming {
2445 user_id,
2446 should_notify,
2447 } => update
2448 .incoming_requests
2449 .push(proto::IncomingContactRequest {
2450 requester_id: user_id.to_proto(),
2451 should_notify,
2452 }),
2453 }
2454 }
2455
2456 update
2457}
2458
2459fn contact_for_user(
2460 user_id: UserId,
2461 should_notify: bool,
2462 busy: bool,
2463 pool: &ConnectionPool,
2464) -> proto::Contact {
2465 proto::Contact {
2466 user_id: user_id.to_proto(),
2467 online: pool.is_user_online(user_id),
2468 busy,
2469 should_notify,
2470 }
2471}
2472
2473fn room_updated(room: &proto::Room, peer: &Peer) {
2474 broadcast(
2475 None,
2476 room.participants
2477 .iter()
2478 .filter_map(|participant| Some(participant.peer_id?.into())),
2479 |peer_id| {
2480 peer.send(
2481 peer_id.into(),
2482 proto::RoomUpdated {
2483 room: Some(room.clone()),
2484 },
2485 )
2486 },
2487 );
2488}
2489
2490fn channel_updated(
2491 channel_id: ChannelId,
2492 room: &proto::Room,
2493 channel_members: &[UserId],
2494 peer: &Peer,
2495 pool: &ConnectionPool,
2496) {
2497 let participants = room
2498 .participants
2499 .iter()
2500 .map(|p| p.user_id)
2501 .collect::<Vec<_>>();
2502
2503 broadcast(
2504 None,
2505 channel_members
2506 .iter()
2507 .filter(|user_id| {
2508 !room
2509 .participants
2510 .iter()
2511 .any(|p| p.user_id == user_id.to_proto())
2512 })
2513 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2514 |peer_id| {
2515 peer.send(
2516 peer_id.into(),
2517 proto::UpdateChannels {
2518 channel_participants: vec![proto::ChannelParticipants {
2519 channel_id: channel_id.to_proto(),
2520 participant_user_ids: participants.clone(),
2521 }],
2522 ..Default::default()
2523 },
2524 )
2525 },
2526 );
2527}
2528
2529async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2530 let db = session.db().await;
2531 let contacts = db.get_contacts(user_id).await?;
2532 let busy = db.is_user_busy(user_id).await?;
2533
2534 let pool = session.connection_pool().await;
2535 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2536 for contact in contacts {
2537 if let db::Contact::Accepted {
2538 user_id: contact_user_id,
2539 ..
2540 } = contact
2541 {
2542 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2543 session
2544 .peer
2545 .send(
2546 contact_conn_id,
2547 proto::UpdateContacts {
2548 contacts: vec![updated_contact.clone()],
2549 remove_contacts: Default::default(),
2550 incoming_requests: Default::default(),
2551 remove_incoming_requests: Default::default(),
2552 outgoing_requests: Default::default(),
2553 remove_outgoing_requests: Default::default(),
2554 },
2555 )
2556 .trace_err();
2557 }
2558 }
2559 }
2560 Ok(())
2561}
2562
/// Removes the session's connection from its current room, if it is in one,
/// and fans out the consequences. Returns `Ok(())` immediately when the
/// database reports that no room was left.
///
/// In order:
/// - collaborators of each left project are notified via `project_left`;
/// - room participants get the updated room;
/// - for channel rooms, channel members get participant updates;
/// - users in `canceled_calls_to_user_ids` receive `CallCanceled`;
/// - the user's contacts (and those canceled callees') are refreshed;
/// - if a LiveKit client is configured, the user is removed from the LiveKit
///   room, which is also deleted when the database marked the room `deleted`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so that the
    // extracted values outlive `left_room`, which is dropped at the end of
    // that block. Under edition-2021 temporary rules, the database handle
    // from `session.db().await` in the scrutinee is also held until then.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out *before* `room` is
        // moved out below, so the `room` broadcast by `room_updated` carries a
        // default (empty) `live_kit_room`. Confirm clients do not rely on it.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // TODO - do this while holding the room guard.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2640
2641fn project_left(project: &db::LeftProject, session: &Session) {
2642 for connection_id in &project.connection_ids {
2643 if project.host_user_id == session.user_id {
2644 session
2645 .peer
2646 .send(
2647 *connection_id,
2648 proto::UnshareProject {
2649 project_id: project.id.to_proto(),
2650 },
2651 )
2652 .trace_err();
2653 } else {
2654 session
2655 .peer
2656 .send(
2657 *connection_id,
2658 proto::RemoveProjectCollaborator {
2659 project_id: project.id.to_proto(),
2660 peer_id: Some(session.connection_id.into()),
2661 },
2662 )
2663 .trace_err();
2664 }
2665 }
2666}
2667
2668pub trait ResultExt {
2669 type Ok;
2670
2671 fn trace_err(self) -> Option<Self::Ok>;
2672}
2673
2674impl<T, E> ResultExt for Result<T, E>
2675where
2676 E: std::fmt::Debug,
2677{
2678 type Ok = T;
2679
2680 fn trace_err(self) -> Option<T> {
2681 match self {
2682 Ok(value) => Some(value),
2683 Err(error) => {
2684 tracing::error!("{:?}", error);
2685 None
2686 }
2687 }
2688 }
2689}