1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits for a dropped client to come back before it
/// leaves the client's room and refreshes its contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `Server::start` waits before cleaning up stale rooms and servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
// Prometheus gauges; their values are refreshed in `handle_metrics` on every scrape.
lazy_static! {
    // Number of connections in the pool, excluding admin connections (set in `handle_metrics`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Value of `project_count_excluding_admins` from the database (set in `handle_metrics`).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection context handed to every message handler.
///
/// `handle_connection` passes a clone of it to each handler invocation; the
/// `Arc` fields are shared between those clones.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Same pool as `Server::connection_pool`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when the app state carries no LiveKit client.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the peer, the pool of live connections and
/// the table of registered message handlers.
pub struct Server {
    // Behind a mutex only so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Cloned into every `Session` created by `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by `TypeId::of::<M>()` of the message type (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // `handle_connection` subscribes to it; `teardown()` sends on it.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the connection pool that is deliberately `!Send`.
///
/// `PhantomData<Rc<()>>` removes the `Send` auto trait, so a future that holds
/// this guard across an `.await` is not `Send`. In test builds, dropping the
/// guard runs the pool's `check_invariants`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
    /// Builds the server and registers every RPC handler.
    ///
    /// `add_handler` panics if the same message type is registered twice, so each
    /// message type must appear in this list exactly once. The peer is created with
    /// `id.0` cast to `u32`.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created via `subscribe` per connection.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Project requests relayed via `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }
262
    /// Spawns a detached task that cleans up rooms the database reports as stale
    /// for this server's environment.
    ///
    /// After `CLEANUP_TIMEOUT`, the task does the following for each stale room:
    /// - refreshes the room and broadcasts the update to the room and, if the room
    ///   belongs to a channel, to the channel;
    /// - sends `CallCanceled` to users whose calls were canceled;
    /// - pushes updated contact entries for stale participants and canceled callees
    ///   to their accepted contacts;
    /// - deletes the LiveKit room if no participants remain.
    ///
    /// It then deletes stale servers. Every failure is traced and skipped, so this
    /// method always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Filled in only if the room refresh below succeeds.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is deleted only if nobody is left in the room.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Both lookups must succeed for a user's contacts to be notified.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
387
388 pub fn teardown(&self) {
389 self.peer.teardown();
390 self.connection_pool.lock().reset();
391 let _ = self.teardown.send(());
392 }
393
394 #[cfg(test)]
395 pub fn reset(&self, id: ServerId) {
396 self.teardown();
397 *self.id.lock() = id;
398 self.peer.reset(id.0 as u32);
399 }
400
401 #[cfg(test)]
402 pub fn id(&self) -> ServerId {
403 *self.id.lock()
404 }
405
406 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
407 where
408 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
409 Fut: 'static + Send + Future<Output = Result<()>>,
410 M: EnvelopedMessage,
411 {
412 let prev_handler = self.handlers.insert(
413 TypeId::of::<M>(),
414 Box::new(move |envelope, session| {
415 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
416 let span = info_span!(
417 "handle message",
418 payload_type = envelope.payload_type_name()
419 );
420 span.in_scope(|| {
421 tracing::info!(
422 payload_type = envelope.payload_type_name(),
423 "message received"
424 );
425 });
426 let start_time = Instant::now();
427 let future = (handler)(*envelope, session);
428 async move {
429 let result = future.await;
430 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
431 match result {
432 Err(error) => {
433 tracing::error!(%error, ?duration_ms, "error handling message")
434 }
435 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
436 }
437 }
438 .instrument(span)
439 .boxed()
440 }),
441 );
442 if prev_handler.is_some() {
443 panic!("registered a handler for the same message twice");
444 }
445 self
446 }
447
448 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
449 where
450 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
451 Fut: 'static + Send + Future<Output = Result<()>>,
452 M: EnvelopedMessage,
453 {
454 self.add_handler(move |envelope, session| handler(envelope.payload, session));
455 self
456 }
457
458 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
459 where
460 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
461 Fut: Send + Future<Output = Result<()>>,
462 M: RequestMessage,
463 {
464 let handler = Arc::new(handler);
465 self.add_handler(move |envelope, session| {
466 let receipt = envelope.receipt();
467 let handler = handler.clone();
468 async move {
469 let peer = session.peer.clone();
470 let responded = Arc::new(AtomicBool::default());
471 let response = Response {
472 peer: peer.clone(),
473 responded: responded.clone(),
474 receipt,
475 };
476 match (handler)(envelope.payload, response, session).await {
477 Ok(()) => {
478 if responded.load(std::sync::atomic::Ordering::SeqCst) {
479 Ok(())
480 } else {
481 Err(anyhow!("handler did not send a response"))?
482 }
483 }
484 Err(error) => {
485 peer.respond_with_error(
486 receipt,
487 proto::Error {
488 message: error.to_string(),
489 },
490 )?;
491 Err(error)
492 }
493 }
494 }
495 })
496 }
497
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Setup, in order:
    /// - registers the connection with the peer and sends `Hello`;
    /// - forwards the new `ConnectionId` through `send_connection_id` if provided;
    /// - sends the initial contacts, channels, invite info and any pending incoming
    ///   call, then updates the user's contacts.
    ///
    /// It then dispatches incoming messages to the registered handlers. The loop ends
    /// in one of three ways:
    /// - the I/O task finishes, or the client closes the connection: `connection_lost`
    ///   runs before returning;
    /// - the server's teardown is signalled: the function returns `Ok(())`
    ///   immediately, without calling `connection_lost`.
    ///
    /// # Errors
    ///
    /// Returns an error if any setup step fails (a send or a database query).
    /// Errors after setup (I/O, sign-out) are logged instead of returned.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed when this function is called, not when the returned future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // First connection ever for this user: send `ShowContacts` and record the fact.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            {
                // The connection joins the pool and receives its initial state while the
                // same pool lock is held.
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user.channels,
                    channels_for_user.channel_participants,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps this connection at 256 in-flight handlers (foreground and background).
            // A permit is taken before each message is read and released when its handler
            // finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is local and never closed, so acquiring cannot fail.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branch priority: teardown, I/O completion, foreground handler progress,
                // then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached; foreground handlers are
                                // polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
641
642 pub async fn invite_code_redeemed(
643 self: &Arc<Self>,
644 inviter_id: UserId,
645 invitee_id: UserId,
646 ) -> Result<()> {
647 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
648 if let Some(code) = &user.invite_code {
649 let pool = self.connection_pool.lock();
650 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
651 for connection_id in pool.user_connection_ids(inviter_id) {
652 self.peer.send(
653 connection_id,
654 proto::UpdateContacts {
655 contacts: vec![invitee_contact.clone()],
656 ..Default::default()
657 },
658 )?;
659 self.peer.send(
660 connection_id,
661 proto::UpdateInviteInfo {
662 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
663 count: user.invite_count as u32,
664 },
665 )?;
666 }
667 }
668 }
669 Ok(())
670 }
671
672 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
673 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
674 if let Some(invite_code) = &user.invite_code {
675 let pool = self.connection_pool.lock();
676 for connection_id in pool.user_connection_ids(user_id) {
677 self.peer.send(
678 connection_id,
679 proto::UpdateInviteInfo {
680 url: format!(
681 "{}{}",
682 self.app_state.config.invite_link_prefix, invite_code
683 ),
684 count: user.invite_count as u32,
685 },
686 )?;
687 }
688 }
689 }
690 Ok(())
691 }
692
693 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
694 ServerSnapshot {
695 connection_pool: ConnectionPoolGuard {
696 guard: self.connection_pool.lock(),
697 _not_send: PhantomData,
698 },
699 peer: &self.peer,
700 }
701 }
702}
703
704impl<'a> Deref for ConnectionPoolGuard<'a> {
705 type Target = ConnectionPool;
706
707 fn deref(&self) -> &Self::Target {
708 &*self.guard
709 }
710}
711
712impl<'a> DerefMut for ConnectionPoolGuard<'a> {
713 fn deref_mut(&mut self) -> &mut Self::Target {
714 &mut *self.guard
715 }
716}
717
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants whenever the guard is
    /// dropped. In other builds the body compiles to nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
