1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// Grace period a dropped connection gets before `connection_lost` makes the user leave
/// their room, declines pending calls (if the user is fully offline) and updates contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay after `Server::start` before rooms reported stale for this server are refreshed
/// and stale servers are deleted from the database.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(get_channel_members)
252 .add_request_handler(respond_to_channel_invite)
253 .add_request_handler(join_channel)
254 .add_request_handler(follow)
255 .add_message_handler(unfollow)
256 .add_message_handler(update_followers)
257 .add_message_handler(update_diff_base)
258 .add_request_handler(get_private_user_info);
259
260 Arc::new(server)
261 }
262
263 pub async fn start(&self) -> Result<()> {
264 let server_id = *self.id.lock();
265 let app_state = self.app_state.clone();
266 let peer = self.peer.clone();
267 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
268 let pool = self.connection_pool.clone();
269 let live_kit_client = self.app_state.live_kit_client.clone();
270
271 let span = info_span!("start server");
272 self.executor.spawn_detached(
273 async move {
274 tracing::info!("waiting for cleanup timeout");
275 timeout.await;
276 tracing::info!("cleanup timeout expired, retrieving stale rooms");
277 if let Some(room_ids) = app_state
278 .db
279 .stale_room_ids(&app_state.config.zed_environment, server_id)
280 .await
281 .trace_err()
282 {
283 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
284 for room_id in room_ids {
285 let mut contacts_to_update = HashSet::default();
286 let mut canceled_calls_to_user_ids = Vec::new();
287 let mut live_kit_room = String::new();
288 let mut delete_live_kit_room = false;
289
290 if let Some(mut refreshed_room) = app_state
291 .db
292 .refresh_room(room_id, server_id)
293 .await
294 .trace_err()
295 {
296 tracing::info!(
297 room_id = room_id.0,
298 new_participant_count = refreshed_room.room.participants.len(),
299 "refreshed room"
300 );
301 room_updated(&refreshed_room.room, &peer);
302 if let Some(channel_id) = refreshed_room.channel_id {
303 channel_updated(
304 channel_id,
305 &refreshed_room.room,
306 &refreshed_room.channel_members,
307 &peer,
308 &*pool.lock(),
309 );
310 }
311 contacts_to_update
312 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
313 contacts_to_update
314 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
315 canceled_calls_to_user_ids =
316 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
317 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
318 delete_live_kit_room = refreshed_room.room.participants.is_empty();
319 }
320
321 {
322 let pool = pool.lock();
323 for canceled_user_id in canceled_calls_to_user_ids {
324 for connection_id in pool.user_connection_ids(canceled_user_id) {
325 peer.send(
326 connection_id,
327 proto::CallCanceled {
328 room_id: room_id.to_proto(),
329 },
330 )
331 .trace_err();
332 }
333 }
334 }
335
336 for user_id in contacts_to_update {
337 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
338 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
339 if let Some((busy, contacts)) = busy.zip(contacts) {
340 let pool = pool.lock();
341 let updated_contact = contact_for_user(user_id, false, busy, &pool);
342 for contact in contacts {
343 if let db::Contact::Accepted {
344 user_id: contact_user_id,
345 ..
346 } = contact
347 {
348 for contact_conn_id in
349 pool.user_connection_ids(contact_user_id)
350 {
351 peer.send(
352 contact_conn_id,
353 proto::UpdateContacts {
354 contacts: vec![updated_contact.clone()],
355 remove_contacts: Default::default(),
356 incoming_requests: Default::default(),
357 remove_incoming_requests: Default::default(),
358 outgoing_requests: Default::default(),
359 remove_outgoing_requests: Default::default(),
360 },
361 )
362 .trace_err();
363 }
364 }
365 }
366 }
367 }
368
369 if let Some(live_kit) = live_kit_client.as_ref() {
370 if delete_live_kit_room {
371 live_kit.delete_room(live_kit_room).await.trace_err();
372 }
373 }
374 }
375 }
376
377 app_state
378 .db
379 .delete_stale_servers(&app_state.config.zed_environment, server_id)
380 .await
381 .trace_err();
382 }
383 .instrument(span),
384 );
385 Ok(())
386 }
387
388 pub fn teardown(&self) {
389 self.peer.teardown();
390 self.connection_pool.lock().reset();
391 let _ = self.teardown.send(());
392 }
393
394 #[cfg(test)]
395 pub fn reset(&self, id: ServerId) {
396 self.teardown();
397 *self.id.lock() = id;
398 self.peer.reset(id.0 as u32);
399 }
400
401 #[cfg(test)]
402 pub fn id(&self) -> ServerId {
403 *self.id.lock()
404 }
405
406 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
407 where
408 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
409 Fut: 'static + Send + Future<Output = Result<()>>,
410 M: EnvelopedMessage,
411 {
412 let prev_handler = self.handlers.insert(
413 TypeId::of::<M>(),
414 Box::new(move |envelope, session| {
415 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
416 let span = info_span!(
417 "handle message",
418 payload_type = envelope.payload_type_name()
419 );
420 span.in_scope(|| {
421 tracing::info!(
422 payload_type = envelope.payload_type_name(),
423 "message received"
424 );
425 });
426 let start_time = Instant::now();
427 let future = (handler)(*envelope, session);
428 async move {
429 let result = future.await;
430 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
431 match result {
432 Err(error) => {
433 tracing::error!(%error, ?duration_ms, "error handling message")
434 }
435 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
436 }
437 }
438 .instrument(span)
439 .boxed()
440 }),
441 );
442 if prev_handler.is_some() {
443 panic!("registered a handler for the same message twice");
444 }
445 self
446 }
447
448 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
449 where
450 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
451 Fut: 'static + Send + Future<Output = Result<()>>,
452 M: EnvelopedMessage,
453 {
454 self.add_handler(move |envelope, session| handler(envelope.payload, session));
455 self
456 }
457
458 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
459 where
460 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
461 Fut: Send + Future<Output = Result<()>>,
462 M: RequestMessage,
463 {
464 let handler = Arc::new(handler);
465 self.add_handler(move |envelope, session| {
466 let receipt = envelope.receipt();
467 let handler = handler.clone();
468 async move {
469 let peer = session.peer.clone();
470 let responded = Arc::new(AtomicBool::default());
471 let response = Response {
472 peer: peer.clone(),
473 responded: responded.clone(),
474 receipt,
475 };
476 match (handler)(envelope.payload, response, session).await {
477 Ok(()) => {
478 if responded.load(std::sync::atomic::Ordering::SeqCst) {
479 Ok(())
480 } else {
481 Err(anyhow!("handler did not send a response"))?
482 }
483 }
484 Err(error) => {
485 peer.respond_with_error(
486 receipt,
487 proto::Error {
488 message: error.to_string(),
489 },
490 )?;
491 Err(error)
492 }
493 }
494 }
495 })
496 }
497
498 pub fn handle_connection(
499 self: &Arc<Self>,
500 connection: Connection,
501 address: String,
502 user: User,
503 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
504 executor: Executor,
505 ) -> impl Future<Output = Result<()>> {
506 let this = self.clone();
507 let user_id = user.id;
508 let login = user.github_login;
509 let span = info_span!("handle connection", %user_id, %login, %address);
510 let mut teardown = self.teardown.subscribe();
511 async move {
512 let (connection_id, handle_io, mut incoming_rx) = this
513 .peer
514 .add_connection(connection, {
515 let executor = executor.clone();
516 move |duration| executor.sleep(duration)
517 });
518
519 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
520 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
521 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
522
523 if let Some(send_connection_id) = send_connection_id.take() {
524 let _ = send_connection_id.send(connection_id);
525 }
526
527 if !user.connected_once {
528 this.peer.send(connection_id, proto::ShowContacts {})?;
529 this.app_state.db.set_user_connected_once(user_id, true).await?;
530 }
531
532 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
533 this.app_state.db.get_contacts(user_id),
534 this.app_state.db.get_invite_code_for_user(user_id),
535 this.app_state.db.get_channels_for_user(user_id),
536 this.app_state.db.get_channel_invites_for_user(user_id)
537 ).await?;
538
539 {
540 let mut pool = this.connection_pool.lock();
541 pool.add_connection(connection_id, user_id, user.admin);
542 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
543 this.peer.send(connection_id, build_initial_channels_update(
544 channels_for_user,
545 channel_invites
546 ))?;
547
548 if let Some((code, count)) = invite_code {
549 this.peer.send(connection_id, proto::UpdateInviteInfo {
550 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
551 count: count as u32,
552 })?;
553 }
554 }
555
556 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
557 this.peer.send(connection_id, incoming_call)?;
558 }
559
560 let session = Session {
561 user_id,
562 connection_id,
563 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
564 peer: this.peer.clone(),
565 connection_pool: this.connection_pool.clone(),
566 live_kit_client: this.app_state.live_kit_client.clone(),
567 executor: executor.clone(),
568 };
569 update_user_contacts(user_id, &session).await?;
570
571 let handle_io = handle_io.fuse();
572 futures::pin_mut!(handle_io);
573
574 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
575 // This prevents deadlocks when e.g., client A performs a request to client B and
576 // client B performs a request to client A. If both clients stop processing further
577 // messages until their respective request completes, they won't have a chance to
578 // respond to the other client's request and cause a deadlock.
579 //
580 // This arrangement ensures we will attempt to process earlier messages first, but fall
581 // back to processing messages arrived later in the spirit of making progress.
582 let mut foreground_message_handlers = FuturesUnordered::new();
583 let concurrent_handlers = Arc::new(Semaphore::new(256));
584 loop {
585 let next_message = async {
586 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
587 let message = incoming_rx.next().await;
588 (permit, message)
589 }.fuse();
590 futures::pin_mut!(next_message);
591 futures::select_biased! {
592 _ = teardown.changed().fuse() => return Ok(()),
593 result = handle_io => {
594 if let Err(error) = result {
595 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
596 }
597 break;
598 }
599 _ = foreground_message_handlers.next() => {}
600 next_message = next_message => {
601 let (permit, message) = next_message;
602 if let Some(message) = message {
603 let type_name = message.payload_type_name();
604 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
605 let span_enter = span.enter();
606 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
607 let is_background = message.is_background();
608 let handle_message = (handler)(message, session.clone());
609 drop(span_enter);
610
611 let handle_message = async move {
612 handle_message.await;
613 drop(permit);
614 }.instrument(span);
615 if is_background {
616 executor.spawn_detached(handle_message);
617 } else {
618 foreground_message_handlers.push(handle_message);
619 }
620 } else {
621 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
622 }
623 } else {
624 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
625 break;
626 }
627 }
628 }
629 }
630
631 drop(foreground_message_handlers);
632 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
633 if let Err(error) = connection_lost(session, teardown, executor).await {
634 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
635 }
636
637 Ok(())
638 }.instrument(span)
639 }
640
641 pub async fn invite_code_redeemed(
642 self: &Arc<Self>,
643 inviter_id: UserId,
644 invitee_id: UserId,
645 ) -> Result<()> {
646 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
647 if let Some(code) = &user.invite_code {
648 let pool = self.connection_pool.lock();
649 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
650 for connection_id in pool.user_connection_ids(inviter_id) {
651 self.peer.send(
652 connection_id,
653 proto::UpdateContacts {
654 contacts: vec![invitee_contact.clone()],
655 ..Default::default()
656 },
657 )?;
658 self.peer.send(
659 connection_id,
660 proto::UpdateInviteInfo {
661 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
662 count: user.invite_count as u32,
663 },
664 )?;
665 }
666 }
667 }
668 Ok(())
669 }
670
671 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
672 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
673 if let Some(invite_code) = &user.invite_code {
674 let pool = self.connection_pool.lock();
675 for connection_id in pool.user_connection_ids(user_id) {
676 self.peer.send(
677 connection_id,
678 proto::UpdateInviteInfo {
679 url: format!(
680 "{}{}",
681 self.app_state.config.invite_link_prefix, invite_code
682 ),
683 count: user.invite_count as u32,
684 },
685 )?;
686 }
687 }
688 }
689 Ok(())
690 }
691
692 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
693 ServerSnapshot {
694 connection_pool: ConnectionPoolGuard {
695 guard: self.connection_pool.lock(),
696 _not_send: PhantomData,
697 },
698 peer: &self.peer,
699 }
700 }
701}
702
703impl<'a> Deref for ConnectionPoolGuard<'a> {
704 type Target = ConnectionPool;
705
706 fn deref(&self) -> &Self::Target {
707 &*self.guard
708 }
709}
710
711impl<'a> DerefMut for ConnectionPoolGuard<'a> {
712 fn deref_mut(&mut self) -> &mut Self::Target {
713 &mut *self.guard
714 }
715}
716
717impl<'a> Drop for ConnectionPoolGuard<'a> {
718 fn drop(&mut self) {
719 #[cfg(test)]
720 self.check_invariants();
721 }
722}
723
724fn broadcast<F>(
725 sender_id: Option<ConnectionId>,
726 receiver_ids: impl IntoIterator<Item = ConnectionId>,
727 mut f: F,
728) where
729 F: FnMut(ConnectionId) -> anyhow::Result<()>,
730{
731 for receiver_id in receiver_ids {
732 if Some(receiver_id) != sender_id {
733 if let Err(error) = f(receiver_id) {
734 tracing::error!("failed to send to {:?} {}", receiver_id, error);
735 }
736 }
737 }
738}
739
740lazy_static! {
741 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
742}
743
744pub struct ProtocolVersion(u32);
745
746impl Header for ProtocolVersion {
747 fn name() -> &'static HeaderName {
748 &ZED_PROTOCOL_VERSION
749 }
750
751 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
752 where
753 Self: Sized,
754 I: Iterator<Item = &'i axum::http::HeaderValue>,
755 {
756 let version = values
757 .next()
758 .ok_or_else(axum::headers::Error::invalid)?
759 .to_str()
760 .map_err(|_| axum::headers::Error::invalid())?
761 .parse()
762 .map_err(|_| axum::headers::Error::invalid())?;
763 Ok(Self(version))
764 }
765
766 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
767 values.extend([self.0.to_string().parse().unwrap()]);
768 }
769}
770
771pub fn routes(server: Arc<Server>) -> Router<Body> {
772 Router::new()
773 .route("/rpc", get(handle_websocket_request))
774 .layer(
775 ServiceBuilder::new()
776 .layer(Extension(server.app_state.clone()))
777 .layer(middleware::from_fn(auth::validate_header)),
778 )
779 .route("/metrics", get(handle_metrics))
780 .layer(Extension(server))
781}
782
783pub async fn handle_websocket_request(
784 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
785 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
786 Extension(server): Extension<Arc<Server>>,
787 Extension(user): Extension<User>,
788 ws: WebSocketUpgrade,
789) -> axum::response::Response {
790 if protocol_version != rpc::PROTOCOL_VERSION {
791 return (
792 StatusCode::UPGRADE_REQUIRED,
793 "client must be upgraded".to_string(),
794 )
795 .into_response();
796 }
797 let socket_address = socket_address.to_string();
798 ws.on_upgrade(move |socket| {
799 use util::ResultExt;
800 let socket = socket
801 .map_ok(to_tungstenite_message)
802 .err_into()
803 .with(|message| async move { Ok(to_axum_message(message)) });
804 let connection = Connection::new(Box::pin(socket));
805 async move {
806 server
807 .handle_connection(connection, socket_address, user, None, Executor::Production)
808 .await
809 .log_err();
810 }
811 })
812}
813
814pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
815 let connections = server
816 .connection_pool
817 .lock()
818 .connections()
819 .filter(|connection| !connection.admin)
820 .count();
821
822 METRIC_CONNECTIONS.set(connections as _);
823
824 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
825 METRIC_SHARED_PROJECTS.set(shared_projects as _);
826
827 let encoder = prometheus::TextEncoder::new();
828 let metric_families = prometheus::gather();
829 let encoded_metrics = encoder
830 .encode_to_string(&metric_families)
831 .map_err(|err| anyhow!("{}", err))?;
832 Ok(encoded_metrics)
833}
834
835#[instrument(err, skip(executor))]
836async fn connection_lost(
837 session: Session,
838 mut teardown: watch::Receiver<()>,
839 executor: Executor,
840) -> Result<()> {
841 session.peer.disconnect(session.connection_id);
842 session
843 .connection_pool()
844 .await
845 .remove_connection(session.connection_id)?;
846
847 session
848 .db()
849 .await
850 .connection_lost(session.connection_id)
851 .await
852 .trace_err();
853
854 futures::select_biased! {
855 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
856 leave_room_for_session(&session).await.trace_err();
857
858 if !session
859 .connection_pool()
860 .await
861 .is_user_online(session.user_id)
862 {
863 let db = session.db().await;
864 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
865 room_updated(&room, &session.peer);
866 }
867 }
868 update_user_contacts(session.user_id, &session).await?;
869 }
870 _ = teardown.changed().fuse() => {}
871 }
872
873 Ok(())
874}
875
876async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
877 response.send(proto::Ack {})?;
878 Ok(())
879}
880
881async fn create_room(
882 _request: proto::CreateRoom,
883 response: Response<proto::CreateRoom>,
884 session: Session,
885) -> Result<()> {
886 let live_kit_room = nanoid::nanoid!(30);
887
888 let live_kit_connection_info = {
889 let live_kit_room = live_kit_room.clone();
890 let live_kit = session.live_kit_client.as_ref();
891
892 util::async_iife!({
893 let live_kit = live_kit?;
894
895 live_kit
896 .create_room(live_kit_room.clone())
897 .await
898 .trace_err()?;
899
900 let token = live_kit
901 .room_token(&live_kit_room, &session.user_id.to_string())
902 .trace_err()?;
903
904 Some(proto::LiveKitConnectionInfo {
905 server_url: live_kit.url().into(),
906 token,
907 })
908 })
909 }
910 .await;
911
912 let room = session
913 .db()
914 .await
915 .create_room(session.user_id, session.connection_id, &live_kit_room)
916 .await?;
917
918 response.send(proto::CreateRoomResponse {
919 room: Some(room.clone()),
920 live_kit_connection_info,
921 })?;
922
923 update_user_contacts(session.user_id, &session).await?;
924 Ok(())
925}
926
927async fn join_room(
928 request: proto::JoinRoom,
929 response: Response<proto::JoinRoom>,
930 session: Session,
931) -> Result<()> {
932 let room_id = RoomId::from_proto(request.id);
933 let room = {
934 let room = session
935 .db()
936 .await
937 .join_room(room_id, session.user_id, None, session.connection_id)
938 .await?;
939 room_updated(&room.room, &session.peer);
940 room.room.clone()
941 };
942
943 for connection_id in session
944 .connection_pool()
945 .await
946 .user_connection_ids(session.user_id)
947 {
948 session
949 .peer
950 .send(
951 connection_id,
952 proto::CallCanceled {
953 room_id: room_id.to_proto(),
954 },
955 )
956 .trace_err();
957 }
958
959 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
960 if let Some(token) = live_kit
961 .room_token(&room.live_kit_room, &session.user_id.to_string())
962 .trace_err()
963 {
964 Some(proto::LiveKitConnectionInfo {
965 server_url: live_kit.url().into(),
966 token,
967 })
968 } else {
969 None
970 }
971 } else {
972 None
973 };
974
975 response.send(proto::JoinRoomResponse {
976 room: Some(room),
977 live_kit_connection_info,
978 })?;
979
980 update_user_contacts(session.user_id, &session).await?;
981 Ok(())
982}
983
984async fn rejoin_room(
985 request: proto::RejoinRoom,
986 response: Response<proto::RejoinRoom>,
987 session: Session,
988) -> Result<()> {
989 let room;
990 let channel_id;
991 let channel_members;
992 {
993 let mut rejoined_room = session
994 .db()
995 .await
996 .rejoin_room(request, session.user_id, session.connection_id)
997 .await?;
998
999 response.send(proto::RejoinRoomResponse {
1000 room: Some(rejoined_room.room.clone()),
1001 reshared_projects: rejoined_room
1002 .reshared_projects
1003 .iter()
1004 .map(|project| proto::ResharedProject {
1005 id: project.id.to_proto(),
1006 collaborators: project
1007 .collaborators
1008 .iter()
1009 .map(|collaborator| collaborator.to_proto())
1010 .collect(),
1011 })
1012 .collect(),
1013 rejoined_projects: rejoined_room
1014 .rejoined_projects
1015 .iter()
1016 .map(|rejoined_project| proto::RejoinedProject {
1017 id: rejoined_project.id.to_proto(),
1018 worktrees: rejoined_project
1019 .worktrees
1020 .iter()
1021 .map(|worktree| proto::WorktreeMetadata {
1022 id: worktree.id,
1023 root_name: worktree.root_name.clone(),
1024 visible: worktree.visible,
1025 abs_path: worktree.abs_path.clone(),
1026 })
1027 .collect(),
1028 collaborators: rejoined_project
1029 .collaborators
1030 .iter()
1031 .map(|collaborator| collaborator.to_proto())
1032 .collect(),
1033 language_servers: rejoined_project.language_servers.clone(),
1034 })
1035 .collect(),
1036 })?;
1037 room_updated(&rejoined_room.room, &session.peer);
1038
1039 for project in &rejoined_room.reshared_projects {
1040 for collaborator in &project.collaborators {
1041 session
1042 .peer
1043 .send(
1044 collaborator.connection_id,
1045 proto::UpdateProjectCollaborator {
1046 project_id: project.id.to_proto(),
1047 old_peer_id: Some(project.old_connection_id.into()),
1048 new_peer_id: Some(session.connection_id.into()),
1049 },
1050 )
1051 .trace_err();
1052 }
1053
1054 broadcast(
1055 Some(session.connection_id),
1056 project
1057 .collaborators
1058 .iter()
1059 .map(|collaborator| collaborator.connection_id),
1060 |connection_id| {
1061 session.peer.forward_send(
1062 session.connection_id,
1063 connection_id,
1064 proto::UpdateProject {
1065 project_id: project.id.to_proto(),
1066 worktrees: project.worktrees.clone(),
1067 },
1068 )
1069 },
1070 );
1071 }
1072
1073 for project in &rejoined_room.rejoined_projects {
1074 for collaborator in &project.collaborators {
1075 session
1076 .peer
1077 .send(
1078 collaborator.connection_id,
1079 proto::UpdateProjectCollaborator {
1080 project_id: project.id.to_proto(),
1081 old_peer_id: Some(project.old_connection_id.into()),
1082 new_peer_id: Some(session.connection_id.into()),
1083 },
1084 )
1085 .trace_err();
1086 }
1087 }
1088
1089 for project in &mut rejoined_room.rejoined_projects {
1090 for worktree in mem::take(&mut project.worktrees) {
1091 #[cfg(any(test, feature = "test-support"))]
1092 const MAX_CHUNK_SIZE: usize = 2;
1093 #[cfg(not(any(test, feature = "test-support")))]
1094 const MAX_CHUNK_SIZE: usize = 256;
1095
1096 // Stream this worktree's entries.
1097 let message = proto::UpdateWorktree {
1098 project_id: project.id.to_proto(),
1099 worktree_id: worktree.id,
1100 abs_path: worktree.abs_path.clone(),
1101 root_name: worktree.root_name,
1102 updated_entries: worktree.updated_entries,
1103 removed_entries: worktree.removed_entries,
1104 scan_id: worktree.scan_id,
1105 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1106 updated_repositories: worktree.updated_repositories,
1107 removed_repositories: worktree.removed_repositories,
1108 };
1109 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1110 session.peer.send(session.connection_id, update.clone())?;
1111 }
1112
1113 // Stream this worktree's diagnostics.
1114 for summary in worktree.diagnostic_summaries {
1115 session.peer.send(
1116 session.connection_id,
1117 proto::UpdateDiagnosticSummary {
1118 project_id: project.id.to_proto(),
1119 worktree_id: worktree.id,
1120 summary: Some(summary),
1121 },
1122 )?;
1123 }
1124
1125 for settings_file in worktree.settings_files {
1126 session.peer.send(
1127 session.connection_id,
1128 proto::UpdateWorktreeSettings {
1129 project_id: project.id.to_proto(),
1130 worktree_id: worktree.id,
1131 path: settings_file.path,
1132 content: Some(settings_file.content),
1133 },
1134 )?;
1135 }
1136 }
1137
1138 for language_server in &project.language_servers {
1139 session.peer.send(
1140 session.connection_id,
1141 proto::UpdateLanguageServer {
1142 project_id: project.id.to_proto(),
1143 language_server_id: language_server.id,
1144 variant: Some(
1145 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1146 proto::LspDiskBasedDiagnosticsUpdated {},
1147 ),
1148 ),
1149 },
1150 )?;
1151 }
1152 }
1153
1154 room = mem::take(&mut rejoined_room.room);
1155 channel_id = rejoined_room.channel_id;
1156 channel_members = mem::take(&mut rejoined_room.channel_members);
1157 }
1158
1159 if let Some(channel_id) = channel_id {
1160 channel_updated(
1161 channel_id,
1162 &room,
1163 &channel_members,
1164 &session.peer,
1165 &*session.connection_pool().await,
1166 );
1167 }
1168
1169 update_user_contacts(session.user_id, &session).await?;
1170 Ok(())
1171}
1172
1173async fn leave_room(
1174 _: proto::LeaveRoom,
1175 response: Response<proto::LeaveRoom>,
1176 session: Session,
1177) -> Result<()> {
1178 leave_room_for_session(&session).await?;
1179 response.send(proto::Ack {})?;
1180 Ok(())
1181}
1182
1183async fn call(
1184 request: proto::Call,
1185 response: Response<proto::Call>,
1186 session: Session,
1187) -> Result<()> {
1188 let room_id = RoomId::from_proto(request.room_id);
1189 let calling_user_id = session.user_id;
1190 let calling_connection_id = session.connection_id;
1191 let called_user_id = UserId::from_proto(request.called_user_id);
1192 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1193 if !session
1194 .db()
1195 .await
1196 .has_contact(calling_user_id, called_user_id)
1197 .await?
1198 {
1199 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1200 }
1201
1202 let incoming_call = {
1203 let (room, incoming_call) = &mut *session
1204 .db()
1205 .await
1206 .call(
1207 room_id,
1208 calling_user_id,
1209 calling_connection_id,
1210 called_user_id,
1211 initial_project_id,
1212 )
1213 .await?;
1214 room_updated(&room, &session.peer);
1215 mem::take(incoming_call)
1216 };
1217 update_user_contacts(called_user_id, &session).await?;
1218
1219 let mut calls = session
1220 .connection_pool()
1221 .await
1222 .user_connection_ids(called_user_id)
1223 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1224 .collect::<FuturesUnordered<_>>();
1225
1226 while let Some(call_response) = calls.next().await {
1227 match call_response.as_ref() {
1228 Ok(_) => {
1229 response.send(proto::Ack {})?;
1230 return Ok(());
1231 }
1232 Err(_) => {
1233 call_response.trace_err();
1234 }
1235 }
1236 }
1237
1238 {
1239 let room = session
1240 .db()
1241 .await
1242 .call_failed(room_id, called_user_id)
1243 .await?;
1244 room_updated(&room, &session.peer);
1245 }
1246 update_user_contacts(called_user_id, &session).await?;
1247
1248 Err(anyhow!("failed to ring user"))?
1249}
1250
1251async fn cancel_call(
1252 request: proto::CancelCall,
1253 response: Response<proto::CancelCall>,
1254 session: Session,
1255) -> Result<()> {
1256 let called_user_id = UserId::from_proto(request.called_user_id);
1257 let room_id = RoomId::from_proto(request.room_id);
1258 {
1259 let room = session
1260 .db()
1261 .await
1262 .cancel_call(room_id, session.connection_id, called_user_id)
1263 .await?;
1264 room_updated(&room, &session.peer);
1265 }
1266
1267 for connection_id in session
1268 .connection_pool()
1269 .await
1270 .user_connection_ids(called_user_id)
1271 {
1272 session
1273 .peer
1274 .send(
1275 connection_id,
1276 proto::CallCanceled {
1277 room_id: room_id.to_proto(),
1278 },
1279 )
1280 .trace_err();
1281 }
1282 response.send(proto::Ack {})?;
1283
1284 update_user_contacts(called_user_id, &session).await?;
1285 Ok(())
1286}
1287
1288async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1289 let room_id = RoomId::from_proto(message.room_id);
1290 {
1291 let room = session
1292 .db()
1293 .await
1294 .decline_call(Some(room_id), session.user_id)
1295 .await?
1296 .ok_or_else(|| anyhow!("failed to decline call"))?;
1297 room_updated(&room, &session.peer);
1298 }
1299
1300 for connection_id in session
1301 .connection_pool()
1302 .await
1303 .user_connection_ids(session.user_id)
1304 {
1305 session
1306 .peer
1307 .send(
1308 connection_id,
1309 proto::CallCanceled {
1310 room_id: room_id.to_proto(),
1311 },
1312 )
1313 .trace_err();
1314 }
1315 update_user_contacts(session.user_id, &session).await?;
1316 Ok(())
1317}
1318
1319async fn update_participant_location(
1320 request: proto::UpdateParticipantLocation,
1321 response: Response<proto::UpdateParticipantLocation>,
1322 session: Session,
1323) -> Result<()> {
1324 let room_id = RoomId::from_proto(request.room_id);
1325 let location = request
1326 .location
1327 .ok_or_else(|| anyhow!("invalid location"))?;
1328
1329 let db = session.db().await;
1330 let room = db
1331 .update_room_participant_location(room_id, session.connection_id, location)
1332 .await?;
1333
1334 room_updated(&room, &session.peer);
1335 response.send(proto::Ack {})?;
1336 Ok(())
1337}
1338
1339async fn share_project(
1340 request: proto::ShareProject,
1341 response: Response<proto::ShareProject>,
1342 session: Session,
1343) -> Result<()> {
1344 let (project_id, room) = &*session
1345 .db()
1346 .await
1347 .share_project(
1348 RoomId::from_proto(request.room_id),
1349 session.connection_id,
1350 &request.worktrees,
1351 )
1352 .await?;
1353 response.send(proto::ShareProjectResponse {
1354 project_id: project_id.to_proto(),
1355 })?;
1356 room_updated(&room, &session.peer);
1357
1358 Ok(())
1359}
1360
1361async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1362 let project_id = ProjectId::from_proto(message.project_id);
1363
1364 let (room, guest_connection_ids) = &*session
1365 .db()
1366 .await
1367 .unshare_project(project_id, session.connection_id)
1368 .await?;
1369
1370 broadcast(
1371 Some(session.connection_id),
1372 guest_connection_ids.iter().copied(),
1373 |conn_id| session.peer.send(conn_id, message.clone()),
1374 );
1375 room_updated(&room, &session.peer);
1376
1377 Ok(())
1378}
1379
1380async fn join_project(
1381 request: proto::JoinProject,
1382 response: Response<proto::JoinProject>,
1383 session: Session,
1384) -> Result<()> {
1385 let project_id = ProjectId::from_proto(request.project_id);
1386 let guest_user_id = session.user_id;
1387
1388 tracing::info!(%project_id, "join project");
1389
1390 let (project, replica_id) = &mut *session
1391 .db()
1392 .await
1393 .join_project(project_id, session.connection_id)
1394 .await?;
1395
1396 let collaborators = project
1397 .collaborators
1398 .iter()
1399 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1400 .map(|collaborator| collaborator.to_proto())
1401 .collect::<Vec<_>>();
1402
1403 let worktrees = project
1404 .worktrees
1405 .iter()
1406 .map(|(id, worktree)| proto::WorktreeMetadata {
1407 id: *id,
1408 root_name: worktree.root_name.clone(),
1409 visible: worktree.visible,
1410 abs_path: worktree.abs_path.clone(),
1411 })
1412 .collect::<Vec<_>>();
1413
1414 for collaborator in &collaborators {
1415 session
1416 .peer
1417 .send(
1418 collaborator.peer_id.unwrap().into(),
1419 proto::AddProjectCollaborator {
1420 project_id: project_id.to_proto(),
1421 collaborator: Some(proto::Collaborator {
1422 peer_id: Some(session.connection_id.into()),
1423 replica_id: replica_id.0 as u32,
1424 user_id: guest_user_id.to_proto(),
1425 }),
1426 },
1427 )
1428 .trace_err();
1429 }
1430
1431 // First, we send the metadata associated with each worktree.
1432 response.send(proto::JoinProjectResponse {
1433 worktrees: worktrees.clone(),
1434 replica_id: replica_id.0 as u32,
1435 collaborators: collaborators.clone(),
1436 language_servers: project.language_servers.clone(),
1437 })?;
1438
1439 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1440 #[cfg(any(test, feature = "test-support"))]
1441 const MAX_CHUNK_SIZE: usize = 2;
1442 #[cfg(not(any(test, feature = "test-support")))]
1443 const MAX_CHUNK_SIZE: usize = 256;
1444
1445 // Stream this worktree's entries.
1446 let message = proto::UpdateWorktree {
1447 project_id: project_id.to_proto(),
1448 worktree_id,
1449 abs_path: worktree.abs_path.clone(),
1450 root_name: worktree.root_name,
1451 updated_entries: worktree.entries,
1452 removed_entries: Default::default(),
1453 scan_id: worktree.scan_id,
1454 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1455 updated_repositories: worktree.repository_entries.into_values().collect(),
1456 removed_repositories: Default::default(),
1457 };
1458 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1459 session.peer.send(session.connection_id, update.clone())?;
1460 }
1461
1462 // Stream this worktree's diagnostics.
1463 for summary in worktree.diagnostic_summaries {
1464 session.peer.send(
1465 session.connection_id,
1466 proto::UpdateDiagnosticSummary {
1467 project_id: project_id.to_proto(),
1468 worktree_id: worktree.id,
1469 summary: Some(summary),
1470 },
1471 )?;
1472 }
1473
1474 for settings_file in worktree.settings_files {
1475 session.peer.send(
1476 session.connection_id,
1477 proto::UpdateWorktreeSettings {
1478 project_id: project_id.to_proto(),
1479 worktree_id: worktree.id,
1480 path: settings_file.path,
1481 content: Some(settings_file.content),
1482 },
1483 )?;
1484 }
1485 }
1486
1487 for language_server in &project.language_servers {
1488 session.peer.send(
1489 session.connection_id,
1490 proto::UpdateLanguageServer {
1491 project_id: project_id.to_proto(),
1492 language_server_id: language_server.id,
1493 variant: Some(
1494 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1495 proto::LspDiskBasedDiagnosticsUpdated {},
1496 ),
1497 ),
1498 },
1499 )?;
1500 }
1501
1502 Ok(())
1503}
1504
1505async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1506 let sender_id = session.connection_id;
1507 let project_id = ProjectId::from_proto(request.project_id);
1508
1509 let (room, project) = &*session
1510 .db()
1511 .await
1512 .leave_project(project_id, sender_id)
1513 .await?;
1514 tracing::info!(
1515 %project_id,
1516 host_user_id = %project.host_user_id,
1517 host_connection_id = %project.host_connection_id,
1518 "leave project"
1519 );
1520
1521 project_left(&project, &session);
1522 room_updated(&room, &session.peer);
1523
1524 Ok(())
1525}
1526
1527async fn update_project(
1528 request: proto::UpdateProject,
1529 response: Response<proto::UpdateProject>,
1530 session: Session,
1531) -> Result<()> {
1532 let project_id = ProjectId::from_proto(request.project_id);
1533 let (room, guest_connection_ids) = &*session
1534 .db()
1535 .await
1536 .update_project(project_id, session.connection_id, &request.worktrees)
1537 .await?;
1538 broadcast(
1539 Some(session.connection_id),
1540 guest_connection_ids.iter().copied(),
1541 |connection_id| {
1542 session
1543 .peer
1544 .forward_send(session.connection_id, connection_id, request.clone())
1545 },
1546 );
1547 room_updated(&room, &session.peer);
1548 response.send(proto::Ack {})?;
1549
1550 Ok(())
1551}
1552
1553async fn update_worktree(
1554 request: proto::UpdateWorktree,
1555 response: Response<proto::UpdateWorktree>,
1556 session: Session,
1557) -> Result<()> {
1558 let guest_connection_ids = session
1559 .db()
1560 .await
1561 .update_worktree(&request, session.connection_id)
1562 .await?;
1563
1564 broadcast(
1565 Some(session.connection_id),
1566 guest_connection_ids.iter().copied(),
1567 |connection_id| {
1568 session
1569 .peer
1570 .forward_send(session.connection_id, connection_id, request.clone())
1571 },
1572 );
1573 response.send(proto::Ack {})?;
1574 Ok(())
1575}
1576
1577async fn update_diagnostic_summary(
1578 message: proto::UpdateDiagnosticSummary,
1579 session: Session,
1580) -> Result<()> {
1581 let guest_connection_ids = session
1582 .db()
1583 .await
1584 .update_diagnostic_summary(&message, session.connection_id)
1585 .await?;
1586
1587 broadcast(
1588 Some(session.connection_id),
1589 guest_connection_ids.iter().copied(),
1590 |connection_id| {
1591 session
1592 .peer
1593 .forward_send(session.connection_id, connection_id, message.clone())
1594 },
1595 );
1596
1597 Ok(())
1598}
1599
1600async fn update_worktree_settings(
1601 message: proto::UpdateWorktreeSettings,
1602 session: Session,
1603) -> Result<()> {
1604 let guest_connection_ids = session
1605 .db()
1606 .await
1607 .update_worktree_settings(&message, session.connection_id)
1608 .await?;
1609
1610 broadcast(
1611 Some(session.connection_id),
1612 guest_connection_ids.iter().copied(),
1613 |connection_id| {
1614 session
1615 .peer
1616 .forward_send(session.connection_id, connection_id, message.clone())
1617 },
1618 );
1619
1620 Ok(())
1621}
1622
1623async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1624 broadcast_project_message(request.project_id, request, session).await
1625}
1626
1627async fn start_language_server(
1628 request: proto::StartLanguageServer,
1629 session: Session,
1630) -> Result<()> {
1631 let guest_connection_ids = session
1632 .db()
1633 .await
1634 .start_language_server(&request, session.connection_id)
1635 .await?;
1636
1637 broadcast(
1638 Some(session.connection_id),
1639 guest_connection_ids.iter().copied(),
1640 |connection_id| {
1641 session
1642 .peer
1643 .forward_send(session.connection_id, connection_id, request.clone())
1644 },
1645 );
1646 Ok(())
1647}
1648
1649async fn update_language_server(
1650 request: proto::UpdateLanguageServer,
1651 session: Session,
1652) -> Result<()> {
1653 session.executor.record_backtrace();
1654 let project_id = ProjectId::from_proto(request.project_id);
1655 let project_connection_ids = session
1656 .db()
1657 .await
1658 .project_connection_ids(project_id, session.connection_id)
1659 .await?;
1660 broadcast(
1661 Some(session.connection_id),
1662 project_connection_ids.iter().copied(),
1663 |connection_id| {
1664 session
1665 .peer
1666 .forward_send(session.connection_id, connection_id, request.clone())
1667 },
1668 );
1669 Ok(())
1670}
1671
1672async fn forward_project_request<T>(
1673 request: T,
1674 response: Response<T>,
1675 session: Session,
1676) -> Result<()>
1677where
1678 T: EntityMessage + RequestMessage,
1679{
1680 session.executor.record_backtrace();
1681 let project_id = ProjectId::from_proto(request.remote_entity_id());
1682 let host_connection_id = {
1683 let collaborators = session
1684 .db()
1685 .await
1686 .project_collaborators(project_id, session.connection_id)
1687 .await?;
1688 collaborators
1689 .iter()
1690 .find(|collaborator| collaborator.is_host)
1691 .ok_or_else(|| anyhow!("host not found"))?
1692 .connection_id
1693 };
1694
1695 let payload = session
1696 .peer
1697 .forward_request(session.connection_id, host_connection_id, request)
1698 .await?;
1699
1700 response.send(payload)?;
1701 Ok(())
1702}
1703
1704async fn create_buffer_for_peer(
1705 request: proto::CreateBufferForPeer,
1706 session: Session,
1707) -> Result<()> {
1708 session.executor.record_backtrace();
1709 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1710 session
1711 .peer
1712 .forward_send(session.connection_id, peer_id.into(), request)?;
1713 Ok(())
1714}
1715
1716async fn update_buffer(
1717 request: proto::UpdateBuffer,
1718 response: Response<proto::UpdateBuffer>,
1719 session: Session,
1720) -> Result<()> {
1721 session.executor.record_backtrace();
1722 let project_id = ProjectId::from_proto(request.project_id);
1723 let mut guest_connection_ids;
1724 let mut host_connection_id = None;
1725 {
1726 let collaborators = session
1727 .db()
1728 .await
1729 .project_collaborators(project_id, session.connection_id)
1730 .await?;
1731 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1732 for collaborator in collaborators.iter() {
1733 if collaborator.is_host {
1734 host_connection_id = Some(collaborator.connection_id);
1735 } else {
1736 guest_connection_ids.push(collaborator.connection_id);
1737 }
1738 }
1739 }
1740 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1741
1742 session.executor.record_backtrace();
1743 broadcast(
1744 Some(session.connection_id),
1745 guest_connection_ids,
1746 |connection_id| {
1747 session
1748 .peer
1749 .forward_send(session.connection_id, connection_id, request.clone())
1750 },
1751 );
1752 if host_connection_id != session.connection_id {
1753 session
1754 .peer
1755 .forward_request(session.connection_id, host_connection_id, request.clone())
1756 .await?;
1757 }
1758
1759 response.send(proto::Ack {})?;
1760 Ok(())
1761}
1762
1763async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1764 let project_id = ProjectId::from_proto(request.project_id);
1765 let project_connection_ids = session
1766 .db()
1767 .await
1768 .project_connection_ids(project_id, session.connection_id)
1769 .await?;
1770
1771 broadcast(
1772 Some(session.connection_id),
1773 project_connection_ids.iter().copied(),
1774 |connection_id| {
1775 session
1776 .peer
1777 .forward_send(session.connection_id, connection_id, request.clone())
1778 },
1779 );
1780 Ok(())
1781}
1782
1783async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1784 let project_id = ProjectId::from_proto(request.project_id);
1785 let project_connection_ids = session
1786 .db()
1787 .await
1788 .project_connection_ids(project_id, session.connection_id)
1789 .await?;
1790 broadcast(
1791 Some(session.connection_id),
1792 project_connection_ids.iter().copied(),
1793 |connection_id| {
1794 session
1795 .peer
1796 .forward_send(session.connection_id, connection_id, request.clone())
1797 },
1798 );
1799 Ok(())
1800}
1801
1802async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1803 broadcast_project_message(request.project_id, request, session).await
1804}
1805
1806async fn broadcast_project_message<T: EnvelopedMessage>(
1807 project_id: u64,
1808 request: T,
1809 session: Session,
1810) -> Result<()> {
1811 let project_id = ProjectId::from_proto(project_id);
1812 let project_connection_ids = session
1813 .db()
1814 .await
1815 .project_connection_ids(project_id, session.connection_id)
1816 .await?;
1817 broadcast(
1818 Some(session.connection_id),
1819 project_connection_ids.iter().copied(),
1820 |connection_id| {
1821 session
1822 .peer
1823 .forward_send(session.connection_id, connection_id, request.clone())
1824 },
1825 );
1826 Ok(())
1827}
1828
1829async fn follow(
1830 request: proto::Follow,
1831 response: Response<proto::Follow>,
1832 session: Session,
1833) -> Result<()> {
1834 let project_id = ProjectId::from_proto(request.project_id);
1835 let leader_id = request
1836 .leader_id
1837 .ok_or_else(|| anyhow!("invalid leader id"))?
1838 .into();
1839 let follower_id = session.connection_id;
1840
1841 {
1842 let project_connection_ids = session
1843 .db()
1844 .await
1845 .project_connection_ids(project_id, session.connection_id)
1846 .await?;
1847
1848 if !project_connection_ids.contains(&leader_id) {
1849 Err(anyhow!("no such peer"))?;
1850 }
1851 }
1852
1853 let mut response_payload = session
1854 .peer
1855 .forward_request(session.connection_id, leader_id, request)
1856 .await?;
1857 response_payload
1858 .views
1859 .retain(|view| view.leader_id != Some(follower_id.into()));
1860 response.send(response_payload)?;
1861
1862 let room = session
1863 .db()
1864 .await
1865 .follow(project_id, leader_id, follower_id)
1866 .await?;
1867 room_updated(&room, &session.peer);
1868
1869 Ok(())
1870}
1871
1872async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1873 let project_id = ProjectId::from_proto(request.project_id);
1874 let leader_id = request
1875 .leader_id
1876 .ok_or_else(|| anyhow!("invalid leader id"))?
1877 .into();
1878 let follower_id = session.connection_id;
1879
1880 if !session
1881 .db()
1882 .await
1883 .project_connection_ids(project_id, session.connection_id)
1884 .await?
1885 .contains(&leader_id)
1886 {
1887 Err(anyhow!("no such peer"))?;
1888 }
1889
1890 session
1891 .peer
1892 .forward_send(session.connection_id, leader_id, request)?;
1893
1894 let room = session
1895 .db()
1896 .await
1897 .unfollow(project_id, leader_id, follower_id)
1898 .await?;
1899 room_updated(&room, &session.peer);
1900
1901 Ok(())
1902}
1903
1904async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1905 let project_id = ProjectId::from_proto(request.project_id);
1906 let project_connection_ids = session
1907 .db
1908 .lock()
1909 .await
1910 .project_connection_ids(project_id, session.connection_id)
1911 .await?;
1912
1913 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1914 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1915 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1916 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1917 });
1918 for follower_peer_id in request.follower_ids.iter().copied() {
1919 let follower_connection_id = follower_peer_id.into();
1920 if project_connection_ids.contains(&follower_connection_id)
1921 && Some(follower_peer_id) != leader_id
1922 {
1923 session.peer.forward_send(
1924 session.connection_id,
1925 follower_connection_id,
1926 request.clone(),
1927 )?;
1928 }
1929 }
1930 Ok(())
1931}
1932
1933async fn get_users(
1934 request: proto::GetUsers,
1935 response: Response<proto::GetUsers>,
1936 session: Session,
1937) -> Result<()> {
1938 let user_ids = request
1939 .user_ids
1940 .into_iter()
1941 .map(UserId::from_proto)
1942 .collect();
1943 let users = session
1944 .db()
1945 .await
1946 .get_users_by_ids(user_ids)
1947 .await?
1948 .into_iter()
1949 .map(|user| proto::User {
1950 id: user.id.to_proto(),
1951 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1952 github_login: user.github_login,
1953 })
1954 .collect();
1955 response.send(proto::UsersResponse { users })?;
1956 Ok(())
1957}
1958
1959async fn fuzzy_search_users(
1960 request: proto::FuzzySearchUsers,
1961 response: Response<proto::FuzzySearchUsers>,
1962 session: Session,
1963) -> Result<()> {
1964 let query = request.query;
1965 let users = match query.len() {
1966 0 => vec![],
1967 1 | 2 => session
1968 .db()
1969 .await
1970 .get_user_by_github_login(&query)
1971 .await?
1972 .into_iter()
1973 .collect(),
1974 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1975 };
1976 let users = users
1977 .into_iter()
1978 .filter(|user| user.id != session.user_id)
1979 .map(|user| proto::User {
1980 id: user.id.to_proto(),
1981 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1982 github_login: user.github_login,
1983 })
1984 .collect();
1985 response.send(proto::UsersResponse { users })?;
1986 Ok(())
1987}
1988
1989async fn request_contact(
1990 request: proto::RequestContact,
1991 response: Response<proto::RequestContact>,
1992 session: Session,
1993) -> Result<()> {
1994 let requester_id = session.user_id;
1995 let responder_id = UserId::from_proto(request.responder_id);
1996 if requester_id == responder_id {
1997 return Err(anyhow!("cannot add yourself as a contact"))?;
1998 }
1999
2000 session
2001 .db()
2002 .await
2003 .send_contact_request(requester_id, responder_id)
2004 .await?;
2005
2006 // Update outgoing contact requests of requester
2007 let mut update = proto::UpdateContacts::default();
2008 update.outgoing_requests.push(responder_id.to_proto());
2009 for connection_id in session
2010 .connection_pool()
2011 .await
2012 .user_connection_ids(requester_id)
2013 {
2014 session.peer.send(connection_id, update.clone())?;
2015 }
2016
2017 // Update incoming contact requests of responder
2018 let mut update = proto::UpdateContacts::default();
2019 update
2020 .incoming_requests
2021 .push(proto::IncomingContactRequest {
2022 requester_id: requester_id.to_proto(),
2023 should_notify: true,
2024 });
2025 for connection_id in session
2026 .connection_pool()
2027 .await
2028 .user_connection_ids(responder_id)
2029 {
2030 session.peer.send(connection_id, update.clone())?;
2031 }
2032
2033 response.send(proto::Ack {})?;
2034 Ok(())
2035}
2036
2037async fn respond_to_contact_request(
2038 request: proto::RespondToContactRequest,
2039 response: Response<proto::RespondToContactRequest>,
2040 session: Session,
2041) -> Result<()> {
2042 let responder_id = session.user_id;
2043 let requester_id = UserId::from_proto(request.requester_id);
2044 let db = session.db().await;
2045 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2046 db.dismiss_contact_notification(responder_id, requester_id)
2047 .await?;
2048 } else {
2049 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2050
2051 db.respond_to_contact_request(responder_id, requester_id, accept)
2052 .await?;
2053 let requester_busy = db.is_user_busy(requester_id).await?;
2054 let responder_busy = db.is_user_busy(responder_id).await?;
2055
2056 let pool = session.connection_pool().await;
2057 // Update responder with new contact
2058 let mut update = proto::UpdateContacts::default();
2059 if accept {
2060 update
2061 .contacts
2062 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2063 }
2064 update
2065 .remove_incoming_requests
2066 .push(requester_id.to_proto());
2067 for connection_id in pool.user_connection_ids(responder_id) {
2068 session.peer.send(connection_id, update.clone())?;
2069 }
2070
2071 // Update requester with new contact
2072 let mut update = proto::UpdateContacts::default();
2073 if accept {
2074 update
2075 .contacts
2076 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2077 }
2078 update
2079 .remove_outgoing_requests
2080 .push(responder_id.to_proto());
2081 for connection_id in pool.user_connection_ids(requester_id) {
2082 session.peer.send(connection_id, update.clone())?;
2083 }
2084 }
2085
2086 response.send(proto::Ack {})?;
2087 Ok(())
2088}
2089
2090async fn remove_contact(
2091 request: proto::RemoveContact,
2092 response: Response<proto::RemoveContact>,
2093 session: Session,
2094) -> Result<()> {
2095 let requester_id = session.user_id;
2096 let responder_id = UserId::from_proto(request.user_id);
2097 let db = session.db().await;
2098 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2099
2100 let pool = session.connection_pool().await;
2101 // Update outgoing contact requests of requester
2102 let mut update = proto::UpdateContacts::default();
2103 if contact_accepted {
2104 update.remove_contacts.push(responder_id.to_proto());
2105 } else {
2106 update
2107 .remove_outgoing_requests
2108 .push(responder_id.to_proto());
2109 }
2110 for connection_id in pool.user_connection_ids(requester_id) {
2111 session.peer.send(connection_id, update.clone())?;
2112 }
2113
2114 // Update incoming contact requests of responder
2115 let mut update = proto::UpdateContacts::default();
2116 if contact_accepted {
2117 update.remove_contacts.push(requester_id.to_proto());
2118 } else {
2119 update
2120 .remove_incoming_requests
2121 .push(requester_id.to_proto());
2122 }
2123 for connection_id in pool.user_connection_ids(responder_id) {
2124 session.peer.send(connection_id, update.clone())?;
2125 }
2126
2127 response.send(proto::Ack {})?;
2128 Ok(())
2129}
2130
2131async fn create_channel(
2132 request: proto::CreateChannel,
2133 response: Response<proto::CreateChannel>,
2134 session: Session,
2135) -> Result<()> {
2136 let db = session.db().await;
2137 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2138
2139 if let Some(live_kit) = session.live_kit_client.as_ref() {
2140 live_kit.create_room(live_kit_room.clone()).await?;
2141 }
2142
2143 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2144 let id = db
2145 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2146 .await?;
2147
2148 let channel = proto::Channel {
2149 id: id.to_proto(),
2150 name: request.name,
2151 parent_id: request.parent_id,
2152 };
2153
2154 response.send(proto::ChannelResponse {
2155 channel: Some(channel.clone()),
2156 })?;
2157
2158 let mut update = proto::UpdateChannels::default();
2159 update.channels.push(channel);
2160
2161 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2162 db.get_channel_members(parent_id).await?
2163 } else {
2164 vec![session.user_id]
2165 };
2166
2167 let connection_pool = session.connection_pool().await;
2168 for user_id in user_ids_to_notify {
2169 for connection_id in connection_pool.user_connection_ids(user_id) {
2170 let mut update = update.clone();
2171 if user_id == session.user_id {
2172 update.channel_permissions.push(proto::ChannelPermission {
2173 channel_id: id.to_proto(),
2174 is_admin: true,
2175 });
2176 }
2177 session.peer.send(connection_id, update)?;
2178 }
2179 }
2180
2181 Ok(())
2182}
2183
2184async fn remove_channel(
2185 request: proto::RemoveChannel,
2186 response: Response<proto::RemoveChannel>,
2187 session: Session,
2188) -> Result<()> {
2189 let db = session.db().await;
2190
2191 let channel_id = request.channel_id;
2192 let (removed_channels, member_ids) = db
2193 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2194 .await?;
2195 response.send(proto::Ack {})?;
2196
2197 // Notify members of removed channels
2198 let mut update = proto::UpdateChannels::default();
2199 update
2200 .remove_channels
2201 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2202
2203 let connection_pool = session.connection_pool().await;
2204 for member_id in member_ids {
2205 for connection_id in connection_pool.user_connection_ids(member_id) {
2206 session.peer.send(connection_id, update.clone())?;
2207 }
2208 }
2209
2210 Ok(())
2211}
2212
2213async fn invite_channel_member(
2214 request: proto::InviteChannelMember,
2215 response: Response<proto::InviteChannelMember>,
2216 session: Session,
2217) -> Result<()> {
2218 let db = session.db().await;
2219 let channel_id = ChannelId::from_proto(request.channel_id);
2220 let invitee_id = UserId::from_proto(request.user_id);
2221 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2222 .await?;
2223
2224 let (channel, _) = db
2225 .get_channel(channel_id, session.user_id)
2226 .await?
2227 .ok_or_else(|| anyhow!("channel not found"))?;
2228
2229 let mut update = proto::UpdateChannels::default();
2230 update.channel_invitations.push(proto::Channel {
2231 id: channel.id.to_proto(),
2232 name: channel.name,
2233 parent_id: None,
2234 });
2235 for connection_id in session
2236 .connection_pool()
2237 .await
2238 .user_connection_ids(invitee_id)
2239 {
2240 session.peer.send(connection_id, update.clone())?;
2241 }
2242
2243 response.send(proto::Ack {})?;
2244 Ok(())
2245}
2246
2247async fn remove_channel_member(
2248 request: proto::RemoveChannelMember,
2249 response: Response<proto::RemoveChannelMember>,
2250 session: Session,
2251) -> Result<()> {
2252 let db = session.db().await;
2253 let channel_id = ChannelId::from_proto(request.channel_id);
2254 let member_id = UserId::from_proto(request.user_id);
2255
2256 db.remove_channel_member(channel_id, member_id, session.user_id)
2257 .await?;
2258
2259 let mut update = proto::UpdateChannels::default();
2260 update.remove_channels.push(channel_id.to_proto());
2261
2262 for connection_id in session
2263 .connection_pool()
2264 .await
2265 .user_connection_ids(member_id)
2266 {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269
2270 response.send(proto::Ack {})?;
2271 Ok(())
2272}
2273
2274async fn set_channel_member_admin(
2275 request: proto::SetChannelMemberAdmin,
2276 response: Response<proto::SetChannelMemberAdmin>,
2277 session: Session,
2278) -> Result<()> {
2279 let db = session.db().await;
2280 let channel_id = ChannelId::from_proto(request.channel_id);
2281 let member_id = UserId::from_proto(request.user_id);
2282 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2283 .await?;
2284
2285 let (channel, has_accepted) = db
2286 .get_channel(channel_id, member_id)
2287 .await?
2288 .ok_or_else(|| anyhow!("channel not found"))?;
2289
2290 let mut update = proto::UpdateChannels::default();
2291 if has_accepted {
2292 update.channel_permissions.push(proto::ChannelPermission {
2293 channel_id: channel.id.to_proto(),
2294 is_admin: request.admin,
2295 });
2296 }
2297
2298 for connection_id in session
2299 .connection_pool()
2300 .await
2301 .user_connection_ids(member_id)
2302 {
2303 session.peer.send(connection_id, update.clone())?;
2304 }
2305
2306 response.send(proto::Ack {})?;
2307 Ok(())
2308}
2309
2310async fn rename_channel(
2311 request: proto::RenameChannel,
2312 response: Response<proto::RenameChannel>,
2313 session: Session,
2314) -> Result<()> {
2315 let db = session.db().await;
2316 let channel_id = ChannelId::from_proto(request.channel_id);
2317 let new_name = db
2318 .rename_channel(channel_id, session.user_id, &request.name)
2319 .await?;
2320
2321 let channel = proto::Channel {
2322 id: request.channel_id,
2323 name: new_name,
2324 parent_id: None,
2325 };
2326 response.send(proto::ChannelResponse {
2327 channel: Some(channel.clone()),
2328 })?;
2329 let mut update = proto::UpdateChannels::default();
2330 update.channels.push(channel);
2331
2332 let member_ids = db.get_channel_members(channel_id).await?;
2333
2334 let connection_pool = session.connection_pool().await;
2335 for member_id in member_ids {
2336 for connection_id in connection_pool.user_connection_ids(member_id) {
2337 session.peer.send(connection_id, update.clone())?;
2338 }
2339 }
2340
2341 Ok(())
2342}
2343
2344async fn get_channel_members(
2345 request: proto::GetChannelMembers,
2346 response: Response<proto::GetChannelMembers>,
2347 session: Session,
2348) -> Result<()> {
2349 let db = session.db().await;
2350 let channel_id = ChannelId::from_proto(request.channel_id);
2351 let members = db
2352 .get_channel_member_details(channel_id, session.user_id)
2353 .await?;
2354 response.send(proto::GetChannelMembersResponse { members })?;
2355 Ok(())
2356}
2357
2358async fn respond_to_channel_invite(
2359 request: proto::RespondToChannelInvite,
2360 response: Response<proto::RespondToChannelInvite>,
2361 session: Session,
2362) -> Result<()> {
2363 let db = session.db().await;
2364 let channel_id = ChannelId::from_proto(request.channel_id);
2365 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2366 .await?;
2367
2368 let mut update = proto::UpdateChannels::default();
2369 update
2370 .remove_channel_invitations
2371 .push(channel_id.to_proto());
2372 if request.accept {
2373 let result = db.get_channels_for_user(session.user_id).await?;
2374 update
2375 .channels
2376 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2377 id: channel.id.to_proto(),
2378 name: channel.name,
2379 parent_id: channel.parent_id.map(ChannelId::to_proto),
2380 }));
2381 update
2382 .channel_participants
2383 .extend(
2384 result
2385 .channel_participants
2386 .into_iter()
2387 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2388 channel_id: channel_id.to_proto(),
2389 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2390 }),
2391 );
2392 update
2393 .channel_permissions
2394 .extend(
2395 result
2396 .channels_with_admin_privileges
2397 .into_iter()
2398 .map(|channel_id| proto::ChannelPermission {
2399 channel_id: channel_id.to_proto(),
2400 is_admin: true,
2401 }),
2402 );
2403 }
2404 session.peer.send(session.connection_id, update)?;
2405 response.send(proto::Ack {})?;
2406
2407 Ok(())
2408}
2409
2410async fn join_channel(
2411 request: proto::JoinChannel,
2412 response: Response<proto::JoinChannel>,
2413 session: Session,
2414) -> Result<()> {
2415 let channel_id = ChannelId::from_proto(request.channel_id);
2416
2417 let joined_room = {
2418 let db = session.db().await;
2419
2420 let room_id = db.room_id_for_channel(channel_id).await?;
2421
2422 let joined_room = db
2423 .join_room(
2424 room_id,
2425 session.user_id,
2426 Some(channel_id),
2427 session.connection_id,
2428 )
2429 .await?;
2430
2431 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2432 let token = live_kit
2433 .room_token(
2434 &joined_room.room.live_kit_room,
2435 &session.user_id.to_string(),
2436 )
2437 .trace_err()?;
2438
2439 Some(LiveKitConnectionInfo {
2440 server_url: live_kit.url().into(),
2441 token,
2442 })
2443 });
2444
2445 response.send(proto::JoinRoomResponse {
2446 room: Some(joined_room.room.clone()),
2447 live_kit_connection_info,
2448 })?;
2449
2450 room_updated(&joined_room.room, &session.peer);
2451
2452 joined_room.clone()
2453 };
2454
2455 channel_updated(
2456 channel_id,
2457 &joined_room.room,
2458 &joined_room.channel_members,
2459 &session.peer,
2460 &*session.connection_pool().await,
2461 );
2462
2463 update_user_contacts(session.user_id, &session).await?;
2464
2465 Ok(())
2466}
2467
2468async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2469 let project_id = ProjectId::from_proto(request.project_id);
2470 let project_connection_ids = session
2471 .db()
2472 .await
2473 .project_connection_ids(project_id, session.connection_id)
2474 .await?;
2475 broadcast(
2476 Some(session.connection_id),
2477 project_connection_ids.iter().copied(),
2478 |connection_id| {
2479 session
2480 .peer
2481 .forward_send(session.connection_id, connection_id, request.clone())
2482 },
2483 );
2484 Ok(())
2485}
2486
2487async fn get_private_user_info(
2488 _request: proto::GetPrivateUserInfo,
2489 response: Response<proto::GetPrivateUserInfo>,
2490 session: Session,
2491) -> Result<()> {
2492 let metrics_id = session
2493 .db()
2494 .await
2495 .get_user_metrics_id(session.user_id)
2496 .await?;
2497 let user = session
2498 .db()
2499 .await
2500 .get_user_by_id(session.user_id)
2501 .await?
2502 .ok_or_else(|| anyhow!("user not found"))?;
2503 response.send(proto::GetPrivateUserInfoResponse {
2504 metrics_id,
2505 staff: user.admin,
2506 })?;
2507 Ok(())
2508}
2509
2510fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2511 match message {
2512 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2513 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2514 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2515 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2516 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2517 code: frame.code.into(),
2518 reason: frame.reason,
2519 })),
2520 }
2521}
2522
2523fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2524 match message {
2525 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2526 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2527 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2528 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2529 AxumMessage::Close(frame) => {
2530 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2531 code: frame.code.into(),
2532 reason: frame.reason,
2533 }))
2534 }
2535 }
2536}
2537
2538fn build_initial_channels_update(
2539 channels: ChannelsForUser,
2540 channel_invites: Vec<db::Channel>,
2541) -> proto::UpdateChannels {
2542 let mut update = proto::UpdateChannels::default();
2543
2544 for channel in channels.channels {
2545 update.channels.push(proto::Channel {
2546 id: channel.id.to_proto(),
2547 name: channel.name,
2548 parent_id: channel.parent_id.map(|id| id.to_proto()),
2549 });
2550 }
2551
2552 for (channel_id, participants) in channels.channel_participants {
2553 update
2554 .channel_participants
2555 .push(proto::ChannelParticipants {
2556 channel_id: channel_id.to_proto(),
2557 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2558 });
2559 }
2560
2561 update
2562 .channel_permissions
2563 .extend(
2564 channels
2565 .channels_with_admin_privileges
2566 .into_iter()
2567 .map(|id| proto::ChannelPermission {
2568 channel_id: id.to_proto(),
2569 is_admin: true,
2570 }),
2571 );
2572
2573 for channel in channel_invites {
2574 update.channel_invitations.push(proto::Channel {
2575 id: channel.id.to_proto(),
2576 name: channel.name,
2577 parent_id: None,
2578 });
2579 }
2580
2581 update
2582}
2583
2584fn build_initial_contacts_update(
2585 contacts: Vec<db::Contact>,
2586 pool: &ConnectionPool,
2587) -> proto::UpdateContacts {
2588 let mut update = proto::UpdateContacts::default();
2589
2590 for contact in contacts {
2591 match contact {
2592 db::Contact::Accepted {
2593 user_id,
2594 should_notify,
2595 busy,
2596 } => {
2597 update
2598 .contacts
2599 .push(contact_for_user(user_id, should_notify, busy, &pool));
2600 }
2601 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2602 db::Contact::Incoming {
2603 user_id,
2604 should_notify,
2605 } => update
2606 .incoming_requests
2607 .push(proto::IncomingContactRequest {
2608 requester_id: user_id.to_proto(),
2609 should_notify,
2610 }),
2611 }
2612 }
2613
2614 update
2615}
2616
2617fn contact_for_user(
2618 user_id: UserId,
2619 should_notify: bool,
2620 busy: bool,
2621 pool: &ConnectionPool,
2622) -> proto::Contact {
2623 proto::Contact {
2624 user_id: user_id.to_proto(),
2625 online: pool.is_user_online(user_id),
2626 busy,
2627 should_notify,
2628 }
2629}
2630
2631fn room_updated(room: &proto::Room, peer: &Peer) {
2632 broadcast(
2633 None,
2634 room.participants
2635 .iter()
2636 .filter_map(|participant| Some(participant.peer_id?.into())),
2637 |peer_id| {
2638 peer.send(
2639 peer_id.into(),
2640 proto::RoomUpdated {
2641 room: Some(room.clone()),
2642 },
2643 )
2644 },
2645 );
2646}
2647
2648fn channel_updated(
2649 channel_id: ChannelId,
2650 room: &proto::Room,
2651 channel_members: &[UserId],
2652 peer: &Peer,
2653 pool: &ConnectionPool,
2654) {
2655 let participants = room
2656 .participants
2657 .iter()
2658 .map(|p| p.user_id)
2659 .collect::<Vec<_>>();
2660
2661 broadcast(
2662 None,
2663 channel_members
2664 .iter()
2665 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2666 |peer_id| {
2667 peer.send(
2668 peer_id.into(),
2669 proto::UpdateChannels {
2670 channel_participants: vec![proto::ChannelParticipants {
2671 channel_id: channel_id.to_proto(),
2672 participant_user_ids: participants.clone(),
2673 }],
2674 ..Default::default()
2675 },
2676 )
2677 },
2678 );
2679}
2680
2681async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2682 let db = session.db().await;
2683
2684 let contacts = db.get_contacts(user_id).await?;
2685 let busy = db.is_user_busy(user_id).await?;
2686
2687 let pool = session.connection_pool().await;
2688 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2689 for contact in contacts {
2690 if let db::Contact::Accepted {
2691 user_id: contact_user_id,
2692 ..
2693 } = contact
2694 {
2695 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2696 session
2697 .peer
2698 .send(
2699 contact_conn_id,
2700 proto::UpdateContacts {
2701 contacts: vec![updated_contact.clone()],
2702 remove_contacts: Default::default(),
2703 incoming_requests: Default::default(),
2704 remove_incoming_requests: Default::default(),
2705 outgoing_requests: Default::default(),
2706 remove_outgoing_requests: Default::default(),
2707 },
2708 )
2709 .trace_err();
2710 }
2711 }
2712 }
2713 Ok(())
2714}
2715
2716async fn leave_room_for_session(session: &Session) -> Result<()> {
2717 let mut contacts_to_update = HashSet::default();
2718
2719 let room_id;
2720 let canceled_calls_to_user_ids;
2721 let live_kit_room;
2722 let delete_live_kit_room;
2723 let room;
2724 let channel_members;
2725 let channel_id;
2726
2727 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2728 contacts_to_update.insert(session.user_id);
2729
2730 for project in left_room.left_projects.values() {
2731 project_left(project, session);
2732 }
2733
2734 room_id = RoomId::from_proto(left_room.room.id);
2735 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2736 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2737 delete_live_kit_room = left_room.deleted;
2738 room = mem::take(&mut left_room.room);
2739 channel_members = mem::take(&mut left_room.channel_members);
2740 channel_id = left_room.channel_id;
2741
2742 room_updated(&room, &session.peer);
2743 } else {
2744 return Ok(());
2745 }
2746
2747 if let Some(channel_id) = channel_id {
2748 channel_updated(
2749 channel_id,
2750 &room,
2751 &channel_members,
2752 &session.peer,
2753 &*session.connection_pool().await,
2754 );
2755 }
2756
2757 {
2758 let pool = session.connection_pool().await;
2759 for canceled_user_id in canceled_calls_to_user_ids {
2760 for connection_id in pool.user_connection_ids(canceled_user_id) {
2761 session
2762 .peer
2763 .send(
2764 connection_id,
2765 proto::CallCanceled {
2766 room_id: room_id.to_proto(),
2767 },
2768 )
2769 .trace_err();
2770 }
2771 contacts_to_update.insert(canceled_user_id);
2772 }
2773 }
2774
2775 for contact_user_id in contacts_to_update {
2776 update_user_contacts(contact_user_id, &session).await?;
2777 }
2778
2779 if let Some(live_kit) = session.live_kit_client.as_ref() {
2780 live_kit
2781 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2782 .await
2783 .trace_err();
2784
2785 if delete_live_kit_room {
2786 live_kit.delete_room(live_kit_room).await.trace_err();
2787 }
2788 }
2789
2790 Ok(())
2791}
2792
2793fn project_left(project: &db::LeftProject, session: &Session) {
2794 for connection_id in &project.connection_ids {
2795 if project.host_user_id == session.user_id {
2796 session
2797 .peer
2798 .send(
2799 *connection_id,
2800 proto::UnshareProject {
2801 project_id: project.id.to_proto(),
2802 },
2803 )
2804 .trace_err();
2805 } else {
2806 session
2807 .peer
2808 .send(
2809 *connection_id,
2810 proto::RemoveProjectCollaborator {
2811 project_id: project.id.to_proto(),
2812 peer_id: Some(session.connection_id.into()),
2813 },
2814 )
2815 .trace_err();
2816 }
2817 }
2818}
2819
2820pub trait ResultExt {
2821 type Ok;
2822
2823 fn trace_err(self) -> Option<Self::Ok>;
2824}
2825
2826impl<T, E> ResultExt for Result<T, E>
2827where
2828 E: std::fmt::Debug,
2829{
2830 type Ok = T;
2831
2832 fn trace_err(self) -> Option<T> {
2833 match self {
2834 Ok(value) => Some(value),
2835 Err(error) => {
2836 tracing::error!("{:?}", error);
2837 None
2838 }
2839 }
2840 }
2841}