1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits for a dropped client to come back before removing the
/// session from its room and refreshing the user's contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `Server::start` waits before looking up and refreshing stale rooms.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
lazy_static! {
    // Gauge refreshed by `handle_metrics` with the number of non-admin connections in the pool.
    // Registration happens lazily on first access; `unwrap` panics if registering fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge refreshed by `handle_metrics` from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
/// Reply handle passed to request handlers registered with `add_request_handler`.
struct Response<R> {
    // Peer used to deliver the reply.
    peer: Arc<Peer>,
    // Identifies the request being answered.
    receipt: Receipt<R>,
    // Shared with the dispatcher in `add_request_handler`, which reads it after the handler
    // finishes to see whether a reply was sent.
    responded: Arc<AtomicBool>,
}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection state; a clone is handed to every message handler for that connection.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // One mutex-wrapped database handle per connection (created in `handle_connection`); all
    // handler clones of this session share it, so their database access is serialized.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Pool shared with the `Server`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when no LiveKit client is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the peer, the connection pool and the table of message
/// handlers that `handle_connection` dispatches to.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    // Shared with every `Session`.
    peer: Arc<Peer>,
    // Live connections; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Used by `start` to spawn the stale-room cleanup task.
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message payload type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown` sends on this; every `handle_connection` loop holds a subscription.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future that keeps it alive
/// across an `.await` cannot be used where a `Send` future is required.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server: the peer plus the connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the pool stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard points to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(get_channel_members)
250 .add_request_handler(respond_to_channel_invite)
251 .add_request_handler(join_channel)
252 .add_request_handler(follow)
253 .add_message_handler(unfollow)
254 .add_message_handler(update_followers)
255 .add_message_handler(update_diff_base)
256 .add_request_handler(get_private_user_info);
257
258 Arc::new(server)
259 }
260
261 pub async fn start(&self) -> Result<()> {
262 let server_id = *self.id.lock();
263 let app_state = self.app_state.clone();
264 let peer = self.peer.clone();
265 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
266 let pool = self.connection_pool.clone();
267 let live_kit_client = self.app_state.live_kit_client.clone();
268
269 let span = info_span!("start server");
270 self.executor.spawn_detached(
271 async move {
272 tracing::info!("waiting for cleanup timeout");
273 timeout.await;
274 tracing::info!("cleanup timeout expired, retrieving stale rooms");
275 if let Some(room_ids) = app_state
276 .db
277 .stale_room_ids(&app_state.config.zed_environment, server_id)
278 .await
279 .trace_err()
280 {
281 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
282 for room_id in room_ids {
283 let mut contacts_to_update = HashSet::default();
284 let mut canceled_calls_to_user_ids = Vec::new();
285 let mut live_kit_room = String::new();
286 let mut delete_live_kit_room = false;
287
288 if let Some(mut refreshed_room) = app_state
289 .db
290 .refresh_room(room_id, server_id)
291 .await
292 .trace_err()
293 {
294 tracing::info!(
295 room_id = room_id.0,
296 new_participant_count = refreshed_room.room.participants.len(),
297 "refreshed room"
298 );
299 room_updated(&refreshed_room.room, &peer);
300 if let Some(channel_id) = refreshed_room.channel_id {
301 channel_updated(
302 channel_id,
303 &refreshed_room.room,
304 &refreshed_room.channel_members,
305 &peer,
306 &*pool.lock(),
307 );
308 }
309 contacts_to_update
310 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
311 contacts_to_update
312 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
313 canceled_calls_to_user_ids =
314 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
315 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
316 delete_live_kit_room = refreshed_room.room.participants.is_empty();
317 }
318
319 {
320 let pool = pool.lock();
321 for canceled_user_id in canceled_calls_to_user_ids {
322 for connection_id in pool.user_connection_ids(canceled_user_id) {
323 peer.send(
324 connection_id,
325 proto::CallCanceled {
326 room_id: room_id.to_proto(),
327 },
328 )
329 .trace_err();
330 }
331 }
332 }
333
334 for user_id in contacts_to_update {
335 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
336 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
337 if let Some((busy, contacts)) = busy.zip(contacts) {
338 let pool = pool.lock();
339 let updated_contact = contact_for_user(user_id, false, busy, &pool);
340 for contact in contacts {
341 if let db::Contact::Accepted {
342 user_id: contact_user_id,
343 ..
344 } = contact
345 {
346 for contact_conn_id in
347 pool.user_connection_ids(contact_user_id)
348 {
349 peer.send(
350 contact_conn_id,
351 proto::UpdateContacts {
352 contacts: vec![updated_contact.clone()],
353 remove_contacts: Default::default(),
354 incoming_requests: Default::default(),
355 remove_incoming_requests: Default::default(),
356 outgoing_requests: Default::default(),
357 remove_outgoing_requests: Default::default(),
358 },
359 )
360 .trace_err();
361 }
362 }
363 }
364 }
365 }
366
367 if let Some(live_kit) = live_kit_client.as_ref() {
368 if delete_live_kit_room {
369 live_kit.delete_room(live_kit_room).await.trace_err();
370 }
371 }
372 }
373 }
374
375 app_state
376 .db
377 .delete_stale_servers(&app_state.config.zed_environment, server_id)
378 .await
379 .trace_err();
380 }
381 .instrument(span),
382 );
383 Ok(())
384 }
385
386 pub fn teardown(&self) {
387 self.peer.teardown();
388 self.connection_pool.lock().reset();
389 let _ = self.teardown.send(());
390 }
391
392 #[cfg(test)]
393 pub fn reset(&self, id: ServerId) {
394 self.teardown();
395 *self.id.lock() = id;
396 self.peer.reset(id.0 as u32);
397 }
398
399 #[cfg(test)]
400 pub fn id(&self) -> ServerId {
401 *self.id.lock()
402 }
403
404 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
405 where
406 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
407 Fut: 'static + Send + Future<Output = Result<()>>,
408 M: EnvelopedMessage,
409 {
410 let prev_handler = self.handlers.insert(
411 TypeId::of::<M>(),
412 Box::new(move |envelope, session| {
413 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
414 let span = info_span!(
415 "handle message",
416 payload_type = envelope.payload_type_name()
417 );
418 span.in_scope(|| {
419 tracing::info!(
420 payload_type = envelope.payload_type_name(),
421 "message received"
422 );
423 });
424 let start_time = Instant::now();
425 let future = (handler)(*envelope, session);
426 async move {
427 let result = future.await;
428 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
429 match result {
430 Err(error) => {
431 tracing::error!(%error, ?duration_ms, "error handling message")
432 }
433 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
434 }
435 }
436 .instrument(span)
437 .boxed()
438 }),
439 );
440 if prev_handler.is_some() {
441 panic!("registered a handler for the same message twice");
442 }
443 self
444 }
445
446 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
447 where
448 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
449 Fut: 'static + Send + Future<Output = Result<()>>,
450 M: EnvelopedMessage,
451 {
452 self.add_handler(move |envelope, session| handler(envelope.payload, session));
453 self
454 }
455
456 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
457 where
458 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
459 Fut: Send + Future<Output = Result<()>>,
460 M: RequestMessage,
461 {
462 let handler = Arc::new(handler);
463 self.add_handler(move |envelope, session| {
464 let receipt = envelope.receipt();
465 let handler = handler.clone();
466 async move {
467 let peer = session.peer.clone();
468 let responded = Arc::new(AtomicBool::default());
469 let response = Response {
470 peer: peer.clone(),
471 responded: responded.clone(),
472 receipt,
473 };
474 match (handler)(envelope.payload, response, session).await {
475 Ok(()) => {
476 if responded.load(std::sync::atomic::Ordering::SeqCst) {
477 Ok(())
478 } else {
479 Err(anyhow!("handler did not send a response"))?
480 }
481 }
482 Err(error) => {
483 peer.respond_with_error(
484 receipt,
485 proto::Error {
486 message: error.to_string(),
487 },
488 )?;
489 Err(error)
490 }
491 }
492 }
493 })
494 }
495
496 pub fn handle_connection(
497 self: &Arc<Self>,
498 connection: Connection,
499 address: String,
500 user: User,
501 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
502 executor: Executor,
503 ) -> impl Future<Output = Result<()>> {
504 let this = self.clone();
505 let user_id = user.id;
506 let login = user.github_login;
507 let span = info_span!("handle connection", %user_id, %login, %address);
508 let mut teardown = self.teardown.subscribe();
509 async move {
510 let (connection_id, handle_io, mut incoming_rx) = this
511 .peer
512 .add_connection(connection, {
513 let executor = executor.clone();
514 move |duration| executor.sleep(duration)
515 });
516
517 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
518 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
519 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
520
521 if let Some(send_connection_id) = send_connection_id.take() {
522 let _ = send_connection_id.send(connection_id);
523 }
524
525 if !user.connected_once {
526 this.peer.send(connection_id, proto::ShowContacts {})?;
527 this.app_state.db.set_user_connected_once(user_id, true).await?;
528 }
529
530 let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
531 this.app_state.db.get_contacts(user_id),
532 this.app_state.db.get_invite_code_for_user(user_id),
533 this.app_state.db.get_channels_for_user(user_id),
534 this.app_state.db.get_channel_invites_for_user(user_id)
535 ).await?;
536
537 {
538 let mut pool = this.connection_pool.lock();
539 pool.add_connection(connection_id, user_id, user.admin);
540 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
541 this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;
542
543 if let Some((code, count)) = invite_code {
544 this.peer.send(connection_id, proto::UpdateInviteInfo {
545 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
546 count: count as u32,
547 })?;
548 }
549 }
550
551 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
552 this.peer.send(connection_id, incoming_call)?;
553 }
554
555 let session = Session {
556 user_id,
557 connection_id,
558 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
559 peer: this.peer.clone(),
560 connection_pool: this.connection_pool.clone(),
561 live_kit_client: this.app_state.live_kit_client.clone(),
562 executor: executor.clone(),
563 };
564 update_user_contacts(user_id, &session).await?;
565
566 let handle_io = handle_io.fuse();
567 futures::pin_mut!(handle_io);
568
569 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
570 // This prevents deadlocks when e.g., client A performs a request to client B and
571 // client B performs a request to client A. If both clients stop processing further
572 // messages until their respective request completes, they won't have a chance to
573 // respond to the other client's request and cause a deadlock.
574 //
575 // This arrangement ensures we will attempt to process earlier messages first, but fall
576 // back to processing messages arrived later in the spirit of making progress.
577 let mut foreground_message_handlers = FuturesUnordered::new();
578 let concurrent_handlers = Arc::new(Semaphore::new(256));
579 loop {
580 let next_message = async {
581 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
582 let message = incoming_rx.next().await;
583 (permit, message)
584 }.fuse();
585 futures::pin_mut!(next_message);
586 futures::select_biased! {
587 _ = teardown.changed().fuse() => return Ok(()),
588 result = handle_io => {
589 if let Err(error) = result {
590 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
591 }
592 break;
593 }
594 _ = foreground_message_handlers.next() => {}
595 next_message = next_message => {
596 let (permit, message) = next_message;
597 if let Some(message) = message {
598 let type_name = message.payload_type_name();
599 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
600 let span_enter = span.enter();
601 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
602 let is_background = message.is_background();
603 let handle_message = (handler)(message, session.clone());
604 drop(span_enter);
605
606 let handle_message = async move {
607 handle_message.await;
608 drop(permit);
609 }.instrument(span);
610 if is_background {
611 executor.spawn_detached(handle_message);
612 } else {
613 foreground_message_handlers.push(handle_message);
614 }
615 } else {
616 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
617 }
618 } else {
619 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
620 break;
621 }
622 }
623 }
624 }
625
626 drop(foreground_message_handlers);
627 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
628 if let Err(error) = connection_lost(session, teardown, executor).await {
629 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
630 }
631
632 Ok(())
633 }.instrument(span)
634 }
635
636 pub async fn invite_code_redeemed(
637 self: &Arc<Self>,
638 inviter_id: UserId,
639 invitee_id: UserId,
640 ) -> Result<()> {
641 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
642 if let Some(code) = &user.invite_code {
643 let pool = self.connection_pool.lock();
644 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
645 for connection_id in pool.user_connection_ids(inviter_id) {
646 self.peer.send(
647 connection_id,
648 proto::UpdateContacts {
649 contacts: vec![invitee_contact.clone()],
650 ..Default::default()
651 },
652 )?;
653 self.peer.send(
654 connection_id,
655 proto::UpdateInviteInfo {
656 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
657 count: user.invite_count as u32,
658 },
659 )?;
660 }
661 }
662 }
663 Ok(())
664 }
665
666 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
667 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
668 if let Some(invite_code) = &user.invite_code {
669 let pool = self.connection_pool.lock();
670 for connection_id in pool.user_connection_ids(user_id) {
671 self.peer.send(
672 connection_id,
673 proto::UpdateInviteInfo {
674 url: format!(
675 "{}{}",
676 self.app_state.config.invite_link_prefix, invite_code
677 ),
678 count: user.invite_count as u32,
679 },
680 )?;
681 }
682 }
683 }
684 Ok(())
685 }
686
687 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
688 ServerSnapshot {
689 connection_pool: ConnectionPoolGuard {
690 guard: self.connection_pool.lock(),
691 _not_send: PhantomData,
692 },
693 peer: &self.peer,
694 }
695 }
696}
697
698impl<'a> Deref for ConnectionPoolGuard<'a> {
699 type Target = ConnectionPool;
700
701 fn deref(&self) -> &Self::Target {
702 &*self.guard
703 }
704}
705
706impl<'a> DerefMut for ConnectionPoolGuard<'a> {
707 fn deref_mut(&mut self) -> &mut Self::Target {
708 &mut *self.guard
709 }
710}
711
712impl<'a> Drop for ConnectionPoolGuard<'a> {
713 fn drop(&mut self) {
714 #[cfg(test)]
715 self.check_invariants();
716 }
717}
718
719fn broadcast<F>(
720 sender_id: Option<ConnectionId>,
721 receiver_ids: impl IntoIterator<Item = ConnectionId>,
722 mut f: F,
723) where
724 F: FnMut(ConnectionId) -> anyhow::Result<()>,
725{
726 for receiver_id in receiver_ids {
727 if Some(receiver_id) != sender_id {
728 if let Err(error) = f(receiver_id) {
729 tracing::error!("failed to send to {:?} {}", receiver_id, error);
730 }
731 }
732 }
733}
734
lazy_static! {
    // Name of the request header carrying the client's RPC protocol version; returned by
    // `ProtocolVersion::name`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
738
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client speaks.
/// `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
740
741impl Header for ProtocolVersion {
742 fn name() -> &'static HeaderName {
743 &ZED_PROTOCOL_VERSION
744 }
745
746 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
747 where
748 Self: Sized,
749 I: Iterator<Item = &'i axum::http::HeaderValue>,
750 {
751 let version = values
752 .next()
753 .ok_or_else(axum::headers::Error::invalid)?
754 .to_str()
755 .map_err(|_| axum::headers::Error::invalid())?
756 .parse()
757 .map_err(|_| axum::headers::Error::invalid())?;
758 Ok(Self(version))
759 }
760
761 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
762 values.extend([self.0.to_string().parse().unwrap()]);
763 }
764}
765
/// Builds the router for the RPC websocket endpoint and the metrics endpoint.
///
/// Ordering matters: axum's `Router::layer` wraps only the routes added before it. So `/rpc`
/// gets the `AppState` extension and the `auth::validate_header` middleware, while `/metrics`,
/// added afterwards, does not. The final layer provides the `Server` extension to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
777
778pub async fn handle_websocket_request(
779 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
780 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
781 Extension(server): Extension<Arc<Server>>,
782 Extension(user): Extension<User>,
783 ws: WebSocketUpgrade,
784) -> axum::response::Response {
785 if protocol_version != rpc::PROTOCOL_VERSION {
786 return (
787 StatusCode::UPGRADE_REQUIRED,
788 "client must be upgraded".to_string(),
789 )
790 .into_response();
791 }
792 let socket_address = socket_address.to_string();
793 ws.on_upgrade(move |socket| {
794 use util::ResultExt;
795 let socket = socket
796 .map_ok(to_tungstenite_message)
797 .err_into()
798 .with(|message| async move { Ok(to_axum_message(message)) });
799 let connection = Connection::new(Box::pin(socket));
800 async move {
801 server
802 .handle_connection(connection, socket_address, user, None, Executor::Production)
803 .await
804 .log_err();
805 }
806 })
807}
808
809pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
810 let connections = server
811 .connection_pool
812 .lock()
813 .connections()
814 .filter(|connection| !connection.admin)
815 .count();
816
817 METRIC_CONNECTIONS.set(connections as _);
818
819 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
820 METRIC_SHARED_PROJECTS.set(shared_projects as _);
821
822 let encoder = prometheus::TextEncoder::new();
823 let metric_families = prometheus::gather();
824 let encoded_metrics = encoder
825 .encode_to_string(&metric_families)
826 .map_err(|err| anyhow!("{}", err))?;
827 Ok(encoded_metrics)
828}
829
/// Cleans up after a connection's I/O loop or message stream has ended.
///
/// Immediate steps:
/// - Disconnects the peer.
/// - Removes the connection from the pool (errors propagate).
/// - Calls `connection_lost` on the database (errors are only logged).
///
/// It then races `RECONNECT_TIMEOUT` against server teardown. If the timeout wins:
/// - calls `leave_room_for_session`;
/// - if the user has no other online connection, calls `decline_call(None, user_id)` and
///   broadcasts any room it returns;
/// - refreshes the user's contacts (errors propagate).
///
/// If teardown fires first, none of that follow-up runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the timer branch is polled first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                // The database guard is scoped to this block, so it is released before
                // `update_user_contacts` below runs.
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
870
871async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
872 response.send(proto::Ack {})?;
873 Ok(())
874}
875
876async fn create_room(
877 _request: proto::CreateRoom,
878 response: Response<proto::CreateRoom>,
879 session: Session,
880) -> Result<()> {
881 let live_kit_room = nanoid::nanoid!(30);
882
883 let live_kit_connection_info = {
884 let live_kit_room = live_kit_room.clone();
885 let live_kit = session.live_kit_client.as_ref();
886
887 util::async_iife!({
888 let live_kit = live_kit?;
889
890 live_kit
891 .create_room(live_kit_room.clone())
892 .await
893 .trace_err()?;
894
895 let token = live_kit
896 .room_token(&live_kit_room, &session.user_id.to_string())
897 .trace_err()?;
898
899 Some(proto::LiveKitConnectionInfo {
900 server_url: live_kit.url().into(),
901 token,
902 })
903 })
904 }
905 .await;
906
907 let room = session
908 .db()
909 .await
910 .create_room(session.user_id, session.connection_id, &live_kit_room)
911 .await?;
912
913 response.send(proto::CreateRoomResponse {
914 room: Some(room.clone()),
915 live_kit_connection_info,
916 })?;
917
918 update_user_contacts(session.user_id, &session).await?;
919 Ok(())
920}
921
922async fn join_room(
923 request: proto::JoinRoom,
924 response: Response<proto::JoinRoom>,
925 session: Session,
926) -> Result<()> {
927 let room_id = RoomId::from_proto(request.id);
928 let room = {
929 let room = session
930 .db()
931 .await
932 .join_room(room_id, session.user_id, None, session.connection_id)
933 .await?;
934 room_updated(&room.room, &session.peer);
935 room.room.clone()
936 };
937
938 for connection_id in session
939 .connection_pool()
940 .await
941 .user_connection_ids(session.user_id)
942 {
943 session
944 .peer
945 .send(
946 connection_id,
947 proto::CallCanceled {
948 room_id: room_id.to_proto(),
949 },
950 )
951 .trace_err();
952 }
953
954 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
955 if let Some(token) = live_kit
956 .room_token(&room.live_kit_room, &session.user_id.to_string())
957 .trace_err()
958 {
959 Some(proto::LiveKitConnectionInfo {
960 server_url: live_kit.url().into(),
961 token,
962 })
963 } else {
964 None
965 }
966 } else {
967 None
968 };
969
970 response.send(proto::JoinRoomResponse {
971 room: Some(room),
972 live_kit_connection_info,
973 })?;
974
975 update_user_contacts(session.user_id, &session).await?;
976 Ok(())
977}
978
979async fn rejoin_room(
980 request: proto::RejoinRoom,
981 response: Response<proto::RejoinRoom>,
982 session: Session,
983) -> Result<()> {
984 let room;
985 let channel_id;
986 let channel_members;
987 {
988 let mut rejoined_room = session
989 .db()
990 .await
991 .rejoin_room(request, session.user_id, session.connection_id)
992 .await?;
993
994 response.send(proto::RejoinRoomResponse {
995 room: Some(rejoined_room.room.clone()),
996 reshared_projects: rejoined_room
997 .reshared_projects
998 .iter()
999 .map(|project| proto::ResharedProject {
1000 id: project.id.to_proto(),
1001 collaborators: project
1002 .collaborators
1003 .iter()
1004 .map(|collaborator| collaborator.to_proto())
1005 .collect(),
1006 })
1007 .collect(),
1008 rejoined_projects: rejoined_room
1009 .rejoined_projects
1010 .iter()
1011 .map(|rejoined_project| proto::RejoinedProject {
1012 id: rejoined_project.id.to_proto(),
1013 worktrees: rejoined_project
1014 .worktrees
1015 .iter()
1016 .map(|worktree| proto::WorktreeMetadata {
1017 id: worktree.id,
1018 root_name: worktree.root_name.clone(),
1019 visible: worktree.visible,
1020 abs_path: worktree.abs_path.clone(),
1021 })
1022 .collect(),
1023 collaborators: rejoined_project
1024 .collaborators
1025 .iter()
1026 .map(|collaborator| collaborator.to_proto())
1027 .collect(),
1028 language_servers: rejoined_project.language_servers.clone(),
1029 })
1030 .collect(),
1031 })?;
1032 room_updated(&rejoined_room.room, &session.peer);
1033
1034 for project in &rejoined_room.reshared_projects {
1035 for collaborator in &project.collaborators {
1036 session
1037 .peer
1038 .send(
1039 collaborator.connection_id,
1040 proto::UpdateProjectCollaborator {
1041 project_id: project.id.to_proto(),
1042 old_peer_id: Some(project.old_connection_id.into()),
1043 new_peer_id: Some(session.connection_id.into()),
1044 },
1045 )
1046 .trace_err();
1047 }
1048
1049 broadcast(
1050 Some(session.connection_id),
1051 project
1052 .collaborators
1053 .iter()
1054 .map(|collaborator| collaborator.connection_id),
1055 |connection_id| {
1056 session.peer.forward_send(
1057 session.connection_id,
1058 connection_id,
1059 proto::UpdateProject {
1060 project_id: project.id.to_proto(),
1061 worktrees: project.worktrees.clone(),
1062 },
1063 )
1064 },
1065 );
1066 }
1067
1068 for project in &rejoined_room.rejoined_projects {
1069 for collaborator in &project.collaborators {
1070 session
1071 .peer
1072 .send(
1073 collaborator.connection_id,
1074 proto::UpdateProjectCollaborator {
1075 project_id: project.id.to_proto(),
1076 old_peer_id: Some(project.old_connection_id.into()),
1077 new_peer_id: Some(session.connection_id.into()),
1078 },
1079 )
1080 .trace_err();
1081 }
1082 }
1083
1084 for project in &mut rejoined_room.rejoined_projects {
1085 for worktree in mem::take(&mut project.worktrees) {
1086 #[cfg(any(test, feature = "test-support"))]
1087 const MAX_CHUNK_SIZE: usize = 2;
1088 #[cfg(not(any(test, feature = "test-support")))]
1089 const MAX_CHUNK_SIZE: usize = 256;
1090
1091 // Stream this worktree's entries.
1092 let message = proto::UpdateWorktree {
1093 project_id: project.id.to_proto(),
1094 worktree_id: worktree.id,
1095 abs_path: worktree.abs_path.clone(),
1096 root_name: worktree.root_name,
1097 updated_entries: worktree.updated_entries,
1098 removed_entries: worktree.removed_entries,
1099 scan_id: worktree.scan_id,
1100 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1101 updated_repositories: worktree.updated_repositories,
1102 removed_repositories: worktree.removed_repositories,
1103 };
1104 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1105 session.peer.send(session.connection_id, update.clone())?;
1106 }
1107
1108 // Stream this worktree's diagnostics.
1109 for summary in worktree.diagnostic_summaries {
1110 session.peer.send(
1111 session.connection_id,
1112 proto::UpdateDiagnosticSummary {
1113 project_id: project.id.to_proto(),
1114 worktree_id: worktree.id,
1115 summary: Some(summary),
1116 },
1117 )?;
1118 }
1119
1120 for settings_file in worktree.settings_files {
1121 session.peer.send(
1122 session.connection_id,
1123 proto::UpdateWorktreeSettings {
1124 project_id: project.id.to_proto(),
1125 worktree_id: worktree.id,
1126 path: settings_file.path,
1127 content: Some(settings_file.content),
1128 },
1129 )?;
1130 }
1131 }
1132
1133 for language_server in &project.language_servers {
1134 session.peer.send(
1135 session.connection_id,
1136 proto::UpdateLanguageServer {
1137 project_id: project.id.to_proto(),
1138 language_server_id: language_server.id,
1139 variant: Some(
1140 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1141 proto::LspDiskBasedDiagnosticsUpdated {},
1142 ),
1143 ),
1144 },
1145 )?;
1146 }
1147 }
1148
1149 room = mem::take(&mut rejoined_room.room);
1150 channel_id = rejoined_room.channel_id;
1151 channel_members = mem::take(&mut rejoined_room.channel_members);
1152 }
1153
1154 //TODO: move this into the room guard
1155 if let Some(channel_id) = channel_id {
1156 channel_updated(
1157 channel_id,
1158 &room,
1159 &channel_members,
1160 &session.peer,
1161 &*session.connection_pool().await,
1162 );
1163 }
1164
1165 update_user_contacts(session.user_id, &session).await?;
1166 Ok(())
1167}
1168
1169async fn leave_room(
1170 _: proto::LeaveRoom,
1171 response: Response<proto::LeaveRoom>,
1172 session: Session,
1173) -> Result<()> {
1174 leave_room_for_session(&session).await?;
1175 response.send(proto::Ack {})?;
1176 Ok(())
1177}
1178
1179async fn call(
1180 request: proto::Call,
1181 response: Response<proto::Call>,
1182 session: Session,
1183) -> Result<()> {
1184 let room_id = RoomId::from_proto(request.room_id);
1185 let calling_user_id = session.user_id;
1186 let calling_connection_id = session.connection_id;
1187 let called_user_id = UserId::from_proto(request.called_user_id);
1188 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1189 if !session
1190 .db()
1191 .await
1192 .has_contact(calling_user_id, called_user_id)
1193 .await?
1194 {
1195 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1196 }
1197
1198 let incoming_call = {
1199 let (room, incoming_call) = &mut *session
1200 .db()
1201 .await
1202 .call(
1203 room_id,
1204 calling_user_id,
1205 calling_connection_id,
1206 called_user_id,
1207 initial_project_id,
1208 )
1209 .await?;
1210 room_updated(&room, &session.peer);
1211 mem::take(incoming_call)
1212 };
1213 update_user_contacts(called_user_id, &session).await?;
1214
1215 let mut calls = session
1216 .connection_pool()
1217 .await
1218 .user_connection_ids(called_user_id)
1219 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1220 .collect::<FuturesUnordered<_>>();
1221
1222 while let Some(call_response) = calls.next().await {
1223 match call_response.as_ref() {
1224 Ok(_) => {
1225 response.send(proto::Ack {})?;
1226 return Ok(());
1227 }
1228 Err(_) => {
1229 call_response.trace_err();
1230 }
1231 }
1232 }
1233
1234 {
1235 let room = session
1236 .db()
1237 .await
1238 .call_failed(room_id, called_user_id)
1239 .await?;
1240 room_updated(&room, &session.peer);
1241 }
1242 update_user_contacts(called_user_id, &session).await?;
1243
1244 Err(anyhow!("failed to ring user"))?
1245}
1246
1247async fn cancel_call(
1248 request: proto::CancelCall,
1249 response: Response<proto::CancelCall>,
1250 session: Session,
1251) -> Result<()> {
1252 let called_user_id = UserId::from_proto(request.called_user_id);
1253 let room_id = RoomId::from_proto(request.room_id);
1254 {
1255 let room = session
1256 .db()
1257 .await
1258 .cancel_call(room_id, session.connection_id, called_user_id)
1259 .await?;
1260 room_updated(&room, &session.peer);
1261 }
1262
1263 for connection_id in session
1264 .connection_pool()
1265 .await
1266 .user_connection_ids(called_user_id)
1267 {
1268 session
1269 .peer
1270 .send(
1271 connection_id,
1272 proto::CallCanceled {
1273 room_id: room_id.to_proto(),
1274 },
1275 )
1276 .trace_err();
1277 }
1278 response.send(proto::Ack {})?;
1279
1280 update_user_contacts(called_user_id, &session).await?;
1281 Ok(())
1282}
1283
1284async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1285 let room_id = RoomId::from_proto(message.room_id);
1286 {
1287 let room = session
1288 .db()
1289 .await
1290 .decline_call(Some(room_id), session.user_id)
1291 .await?
1292 .ok_or_else(|| anyhow!("failed to decline call"))?;
1293 room_updated(&room, &session.peer);
1294 }
1295
1296 for connection_id in session
1297 .connection_pool()
1298 .await
1299 .user_connection_ids(session.user_id)
1300 {
1301 session
1302 .peer
1303 .send(
1304 connection_id,
1305 proto::CallCanceled {
1306 room_id: room_id.to_proto(),
1307 },
1308 )
1309 .trace_err();
1310 }
1311 update_user_contacts(session.user_id, &session).await?;
1312 Ok(())
1313}
1314
1315async fn update_participant_location(
1316 request: proto::UpdateParticipantLocation,
1317 response: Response<proto::UpdateParticipantLocation>,
1318 session: Session,
1319) -> Result<()> {
1320 let room_id = RoomId::from_proto(request.room_id);
1321 let location = request
1322 .location
1323 .ok_or_else(|| anyhow!("invalid location"))?;
1324
1325 let db = session.db().await;
1326 let room = db
1327 .update_room_participant_location(room_id, session.connection_id, location)
1328 .await?;
1329
1330 room_updated(&room, &session.peer);
1331 response.send(proto::Ack {})?;
1332 Ok(())
1333}
1334
1335async fn share_project(
1336 request: proto::ShareProject,
1337 response: Response<proto::ShareProject>,
1338 session: Session,
1339) -> Result<()> {
1340 let (project_id, room) = &*session
1341 .db()
1342 .await
1343 .share_project(
1344 RoomId::from_proto(request.room_id),
1345 session.connection_id,
1346 &request.worktrees,
1347 )
1348 .await?;
1349 response.send(proto::ShareProjectResponse {
1350 project_id: project_id.to_proto(),
1351 })?;
1352 room_updated(&room, &session.peer);
1353
1354 Ok(())
1355}
1356
1357async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1358 let project_id = ProjectId::from_proto(message.project_id);
1359
1360 let (room, guest_connection_ids) = &*session
1361 .db()
1362 .await
1363 .unshare_project(project_id, session.connection_id)
1364 .await?;
1365
1366 broadcast(
1367 Some(session.connection_id),
1368 guest_connection_ids.iter().copied(),
1369 |conn_id| session.peer.send(conn_id, message.clone()),
1370 );
1371 room_updated(&room, &session.peer);
1372
1373 Ok(())
1374}
1375
1376async fn join_project(
1377 request: proto::JoinProject,
1378 response: Response<proto::JoinProject>,
1379 session: Session,
1380) -> Result<()> {
1381 let project_id = ProjectId::from_proto(request.project_id);
1382 let guest_user_id = session.user_id;
1383
1384 tracing::info!(%project_id, "join project");
1385
1386 let (project, replica_id) = &mut *session
1387 .db()
1388 .await
1389 .join_project(project_id, session.connection_id)
1390 .await?;
1391
1392 let collaborators = project
1393 .collaborators
1394 .iter()
1395 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1396 .map(|collaborator| collaborator.to_proto())
1397 .collect::<Vec<_>>();
1398
1399 let worktrees = project
1400 .worktrees
1401 .iter()
1402 .map(|(id, worktree)| proto::WorktreeMetadata {
1403 id: *id,
1404 root_name: worktree.root_name.clone(),
1405 visible: worktree.visible,
1406 abs_path: worktree.abs_path.clone(),
1407 })
1408 .collect::<Vec<_>>();
1409
1410 for collaborator in &collaborators {
1411 session
1412 .peer
1413 .send(
1414 collaborator.peer_id.unwrap().into(),
1415 proto::AddProjectCollaborator {
1416 project_id: project_id.to_proto(),
1417 collaborator: Some(proto::Collaborator {
1418 peer_id: Some(session.connection_id.into()),
1419 replica_id: replica_id.0 as u32,
1420 user_id: guest_user_id.to_proto(),
1421 }),
1422 },
1423 )
1424 .trace_err();
1425 }
1426
1427 // First, we send the metadata associated with each worktree.
1428 response.send(proto::JoinProjectResponse {
1429 worktrees: worktrees.clone(),
1430 replica_id: replica_id.0 as u32,
1431 collaborators: collaborators.clone(),
1432 language_servers: project.language_servers.clone(),
1433 })?;
1434
1435 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1436 #[cfg(any(test, feature = "test-support"))]
1437 const MAX_CHUNK_SIZE: usize = 2;
1438 #[cfg(not(any(test, feature = "test-support")))]
1439 const MAX_CHUNK_SIZE: usize = 256;
1440
1441 // Stream this worktree's entries.
1442 let message = proto::UpdateWorktree {
1443 project_id: project_id.to_proto(),
1444 worktree_id,
1445 abs_path: worktree.abs_path.clone(),
1446 root_name: worktree.root_name,
1447 updated_entries: worktree.entries,
1448 removed_entries: Default::default(),
1449 scan_id: worktree.scan_id,
1450 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1451 updated_repositories: worktree.repository_entries.into_values().collect(),
1452 removed_repositories: Default::default(),
1453 };
1454 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1455 session.peer.send(session.connection_id, update.clone())?;
1456 }
1457
1458 // Stream this worktree's diagnostics.
1459 for summary in worktree.diagnostic_summaries {
1460 session.peer.send(
1461 session.connection_id,
1462 proto::UpdateDiagnosticSummary {
1463 project_id: project_id.to_proto(),
1464 worktree_id: worktree.id,
1465 summary: Some(summary),
1466 },
1467 )?;
1468 }
1469
1470 for settings_file in worktree.settings_files {
1471 session.peer.send(
1472 session.connection_id,
1473 proto::UpdateWorktreeSettings {
1474 project_id: project_id.to_proto(),
1475 worktree_id: worktree.id,
1476 path: settings_file.path,
1477 content: Some(settings_file.content),
1478 },
1479 )?;
1480 }
1481 }
1482
1483 for language_server in &project.language_servers {
1484 session.peer.send(
1485 session.connection_id,
1486 proto::UpdateLanguageServer {
1487 project_id: project_id.to_proto(),
1488 language_server_id: language_server.id,
1489 variant: Some(
1490 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1491 proto::LspDiskBasedDiagnosticsUpdated {},
1492 ),
1493 ),
1494 },
1495 )?;
1496 }
1497
1498 Ok(())
1499}
1500
1501async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1502 let sender_id = session.connection_id;
1503 let project_id = ProjectId::from_proto(request.project_id);
1504
1505 let (room, project) = &*session
1506 .db()
1507 .await
1508 .leave_project(project_id, sender_id)
1509 .await?;
1510 tracing::info!(
1511 %project_id,
1512 host_user_id = %project.host_user_id,
1513 host_connection_id = %project.host_connection_id,
1514 "leave project"
1515 );
1516
1517 project_left(&project, &session);
1518 room_updated(&room, &session.peer);
1519
1520 Ok(())
1521}
1522
1523async fn update_project(
1524 request: proto::UpdateProject,
1525 response: Response<proto::UpdateProject>,
1526 session: Session,
1527) -> Result<()> {
1528 let project_id = ProjectId::from_proto(request.project_id);
1529 let (room, guest_connection_ids) = &*session
1530 .db()
1531 .await
1532 .update_project(project_id, session.connection_id, &request.worktrees)
1533 .await?;
1534 broadcast(
1535 Some(session.connection_id),
1536 guest_connection_ids.iter().copied(),
1537 |connection_id| {
1538 session
1539 .peer
1540 .forward_send(session.connection_id, connection_id, request.clone())
1541 },
1542 );
1543 room_updated(&room, &session.peer);
1544 response.send(proto::Ack {})?;
1545
1546 Ok(())
1547}
1548
1549async fn update_worktree(
1550 request: proto::UpdateWorktree,
1551 response: Response<proto::UpdateWorktree>,
1552 session: Session,
1553) -> Result<()> {
1554 let guest_connection_ids = session
1555 .db()
1556 .await
1557 .update_worktree(&request, session.connection_id)
1558 .await?;
1559
1560 broadcast(
1561 Some(session.connection_id),
1562 guest_connection_ids.iter().copied(),
1563 |connection_id| {
1564 session
1565 .peer
1566 .forward_send(session.connection_id, connection_id, request.clone())
1567 },
1568 );
1569 response.send(proto::Ack {})?;
1570 Ok(())
1571}
1572
1573async fn update_diagnostic_summary(
1574 message: proto::UpdateDiagnosticSummary,
1575 session: Session,
1576) -> Result<()> {
1577 let guest_connection_ids = session
1578 .db()
1579 .await
1580 .update_diagnostic_summary(&message, session.connection_id)
1581 .await?;
1582
1583 broadcast(
1584 Some(session.connection_id),
1585 guest_connection_ids.iter().copied(),
1586 |connection_id| {
1587 session
1588 .peer
1589 .forward_send(session.connection_id, connection_id, message.clone())
1590 },
1591 );
1592
1593 Ok(())
1594}
1595
1596async fn update_worktree_settings(
1597 message: proto::UpdateWorktreeSettings,
1598 session: Session,
1599) -> Result<()> {
1600 let guest_connection_ids = session
1601 .db()
1602 .await
1603 .update_worktree_settings(&message, session.connection_id)
1604 .await?;
1605
1606 broadcast(
1607 Some(session.connection_id),
1608 guest_connection_ids.iter().copied(),
1609 |connection_id| {
1610 session
1611 .peer
1612 .forward_send(session.connection_id, connection_id, message.clone())
1613 },
1614 );
1615
1616 Ok(())
1617}
1618
1619async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1620 broadcast_project_message(request.project_id, request, session).await
1621}
1622
1623async fn start_language_server(
1624 request: proto::StartLanguageServer,
1625 session: Session,
1626) -> Result<()> {
1627 let guest_connection_ids = session
1628 .db()
1629 .await
1630 .start_language_server(&request, session.connection_id)
1631 .await?;
1632
1633 broadcast(
1634 Some(session.connection_id),
1635 guest_connection_ids.iter().copied(),
1636 |connection_id| {
1637 session
1638 .peer
1639 .forward_send(session.connection_id, connection_id, request.clone())
1640 },
1641 );
1642 Ok(())
1643}
1644
1645async fn update_language_server(
1646 request: proto::UpdateLanguageServer,
1647 session: Session,
1648) -> Result<()> {
1649 session.executor.record_backtrace();
1650 let project_id = ProjectId::from_proto(request.project_id);
1651 let project_connection_ids = session
1652 .db()
1653 .await
1654 .project_connection_ids(project_id, session.connection_id)
1655 .await?;
1656 broadcast(
1657 Some(session.connection_id),
1658 project_connection_ids.iter().copied(),
1659 |connection_id| {
1660 session
1661 .peer
1662 .forward_send(session.connection_id, connection_id, request.clone())
1663 },
1664 );
1665 Ok(())
1666}
1667
1668async fn forward_project_request<T>(
1669 request: T,
1670 response: Response<T>,
1671 session: Session,
1672) -> Result<()>
1673where
1674 T: EntityMessage + RequestMessage,
1675{
1676 session.executor.record_backtrace();
1677 let project_id = ProjectId::from_proto(request.remote_entity_id());
1678 let host_connection_id = {
1679 let collaborators = session
1680 .db()
1681 .await
1682 .project_collaborators(project_id, session.connection_id)
1683 .await?;
1684 collaborators
1685 .iter()
1686 .find(|collaborator| collaborator.is_host)
1687 .ok_or_else(|| anyhow!("host not found"))?
1688 .connection_id
1689 };
1690
1691 let payload = session
1692 .peer
1693 .forward_request(session.connection_id, host_connection_id, request)
1694 .await?;
1695
1696 response.send(payload)?;
1697 Ok(())
1698}
1699
1700async fn create_buffer_for_peer(
1701 request: proto::CreateBufferForPeer,
1702 session: Session,
1703) -> Result<()> {
1704 session.executor.record_backtrace();
1705 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1706 session
1707 .peer
1708 .forward_send(session.connection_id, peer_id.into(), request)?;
1709 Ok(())
1710}
1711
1712async fn update_buffer(
1713 request: proto::UpdateBuffer,
1714 response: Response<proto::UpdateBuffer>,
1715 session: Session,
1716) -> Result<()> {
1717 session.executor.record_backtrace();
1718 let project_id = ProjectId::from_proto(request.project_id);
1719 let mut guest_connection_ids;
1720 let mut host_connection_id = None;
1721 {
1722 let collaborators = session
1723 .db()
1724 .await
1725 .project_collaborators(project_id, session.connection_id)
1726 .await?;
1727 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1728 for collaborator in collaborators.iter() {
1729 if collaborator.is_host {
1730 host_connection_id = Some(collaborator.connection_id);
1731 } else {
1732 guest_connection_ids.push(collaborator.connection_id);
1733 }
1734 }
1735 }
1736 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1737
1738 session.executor.record_backtrace();
1739 broadcast(
1740 Some(session.connection_id),
1741 guest_connection_ids,
1742 |connection_id| {
1743 session
1744 .peer
1745 .forward_send(session.connection_id, connection_id, request.clone())
1746 },
1747 );
1748 if host_connection_id != session.connection_id {
1749 session
1750 .peer
1751 .forward_request(session.connection_id, host_connection_id, request.clone())
1752 .await?;
1753 }
1754
1755 response.send(proto::Ack {})?;
1756 Ok(())
1757}
1758
1759async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1760 let project_id = ProjectId::from_proto(request.project_id);
1761 let project_connection_ids = session
1762 .db()
1763 .await
1764 .project_connection_ids(project_id, session.connection_id)
1765 .await?;
1766
1767 broadcast(
1768 Some(session.connection_id),
1769 project_connection_ids.iter().copied(),
1770 |connection_id| {
1771 session
1772 .peer
1773 .forward_send(session.connection_id, connection_id, request.clone())
1774 },
1775 );
1776 Ok(())
1777}
1778
1779async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1780 let project_id = ProjectId::from_proto(request.project_id);
1781 let project_connection_ids = session
1782 .db()
1783 .await
1784 .project_connection_ids(project_id, session.connection_id)
1785 .await?;
1786 broadcast(
1787 Some(session.connection_id),
1788 project_connection_ids.iter().copied(),
1789 |connection_id| {
1790 session
1791 .peer
1792 .forward_send(session.connection_id, connection_id, request.clone())
1793 },
1794 );
1795 Ok(())
1796}
1797
1798async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1799 broadcast_project_message(request.project_id, request, session).await
1800}
1801
1802async fn broadcast_project_message<T: EnvelopedMessage>(
1803 project_id: u64,
1804 request: T,
1805 session: Session,
1806) -> Result<()> {
1807 let project_id = ProjectId::from_proto(project_id);
1808 let project_connection_ids = session
1809 .db()
1810 .await
1811 .project_connection_ids(project_id, session.connection_id)
1812 .await?;
1813 broadcast(
1814 Some(session.connection_id),
1815 project_connection_ids.iter().copied(),
1816 |connection_id| {
1817 session
1818 .peer
1819 .forward_send(session.connection_id, connection_id, request.clone())
1820 },
1821 );
1822 Ok(())
1823}
1824
1825async fn follow(
1826 request: proto::Follow,
1827 response: Response<proto::Follow>,
1828 session: Session,
1829) -> Result<()> {
1830 let project_id = ProjectId::from_proto(request.project_id);
1831 let leader_id = request
1832 .leader_id
1833 .ok_or_else(|| anyhow!("invalid leader id"))?
1834 .into();
1835 let follower_id = session.connection_id;
1836
1837 {
1838 let project_connection_ids = session
1839 .db()
1840 .await
1841 .project_connection_ids(project_id, session.connection_id)
1842 .await?;
1843
1844 if !project_connection_ids.contains(&leader_id) {
1845 Err(anyhow!("no such peer"))?;
1846 }
1847 }
1848
1849 let mut response_payload = session
1850 .peer
1851 .forward_request(session.connection_id, leader_id, request)
1852 .await?;
1853 response_payload
1854 .views
1855 .retain(|view| view.leader_id != Some(follower_id.into()));
1856 response.send(response_payload)?;
1857
1858 let room = session
1859 .db()
1860 .await
1861 .follow(project_id, leader_id, follower_id)
1862 .await?;
1863 room_updated(&room, &session.peer);
1864
1865 Ok(())
1866}
1867
1868async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1869 let project_id = ProjectId::from_proto(request.project_id);
1870 let leader_id = request
1871 .leader_id
1872 .ok_or_else(|| anyhow!("invalid leader id"))?
1873 .into();
1874 let follower_id = session.connection_id;
1875
1876 if !session
1877 .db()
1878 .await
1879 .project_connection_ids(project_id, session.connection_id)
1880 .await?
1881 .contains(&leader_id)
1882 {
1883 Err(anyhow!("no such peer"))?;
1884 }
1885
1886 session
1887 .peer
1888 .forward_send(session.connection_id, leader_id, request)?;
1889
1890 let room = session
1891 .db()
1892 .await
1893 .unfollow(project_id, leader_id, follower_id)
1894 .await?;
1895 room_updated(&room, &session.peer);
1896
1897 Ok(())
1898}
1899
1900async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1901 let project_id = ProjectId::from_proto(request.project_id);
1902 let project_connection_ids = session
1903 .db
1904 .lock()
1905 .await
1906 .project_connection_ids(project_id, session.connection_id)
1907 .await?;
1908
1909 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1910 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1911 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1912 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1913 });
1914 for follower_peer_id in request.follower_ids.iter().copied() {
1915 let follower_connection_id = follower_peer_id.into();
1916 if project_connection_ids.contains(&follower_connection_id)
1917 && Some(follower_peer_id) != leader_id
1918 {
1919 session.peer.forward_send(
1920 session.connection_id,
1921 follower_connection_id,
1922 request.clone(),
1923 )?;
1924 }
1925 }
1926 Ok(())
1927}
1928
1929async fn get_users(
1930 request: proto::GetUsers,
1931 response: Response<proto::GetUsers>,
1932 session: Session,
1933) -> Result<()> {
1934 let user_ids = request
1935 .user_ids
1936 .into_iter()
1937 .map(UserId::from_proto)
1938 .collect();
1939 let users = session
1940 .db()
1941 .await
1942 .get_users_by_ids(user_ids)
1943 .await?
1944 .into_iter()
1945 .map(|user| proto::User {
1946 id: user.id.to_proto(),
1947 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1948 github_login: user.github_login,
1949 })
1950 .collect();
1951 response.send(proto::UsersResponse { users })?;
1952 Ok(())
1953}
1954
1955async fn fuzzy_search_users(
1956 request: proto::FuzzySearchUsers,
1957 response: Response<proto::FuzzySearchUsers>,
1958 session: Session,
1959) -> Result<()> {
1960 let query = request.query;
1961 let users = match query.len() {
1962 0 => vec![],
1963 1 | 2 => session
1964 .db()
1965 .await
1966 .get_user_by_github_login(&query)
1967 .await?
1968 .into_iter()
1969 .collect(),
1970 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1971 };
1972 let users = users
1973 .into_iter()
1974 .filter(|user| user.id != session.user_id)
1975 .map(|user| proto::User {
1976 id: user.id.to_proto(),
1977 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1978 github_login: user.github_login,
1979 })
1980 .collect();
1981 response.send(proto::UsersResponse { users })?;
1982 Ok(())
1983}
1984
1985async fn request_contact(
1986 request: proto::RequestContact,
1987 response: Response<proto::RequestContact>,
1988 session: Session,
1989) -> Result<()> {
1990 let requester_id = session.user_id;
1991 let responder_id = UserId::from_proto(request.responder_id);
1992 if requester_id == responder_id {
1993 return Err(anyhow!("cannot add yourself as a contact"))?;
1994 }
1995
1996 session
1997 .db()
1998 .await
1999 .send_contact_request(requester_id, responder_id)
2000 .await?;
2001
2002 // Update outgoing contact requests of requester
2003 let mut update = proto::UpdateContacts::default();
2004 update.outgoing_requests.push(responder_id.to_proto());
2005 for connection_id in session
2006 .connection_pool()
2007 .await
2008 .user_connection_ids(requester_id)
2009 {
2010 session.peer.send(connection_id, update.clone())?;
2011 }
2012
2013 // Update incoming contact requests of responder
2014 let mut update = proto::UpdateContacts::default();
2015 update
2016 .incoming_requests
2017 .push(proto::IncomingContactRequest {
2018 requester_id: requester_id.to_proto(),
2019 should_notify: true,
2020 });
2021 for connection_id in session
2022 .connection_pool()
2023 .await
2024 .user_connection_ids(responder_id)
2025 {
2026 session.peer.send(connection_id, update.clone())?;
2027 }
2028
2029 response.send(proto::Ack {})?;
2030 Ok(())
2031}
2032
2033async fn respond_to_contact_request(
2034 request: proto::RespondToContactRequest,
2035 response: Response<proto::RespondToContactRequest>,
2036 session: Session,
2037) -> Result<()> {
2038 let responder_id = session.user_id;
2039 let requester_id = UserId::from_proto(request.requester_id);
2040 let db = session.db().await;
2041 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2042 db.dismiss_contact_notification(responder_id, requester_id)
2043 .await?;
2044 } else {
2045 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2046
2047 db.respond_to_contact_request(responder_id, requester_id, accept)
2048 .await?;
2049 let requester_busy = db.is_user_busy(requester_id).await?;
2050 let responder_busy = db.is_user_busy(responder_id).await?;
2051
2052 let pool = session.connection_pool().await;
2053 // Update responder with new contact
2054 let mut update = proto::UpdateContacts::default();
2055 if accept {
2056 update
2057 .contacts
2058 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2059 }
2060 update
2061 .remove_incoming_requests
2062 .push(requester_id.to_proto());
2063 for connection_id in pool.user_connection_ids(responder_id) {
2064 session.peer.send(connection_id, update.clone())?;
2065 }
2066
2067 // Update requester with new contact
2068 let mut update = proto::UpdateContacts::default();
2069 if accept {
2070 update
2071 .contacts
2072 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2073 }
2074 update
2075 .remove_outgoing_requests
2076 .push(responder_id.to_proto());
2077 for connection_id in pool.user_connection_ids(requester_id) {
2078 session.peer.send(connection_id, update.clone())?;
2079 }
2080 }
2081
2082 response.send(proto::Ack {})?;
2083 Ok(())
2084}
2085
2086async fn remove_contact(
2087 request: proto::RemoveContact,
2088 response: Response<proto::RemoveContact>,
2089 session: Session,
2090) -> Result<()> {
2091 let requester_id = session.user_id;
2092 let responder_id = UserId::from_proto(request.user_id);
2093 let db = session.db().await;
2094 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2095
2096 let pool = session.connection_pool().await;
2097 // Update outgoing contact requests of requester
2098 let mut update = proto::UpdateContacts::default();
2099 if contact_accepted {
2100 update.remove_contacts.push(responder_id.to_proto());
2101 } else {
2102 update
2103 .remove_outgoing_requests
2104 .push(responder_id.to_proto());
2105 }
2106 for connection_id in pool.user_connection_ids(requester_id) {
2107 session.peer.send(connection_id, update.clone())?;
2108 }
2109
2110 // Update incoming contact requests of responder
2111 let mut update = proto::UpdateContacts::default();
2112 if contact_accepted {
2113 update.remove_contacts.push(requester_id.to_proto());
2114 } else {
2115 update
2116 .remove_incoming_requests
2117 .push(requester_id.to_proto());
2118 }
2119 for connection_id in pool.user_connection_ids(responder_id) {
2120 session.peer.send(connection_id, update.clone())?;
2121 }
2122
2123 response.send(proto::Ack {})?;
2124 Ok(())
2125}
2126
2127async fn create_channel(
2128 request: proto::CreateChannel,
2129 response: Response<proto::CreateChannel>,
2130 session: Session,
2131) -> Result<()> {
2132 let db = session.db().await;
2133 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2134
2135 if let Some(live_kit) = session.live_kit_client.as_ref() {
2136 live_kit.create_room(live_kit_room.clone()).await?;
2137 }
2138
2139 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2140 let id = db
2141 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2142 .await?;
2143
2144 response.send(proto::CreateChannelResponse {
2145 channel_id: id.to_proto(),
2146 })?;
2147
2148 let mut update = proto::UpdateChannels::default();
2149 update.channels.push(proto::Channel {
2150 id: id.to_proto(),
2151 name: request.name,
2152 parent_id: request.parent_id,
2153 });
2154
2155 if let Some(parent_id) = parent_id {
2156 let member_ids = db.get_channel_members(parent_id).await?;
2157 let connection_pool = session.connection_pool().await;
2158 for member_id in member_ids {
2159 for connection_id in connection_pool.user_connection_ids(member_id) {
2160 session.peer.send(connection_id, update.clone())?;
2161 }
2162 }
2163 } else {
2164 session.peer.send(session.connection_id, update)?;
2165 }
2166
2167 Ok(())
2168}
2169
2170async fn remove_channel(
2171 request: proto::RemoveChannel,
2172 response: Response<proto::RemoveChannel>,
2173 session: Session,
2174) -> Result<()> {
2175 let db = session.db().await;
2176
2177 let channel_id = request.channel_id;
2178 let (removed_channels, member_ids) = db
2179 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2180 .await?;
2181 response.send(proto::Ack {})?;
2182
2183 // Notify members of removed channels
2184 let mut update = proto::UpdateChannels::default();
2185 update
2186 .remove_channels
2187 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2188
2189 let connection_pool = session.connection_pool().await;
2190 for member_id in member_ids {
2191 for connection_id in connection_pool.user_connection_ids(member_id) {
2192 session.peer.send(connection_id, update.clone())?;
2193 }
2194 }
2195
2196 Ok(())
2197}
2198
2199async fn invite_channel_member(
2200 request: proto::InviteChannelMember,
2201 response: Response<proto::InviteChannelMember>,
2202 session: Session,
2203) -> Result<()> {
2204 let db = session.db().await;
2205 let channel_id = ChannelId::from_proto(request.channel_id);
2206 let channel = db
2207 .get_channel(channel_id)
2208 .await?
2209 .ok_or_else(|| anyhow!("channel not found"))?;
2210 let invitee_id = UserId::from_proto(request.user_id);
2211 db.invite_channel_member(channel_id, invitee_id, session.user_id, false)
2212 .await?;
2213
2214 let mut update = proto::UpdateChannels::default();
2215 update.channel_invitations.push(proto::Channel {
2216 id: channel.id.to_proto(),
2217 name: channel.name,
2218 parent_id: None,
2219 });
2220 for connection_id in session
2221 .connection_pool()
2222 .await
2223 .user_connection_ids(invitee_id)
2224 {
2225 session.peer.send(connection_id, update.clone())?;
2226 }
2227
2228 response.send(proto::Ack {})?;
2229 Ok(())
2230}
2231
2232async fn remove_channel_member(
2233 request: proto::RemoveChannelMember,
2234 response: Response<proto::RemoveChannelMember>,
2235 session: Session,
2236) -> Result<()> {
2237 let db = session.db().await;
2238 let channel_id = ChannelId::from_proto(request.channel_id);
2239 let member_id = UserId::from_proto(request.user_id);
2240 db.remove_channel_member(channel_id, member_id, session.user_id)
2241 .await?;
2242 response.send(proto::Ack {})?;
2243 Ok(())
2244}
2245
2246async fn get_channel_members(
2247 request: proto::GetChannelMembers,
2248 response: Response<proto::GetChannelMembers>,
2249 session: Session,
2250) -> Result<()> {
2251 let db = session.db().await;
2252 let channel_id = ChannelId::from_proto(request.channel_id);
2253 let members = db
2254 .get_channel_member_details(channel_id, session.user_id)
2255 .await?;
2256 response.send(proto::GetChannelMembersResponse { members })?;
2257 Ok(())
2258}
2259
2260async fn respond_to_channel_invite(
2261 request: proto::RespondToChannelInvite,
2262 response: Response<proto::RespondToChannelInvite>,
2263 session: Session,
2264) -> Result<()> {
2265 let db = session.db().await;
2266 let channel_id = ChannelId::from_proto(request.channel_id);
2267 let channel = db
2268 .get_channel(channel_id)
2269 .await?
2270 .ok_or_else(|| anyhow!("no such channel"))?;
2271 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2272 .await?;
2273
2274 let mut update = proto::UpdateChannels::default();
2275 update
2276 .remove_channel_invitations
2277 .push(channel_id.to_proto());
2278 if request.accept {
2279 update.channels.push(proto::Channel {
2280 id: channel.id.to_proto(),
2281 name: channel.name,
2282 parent_id: None,
2283 });
2284 }
2285 session.peer.send(session.connection_id, update)?;
2286 response.send(proto::Ack {})?;
2287
2288 Ok(())
2289}
2290
/// Handles a `JoinChannel` request.
///
/// Joins the session's user to the room associated with the channel and replies
/// with a `JoinRoomResponse`. The response carries the room state and, when a
/// LiveKit client is configured and a token could be issued, LiveKit connection
/// info. Afterwards it broadcasts the updated room to its participants, sends
/// the channel's members the new participant list, and refreshes this user's
/// status for their contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The database handle (and the value returned by `join_room`) only live
    // inside this block; an owned copy of the joined room escapes it.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // A failure to mint a LiveKit token is logged and results in no
        // connection info, rather than failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // NOTE(review): presumably `join_room` returns a room guard; cloning
        // here yields the plain joined-room data so the guard is dropped at the
        // end of this block (see the TODO below) — confirm in db.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2351