724
725fn broadcast<F>(
726 sender_id: Option<ConnectionId>,
727 receiver_ids: impl IntoIterator<Item = ConnectionId>,
728 mut f: F,
729) where
730 F: FnMut(ConnectionId) -> anyhow::Result<()>,
731{
732 for receiver_id in receiver_ids {
733 if Some(receiver_id) != sender_id {
734 if let Err(error) = f(receiver_id) {
735 tracing::error!("failed to send to {:?} {}", receiver_id, error);
736 }
737 }
738 }
739}
740
lazy_static! {
    // Name of the request header that carries the client's RPC protocol version
    // (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
744
745pub struct ProtocolVersion(u32);
746
747impl Header for ProtocolVersion {
748 fn name() -> &'static HeaderName {
749 &ZED_PROTOCOL_VERSION
750 }
751
752 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
753 where
754 Self: Sized,
755 I: Iterator<Item = &'i axum::http::HeaderValue>,
756 {
757 let version = values
758 .next()
759 .ok_or_else(axum::headers::Error::invalid)?
760 .to_str()
761 .map_err(|_| axum::headers::Error::invalid())?
762 .parse()
763 .map_err(|_| axum::headers::Error::invalid())?;
764 Ok(Self(version))
765 }
766
767 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
768 values.extend([self.0.to_string().parse().unwrap()]);
769 }
770}
771
/// Builds the router for the RPC endpoints.
///
/// Call order matters: a `.layer(...)` wraps only the routes added before it.
/// `/rpc` therefore gets the `AppState` extension and `auth::validate_header`,
/// while `/metrics` does not. The final layer supplies the `Server` extension to
/// both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
783
784pub async fn handle_websocket_request(
785 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
786 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
787 Extension(server): Extension<Arc<Server>>,
788 Extension(user): Extension<User>,
789 ws: WebSocketUpgrade,
790) -> axum::response::Response {
791 if protocol_version != rpc::PROTOCOL_VERSION {
792 return (
793 StatusCode::UPGRADE_REQUIRED,
794 "client must be upgraded".to_string(),
795 )
796 .into_response();
797 }
798 let socket_address = socket_address.to_string();
799 ws.on_upgrade(move |socket| {
800 use util::ResultExt;
801 let socket = socket
802 .map_ok(to_tungstenite_message)
803 .err_into()
804 .with(|message| async move { Ok(to_axum_message(message)) });
805 let connection = Connection::new(Box::pin(socket));
806 async move {
807 server
808 .handle_connection(connection, socket_address, user, None, Executor::Production)
809 .await
810 .log_err();
811 }
812 })
813}
814
815pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
816 let connections = server
817 .connection_pool
818 .lock()
819 .connections()
820 .filter(|connection| !connection.admin)
821 .count();
822
823 METRIC_CONNECTIONS.set(connections as _);
824
825 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
826 METRIC_SHARED_PROJECTS.set(shared_projects as _);
827
828 let encoder = prometheus::TextEncoder::new();
829 let metric_families = prometheus::gather();
830 let encoded_metrics = encoder
831 .encode_to_string(&metric_families)
832 .map_err(|err| anyhow!("{}", err))?;
833 Ok(encoded_metrics)
834}
835
/// Cleans up after a connection ends.
///
/// The following happens immediately:
/// - the peer connection is disconnected;
/// - the connection is removed from the pool;
/// - the database is told the connection was lost (failures are traced).
///
/// Then the function waits `RECONNECT_TIMEOUT`. When the wait elapses:
/// - the session leaves its room;
/// - `decline_call` is issued for the user if none of their connections is online;
/// - the user's contacts are refreshed.
///
/// A server teardown during the wait skips that second phase.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool or updating contacts fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Another connection of the same user may have come online meanwhile.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
876
877async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
878 response.send(proto::Ack {})?;
879 Ok(())
880}
881
882async fn create_room(
883 _request: proto::CreateRoom,
884 response: Response<proto::CreateRoom>,
885 session: Session,
886) -> Result<()> {
887 let live_kit_room = nanoid::nanoid!(30);
888
889 let live_kit_connection_info = {
890 let live_kit_room = live_kit_room.clone();
891 let live_kit = session.live_kit_client.as_ref();
892
893 util::async_iife!({
894 let live_kit = live_kit?;
895
896 live_kit
897 .create_room(live_kit_room.clone())
898 .await
899 .trace_err()?;
900
901 let token = live_kit
902 .room_token(&live_kit_room, &session.user_id.to_string())
903 .trace_err()?;
904
905 Some(proto::LiveKitConnectionInfo {
906 server_url: live_kit.url().into(),
907 token,
908 })
909 })
910 }
911 .await;
912
913 let room = session
914 .db()
915 .await
916 .create_room(session.user_id, session.connection_id, &live_kit_room)
917 .await?;
918
919 response.send(proto::CreateRoomResponse {
920 room: Some(room.clone()),
921 live_kit_connection_info,
922 })?;
923
924 update_user_contacts(session.user_id, &session).await?;
925 Ok(())
926}
927
928async fn join_room(
929 request: proto::JoinRoom,
930 response: Response<proto::JoinRoom>,
931 session: Session,
932) -> Result<()> {
933 let room_id = RoomId::from_proto(request.id);
934 let room = {
935 let room = session
936 .db()
937 .await
938 .join_room(room_id, session.user_id, None, session.connection_id)
939 .await?;
940 room_updated(&room.room, &session.peer);
941 room.room.clone()
942 };
943
944 for connection_id in session
945 .connection_pool()
946 .await
947 .user_connection_ids(session.user_id)
948 {
949 session
950 .peer
951 .send(
952 connection_id,
953 proto::CallCanceled {
954 room_id: room_id.to_proto(),
955 },
956 )
957 .trace_err();
958 }
959
960 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
961 if let Some(token) = live_kit
962 .room_token(&room.live_kit_room, &session.user_id.to_string())
963 .trace_err()
964 {
965 Some(proto::LiveKitConnectionInfo {
966 server_url: live_kit.url().into(),
967 token,
968 })
969 } else {
970 None
971 }
972 } else {
973 None
974 };
975
976 response.send(proto::JoinRoomResponse {
977 room: Some(room),
978 live_kit_connection_info,
979 })?;
980
981 update_user_contacts(session.user_id, &session).await?;
982 Ok(())
983}
984
985async fn rejoin_room(
986 request: proto::RejoinRoom,
987 response: Response<proto::RejoinRoom>,
988 session: Session,
989) -> Result<()> {
990 let room;
991 let channel_id;
992 let channel_members;
993 {
994 let mut rejoined_room = session
995 .db()
996 .await
997 .rejoin_room(request, session.user_id, session.connection_id)
998 .await?;
999
1000 response.send(proto::RejoinRoomResponse {
1001 room: Some(rejoined_room.room.clone()),
1002 reshared_projects: rejoined_room
1003 .reshared_projects
1004 .iter()
1005 .map(|project| proto::ResharedProject {
1006 id: project.id.to_proto(),
1007 collaborators: project
1008 .collaborators
1009 .iter()
1010 .map(|collaborator| collaborator.to_proto())
1011 .collect(),
1012 })
1013 .collect(),
1014 rejoined_projects: rejoined_room
1015 .rejoined_projects
1016 .iter()
1017 .map(|rejoined_project| proto::RejoinedProject {
1018 id: rejoined_project.id.to_proto(),
1019 worktrees: rejoined_project
1020 .worktrees
1021 .iter()
1022 .map(|worktree| proto::WorktreeMetadata {
1023 id: worktree.id,
1024 root_name: worktree.root_name.clone(),
1025 visible: worktree.visible,
1026 abs_path: worktree.abs_path.clone(),
1027 })
1028 .collect(),
1029 collaborators: rejoined_project
1030 .collaborators
1031 .iter()
1032 .map(|collaborator| collaborator.to_proto())
1033 .collect(),
1034 language_servers: rejoined_project.language_servers.clone(),
1035 })
1036 .collect(),
1037 })?;
1038 room_updated(&rejoined_room.room, &session.peer);
1039
1040 for project in &rejoined_room.reshared_projects {
1041 for collaborator in &project.collaborators {
1042 session
1043 .peer
1044 .send(
1045 collaborator.connection_id,
1046 proto::UpdateProjectCollaborator {
1047 project_id: project.id.to_proto(),
1048 old_peer_id: Some(project.old_connection_id.into()),
1049 new_peer_id: Some(session.connection_id.into()),
1050 },
1051 )
1052 .trace_err();
1053 }
1054
1055 broadcast(
1056 Some(session.connection_id),
1057 project
1058 .collaborators
1059 .iter()
1060 .map(|collaborator| collaborator.connection_id),
1061 |connection_id| {
1062 session.peer.forward_send(
1063 session.connection_id,
1064 connection_id,
1065 proto::UpdateProject {
1066 project_id: project.id.to_proto(),
1067 worktrees: project.worktrees.clone(),
1068 },
1069 )
1070 },
1071 );
1072 }
1073
1074 for project in &rejoined_room.rejoined_projects {
1075 for collaborator in &project.collaborators {
1076 session
1077 .peer
1078 .send(
1079 collaborator.connection_id,
1080 proto::UpdateProjectCollaborator {
1081 project_id: project.id.to_proto(),
1082 old_peer_id: Some(project.old_connection_id.into()),
1083 new_peer_id: Some(session.connection_id.into()),
1084 },
1085 )
1086 .trace_err();
1087 }
1088 }
1089
1090 for project in &mut rejoined_room.rejoined_projects {
1091 for worktree in mem::take(&mut project.worktrees) {
1092 #[cfg(any(test, feature = "test-support"))]
1093 const MAX_CHUNK_SIZE: usize = 2;
1094 #[cfg(not(any(test, feature = "test-support")))]
1095 const MAX_CHUNK_SIZE: usize = 256;
1096
1097 // Stream this worktree's entries.
1098 let message = proto::UpdateWorktree {
1099 project_id: project.id.to_proto(),
1100 worktree_id: worktree.id,
1101 abs_path: worktree.abs_path.clone(),
1102 root_name: worktree.root_name,
1103 updated_entries: worktree.updated_entries,
1104 removed_entries: worktree.removed_entries,
1105 scan_id: worktree.scan_id,
1106 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1107 updated_repositories: worktree.updated_repositories,
1108 removed_repositories: worktree.removed_repositories,
1109 };
1110 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1111 session.peer.send(session.connection_id, update.clone())?;
1112 }
1113
1114 // Stream this worktree's diagnostics.
1115 for summary in worktree.diagnostic_summaries {
1116 session.peer.send(
1117 session.connection_id,
1118 proto::UpdateDiagnosticSummary {
1119 project_id: project.id.to_proto(),
1120 worktree_id: worktree.id,
1121 summary: Some(summary),
1122 },
1123 )?;
1124 }
1125
1126 for settings_file in worktree.settings_files {
1127 session.peer.send(
1128 session.connection_id,
1129 proto::UpdateWorktreeSettings {
1130 project_id: project.id.to_proto(),
1131 worktree_id: worktree.id,
1132 path: settings_file.path,
1133 content: Some(settings_file.content),
1134 },
1135 )?;
1136 }
1137 }
1138
1139 for language_server in &project.language_servers {
1140 session.peer.send(
1141 session.connection_id,
1142 proto::UpdateLanguageServer {
1143 project_id: project.id.to_proto(),
1144 language_server_id: language_server.id,
1145 variant: Some(
1146 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1147 proto::LspDiskBasedDiagnosticsUpdated {},
1148 ),
1149 ),
1150 },
1151 )?;
1152 }
1153 }
1154
1155 room = mem::take(&mut rejoined_room.room);
1156 channel_id = rejoined_room.channel_id;
1157 channel_members = mem::take(&mut rejoined_room.channel_members);
1158 }
1159
1160 //TODO: move this into the room guard
1161 if let Some(channel_id) = channel_id {
1162 channel_updated(
1163 channel_id,
1164 &room,
1165 &channel_members,
1166 &session.peer,
1167 &*session.connection_pool().await,
1168 );
1169 }
1170
1171 update_user_contacts(session.user_id, &session).await?;
1172 Ok(())
1173}
1174
1175async fn leave_room(
1176 _: proto::LeaveRoom,
1177 response: Response<proto::LeaveRoom>,
1178 session: Session,
1179) -> Result<()> {
1180 leave_room_for_session(&session).await?;
1181 response.send(proto::Ack {})?;
1182 Ok(())
1183}
1184
1185async fn call(
1186 request: proto::Call,
1187 response: Response<proto::Call>,
1188 session: Session,
1189) -> Result<()> {
1190 let room_id = RoomId::from_proto(request.room_id);
1191 let calling_user_id = session.user_id;
1192 let calling_connection_id = session.connection_id;
1193 let called_user_id = UserId::from_proto(request.called_user_id);
1194 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1195 if !session
1196 .db()
1197 .await
1198 .has_contact(calling_user_id, called_user_id)
1199 .await?
1200 {
1201 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1202 }
1203
1204 let incoming_call = {
1205 let (room, incoming_call) = &mut *session
1206 .db()
1207 .await
1208 .call(
1209 room_id,
1210 calling_user_id,
1211 calling_connection_id,
1212 called_user_id,
1213 initial_project_id,
1214 )
1215 .await?;
1216 room_updated(&room, &session.peer);
1217 mem::take(incoming_call)
1218 };
1219 update_user_contacts(called_user_id, &session).await?;
1220
1221 let mut calls = session
1222 .connection_pool()
1223 .await
1224 .user_connection_ids(called_user_id)
1225 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1226 .collect::<FuturesUnordered<_>>();
1227
1228 while let Some(call_response) = calls.next().await {
1229 match call_response.as_ref() {
1230 Ok(_) => {
1231 response.send(proto::Ack {})?;
1232 return Ok(());
1233 }
1234 Err(_) => {
1235 call_response.trace_err();
1236 }
1237 }
1238 }
1239
1240 {
1241 let room = session
1242 .db()
1243 .await
1244 .call_failed(room_id, called_user_id)
1245 .await?;
1246 room_updated(&room, &session.peer);
1247 }
1248 update_user_contacts(called_user_id, &session).await?;
1249
1250 Err(anyhow!("failed to ring user"))?
1251}
1252
1253async fn cancel_call(
1254 request: proto::CancelCall,
1255 response: Response<proto::CancelCall>,
1256 session: Session,
1257) -> Result<()> {
1258 let called_user_id = UserId::from_proto(request.called_user_id);
1259 let room_id = RoomId::from_proto(request.room_id);
1260 {
1261 let room = session
1262 .db()
1263 .await
1264 .cancel_call(room_id, session.connection_id, called_user_id)
1265 .await?;
1266 room_updated(&room, &session.peer);
1267 }
1268
1269 for connection_id in session
1270 .connection_pool()
1271 .await
1272 .user_connection_ids(called_user_id)
1273 {
1274 session
1275 .peer
1276 .send(
1277 connection_id,
1278 proto::CallCanceled {
1279 room_id: room_id.to_proto(),
1280 },
1281 )
1282 .trace_err();
1283 }
1284 response.send(proto::Ack {})?;
1285
1286 update_user_contacts(called_user_id, &session).await?;
1287 Ok(())
1288}
1289
1290async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1291 let room_id = RoomId::from_proto(message.room_id);
1292 {
1293 let room = session
1294 .db()
1295 .await
1296 .decline_call(Some(room_id), session.user_id)
1297 .await?
1298 .ok_or_else(|| anyhow!("failed to decline call"))?;
1299 room_updated(&room, &session.peer);
1300 }
1301
1302 for connection_id in session
1303 .connection_pool()
1304 .await
1305 .user_connection_ids(session.user_id)
1306 {
1307 session
1308 .peer
1309 .send(
1310 connection_id,
1311 proto::CallCanceled {
1312 room_id: room_id.to_proto(),
1313 },
1314 )
1315 .trace_err();
1316 }
1317 update_user_contacts(session.user_id, &session).await?;
1318 Ok(())
1319}
1320
1321async fn update_participant_location(
1322 request: proto::UpdateParticipantLocation,
1323 response: Response<proto::UpdateParticipantLocation>,
1324 session: Session,
1325) -> Result<()> {
1326 let room_id = RoomId::from_proto(request.room_id);
1327 let location = request
1328 .location
1329 .ok_or_else(|| anyhow!("invalid location"))?;
1330
1331 let db = session.db().await;
1332 let room = db
1333 .update_room_participant_location(room_id, session.connection_id, location)
1334 .await?;
1335
1336 room_updated(&room, &session.peer);
1337 response.send(proto::Ack {})?;
1338 Ok(())
1339}
1340
1341async fn share_project(
1342 request: proto::ShareProject,
1343 response: Response<proto::ShareProject>,
1344 session: Session,
1345) -> Result<()> {
1346 let (project_id, room) = &*session
1347 .db()
1348 .await
1349 .share_project(
1350 RoomId::from_proto(request.room_id),
1351 session.connection_id,
1352 &request.worktrees,
1353 )
1354 .await?;
1355 response.send(proto::ShareProjectResponse {
1356 project_id: project_id.to_proto(),
1357 })?;
1358 room_updated(&room, &session.peer);
1359
1360 Ok(())
1361}
1362
1363async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1364 let project_id = ProjectId::from_proto(message.project_id);
1365
1366 let (room, guest_connection_ids) = &*session
1367 .db()
1368 .await
1369 .unshare_project(project_id, session.connection_id)
1370 .await?;
1371
1372 broadcast(
1373 Some(session.connection_id),
1374 guest_connection_ids.iter().copied(),
1375 |conn_id| session.peer.send(conn_id, message.clone()),
1376 );
1377 room_updated(&room, &session.peer);
1378
1379 Ok(())
1380}
1381
1382async fn join_project(
1383 request: proto::JoinProject,
1384 response: Response<proto::JoinProject>,
1385 session: Session,
1386) -> Result<()> {
1387 let project_id = ProjectId::from_proto(request.project_id);
1388 let guest_user_id = session.user_id;
1389
1390 tracing::info!(%project_id, "join project");
1391
1392 let (project, replica_id) = &mut *session
1393 .db()
1394 .await
1395 .join_project(project_id, session.connection_id)
1396 .await?;
1397
1398 let collaborators = project
1399 .collaborators
1400 .iter()
1401 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1402 .map(|collaborator| collaborator.to_proto())
1403 .collect::<Vec<_>>();
1404
1405 let worktrees = project
1406 .worktrees
1407 .iter()
1408 .map(|(id, worktree)| proto::WorktreeMetadata {
1409 id: *id,
1410 root_name: worktree.root_name.clone(),
1411 visible: worktree.visible,
1412 abs_path: worktree.abs_path.clone(),
1413 })
1414 .collect::<Vec<_>>();
1415
1416 for collaborator in &collaborators {
1417 session
1418 .peer
1419 .send(
1420 collaborator.peer_id.unwrap().into(),
1421 proto::AddProjectCollaborator {
1422 project_id: project_id.to_proto(),
1423 collaborator: Some(proto::Collaborator {
1424 peer_id: Some(session.connection_id.into()),
1425 replica_id: replica_id.0 as u32,
1426 user_id: guest_user_id.to_proto(),
1427 }),
1428 },
1429 )
1430 .trace_err();
1431 }
1432
1433 // First, we send the metadata associated with each worktree.
1434 response.send(proto::JoinProjectResponse {
1435 worktrees: worktrees.clone(),
1436 replica_id: replica_id.0 as u32,
1437 collaborators: collaborators.clone(),
1438 language_servers: project.language_servers.clone(),
1439 })?;
1440
1441 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1442 #[cfg(any(test, feature = "test-support"))]
1443 const MAX_CHUNK_SIZE: usize = 2;
1444 #[cfg(not(any(test, feature = "test-support")))]
1445 const MAX_CHUNK_SIZE: usize = 256;
1446
1447 // Stream this worktree's entries.
1448 let message = proto::UpdateWorktree {
1449 project_id: project_id.to_proto(),
1450 worktree_id,
1451 abs_path: worktree.abs_path.clone(),
1452 root_name: worktree.root_name,
1453 updated_entries: worktree.entries,
1454 removed_entries: Default::default(),
1455 scan_id: worktree.scan_id,
1456 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1457 updated_repositories: worktree.repository_entries.into_values().collect(),
1458 removed_repositories: Default::default(),
1459 };
1460 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1461 session.peer.send(session.connection_id, update.clone())?;
1462 }
1463
1464 // Stream this worktree's diagnostics.
1465 for summary in worktree.diagnostic_summaries {
1466 session.peer.send(
1467 session.connection_id,
1468 proto::UpdateDiagnosticSummary {
1469 project_id: project_id.to_proto(),
1470 worktree_id: worktree.id,
1471 summary: Some(summary),
1472 },
1473 )?;
1474 }
1475
1476 for settings_file in worktree.settings_files {
1477 session.peer.send(
1478 session.connection_id,
1479 proto::UpdateWorktreeSettings {
1480 project_id: project_id.to_proto(),
1481 worktree_id: worktree.id,
1482 path: settings_file.path,
1483 content: Some(settings_file.content),
1484 },
1485 )?;
1486 }
1487 }
1488
1489 for language_server in &project.language_servers {
1490 session.peer.send(
1491 session.connection_id,
1492 proto::UpdateLanguageServer {
1493 project_id: project_id.to_proto(),
1494 language_server_id: language_server.id,
1495 variant: Some(
1496 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1497 proto::LspDiskBasedDiagnosticsUpdated {},
1498 ),
1499 ),
1500 },
1501 )?;
1502 }
1503
1504 Ok(())
1505}
1506
1507async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1508 let sender_id = session.connection_id;
1509 let project_id = ProjectId::from_proto(request.project_id);
1510
1511 let (room, project) = &*session
1512 .db()
1513 .await
1514 .leave_project(project_id, sender_id)
1515 .await?;
1516 tracing::info!(
1517 %project_id,
1518 host_user_id = %project.host_user_id,
1519 host_connection_id = %project.host_connection_id,
1520 "leave project"
1521 );
1522
1523 project_left(&project, &session);
1524 room_updated(&room, &session.peer);
1525
1526 Ok(())
1527}
1528
1529async fn update_project(
1530 request: proto::UpdateProject,
1531 response: Response<proto::UpdateProject>,
1532 session: Session,
1533) -> Result<()> {
1534 let project_id = ProjectId::from_proto(request.project_id);
1535 let (room, guest_connection_ids) = &*session
1536 .db()
1537 .await
1538 .update_project(project_id, session.connection_id, &request.worktrees)
1539 .await?;
1540 broadcast(
1541 Some(session.connection_id),
1542 guest_connection_ids.iter().copied(),
1543 |connection_id| {
1544 session
1545 .peer
1546 .forward_send(session.connection_id, connection_id, request.clone())
1547 },
1548 );
1549 room_updated(&room, &session.peer);
1550 response.send(proto::Ack {})?;
1551
1552 Ok(())
1553}
1554
1555async fn update_worktree(
1556 request: proto::UpdateWorktree,
1557 response: Response<proto::UpdateWorktree>,
1558 session: Session,
1559) -> Result<()> {
1560 let guest_connection_ids = session
1561 .db()
1562 .await
1563 .update_worktree(&request, session.connection_id)
1564 .await?;
1565
1566 broadcast(
1567 Some(session.connection_id),
1568 guest_connection_ids.iter().copied(),
1569 |connection_id| {
1570 session
1571 .peer
1572 .forward_send(session.connection_id, connection_id, request.clone())
1573 },
1574 );
1575 response.send(proto::Ack {})?;
1576 Ok(())
1577}
1578
1579async fn update_diagnostic_summary(
1580 message: proto::UpdateDiagnosticSummary,
1581 session: Session,
1582) -> Result<()> {
1583 let guest_connection_ids = session
1584 .db()
1585 .await
1586 .update_diagnostic_summary(&message, session.connection_id)
1587 .await?;
1588
1589 broadcast(
1590 Some(session.connection_id),
1591 guest_connection_ids.iter().copied(),
1592 |connection_id| {
1593 session
1594 .peer
1595 .forward_send(session.connection_id, connection_id, message.clone())
1596 },
1597 );
1598
1599 Ok(())
1600}
1601
1602async fn update_worktree_settings(
1603 message: proto::UpdateWorktreeSettings,
1604 session: Session,
1605) -> Result<()> {
1606 let guest_connection_ids = session
1607 .db()
1608 .await
1609 .update_worktree_settings(&message, session.connection_id)
1610 .await?;
1611
1612 broadcast(
1613 Some(session.connection_id),
1614 guest_connection_ids.iter().copied(),
1615 |connection_id| {
1616 session
1617 .peer
1618 .forward_send(session.connection_id, connection_id, message.clone())
1619 },
1620 );
1621
1622 Ok(())
1623}
1624
1625async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1626 broadcast_project_message(request.project_id, request, session).await
1627}
1628
1629async fn start_language_server(
1630 request: proto::StartLanguageServer,
1631 session: Session,
1632) -> Result<()> {
1633 let guest_connection_ids = session
1634 .db()
1635 .await
1636 .start_language_server(&request, session.connection_id)
1637 .await?;
1638
1639 broadcast(
1640 Some(session.connection_id),
1641 guest_connection_ids.iter().copied(),
1642 |connection_id| {
1643 session
1644 .peer
1645 .forward_send(session.connection_id, connection_id, request.clone())
1646 },
1647 );
1648 Ok(())
1649}
1650
1651async fn update_language_server(
1652 request: proto::UpdateLanguageServer,
1653 session: Session,
1654) -> Result<()> {
1655 session.executor.record_backtrace();
1656 let project_id = ProjectId::from_proto(request.project_id);
1657 let project_connection_ids = session
1658 .db()
1659 .await
1660 .project_connection_ids(project_id, session.connection_id)
1661 .await?;
1662 broadcast(
1663 Some(session.connection_id),
1664 project_connection_ids.iter().copied(),
1665 |connection_id| {
1666 session
1667 .peer
1668 .forward_send(session.connection_id, connection_id, request.clone())
1669 },
1670 );
1671 Ok(())
1672}
1673
1674async fn forward_project_request<T>(
1675 request: T,
1676 response: Response<T>,
1677 session: Session,
1678) -> Result<()>
1679where
1680 T: EntityMessage + RequestMessage,
1681{
1682 session.executor.record_backtrace();
1683 let project_id = ProjectId::from_proto(request.remote_entity_id());
1684 let host_connection_id = {
1685 let collaborators = session
1686 .db()
1687 .await
1688 .project_collaborators(project_id, session.connection_id)
1689 .await?;
1690 collaborators
1691 .iter()
1692 .find(|collaborator| collaborator.is_host)
1693 .ok_or_else(|| anyhow!("host not found"))?
1694 .connection_id
1695 };
1696
1697 let payload = session
1698 .peer
1699 .forward_request(session.connection_id, host_connection_id, request)
1700 .await?;
1701
1702 response.send(payload)?;
1703 Ok(())
1704}
1705
1706async fn create_buffer_for_peer(
1707 request: proto::CreateBufferForPeer,
1708 session: Session,
1709) -> Result<()> {
1710 session.executor.record_backtrace();
1711 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1712 session
1713 .peer
1714 .forward_send(session.connection_id, peer_id.into(), request)?;
1715 Ok(())
1716}
1717
1718async fn update_buffer(
1719 request: proto::UpdateBuffer,
1720 response: Response<proto::UpdateBuffer>,
1721 session: Session,
1722) -> Result<()> {
1723 session.executor.record_backtrace();
1724 let project_id = ProjectId::from_proto(request.project_id);
1725 let mut guest_connection_ids;
1726 let mut host_connection_id = None;
1727 {
1728 let collaborators = session
1729 .db()
1730 .await
1731 .project_collaborators(project_id, session.connection_id)
1732 .await?;
1733 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1734 for collaborator in collaborators.iter() {
1735 if collaborator.is_host {
1736 host_connection_id = Some(collaborator.connection_id);
1737 } else {
1738 guest_connection_ids.push(collaborator.connection_id);
1739 }
1740 }
1741 }
1742 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1743
1744 session.executor.record_backtrace();
1745 broadcast(
1746 Some(session.connection_id),
1747 guest_connection_ids,
1748 |connection_id| {
1749 session
1750 .peer
1751 .forward_send(session.connection_id, connection_id, request.clone())
1752 },
1753 );
1754 if host_connection_id != session.connection_id {
1755 session
1756 .peer
1757 .forward_request(session.connection_id, host_connection_id, request.clone())
1758 .await?;
1759 }
1760
1761 response.send(proto::Ack {})?;
1762 Ok(())
1763}
1764
1765async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1766 let project_id = ProjectId::from_proto(request.project_id);
1767 let project_connection_ids = session
1768 .db()
1769 .await
1770 .project_connection_ids(project_id, session.connection_id)
1771 .await?;
1772
1773 broadcast(
1774 Some(session.connection_id),
1775 project_connection_ids.iter().copied(),
1776 |connection_id| {
1777 session
1778 .peer
1779 .forward_send(session.connection_id, connection_id, request.clone())
1780 },
1781 );
1782 Ok(())
1783}
1784
1785async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1786 let project_id = ProjectId::from_proto(request.project_id);
1787 let project_connection_ids = session
1788 .db()
1789 .await
1790 .project_connection_ids(project_id, session.connection_id)
1791 .await?;
1792 broadcast(
1793 Some(session.connection_id),
1794 project_connection_ids.iter().copied(),
1795 |connection_id| {
1796 session
1797 .peer
1798 .forward_send(session.connection_id, connection_id, request.clone())
1799 },
1800 );
1801 Ok(())
1802}
1803
1804async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1805 broadcast_project_message(request.project_id, request, session).await
1806}
1807
1808async fn broadcast_project_message<T: EnvelopedMessage>(
1809 project_id: u64,
1810 request: T,
1811 session: Session,
1812) -> Result<()> {
1813 let project_id = ProjectId::from_proto(project_id);
1814 let project_connection_ids = session
1815 .db()
1816 .await
1817 .project_connection_ids(project_id, session.connection_id)
1818 .await?;
1819 broadcast(
1820 Some(session.connection_id),
1821 project_connection_ids.iter().copied(),
1822 |connection_id| {
1823 session
1824 .peer
1825 .forward_send(session.connection_id, connection_id, request.clone())
1826 },
1827 );
1828 Ok(())
1829}
1830
1831async fn follow(
1832 request: proto::Follow,
1833 response: Response<proto::Follow>,
1834 session: Session,
1835) -> Result<()> {
1836 let project_id = ProjectId::from_proto(request.project_id);
1837 let leader_id = request
1838 .leader_id
1839 .ok_or_else(|| anyhow!("invalid leader id"))?
1840 .into();
1841 let follower_id = session.connection_id;
1842
1843 {
1844 let project_connection_ids = session
1845 .db()
1846 .await
1847 .project_connection_ids(project_id, session.connection_id)
1848 .await?;
1849
1850 if !project_connection_ids.contains(&leader_id) {
1851 Err(anyhow!("no such peer"))?;
1852 }
1853 }
1854
1855 let mut response_payload = session
1856 .peer
1857 .forward_request(session.connection_id, leader_id, request)
1858 .await?;
1859 response_payload
1860 .views
1861 .retain(|view| view.leader_id != Some(follower_id.into()));
1862 response.send(response_payload)?;
1863
1864 let room = session
1865 .db()
1866 .await
1867 .follow(project_id, leader_id, follower_id)
1868 .await?;
1869 room_updated(&room, &session.peer);
1870
1871 Ok(())
1872}
1873
1874async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1875 let project_id = ProjectId::from_proto(request.project_id);
1876 let leader_id = request
1877 .leader_id
1878 .ok_or_else(|| anyhow!("invalid leader id"))?
1879 .into();
1880 let follower_id = session.connection_id;
1881
1882 if !session
1883 .db()
1884 .await
1885 .project_connection_ids(project_id, session.connection_id)
1886 .await?
1887 .contains(&leader_id)
1888 {
1889 Err(anyhow!("no such peer"))?;
1890 }
1891
1892 session
1893 .peer
1894 .forward_send(session.connection_id, leader_id, request)?;
1895
1896 let room = session
1897 .db()
1898 .await
1899 .unfollow(project_id, leader_id, follower_id)
1900 .await?;
1901 room_updated(&room, &session.peer);
1902
1903 Ok(())
1904}
1905
1906async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1907 let project_id = ProjectId::from_proto(request.project_id);
1908 let project_connection_ids = session
1909 .db
1910 .lock()
1911 .await
1912 .project_connection_ids(project_id, session.connection_id)
1913 .await?;
1914
1915 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1916 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1917 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1918 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1919 });
1920 for follower_peer_id in request.follower_ids.iter().copied() {
1921 let follower_connection_id = follower_peer_id.into();
1922 if project_connection_ids.contains(&follower_connection_id)
1923 && Some(follower_peer_id) != leader_id
1924 {
1925 session.peer.forward_send(
1926 session.connection_id,
1927 follower_connection_id,
1928 request.clone(),
1929 )?;
1930 }
1931 }
1932 Ok(())
1933}
1934
1935async fn get_users(
1936 request: proto::GetUsers,
1937 response: Response<proto::GetUsers>,
1938 session: Session,
1939) -> Result<()> {
1940 let user_ids = request
1941 .user_ids
1942 .into_iter()
1943 .map(UserId::from_proto)
1944 .collect();
1945 let users = session
1946 .db()
1947 .await
1948 .get_users_by_ids(user_ids)
1949 .await?
1950 .into_iter()
1951 .map(|user| proto::User {
1952 id: user.id.to_proto(),
1953 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1954 github_login: user.github_login,
1955 })
1956 .collect();
1957 response.send(proto::UsersResponse { users })?;
1958 Ok(())
1959}
1960
1961async fn fuzzy_search_users(
1962 request: proto::FuzzySearchUsers,
1963 response: Response<proto::FuzzySearchUsers>,
1964 session: Session,
1965) -> Result<()> {
1966 let query = request.query;
1967 let users = match query.len() {
1968 0 => vec![],
1969 1 | 2 => session
1970 .db()
1971 .await
1972 .get_user_by_github_login(&query)
1973 .await?
1974 .into_iter()
1975 .collect(),
1976 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1977 };
1978 let users = users
1979 .into_iter()
1980 .filter(|user| user.id != session.user_id)
1981 .map(|user| proto::User {
1982 id: user.id.to_proto(),
1983 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1984 github_login: user.github_login,
1985 })
1986 .collect();
1987 response.send(proto::UsersResponse { users })?;
1988 Ok(())
1989}
1990
1991async fn request_contact(
1992 request: proto::RequestContact,
1993 response: Response<proto::RequestContact>,
1994 session: Session,
1995) -> Result<()> {
1996 let requester_id = session.user_id;
1997 let responder_id = UserId::from_proto(request.responder_id);
1998 if requester_id == responder_id {
1999 return Err(anyhow!("cannot add yourself as a contact"))?;
2000 }
2001
2002 session
2003 .db()
2004 .await
2005 .send_contact_request(requester_id, responder_id)
2006 .await?;
2007
2008 // Update outgoing contact requests of requester
2009 let mut update = proto::UpdateContacts::default();
2010 update.outgoing_requests.push(responder_id.to_proto());
2011 for connection_id in session
2012 .connection_pool()
2013 .await
2014 .user_connection_ids(requester_id)
2015 {
2016 session.peer.send(connection_id, update.clone())?;
2017 }
2018
2019 // Update incoming contact requests of responder
2020 let mut update = proto::UpdateContacts::default();
2021 update
2022 .incoming_requests
2023 .push(proto::IncomingContactRequest {
2024 requester_id: requester_id.to_proto(),
2025 should_notify: true,
2026 });
2027 for connection_id in session
2028 .connection_pool()
2029 .await
2030 .user_connection_ids(responder_id)
2031 {
2032 session.peer.send(connection_id, update.clone())?;
2033 }
2034
2035 response.send(proto::Ack {})?;
2036 Ok(())
2037}
2038
2039async fn respond_to_contact_request(
2040 request: proto::RespondToContactRequest,
2041 response: Response<proto::RespondToContactRequest>,
2042 session: Session,
2043) -> Result<()> {
2044 let responder_id = session.user_id;
2045 let requester_id = UserId::from_proto(request.requester_id);
2046 let db = session.db().await;
2047 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2048 db.dismiss_contact_notification(responder_id, requester_id)
2049 .await?;
2050 } else {
2051 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2052
2053 db.respond_to_contact_request(responder_id, requester_id, accept)
2054 .await?;
2055 let requester_busy = db.is_user_busy(requester_id).await?;
2056 let responder_busy = db.is_user_busy(responder_id).await?;
2057
2058 let pool = session.connection_pool().await;
2059 // Update responder with new contact
2060 let mut update = proto::UpdateContacts::default();
2061 if accept {
2062 update
2063 .contacts
2064 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2065 }
2066 update
2067 .remove_incoming_requests
2068 .push(requester_id.to_proto());
2069 for connection_id in pool.user_connection_ids(responder_id) {
2070 session.peer.send(connection_id, update.clone())?;
2071 }
2072
2073 // Update requester with new contact
2074 let mut update = proto::UpdateContacts::default();
2075 if accept {
2076 update
2077 .contacts
2078 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2079 }
2080 update
2081 .remove_outgoing_requests
2082 .push(responder_id.to_proto());
2083 for connection_id in pool.user_connection_ids(requester_id) {
2084 session.peer.send(connection_id, update.clone())?;
2085 }
2086 }
2087
2088 response.send(proto::Ack {})?;
2089 Ok(())
2090}
2091
2092async fn remove_contact(
2093 request: proto::RemoveContact,
2094 response: Response<proto::RemoveContact>,
2095 session: Session,
2096) -> Result<()> {
2097 let requester_id = session.user_id;
2098 let responder_id = UserId::from_proto(request.user_id);
2099 let db = session.db().await;
2100 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2101
2102 let pool = session.connection_pool().await;
2103 // Update outgoing contact requests of requester
2104 let mut update = proto::UpdateContacts::default();
2105 if contact_accepted {
2106 update.remove_contacts.push(responder_id.to_proto());
2107 } else {
2108 update
2109 .remove_outgoing_requests
2110 .push(responder_id.to_proto());
2111 }
2112 for connection_id in pool.user_connection_ids(requester_id) {
2113 session.peer.send(connection_id, update.clone())?;
2114 }
2115
2116 // Update incoming contact requests of responder
2117 let mut update = proto::UpdateContacts::default();
2118 if contact_accepted {
2119 update.remove_contacts.push(requester_id.to_proto());
2120 } else {
2121 update
2122 .remove_incoming_requests
2123 .push(requester_id.to_proto());
2124 }
2125 for connection_id in pool.user_connection_ids(responder_id) {
2126 session.peer.send(connection_id, update.clone())?;
2127 }
2128
2129 response.send(proto::Ack {})?;
2130 Ok(())
2131}
2132
2133async fn create_channel(
2134 request: proto::CreateChannel,
2135 response: Response<proto::CreateChannel>,
2136 session: Session,
2137) -> Result<()> {
2138 let db = session.db().await;
2139 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2140
2141 if let Some(live_kit) = session.live_kit_client.as_ref() {
2142 live_kit.create_room(live_kit_room.clone()).await?;
2143 }
2144
2145 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2146 let id = db
2147 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2148 .await?;
2149
2150 response.send(proto::CreateChannelResponse {
2151 channel_id: id.to_proto(),
2152 })?;
2153
2154 let mut update = proto::UpdateChannels::default();
2155 update.channels.push(proto::Channel {
2156 id: id.to_proto(),
2157 name: request.name,
2158 parent_id: request.parent_id,
2159 });
2160
2161 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2162 db.get_channel_members(parent_id).await?
2163 } else {
2164 vec![session.user_id]
2165 };
2166
2167 let connection_pool = session.connection_pool().await;
2168 for user_id in user_ids_to_notify {
2169 for connection_id in connection_pool.user_connection_ids(user_id) {
2170 let mut update = update.clone();
2171 if user_id == session.user_id {
2172 update.channel_permissions.push(proto::ChannelPermission {
2173 channel_id: id.to_proto(),
2174 is_admin: true,
2175 });
2176 }
2177 session.peer.send(connection_id, update)?;
2178 }
2179 }
2180
2181 Ok(())
2182}
2183
2184async fn remove_channel(
2185 request: proto::RemoveChannel,
2186 response: Response<proto::RemoveChannel>,
2187 session: Session,
2188) -> Result<()> {
2189 let db = session.db().await;
2190
2191 let channel_id = request.channel_id;
2192 let (removed_channels, member_ids) = db
2193 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2194 .await?;
2195 response.send(proto::Ack {})?;
2196
2197 // Notify members of removed channels
2198 let mut update = proto::UpdateChannels::default();
2199 update
2200 .remove_channels
2201 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2202
2203 let connection_pool = session.connection_pool().await;
2204 for member_id in member_ids {
2205 for connection_id in connection_pool.user_connection_ids(member_id) {
2206 session.peer.send(connection_id, update.clone())?;
2207 }
2208 }
2209
2210 Ok(())
2211}
2212
2213async fn invite_channel_member(
2214 request: proto::InviteChannelMember,
2215 response: Response<proto::InviteChannelMember>,
2216 session: Session,
2217) -> Result<()> {
2218 let db = session.db().await;
2219 let channel_id = ChannelId::from_proto(request.channel_id);
2220 let invitee_id = UserId::from_proto(request.user_id);
2221 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2222 .await?;
2223
2224 let (channel, _) = db
2225 .get_channel(channel_id, session.user_id)
2226 .await?
2227 .ok_or_else(|| anyhow!("channel not found"))?;
2228
2229 let mut update = proto::UpdateChannels::default();
2230 update.channel_invitations.push(proto::Channel {
2231 id: channel.id.to_proto(),
2232 name: channel.name,
2233 parent_id: None,
2234 });
2235 for connection_id in session
2236 .connection_pool()
2237 .await
2238 .user_connection_ids(invitee_id)
2239 {
2240 session.peer.send(connection_id, update.clone())?;
2241 }
2242
2243 response.send(proto::Ack {})?;
2244 Ok(())
2245}
2246
2247async fn remove_channel_member(
2248 request: proto::RemoveChannelMember,
2249 response: Response<proto::RemoveChannelMember>,
2250 session: Session,
2251) -> Result<()> {
2252 let db = session.db().await;
2253 let channel_id = ChannelId::from_proto(request.channel_id);
2254 let member_id = UserId::from_proto(request.user_id);
2255
2256 db.remove_channel_member(channel_id, member_id, session.user_id)
2257 .await?;
2258
2259 let mut update = proto::UpdateChannels::default();
2260 update.remove_channels.push(channel_id.to_proto());
2261
2262 for connection_id in session
2263 .connection_pool()
2264 .await
2265 .user_connection_ids(member_id)
2266 {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269
2270 response.send(proto::Ack {})?;
2271 Ok(())
2272}
2273
2274async fn set_channel_member_admin(
2275 request: proto::SetChannelMemberAdmin,
2276 response: Response<proto::SetChannelMemberAdmin>,
2277 session: Session,
2278) -> Result<()> {
2279 let db = session.db().await;
2280 let channel_id = ChannelId::from_proto(request.channel_id);
2281 let member_id = UserId::from_proto(request.user_id);
2282 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2283 .await?;
2284
2285 let (channel, has_accepted) = db
2286 .get_channel(channel_id, member_id)
2287 .await?
2288 .ok_or_else(|| anyhow!("channel not found"))?;
2289
2290 let mut update = proto::UpdateChannels::default();
2291 if has_accepted {
2292 update.channel_permissions.push(proto::ChannelPermission {
2293 channel_id: channel.id.to_proto(),
2294 is_admin: request.admin,
2295 });
2296 }
2297
2298 for connection_id in session
2299 .connection_pool()
2300 .await
2301 .user_connection_ids(member_id)
2302 {
2303 session.peer.send(connection_id, update.clone())?;
2304 }
2305
2306 response.send(proto::Ack {})?;
2307 Ok(())
2308}
2309
2310async fn rename_channel(
2311 request: proto::RenameChannel,
2312 response: Response<proto::RenameChannel>,
2313 session: Session,
2314) -> Result<()> {
2315 let db = session.db().await;
2316 let channel_id = ChannelId::from_proto(request.channel_id);
2317 let new_name = db
2318 .rename_channel(channel_id, session.user_id, &request.name)
2319 .await?;
2320
2321 response.send(proto::Ack {})?;
2322
2323 let mut update = proto::UpdateChannels::default();
2324 update.channels.push(proto::Channel {
2325 id: request.channel_id,
2326 name: new_name,
2327 parent_id: None,
2328 });
2329
2330 let member_ids = db.get_channel_members(channel_id).await?;
2331
2332 let connection_pool = session.connection_pool().await;
2333 for member_id in member_ids {
2334 for connection_id in connection_pool.user_connection_ids(member_id) {
2335 session.peer.send(connection_id, update.clone())?;
2336 }
2337 }
2338
2339 Ok(())
2340}
2341
2342async fn get_channel_members(
2343 request: proto::GetChannelMembers,
2344 response: Response<proto::GetChannelMembers>,
2345 session: Session,
2346) -> Result<()> {
2347 let db = session.db().await;
2348 let channel_id = ChannelId::from_proto(request.channel_id);
2349 let members = db
2350 .get_channel_member_details(channel_id, session.user_id)
2351 .await?;
2352 response.send(proto::GetChannelMembersResponse { members })?;
2353 Ok(())
2354}
2355
2356async fn respond_to_channel_invite(
2357 request: proto::RespondToChannelInvite,
2358 response: Response<proto::RespondToChannelInvite>,
2359 session: Session,
2360) -> Result<()> {
2361 let db = session.db().await;
2362 let channel_id = ChannelId::from_proto(request.channel_id);
2363 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2364 .await?;
2365
2366 let mut update = proto::UpdateChannels::default();
2367 update
2368 .remove_channel_invitations
2369 .push(channel_id.to_proto());
2370 if request.accept {
2371 let result = db.get_channels_for_user(session.user_id).await?;
2372 update
2373 .channels
2374 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2375 id: channel.id.to_proto(),
2376 name: channel.name,
2377 parent_id: channel.parent_id.map(ChannelId::to_proto),
2378 }));
2379 update
2380 .channel_participants
2381 .extend(
2382 result
2383 .channel_participants
2384 .into_iter()
2385 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2386 channel_id: channel_id.to_proto(),
2387 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2388 }),
2389 );
2390 update
2391 .channel_permissions
2392 .extend(
2393 result
2394 .channels_with_admin_privileges
2395 .into_iter()
2396 .map(|channel_id| proto::ChannelPermission {
2397 channel_id: channel_id.to_proto(),
2398 is_admin: true,
2399 }),
2400 );
2401 }
2402 session.peer.send(session.connection_id, update)?;
2403 response.send(proto::Ack {})?;
2404
2405 Ok(())
2406}
2407
/// Joins the session's connection to the room associated with
/// `request.channel_id` and notifies everyone affected.
///
/// Steps, in order:
/// 1. looks up the channel's room and joins it via the database;
/// 2. replies with the room state, plus LiveKit connection info when a LiveKit
///    client is configured and a room token could be created (token errors are
///    logged via `trace_err` and the info is omitted);
/// 3. broadcasts the room state to its participants (`room_updated`);
/// 4. sends the channel's participant list to channel members (`channel_updated`);
/// 5. refreshes the user's contact entry for their contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Everything that must happen while the db handle and the joined-room value
    // are held lives in this block; only a clone of the joined room escapes it.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // `None` when no LiveKit client is configured or minting a token failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Clone the data out so it outlives this block. `joined_room` is
        // presumably a guard over the room (see the TODO below) — it is
        // released when the block ends.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2468
2469async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2470 let project_id = ProjectId::from_proto(request.project_id);
2471 let project_connection_ids = session
2472 .db()
2473 .await
2474 .project_connection_ids(project_id, session.connection_id)
2475 .await?;
2476 broadcast(
2477 Some(session.connection_id),
2478 project_connection_ids.iter().copied(),
2479 |connection_id| {
2480 session
2481 .peer
2482 .forward_send(session.connection_id, connection_id, request.clone())
2483 },
2484 );
2485 Ok(())
2486}
2487
2488async fn get_private_user_info(
2489 _request: proto::GetPrivateUserInfo,
2490 response: Response<proto::GetPrivateUserInfo>,
2491 session: Session,
2492) -> Result<()> {
2493 let metrics_id = session
2494 .db()
2495 .await
2496 .get_user_metrics_id(session.user_id)
2497 .await?;
2498 let user = session
2499 .db()
2500 .await
2501 .get_user_by_id(session.user_id)
2502 .await?
2503 .ok_or_else(|| anyhow!("user not found"))?;
2504 response.send(proto::GetPrivateUserInfoResponse {
2505 metrics_id,
2506 staff: user.admin,
2507 })?;
2508 Ok(())
2509}
2510
2511fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2512 match message {
2513 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2514 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2515 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2516 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2517 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2518 code: frame.code.into(),
2519 reason: frame.reason,
2520 })),
2521 }
2522}
2523
2524fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2525 match message {
2526 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2527 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2528 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2529 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2530 AxumMessage::Close(frame) => {
2531 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2532 code: frame.code.into(),
2533 reason: frame.reason,
2534 }))
2535 }
2536 }
2537}
2538
2539fn build_initial_channels_update(
2540 channels: Vec<db::Channel>,
2541 channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2542 channel_invites: Vec<db::Channel>,
2543) -> proto::UpdateChannels {
2544 let mut update = proto::UpdateChannels::default();
2545
2546 for channel in channels {
2547 update.channels.push(proto::Channel {
2548 id: channel.id.to_proto(),
2549 name: channel.name,
2550 parent_id: channel.parent_id.map(|id| id.to_proto()),
2551 });
2552 }
2553
2554 for (channel_id, participants) in channel_participants {
2555 update
2556 .channel_participants
2557 .push(proto::ChannelParticipants {
2558 channel_id: channel_id.to_proto(),
2559 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2560 });
2561 }
2562
2563 for channel in channel_invites {
2564 update.channel_invitations.push(proto::Channel {
2565 id: channel.id.to_proto(),
2566 name: channel.name,
2567 parent_id: None,
2568 });
2569 }
2570
2571 update
2572}
2573
2574fn build_initial_contacts_update(
2575 contacts: Vec<db::Contact>,
2576 pool: &ConnectionPool,
2577) -> proto::UpdateContacts {
2578 let mut update = proto::UpdateContacts::default();
2579
2580 for contact in contacts {
2581 match contact {
2582 db::Contact::Accepted {
2583 user_id,
2584 should_notify,
2585 busy,
2586 } => {
2587 update
2588 .contacts
2589 .push(contact_for_user(user_id, should_notify, busy, &pool));
2590 }
2591 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2592 db::Contact::Incoming {
2593 user_id,
2594 should_notify,
2595 } => update
2596 .incoming_requests
2597 .push(proto::IncomingContactRequest {
2598 requester_id: user_id.to_proto(),
2599 should_notify,
2600 }),
2601 }
2602 }
2603
2604 update
2605}
2606
2607fn contact_for_user(
2608 user_id: UserId,
2609 should_notify: bool,
2610 busy: bool,
2611 pool: &ConnectionPool,
2612) -> proto::Contact {
2613 proto::Contact {
2614 user_id: user_id.to_proto(),
2615 online: pool.is_user_online(user_id),
2616 busy,
2617 should_notify,
2618 }
2619}
2620
2621fn room_updated(room: &proto::Room, peer: &Peer) {
2622 broadcast(
2623 None,
2624 room.participants
2625 .iter()
2626 .filter_map(|participant| Some(participant.peer_id?.into())),
2627 |peer_id| {
2628 peer.send(
2629 peer_id.into(),
2630 proto::RoomUpdated {
2631 room: Some(room.clone()),
2632 },
2633 )
2634 },
2635 );
2636}
2637
2638fn channel_updated(
2639 channel_id: ChannelId,
2640 room: &proto::Room,
2641 channel_members: &[UserId],
2642 peer: &Peer,
2643 pool: &ConnectionPool,
2644) {
2645 let participants = room
2646 .participants
2647 .iter()
2648 .map(|p| p.user_id)
2649 .collect::<Vec<_>>();
2650
2651 broadcast(
2652 None,
2653 channel_members
2654 .iter()
2655 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2656 |peer_id| {
2657 peer.send(
2658 peer_id.into(),
2659 proto::UpdateChannels {
2660 channel_participants: vec![proto::ChannelParticipants {
2661 channel_id: channel_id.to_proto(),
2662 participant_user_ids: participants.clone(),
2663 }],
2664 ..Default::default()
2665 },
2666 )
2667 },
2668 );
2669}
2670
2671async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2672 let db = session.db().await;
2673
2674 let contacts = db.get_contacts(user_id).await?;
2675 let busy = db.is_user_busy(user_id).await?;
2676
2677 let pool = session.connection_pool().await;
2678 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2679 for contact in contacts {
2680 if let db::Contact::Accepted {
2681 user_id: contact_user_id,
2682 ..
2683 } = contact
2684 {
2685 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2686 session
2687 .peer
2688 .send(
2689 contact_conn_id,
2690 proto::UpdateContacts {
2691 contacts: vec![updated_contact.clone()],
2692 remove_contacts: Default::default(),
2693 incoming_requests: Default::default(),
2694 remove_incoming_requests: Default::default(),
2695 outgoing_requests: Default::default(),
2696 remove_outgoing_requests: Default::default(),
2697 },
2698 )
2699 .trace_err();
2700 }
2701 }
2702 }
2703 Ok(())
2704}
2705
2706async fn leave_room_for_session(session: &Session) -> Result<()> {
2707 let mut contacts_to_update = HashSet::default();
2708
2709 let room_id;
2710 let canceled_calls_to_user_ids;
2711 let live_kit_room;
2712 let delete_live_kit_room;
2713 let room;
2714 let channel_members;
2715 let channel_id;
2716
2717 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2718 contacts_to_update.insert(session.user_id);
2719
2720 for project in left_room.left_projects.values() {
2721 project_left(project, session);
2722 }
2723
2724 room_id = RoomId::from_proto(left_room.room.id);
2725 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2726 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2727 delete_live_kit_room = left_room.deleted;
2728 room = mem::take(&mut left_room.room);
2729 channel_members = mem::take(&mut left_room.channel_members);
2730 channel_id = left_room.channel_id;
2731
2732 room_updated(&room, &session.peer);
2733 } else {
2734 return Ok(());
2735 }
2736
2737 // TODO - do this while holding the room guard.
2738 if let Some(channel_id) = channel_id {
2739 channel_updated(
2740 channel_id,
2741 &room,
2742 &channel_members,
2743 &session.peer,
2744 &*session.connection_pool().await,
2745 );
2746 }
2747
2748 {
2749 let pool = session.connection_pool().await;
2750 for canceled_user_id in canceled_calls_to_user_ids {
2751 for connection_id in pool.user_connection_ids(canceled_user_id) {
2752 session
2753 .peer
2754 .send(
2755 connection_id,
2756 proto::CallCanceled {
2757 room_id: room_id.to_proto(),
2758 },
2759 )
2760 .trace_err();
2761 }
2762 contacts_to_update.insert(canceled_user_id);
2763 }
2764 }
2765
2766 for contact_user_id in contacts_to_update {
2767 update_user_contacts(contact_user_id, &session).await?;
2768 }
2769
2770 if let Some(live_kit) = session.live_kit_client.as_ref() {
2771 live_kit
2772 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2773 .await
2774 .trace_err();
2775
2776 if delete_live_kit_room {
2777 live_kit.delete_room(live_kit_room).await.trace_err();
2778 }
2779 }
2780
2781 Ok(())
2782}
2783
2784fn project_left(project: &db::LeftProject, session: &Session) {
2785 for connection_id in &project.connection_ids {
2786 if project.host_user_id == session.user_id {
2787 session
2788 .peer
2789 .send(
2790 *connection_id,
2791 proto::UnshareProject {
2792 project_id: project.id.to_proto(),
2793 },
2794 )
2795 .trace_err();
2796 } else {
2797 session
2798 .peer
2799 .send(
2800 *connection_id,
2801 proto::RemoveProjectCollaborator {
2802 project_id: project.id.to_proto(),
2803 peer_id: Some(session.connection_id.into()),
2804 },
2805 )
2806 .trace_err();
2807 }
2808 }
2809}
2810
/// Converts a fallible value into an `Option`, reporting the error instead of
/// propagating it. The impl for `Result` in this file logs the error with
/// `tracing::error!`.
pub trait ResultExt {
    /// The success value carried by the implementing type.
    type Ok;

    /// Returns `Some(value)` on success; on failure, reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
2816
2817impl<T, E> ResultExt for Result<T, E>
2818where
2819 E: std::fmt::Debug,
2820{
2821 type Ok = T;
2822
2823 fn trace_err(self) -> Option<T> {
2824 match self {
2825 Ok(value) => Some(value),
2826 Err(error) => {
2827 tracing::error!("{:?}", error);
2828 None
2829 }
2830 }
2831 }
2832}