2352async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2353 let project_id = ProjectId::from_proto(request.project_id);
2354 let project_connection_ids = session
2355 .db()
2356 .await
2357 .project_connection_ids(project_id, session.connection_id)
2358 .await?;
2359 broadcast(
2360 Some(session.connection_id),
2361 project_connection_ids.iter().copied(),
2362 |connection_id| {
2363 session
2364 .peer
2365 .forward_send(session.connection_id, connection_id, request.clone())
2366 },
2367 );
2368 Ok(())
2369}
2370
2371async fn get_private_user_info(
2372 _request: proto::GetPrivateUserInfo,
2373 response: Response<proto::GetPrivateUserInfo>,
2374 session: Session,
2375) -> Result<()> {
2376 let metrics_id = session
2377 .db()
2378 .await
2379 .get_user_metrics_id(session.user_id)
2380 .await?;
2381 let user = session
2382 .db()
2383 .await
2384 .get_user_by_id(session.user_id)
2385 .await?
2386 .ok_or_else(|| anyhow!("user not found"))?;
2387 response.send(proto::GetPrivateUserInfoResponse {
2388 metrics_id,
2389 staff: user.admin,
2390 })?;
2391 Ok(())
2392}
2393
2394fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2395 match message {
2396 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2397 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2398 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2399 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2400 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2401 code: frame.code.into(),
2402 reason: frame.reason,
2403 })),
2404 }
2405}
2406
2407fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2408 match message {
2409 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2410 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2411 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2412 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2413 AxumMessage::Close(frame) => {
2414 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2415 code: frame.code.into(),
2416 reason: frame.reason,
2417 }))
2418 }
2419 }
2420}
2421
2422fn build_initial_channels_update(
2423 channels: Vec<db::Channel>,
2424 channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2425 channel_invites: Vec<db::Channel>,
2426) -> proto::UpdateChannels {
2427 let mut update = proto::UpdateChannels::default();
2428
2429 for channel in channels {
2430 update.channels.push(proto::Channel {
2431 id: channel.id.to_proto(),
2432 name: channel.name,
2433 parent_id: channel.parent_id.map(|id| id.to_proto()),
2434 });
2435 }
2436
2437 for (channel_id, participants) in channel_participants {
2438 update
2439 .channel_participants
2440 .push(proto::ChannelParticipants {
2441 channel_id: channel_id.to_proto(),
2442 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2443 });
2444 }
2445
2446 for channel in channel_invites {
2447 update.channel_invitations.push(proto::Channel {
2448 id: channel.id.to_proto(),
2449 name: channel.name,
2450 parent_id: None,
2451 });
2452 }
2453
2454 update
2455}
2456
2457fn build_initial_contacts_update(
2458 contacts: Vec<db::Contact>,
2459 pool: &ConnectionPool,
2460) -> proto::UpdateContacts {
2461 let mut update = proto::UpdateContacts::default();
2462
2463 for contact in contacts {
2464 match contact {
2465 db::Contact::Accepted {
2466 user_id,
2467 should_notify,
2468 busy,
2469 } => {
2470 update
2471 .contacts
2472 .push(contact_for_user(user_id, should_notify, busy, &pool));
2473 }
2474 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2475 db::Contact::Incoming {
2476 user_id,
2477 should_notify,
2478 } => update
2479 .incoming_requests
2480 .push(proto::IncomingContactRequest {
2481 requester_id: user_id.to_proto(),
2482 should_notify,
2483 }),
2484 }
2485 }
2486
2487 update
2488}
2489
2490fn contact_for_user(
2491 user_id: UserId,
2492 should_notify: bool,
2493 busy: bool,
2494 pool: &ConnectionPool,
2495) -> proto::Contact {
2496 proto::Contact {
2497 user_id: user_id.to_proto(),
2498 online: pool.is_user_online(user_id),
2499 busy,
2500 should_notify,
2501 }
2502}
2503
2504fn room_updated(room: &proto::Room, peer: &Peer) {
2505 broadcast(
2506 None,
2507 room.participants
2508 .iter()
2509 .filter_map(|participant| Some(participant.peer_id?.into())),
2510 |peer_id| {
2511 peer.send(
2512 peer_id.into(),
2513 proto::RoomUpdated {
2514 room: Some(room.clone()),
2515 },
2516 )
2517 },
2518 );
2519}
2520
2521fn channel_updated(
2522 channel_id: ChannelId,
2523 room: &proto::Room,
2524 channel_members: &[UserId],
2525 peer: &Peer,
2526 pool: &ConnectionPool,
2527) {
2528 let participants = room
2529 .participants
2530 .iter()
2531 .map(|p| p.user_id)
2532 .collect::<Vec<_>>();
2533
2534 broadcast(
2535 None,
2536 channel_members
2537 .iter()
2538 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2539 |peer_id| {
2540 peer.send(
2541 peer_id.into(),
2542 proto::UpdateChannels {
2543 channel_participants: vec![proto::ChannelParticipants {
2544 channel_id: channel_id.to_proto(),
2545 participant_user_ids: participants.clone(),
2546 }],
2547 ..Default::default()
2548 },
2549 )
2550 },
2551 );
2552}
2553
2554async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2555 let db = session.db().await;
2556
2557 let contacts = db.get_contacts(user_id).await?;
2558 let busy = db.is_user_busy(user_id).await?;
2559
2560 let pool = session.connection_pool().await;
2561 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2562 for contact in contacts {
2563 if let db::Contact::Accepted {
2564 user_id: contact_user_id,
2565 ..
2566 } = contact
2567 {
2568 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2569 session
2570 .peer
2571 .send(
2572 contact_conn_id,
2573 proto::UpdateContacts {
2574 contacts: vec![updated_contact.clone()],
2575 remove_contacts: Default::default(),
2576 incoming_requests: Default::default(),
2577 remove_incoming_requests: Default::default(),
2578 outgoing_requests: Default::default(),
2579 remove_outgoing_requests: Default::default(),
2580 },
2581 )
2582 .trace_err();
2583 }
2584 }
2585 }
2586 Ok(())
2587}
2588
2589async fn leave_room_for_session(session: &Session) -> Result<()> {
2590 let mut contacts_to_update = HashSet::default();
2591
2592 let room_id;
2593 let canceled_calls_to_user_ids;
2594 let live_kit_room;
2595 let delete_live_kit_room;
2596 let room;
2597 let channel_members;
2598 let channel_id;
2599
2600 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2601 contacts_to_update.insert(session.user_id);
2602
2603 for project in left_room.left_projects.values() {
2604 project_left(project, session);
2605 }
2606
2607 room_id = RoomId::from_proto(left_room.room.id);
2608 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2609 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2610 delete_live_kit_room = left_room.deleted;
2611 room = mem::take(&mut left_room.room);
2612 channel_members = mem::take(&mut left_room.channel_members);
2613 channel_id = left_room.channel_id;
2614
2615 room_updated(&room, &session.peer);
2616 } else {
2617 return Ok(());
2618 }
2619
2620 // TODO - do this while holding the room guard.
2621 if let Some(channel_id) = channel_id {
2622 channel_updated(
2623 channel_id,
2624 &room,
2625 &channel_members,
2626 &session.peer,
2627 &*session.connection_pool().await,
2628 );
2629 }
2630
2631 {
2632 let pool = session.connection_pool().await;
2633 for canceled_user_id in canceled_calls_to_user_ids {
2634 for connection_id in pool.user_connection_ids(canceled_user_id) {
2635 session
2636 .peer
2637 .send(
2638 connection_id,
2639 proto::CallCanceled {
2640 room_id: room_id.to_proto(),
2641 },
2642 )
2643 .trace_err();
2644 }
2645 contacts_to_update.insert(canceled_user_id);
2646 }
2647 }
2648
2649 for contact_user_id in contacts_to_update {
2650 update_user_contacts(contact_user_id, &session).await?;
2651 }
2652
2653 if let Some(live_kit) = session.live_kit_client.as_ref() {
2654 live_kit
2655 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2656 .await
2657 .trace_err();
2658
2659 if delete_live_kit_room {
2660 live_kit.delete_room(live_kit_room).await.trace_err();
2661 }
2662 }
2663
2664 Ok(())
2665}
2666
2667fn project_left(project: &db::LeftProject, session: &Session) {
2668 for connection_id in &project.connection_ids {
2669 if project.host_user_id == session.user_id {
2670 session
2671 .peer
2672 .send(
2673 *connection_id,
2674 proto::UnshareProject {
2675 project_id: project.id.to_proto(),
2676 },
2677 )
2678 .trace_err();
2679 } else {
2680 session
2681 .peer
2682 .send(
2683 *connection_id,
2684 proto::RemoveProjectCollaborator {
2685 project_id: project.id.to_proto(),
2686 peer_id: Some(session.connection_id.into()),
2687 },
2688 )
2689 .trace_err();
2690 }
2691 }
2692}
2693
2694pub trait ResultExt {
2695 type Ok;
2696
2697 fn trace_err(self) -> Option<Self::Ok>;
2698}
2699
2700impl<T, E> ResultExt for Result<T, E>
2701where
2702 E: std::fmt::Debug,
2703{
2704 type Ok = T;
2705
2706 fn trace_err(self) -> Option<T> {
2707 match self {
2708 Ok(value) => Some(value),
2709 Err(error) => {
2710 tracing::error!("{:?}", error);
2711 None
2712 }
2713 }
2714 }
2715}