1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69use util::channel::RELEASE_CHANNEL_NAME;
70
/// How long a dropped connection is given to come back before `connection_lost`
/// releases the user's rooms, channel buffers and pending calls.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before clearing resources reported as stale.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

/// Page size for channel message history (presumably; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (presumably in bytes; used outside this view).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification history (used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
78lazy_static! {
79 static ref METRIC_CONNECTIONS: IntGauge =
80 register_int_gauge!("connections", "number of connections").unwrap();
81 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
82 "shared_projects",
83 "number of open projects with one or more guests"
84 )
85 .unwrap();
86}
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone)]
106struct Session {
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 _executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server: owns the peer, the pool of live client
/// connections, and the table of message handlers used by `handle_connection`.
pub struct Server {
    // Identifier of this server instance; behind a mutex because the test-only
    // `reset` replaces it in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Live client connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown()`; connection futures subscribe to it and stop.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so any async fn that
/// holds it across an `.await` can no longer produce a `Send` future —
/// presumably to keep the blocking `parking_lot` lock from being held across
/// suspension points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable point-in-time view of the server's peer and connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_project_request::<proto::GetHover>)
221 .add_request_handler(forward_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_project_request::<proto::GetCompletions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
232 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
233 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
234 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
235 .add_request_handler(forward_project_request::<proto::PrepareRename>)
236 .add_request_handler(forward_project_request::<proto::PerformRename>)
237 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
238 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
239 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
240 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
243 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
244 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
245 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
246 .add_request_handler(forward_project_request::<proto::InlayHints>)
247 .add_message_handler(create_buffer_for_peer)
248 .add_request_handler(update_buffer)
249 .add_message_handler(update_buffer_file)
250 .add_message_handler(buffer_reloaded)
251 .add_message_handler(buffer_saved)
252 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
253 .add_request_handler(get_users)
254 .add_request_handler(fuzzy_search_users)
255 .add_request_handler(request_contact)
256 .add_request_handler(remove_contact)
257 .add_request_handler(respond_to_contact_request)
258 .add_request_handler(create_channel)
259 .add_request_handler(delete_channel)
260 .add_request_handler(invite_channel_member)
261 .add_request_handler(remove_channel_member)
262 .add_request_handler(set_channel_member_role)
263 .add_request_handler(set_channel_visibility)
264 .add_request_handler(rename_channel)
265 .add_request_handler(join_channel_buffer)
266 .add_request_handler(leave_channel_buffer)
267 .add_message_handler(update_channel_buffer)
268 .add_request_handler(rejoin_channel_buffers)
269 .add_request_handler(get_channel_members)
270 .add_request_handler(respond_to_channel_invite)
271 .add_request_handler(join_channel)
272 .add_request_handler(join_channel_chat)
273 .add_message_handler(leave_channel_chat)
274 .add_request_handler(send_channel_message)
275 .add_request_handler(remove_channel_message)
276 .add_request_handler(get_channel_messages)
277 .add_request_handler(get_channel_messages_by_id)
278 .add_request_handler(get_notifications)
279 .add_request_handler(mark_notification_as_read)
280 .add_request_handler(move_channel)
281 .add_request_handler(follow)
282 .add_message_handler(unfollow)
283 .add_message_handler(update_followers)
284 .add_message_handler(update_diff_base)
285 .add_request_handler(get_private_user_info)
286 .add_message_handler(acknowledge_channel_message)
287 .add_message_handler(acknowledge_buffer_version);
288
289 Arc::new(server)
290 }
291
292 pub async fn start(&self) -> Result<()> {
293 let server_id = *self.id.lock();
294 let app_state = self.app_state.clone();
295 let peer = self.peer.clone();
296 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
297 let pool = self.connection_pool.clone();
298 let live_kit_client = self.app_state.live_kit_client.clone();
299
300 let span = info_span!("start server");
301 self.executor.spawn_detached(
302 async move {
303 tracing::info!("waiting for cleanup timeout");
304 timeout.await;
305 tracing::info!("cleanup timeout expired, retrieving stale rooms");
306 if let Some((room_ids, channel_ids)) = app_state
307 .db
308 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
309 .await
310 .trace_err()
311 {
312 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
313 tracing::info!(
314 stale_channel_buffer_count = channel_ids.len(),
315 "retrieved stale channel buffers"
316 );
317
318 for channel_id in channel_ids {
319 if let Some(refreshed_channel_buffer) = app_state
320 .db
321 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
322 .await
323 .trace_err()
324 {
325 for connection_id in refreshed_channel_buffer.connection_ids {
326 peer.send(
327 connection_id,
328 proto::UpdateChannelBufferCollaborators {
329 channel_id: channel_id.to_proto(),
330 collaborators: refreshed_channel_buffer
331 .collaborators
332 .clone(),
333 },
334 )
335 .trace_err();
336 }
337 }
338 }
339
340 for room_id in room_ids {
341 let mut contacts_to_update = HashSet::default();
342 let mut canceled_calls_to_user_ids = Vec::new();
343 let mut live_kit_room = String::new();
344 let mut delete_live_kit_room = false;
345
346 if let Some(mut refreshed_room) = app_state
347 .db
348 .clear_stale_room_participants(room_id, server_id)
349 .await
350 .trace_err()
351 {
352 tracing::info!(
353 room_id = room_id.0,
354 new_participant_count = refreshed_room.room.participants.len(),
355 "refreshed room"
356 );
357 room_updated(&refreshed_room.room, &peer);
358 if let Some(channel_id) = refreshed_room.channel_id {
359 channel_updated(
360 channel_id,
361 &refreshed_room.room,
362 &refreshed_room.channel_members,
363 &peer,
364 &*pool.lock(),
365 );
366 }
367 contacts_to_update
368 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
369 contacts_to_update
370 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
371 canceled_calls_to_user_ids =
372 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
373 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
374 delete_live_kit_room = refreshed_room.room.participants.is_empty();
375 }
376
377 {
378 let pool = pool.lock();
379 for canceled_user_id in canceled_calls_to_user_ids {
380 for connection_id in pool.user_connection_ids(canceled_user_id) {
381 peer.send(
382 connection_id,
383 proto::CallCanceled {
384 room_id: room_id.to_proto(),
385 },
386 )
387 .trace_err();
388 }
389 }
390 }
391
392 for user_id in contacts_to_update {
393 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
394 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
395 if let Some((busy, contacts)) = busy.zip(contacts) {
396 let pool = pool.lock();
397 let updated_contact = contact_for_user(user_id, busy, &pool);
398 for contact in contacts {
399 if let db::Contact::Accepted {
400 user_id: contact_user_id,
401 ..
402 } = contact
403 {
404 for contact_conn_id in
405 pool.user_connection_ids(contact_user_id)
406 {
407 peer.send(
408 contact_conn_id,
409 proto::UpdateContacts {
410 contacts: vec![updated_contact.clone()],
411 remove_contacts: Default::default(),
412 incoming_requests: Default::default(),
413 remove_incoming_requests: Default::default(),
414 outgoing_requests: Default::default(),
415 remove_outgoing_requests: Default::default(),
416 },
417 )
418 .trace_err();
419 }
420 }
421 }
422 }
423 }
424
425 if let Some(live_kit) = live_kit_client.as_ref() {
426 if delete_live_kit_room {
427 live_kit.delete_room(live_kit_room).await.trace_err();
428 }
429 }
430 }
431 }
432
433 app_state
434 .db
435 .delete_stale_servers(&app_state.config.zed_environment, server_id)
436 .await
437 .trace_err();
438 }
439 .instrument(span),
440 );
441 Ok(())
442 }
443
444 pub fn teardown(&self) {
445 self.peer.teardown();
446 self.connection_pool.lock().reset();
447 let _ = self.teardown.send(());
448 }
449
450 #[cfg(test)]
451 pub fn reset(&self, id: ServerId) {
452 self.teardown();
453 *self.id.lock() = id;
454 self.peer.reset(id.0 as u32);
455 }
456
457 #[cfg(test)]
458 pub fn id(&self) -> ServerId {
459 *self.id.lock()
460 }
461
462 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
463 where
464 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
465 Fut: 'static + Send + Future<Output = Result<()>>,
466 M: EnvelopedMessage,
467 {
468 let prev_handler = self.handlers.insert(
469 TypeId::of::<M>(),
470 Box::new(move |envelope, session| {
471 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
472 let span = info_span!(
473 "handle message",
474 payload_type = envelope.payload_type_name()
475 );
476 span.in_scope(|| {
477 tracing::info!(
478 payload_type = envelope.payload_type_name(),
479 "message received"
480 );
481 });
482 let start_time = Instant::now();
483 let future = (handler)(*envelope, session);
484 async move {
485 let result = future.await;
486 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
487 match result {
488 Err(error) => {
489 tracing::error!(%error, ?duration_ms, "error handling message")
490 }
491 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
492 }
493 }
494 .instrument(span)
495 .boxed()
496 }),
497 );
498 if prev_handler.is_some() {
499 panic!("registered a handler for the same message twice");
500 }
501 self
502 }
503
504 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
505 where
506 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
507 Fut: 'static + Send + Future<Output = Result<()>>,
508 M: EnvelopedMessage,
509 {
510 self.add_handler(move |envelope, session| handler(envelope.payload, session));
511 self
512 }
513
514 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
515 where
516 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
517 Fut: Send + Future<Output = Result<()>>,
518 M: RequestMessage,
519 {
520 let handler = Arc::new(handler);
521 self.add_handler(move |envelope, session| {
522 let receipt = envelope.receipt();
523 let handler = handler.clone();
524 async move {
525 let peer = session.peer.clone();
526 let responded = Arc::new(AtomicBool::default());
527 let response = Response {
528 peer: peer.clone(),
529 responded: responded.clone(),
530 receipt,
531 };
532 match (handler)(envelope.payload, response, session).await {
533 Ok(()) => {
534 if responded.load(std::sync::atomic::Ordering::SeqCst) {
535 Ok(())
536 } else {
537 Err(anyhow!("handler did not send a response"))?
538 }
539 }
540 Err(error) => {
541 peer.respond_with_error(
542 receipt,
543 proto::Error {
544 message: error.to_string(),
545 },
546 )?;
547 Err(error)
548 }
549 }
550 }
551 })
552 }
553
554 pub fn handle_connection(
555 self: &Arc<Self>,
556 connection: Connection,
557 address: String,
558 user: User,
559 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
560 executor: Executor,
561 ) -> impl Future<Output = Result<()>> {
562 let this = self.clone();
563 let user_id = user.id;
564 let login = user.github_login;
565 let span = info_span!("handle connection", %user_id, %login, %address);
566 let mut teardown = self.teardown.subscribe();
567 async move {
568 let (connection_id, handle_io, mut incoming_rx) = this
569 .peer
570 .add_connection(connection, {
571 let executor = executor.clone();
572 move |duration| executor.sleep(duration)
573 });
574
575 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
576 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
577 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
578
579 if let Some(send_connection_id) = send_connection_id.take() {
580 let _ = send_connection_id.send(connection_id);
581 }
582
583 if !user.connected_once {
584 this.peer.send(connection_id, proto::ShowContacts {})?;
585 this.app_state.db.set_user_connected_once(user_id, true).await?;
586 }
587
588 let (contacts, channels_for_user, channel_invites) = future::try_join3(
589 this.app_state.db.get_contacts(user_id),
590 this.app_state.db.get_channels_for_user(user_id),
591 this.app_state.db.get_channel_invites_for_user(user_id),
592 ).await?;
593
594 {
595 let mut pool = this.connection_pool.lock();
596 pool.add_connection(connection_id, user_id, user.admin);
597 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
598 this.peer.send(connection_id, build_channels_update(
599 channels_for_user,
600 channel_invites
601 ))?;
602 }
603
604 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
605 this.peer.send(connection_id, incoming_call)?;
606 }
607
608 let session = Session {
609 user_id,
610 connection_id,
611 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
612 peer: this.peer.clone(),
613 connection_pool: this.connection_pool.clone(),
614 live_kit_client: this.app_state.live_kit_client.clone(),
615 _executor: executor.clone()
616 };
617 update_user_contacts(user_id, &session).await?;
618
619 let handle_io = handle_io.fuse();
620 futures::pin_mut!(handle_io);
621
622 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
623 // This prevents deadlocks when e.g., client A performs a request to client B and
624 // client B performs a request to client A. If both clients stop processing further
625 // messages until their respective request completes, they won't have a chance to
626 // respond to the other client's request and cause a deadlock.
627 //
628 // This arrangement ensures we will attempt to process earlier messages first, but fall
629 // back to processing messages arrived later in the spirit of making progress.
630 let mut foreground_message_handlers = FuturesUnordered::new();
631 let concurrent_handlers = Arc::new(Semaphore::new(256));
632 loop {
633 let next_message = async {
634 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
635 let message = incoming_rx.next().await;
636 (permit, message)
637 }.fuse();
638 futures::pin_mut!(next_message);
639 futures::select_biased! {
640 _ = teardown.changed().fuse() => return Ok(()),
641 result = handle_io => {
642 if let Err(error) = result {
643 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
644 }
645 break;
646 }
647 _ = foreground_message_handlers.next() => {}
648 next_message = next_message => {
649 let (permit, message) = next_message;
650 if let Some(message) = message {
651 let type_name = message.payload_type_name();
652 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
653 let span_enter = span.enter();
654 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
655 let is_background = message.is_background();
656 let handle_message = (handler)(message, session.clone());
657 drop(span_enter);
658
659 let handle_message = async move {
660 handle_message.await;
661 drop(permit);
662 }.instrument(span);
663 if is_background {
664 executor.spawn_detached(handle_message);
665 } else {
666 foreground_message_handlers.push(handle_message);
667 }
668 } else {
669 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
670 }
671 } else {
672 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
673 break;
674 }
675 }
676 }
677 }
678
679 drop(foreground_message_handlers);
680 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
681 if let Err(error) = connection_lost(session, teardown, executor).await {
682 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
683 }
684
685 Ok(())
686 }.instrument(span)
687 }
688
689 pub async fn invite_code_redeemed(
690 self: &Arc<Self>,
691 inviter_id: UserId,
692 invitee_id: UserId,
693 ) -> Result<()> {
694 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
695 if let Some(code) = &user.invite_code {
696 let pool = self.connection_pool.lock();
697 let invitee_contact = contact_for_user(invitee_id, false, &pool);
698 for connection_id in pool.user_connection_ids(inviter_id) {
699 self.peer.send(
700 connection_id,
701 proto::UpdateContacts {
702 contacts: vec![invitee_contact.clone()],
703 ..Default::default()
704 },
705 )?;
706 self.peer.send(
707 connection_id,
708 proto::UpdateInviteInfo {
709 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
710 count: user.invite_count as u32,
711 },
712 )?;
713 }
714 }
715 }
716 Ok(())
717 }
718
719 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
720 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
721 if let Some(invite_code) = &user.invite_code {
722 let pool = self.connection_pool.lock();
723 for connection_id in pool.user_connection_ids(user_id) {
724 self.peer.send(
725 connection_id,
726 proto::UpdateInviteInfo {
727 url: format!(
728 "{}{}",
729 self.app_state.config.invite_link_prefix, invite_code
730 ),
731 count: user.invite_count as u32,
732 },
733 )?;
734 }
735 }
736 }
737 Ok(())
738 }
739
740 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
741 ServerSnapshot {
742 connection_pool: ConnectionPoolGuard {
743 guard: self.connection_pool.lock(),
744 _not_send: PhantomData,
745 },
746 peer: &self.peer,
747 }
748 }
749}
750
751impl<'a> Deref for ConnectionPoolGuard<'a> {
752 type Target = ConnectionPool;
753
754 fn deref(&self) -> &Self::Target {
755 &*self.guard
756 }
757}
758
759impl<'a> DerefMut for ConnectionPoolGuard<'a> {
760 fn deref_mut(&mut self) -> &mut Self::Target {
761 &mut *self.guard
762 }
763}
764
765impl<'a> Drop for ConnectionPoolGuard<'a> {
766 fn drop(&mut self) {
767 #[cfg(test)]
768 self.check_invariants();
769 }
770}
771
772fn broadcast<F>(
773 sender_id: Option<ConnectionId>,
774 receiver_ids: impl IntoIterator<Item = ConnectionId>,
775 mut f: F,
776) where
777 F: FnMut(ConnectionId) -> anyhow::Result<()>,
778{
779 for receiver_id in receiver_ids {
780 if Some(receiver_id) != sender_id {
781 if let Err(error) = f(receiver_id) {
782 tracing::error!("failed to send to {:?} {}", receiver_id, error);
783 }
784 }
785 }
786}
787
788lazy_static! {
789 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
790}
791
792pub struct ProtocolVersion(u32);
793
794impl Header for ProtocolVersion {
795 fn name() -> &'static HeaderName {
796 &ZED_PROTOCOL_VERSION
797 }
798
799 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
800 where
801 Self: Sized,
802 I: Iterator<Item = &'i axum::http::HeaderValue>,
803 {
804 let version = values
805 .next()
806 .ok_or_else(axum::headers::Error::invalid)?
807 .to_str()
808 .map_err(|_| axum::headers::Error::invalid())?
809 .parse()
810 .map_err(|_| axum::headers::Error::invalid())?;
811 Ok(Self(version))
812 }
813
814 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
815 values.extend([self.0.to_string().parse().unwrap()]);
816 }
817}
818
819pub fn routes(server: Arc<Server>) -> Router<Body> {
820 Router::new()
821 .route("/rpc", get(handle_websocket_request))
822 .layer(
823 ServiceBuilder::new()
824 .layer(Extension(server.app_state.clone()))
825 .layer(middleware::from_fn(auth::validate_header)),
826 )
827 .route("/metrics", get(handle_metrics))
828 .layer(Extension(server))
829}
830
831pub async fn handle_websocket_request(
832 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
833 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
834 Extension(server): Extension<Arc<Server>>,
835 Extension(user): Extension<User>,
836 ws: WebSocketUpgrade,
837) -> axum::response::Response {
838 if protocol_version != rpc::PROTOCOL_VERSION {
839 return (
840 StatusCode::UPGRADE_REQUIRED,
841 "client must be upgraded".to_string(),
842 )
843 .into_response();
844 }
845 let socket_address = socket_address.to_string();
846 ws.on_upgrade(move |socket| {
847 use util::ResultExt;
848 let socket = socket
849 .map_ok(to_tungstenite_message)
850 .err_into()
851 .with(|message| async move { Ok(to_axum_message(message)) });
852 let connection = Connection::new(Box::pin(socket));
853 async move {
854 server
855 .handle_connection(connection, socket_address, user, None, Executor::Production)
856 .await
857 .log_err();
858 }
859 })
860}
861
862pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
863 let connections = server
864 .connection_pool
865 .lock()
866 .connections()
867 .filter(|connection| !connection.admin)
868 .count();
869
870 METRIC_CONNECTIONS.set(connections as _);
871
872 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
873 METRIC_SHARED_PROJECTS.set(shared_projects as _);
874
875 let encoder = prometheus::TextEncoder::new();
876 let metric_families = prometheus::gather();
877 let encoded_metrics = encoder
878 .encode_to_string(&metric_families)
879 .map_err(|err| anyhow!("{}", err))?;
880 Ok(encoded_metrics)
881}
882
/// Cleans up after a client connection ends.
///
/// Immediately disconnects the connection from the peer, removes it from the
/// connection pool, and reports the lost connection to the database.
///
/// It then races `RECONNECT_TIMEOUT` against server teardown; if both are ready,
/// the timeout branch wins because the select is biased. When the timeout
/// fires, it:
/// - leaves the session's room and channel buffers;
/// - declines any pending call when the user has no other online connection;
/// - refreshes the user's contacts.
///
/// When teardown fires first, that cleanup is skipped.
///
/// # Errors
///
/// Returns an error if the connection was not in the pool or if the final
/// contact update fails; other failures are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard is a temporary of the `if` condition, so it is dropped
            // before the awaits in the body.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
928
929async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
930 response.send(proto::Ack {})?;
931 Ok(())
932}
933
934async fn create_room(
935 _request: proto::CreateRoom,
936 response: Response<proto::CreateRoom>,
937 session: Session,
938) -> Result<()> {
939 let live_kit_room = nanoid::nanoid!(30);
940
941 let live_kit_connection_info = {
942 let live_kit_room = live_kit_room.clone();
943 let live_kit = session.live_kit_client.as_ref();
944
945 util::async_maybe!({
946 let live_kit = live_kit?;
947
948 let token = live_kit
949 .room_token(&live_kit_room, &session.user_id.to_string())
950 .trace_err()?;
951
952 Some(proto::LiveKitConnectionInfo {
953 server_url: live_kit.url().into(),
954 token,
955 can_publish: true,
956 })
957 })
958 }
959 .await;
960
961 let room = session
962 .db()
963 .await
964 .create_room(
965 session.user_id,
966 session.connection_id,
967 &live_kit_room,
968 RELEASE_CHANNEL_NAME.as_str(),
969 )
970 .await?;
971
972 response.send(proto::CreateRoomResponse {
973 room: Some(room.clone()),
974 live_kit_connection_info,
975 })?;
976
977 update_user_contacts(session.user_id, &session).await?;
978 Ok(())
979}
980
981async fn join_room(
982 request: proto::JoinRoom,
983 response: Response<proto::JoinRoom>,
984 session: Session,
985) -> Result<()> {
986 let room_id = RoomId::from_proto(request.id);
987
988 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
989
990 if let Some(channel_id) = channel_id {
991 return join_channel_internal(channel_id, Box::new(response), session).await;
992 }
993
994 let joined_room = {
995 let room = session
996 .db()
997 .await
998 .join_room(
999 room_id,
1000 session.user_id,
1001 session.connection_id,
1002 RELEASE_CHANNEL_NAME.as_str(),
1003 )
1004 .await?;
1005 room_updated(&room.room, &session.peer);
1006 room.into_inner()
1007 };
1008
1009 for connection_id in session
1010 .connection_pool()
1011 .await
1012 .user_connection_ids(session.user_id)
1013 {
1014 session
1015 .peer
1016 .send(
1017 connection_id,
1018 proto::CallCanceled {
1019 room_id: room_id.to_proto(),
1020 },
1021 )
1022 .trace_err();
1023 }
1024
1025 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1026 if let Some(token) = live_kit
1027 .room_token(
1028 &joined_room.room.live_kit_room,
1029 &session.user_id.to_string(),
1030 )
1031 .trace_err()
1032 {
1033 Some(proto::LiveKitConnectionInfo {
1034 server_url: live_kit.url().into(),
1035 token,
1036 can_publish: true,
1037 })
1038 } else {
1039 None
1040 }
1041 } else {
1042 None
1043 };
1044
1045 response.send(proto::JoinRoomResponse {
1046 room: Some(joined_room.room),
1047 channel_id: None,
1048 live_kit_connection_info,
1049 })?;
1050
1051 update_user_contacts(session.user_id, &session).await?;
1052 Ok(())
1053}
1054
1055async fn rejoin_room(
1056 request: proto::RejoinRoom,
1057 response: Response<proto::RejoinRoom>,
1058 session: Session,
1059) -> Result<()> {
1060 let room;
1061 let channel_id;
1062 let channel_members;
1063 {
1064 let mut rejoined_room = session
1065 .db()
1066 .await
1067 .rejoin_room(request, session.user_id, session.connection_id)
1068 .await?;
1069
1070 response.send(proto::RejoinRoomResponse {
1071 room: Some(rejoined_room.room.clone()),
1072 reshared_projects: rejoined_room
1073 .reshared_projects
1074 .iter()
1075 .map(|project| proto::ResharedProject {
1076 id: project.id.to_proto(),
1077 collaborators: project
1078 .collaborators
1079 .iter()
1080 .map(|collaborator| collaborator.to_proto())
1081 .collect(),
1082 })
1083 .collect(),
1084 rejoined_projects: rejoined_room
1085 .rejoined_projects
1086 .iter()
1087 .map(|rejoined_project| proto::RejoinedProject {
1088 id: rejoined_project.id.to_proto(),
1089 worktrees: rejoined_project
1090 .worktrees
1091 .iter()
1092 .map(|worktree| proto::WorktreeMetadata {
1093 id: worktree.id,
1094 root_name: worktree.root_name.clone(),
1095 visible: worktree.visible,
1096 abs_path: worktree.abs_path.clone(),
1097 })
1098 .collect(),
1099 collaborators: rejoined_project
1100 .collaborators
1101 .iter()
1102 .map(|collaborator| collaborator.to_proto())
1103 .collect(),
1104 language_servers: rejoined_project.language_servers.clone(),
1105 })
1106 .collect(),
1107 })?;
1108 room_updated(&rejoined_room.room, &session.peer);
1109
1110 for project in &rejoined_room.reshared_projects {
1111 for collaborator in &project.collaborators {
1112 session
1113 .peer
1114 .send(
1115 collaborator.connection_id,
1116 proto::UpdateProjectCollaborator {
1117 project_id: project.id.to_proto(),
1118 old_peer_id: Some(project.old_connection_id.into()),
1119 new_peer_id: Some(session.connection_id.into()),
1120 },
1121 )
1122 .trace_err();
1123 }
1124
1125 broadcast(
1126 Some(session.connection_id),
1127 project
1128 .collaborators
1129 .iter()
1130 .map(|collaborator| collaborator.connection_id),
1131 |connection_id| {
1132 session.peer.forward_send(
1133 session.connection_id,
1134 connection_id,
1135 proto::UpdateProject {
1136 project_id: project.id.to_proto(),
1137 worktrees: project.worktrees.clone(),
1138 },
1139 )
1140 },
1141 );
1142 }
1143
1144 for project in &rejoined_room.rejoined_projects {
1145 for collaborator in &project.collaborators {
1146 session
1147 .peer
1148 .send(
1149 collaborator.connection_id,
1150 proto::UpdateProjectCollaborator {
1151 project_id: project.id.to_proto(),
1152 old_peer_id: Some(project.old_connection_id.into()),
1153 new_peer_id: Some(session.connection_id.into()),
1154 },
1155 )
1156 .trace_err();
1157 }
1158 }
1159
1160 for project in &mut rejoined_room.rejoined_projects {
1161 for worktree in mem::take(&mut project.worktrees) {
1162 #[cfg(any(test, feature = "test-support"))]
1163 const MAX_CHUNK_SIZE: usize = 2;
1164 #[cfg(not(any(test, feature = "test-support")))]
1165 const MAX_CHUNK_SIZE: usize = 256;
1166
1167 // Stream this worktree's entries.
1168 let message = proto::UpdateWorktree {
1169 project_id: project.id.to_proto(),
1170 worktree_id: worktree.id,
1171 abs_path: worktree.abs_path.clone(),
1172 root_name: worktree.root_name,
1173 updated_entries: worktree.updated_entries,
1174 removed_entries: worktree.removed_entries,
1175 scan_id: worktree.scan_id,
1176 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1177 updated_repositories: worktree.updated_repositories,
1178 removed_repositories: worktree.removed_repositories,
1179 };
1180 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1181 session.peer.send(session.connection_id, update.clone())?;
1182 }
1183
1184 // Stream this worktree's diagnostics.
1185 for summary in worktree.diagnostic_summaries {
1186 session.peer.send(
1187 session.connection_id,
1188 proto::UpdateDiagnosticSummary {
1189 project_id: project.id.to_proto(),
1190 worktree_id: worktree.id,
1191 summary: Some(summary),
1192 },
1193 )?;
1194 }
1195
1196 for settings_file in worktree.settings_files {
1197 session.peer.send(
1198 session.connection_id,
1199 proto::UpdateWorktreeSettings {
1200 project_id: project.id.to_proto(),
1201 worktree_id: worktree.id,
1202 path: settings_file.path,
1203 content: Some(settings_file.content),
1204 },
1205 )?;
1206 }
1207 }
1208
1209 for language_server in &project.language_servers {
1210 session.peer.send(
1211 session.connection_id,
1212 proto::UpdateLanguageServer {
1213 project_id: project.id.to_proto(),
1214 language_server_id: language_server.id,
1215 variant: Some(
1216 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1217 proto::LspDiskBasedDiagnosticsUpdated {},
1218 ),
1219 ),
1220 },
1221 )?;
1222 }
1223 }
1224
1225 let rejoined_room = rejoined_room.into_inner();
1226
1227 room = rejoined_room.room;
1228 channel_id = rejoined_room.channel_id;
1229 channel_members = rejoined_room.channel_members;
1230 }
1231
1232 if let Some(channel_id) = channel_id {
1233 channel_updated(
1234 channel_id,
1235 &room,
1236 &channel_members,
1237 &session.peer,
1238 &*session.connection_pool().await,
1239 );
1240 }
1241
1242 update_user_contacts(session.user_id, &session).await?;
1243 Ok(())
1244}
1245
1246async fn leave_room(
1247 _: proto::LeaveRoom,
1248 response: Response<proto::LeaveRoom>,
1249 session: Session,
1250) -> Result<()> {
1251 leave_room_for_session(&session).await?;
1252 response.send(proto::Ack {})?;
1253 Ok(())
1254}
1255
1256async fn call(
1257 request: proto::Call,
1258 response: Response<proto::Call>,
1259 session: Session,
1260) -> Result<()> {
1261 let room_id = RoomId::from_proto(request.room_id);
1262 let calling_user_id = session.user_id;
1263 let calling_connection_id = session.connection_id;
1264 let called_user_id = UserId::from_proto(request.called_user_id);
1265 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1266 if !session
1267 .db()
1268 .await
1269 .has_contact(calling_user_id, called_user_id)
1270 .await?
1271 {
1272 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1273 }
1274
1275 let incoming_call = {
1276 let (room, incoming_call) = &mut *session
1277 .db()
1278 .await
1279 .call(
1280 room_id,
1281 calling_user_id,
1282 calling_connection_id,
1283 called_user_id,
1284 initial_project_id,
1285 )
1286 .await?;
1287 room_updated(&room, &session.peer);
1288 mem::take(incoming_call)
1289 };
1290 update_user_contacts(called_user_id, &session).await?;
1291
1292 let mut calls = session
1293 .connection_pool()
1294 .await
1295 .user_connection_ids(called_user_id)
1296 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1297 .collect::<FuturesUnordered<_>>();
1298
1299 while let Some(call_response) = calls.next().await {
1300 match call_response.as_ref() {
1301 Ok(_) => {
1302 response.send(proto::Ack {})?;
1303 return Ok(());
1304 }
1305 Err(_) => {
1306 call_response.trace_err();
1307 }
1308 }
1309 }
1310
1311 {
1312 let room = session
1313 .db()
1314 .await
1315 .call_failed(room_id, called_user_id)
1316 .await?;
1317 room_updated(&room, &session.peer);
1318 }
1319 update_user_contacts(called_user_id, &session).await?;
1320
1321 Err(anyhow!("failed to ring user"))?
1322}
1323
1324async fn cancel_call(
1325 request: proto::CancelCall,
1326 response: Response<proto::CancelCall>,
1327 session: Session,
1328) -> Result<()> {
1329 let called_user_id = UserId::from_proto(request.called_user_id);
1330 let room_id = RoomId::from_proto(request.room_id);
1331 {
1332 let room = session
1333 .db()
1334 .await
1335 .cancel_call(room_id, session.connection_id, called_user_id)
1336 .await?;
1337 room_updated(&room, &session.peer);
1338 }
1339
1340 for connection_id in session
1341 .connection_pool()
1342 .await
1343 .user_connection_ids(called_user_id)
1344 {
1345 session
1346 .peer
1347 .send(
1348 connection_id,
1349 proto::CallCanceled {
1350 room_id: room_id.to_proto(),
1351 },
1352 )
1353 .trace_err();
1354 }
1355 response.send(proto::Ack {})?;
1356
1357 update_user_contacts(called_user_id, &session).await?;
1358 Ok(())
1359}
1360
1361async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1362 let room_id = RoomId::from_proto(message.room_id);
1363 {
1364 let room = session
1365 .db()
1366 .await
1367 .decline_call(Some(room_id), session.user_id)
1368 .await?
1369 .ok_or_else(|| anyhow!("failed to decline call"))?;
1370 room_updated(&room, &session.peer);
1371 }
1372
1373 for connection_id in session
1374 .connection_pool()
1375 .await
1376 .user_connection_ids(session.user_id)
1377 {
1378 session
1379 .peer
1380 .send(
1381 connection_id,
1382 proto::CallCanceled {
1383 room_id: room_id.to_proto(),
1384 },
1385 )
1386 .trace_err();
1387 }
1388 update_user_contacts(session.user_id, &session).await?;
1389 Ok(())
1390}
1391
1392async fn update_participant_location(
1393 request: proto::UpdateParticipantLocation,
1394 response: Response<proto::UpdateParticipantLocation>,
1395 session: Session,
1396) -> Result<()> {
1397 let room_id = RoomId::from_proto(request.room_id);
1398 let location = request
1399 .location
1400 .ok_or_else(|| anyhow!("invalid location"))?;
1401
1402 let db = session.db().await;
1403 let room = db
1404 .update_room_participant_location(room_id, session.connection_id, location)
1405 .await?;
1406
1407 room_updated(&room, &session.peer);
1408 response.send(proto::Ack {})?;
1409 Ok(())
1410}
1411
1412async fn share_project(
1413 request: proto::ShareProject,
1414 response: Response<proto::ShareProject>,
1415 session: Session,
1416) -> Result<()> {
1417 let (project_id, room) = &*session
1418 .db()
1419 .await
1420 .share_project(
1421 RoomId::from_proto(request.room_id),
1422 session.connection_id,
1423 &request.worktrees,
1424 )
1425 .await?;
1426 response.send(proto::ShareProjectResponse {
1427 project_id: project_id.to_proto(),
1428 })?;
1429 room_updated(&room, &session.peer);
1430
1431 Ok(())
1432}
1433
1434async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1435 let project_id = ProjectId::from_proto(message.project_id);
1436
1437 let (room, guest_connection_ids) = &*session
1438 .db()
1439 .await
1440 .unshare_project(project_id, session.connection_id)
1441 .await?;
1442
1443 broadcast(
1444 Some(session.connection_id),
1445 guest_connection_ids.iter().copied(),
1446 |conn_id| session.peer.send(conn_id, message.clone()),
1447 );
1448 room_updated(&room, &session.peer);
1449
1450 Ok(())
1451}
1452
1453async fn join_project(
1454 request: proto::JoinProject,
1455 response: Response<proto::JoinProject>,
1456 session: Session,
1457) -> Result<()> {
1458 let project_id = ProjectId::from_proto(request.project_id);
1459 let guest_user_id = session.user_id;
1460
1461 tracing::info!(%project_id, "join project");
1462
1463 let (project, replica_id) = &mut *session
1464 .db()
1465 .await
1466 .join_project(project_id, session.connection_id)
1467 .await?;
1468
1469 let collaborators = project
1470 .collaborators
1471 .iter()
1472 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1473 .map(|collaborator| collaborator.to_proto())
1474 .collect::<Vec<_>>();
1475
1476 let worktrees = project
1477 .worktrees
1478 .iter()
1479 .map(|(id, worktree)| proto::WorktreeMetadata {
1480 id: *id,
1481 root_name: worktree.root_name.clone(),
1482 visible: worktree.visible,
1483 abs_path: worktree.abs_path.clone(),
1484 })
1485 .collect::<Vec<_>>();
1486
1487 for collaborator in &collaborators {
1488 session
1489 .peer
1490 .send(
1491 collaborator.peer_id.unwrap().into(),
1492 proto::AddProjectCollaborator {
1493 project_id: project_id.to_proto(),
1494 collaborator: Some(proto::Collaborator {
1495 peer_id: Some(session.connection_id.into()),
1496 replica_id: replica_id.0 as u32,
1497 user_id: guest_user_id.to_proto(),
1498 }),
1499 },
1500 )
1501 .trace_err();
1502 }
1503
1504 // First, we send the metadata associated with each worktree.
1505 response.send(proto::JoinProjectResponse {
1506 worktrees: worktrees.clone(),
1507 replica_id: replica_id.0 as u32,
1508 collaborators: collaborators.clone(),
1509 language_servers: project.language_servers.clone(),
1510 })?;
1511
1512 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1513 #[cfg(any(test, feature = "test-support"))]
1514 const MAX_CHUNK_SIZE: usize = 2;
1515 #[cfg(not(any(test, feature = "test-support")))]
1516 const MAX_CHUNK_SIZE: usize = 256;
1517
1518 // Stream this worktree's entries.
1519 let message = proto::UpdateWorktree {
1520 project_id: project_id.to_proto(),
1521 worktree_id,
1522 abs_path: worktree.abs_path.clone(),
1523 root_name: worktree.root_name,
1524 updated_entries: worktree.entries,
1525 removed_entries: Default::default(),
1526 scan_id: worktree.scan_id,
1527 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1528 updated_repositories: worktree.repository_entries.into_values().collect(),
1529 removed_repositories: Default::default(),
1530 };
1531 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1532 session.peer.send(session.connection_id, update.clone())?;
1533 }
1534
1535 // Stream this worktree's diagnostics.
1536 for summary in worktree.diagnostic_summaries {
1537 session.peer.send(
1538 session.connection_id,
1539 proto::UpdateDiagnosticSummary {
1540 project_id: project_id.to_proto(),
1541 worktree_id: worktree.id,
1542 summary: Some(summary),
1543 },
1544 )?;
1545 }
1546
1547 for settings_file in worktree.settings_files {
1548 session.peer.send(
1549 session.connection_id,
1550 proto::UpdateWorktreeSettings {
1551 project_id: project_id.to_proto(),
1552 worktree_id: worktree.id,
1553 path: settings_file.path,
1554 content: Some(settings_file.content),
1555 },
1556 )?;
1557 }
1558 }
1559
1560 for language_server in &project.language_servers {
1561 session.peer.send(
1562 session.connection_id,
1563 proto::UpdateLanguageServer {
1564 project_id: project_id.to_proto(),
1565 language_server_id: language_server.id,
1566 variant: Some(
1567 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1568 proto::LspDiskBasedDiagnosticsUpdated {},
1569 ),
1570 ),
1571 },
1572 )?;
1573 }
1574
1575 Ok(())
1576}
1577
1578async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1579 let sender_id = session.connection_id;
1580 let project_id = ProjectId::from_proto(request.project_id);
1581
1582 let (room, project) = &*session
1583 .db()
1584 .await
1585 .leave_project(project_id, sender_id)
1586 .await?;
1587 tracing::info!(
1588 %project_id,
1589 host_user_id = %project.host_user_id,
1590 host_connection_id = %project.host_connection_id,
1591 "leave project"
1592 );
1593
1594 project_left(&project, &session);
1595 room_updated(&room, &session.peer);
1596
1597 Ok(())
1598}
1599
1600async fn update_project(
1601 request: proto::UpdateProject,
1602 response: Response<proto::UpdateProject>,
1603 session: Session,
1604) -> Result<()> {
1605 let project_id = ProjectId::from_proto(request.project_id);
1606 let (room, guest_connection_ids) = &*session
1607 .db()
1608 .await
1609 .update_project(project_id, session.connection_id, &request.worktrees)
1610 .await?;
1611 broadcast(
1612 Some(session.connection_id),
1613 guest_connection_ids.iter().copied(),
1614 |connection_id| {
1615 session
1616 .peer
1617 .forward_send(session.connection_id, connection_id, request.clone())
1618 },
1619 );
1620 room_updated(&room, &session.peer);
1621 response.send(proto::Ack {})?;
1622
1623 Ok(())
1624}
1625
1626async fn update_worktree(
1627 request: proto::UpdateWorktree,
1628 response: Response<proto::UpdateWorktree>,
1629 session: Session,
1630) -> Result<()> {
1631 let guest_connection_ids = session
1632 .db()
1633 .await
1634 .update_worktree(&request, session.connection_id)
1635 .await?;
1636
1637 broadcast(
1638 Some(session.connection_id),
1639 guest_connection_ids.iter().copied(),
1640 |connection_id| {
1641 session
1642 .peer
1643 .forward_send(session.connection_id, connection_id, request.clone())
1644 },
1645 );
1646 response.send(proto::Ack {})?;
1647 Ok(())
1648}
1649
1650async fn update_diagnostic_summary(
1651 message: proto::UpdateDiagnosticSummary,
1652 session: Session,
1653) -> Result<()> {
1654 let guest_connection_ids = session
1655 .db()
1656 .await
1657 .update_diagnostic_summary(&message, session.connection_id)
1658 .await?;
1659
1660 broadcast(
1661 Some(session.connection_id),
1662 guest_connection_ids.iter().copied(),
1663 |connection_id| {
1664 session
1665 .peer
1666 .forward_send(session.connection_id, connection_id, message.clone())
1667 },
1668 );
1669
1670 Ok(())
1671}
1672
1673async fn update_worktree_settings(
1674 message: proto::UpdateWorktreeSettings,
1675 session: Session,
1676) -> Result<()> {
1677 let guest_connection_ids = session
1678 .db()
1679 .await
1680 .update_worktree_settings(&message, session.connection_id)
1681 .await?;
1682
1683 broadcast(
1684 Some(session.connection_id),
1685 guest_connection_ids.iter().copied(),
1686 |connection_id| {
1687 session
1688 .peer
1689 .forward_send(session.connection_id, connection_id, message.clone())
1690 },
1691 );
1692
1693 Ok(())
1694}
1695
1696async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1697 broadcast_project_message(request.project_id, request, session).await
1698}
1699
1700async fn start_language_server(
1701 request: proto::StartLanguageServer,
1702 session: Session,
1703) -> Result<()> {
1704 let guest_connection_ids = session
1705 .db()
1706 .await
1707 .start_language_server(&request, session.connection_id)
1708 .await?;
1709
1710 broadcast(
1711 Some(session.connection_id),
1712 guest_connection_ids.iter().copied(),
1713 |connection_id| {
1714 session
1715 .peer
1716 .forward_send(session.connection_id, connection_id, request.clone())
1717 },
1718 );
1719 Ok(())
1720}
1721
1722async fn update_language_server(
1723 request: proto::UpdateLanguageServer,
1724 session: Session,
1725) -> Result<()> {
1726 let project_id = ProjectId::from_proto(request.project_id);
1727 let project_connection_ids = session
1728 .db()
1729 .await
1730 .project_connection_ids(project_id, session.connection_id)
1731 .await?;
1732 broadcast(
1733 Some(session.connection_id),
1734 project_connection_ids.iter().copied(),
1735 |connection_id| {
1736 session
1737 .peer
1738 .forward_send(session.connection_id, connection_id, request.clone())
1739 },
1740 );
1741 Ok(())
1742}
1743
1744async fn forward_project_request<T>(
1745 request: T,
1746 response: Response<T>,
1747 session: Session,
1748) -> Result<()>
1749where
1750 T: EntityMessage + RequestMessage,
1751{
1752 let project_id = ProjectId::from_proto(request.remote_entity_id());
1753 let host_connection_id = {
1754 let collaborators = session
1755 .db()
1756 .await
1757 .project_collaborators(project_id, session.connection_id)
1758 .await?;
1759 collaborators
1760 .iter()
1761 .find(|collaborator| collaborator.is_host)
1762 .ok_or_else(|| anyhow!("host not found"))?
1763 .connection_id
1764 };
1765
1766 let payload = session
1767 .peer
1768 .forward_request(session.connection_id, host_connection_id, request)
1769 .await?;
1770
1771 response.send(payload)?;
1772 Ok(())
1773}
1774
1775async fn create_buffer_for_peer(
1776 request: proto::CreateBufferForPeer,
1777 session: Session,
1778) -> Result<()> {
1779 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1780 session
1781 .peer
1782 .forward_send(session.connection_id, peer_id.into(), request)?;
1783 Ok(())
1784}
1785
1786async fn update_buffer(
1787 request: proto::UpdateBuffer,
1788 response: Response<proto::UpdateBuffer>,
1789 session: Session,
1790) -> Result<()> {
1791 let project_id = ProjectId::from_proto(request.project_id);
1792 let mut guest_connection_ids;
1793 let mut host_connection_id = None;
1794 {
1795 let collaborators = session
1796 .db()
1797 .await
1798 .project_collaborators(project_id, session.connection_id)
1799 .await?;
1800 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1801 for collaborator in collaborators.iter() {
1802 if collaborator.is_host {
1803 host_connection_id = Some(collaborator.connection_id);
1804 } else {
1805 guest_connection_ids.push(collaborator.connection_id);
1806 }
1807 }
1808 }
1809 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1810
1811 broadcast(
1812 Some(session.connection_id),
1813 guest_connection_ids,
1814 |connection_id| {
1815 session
1816 .peer
1817 .forward_send(session.connection_id, connection_id, request.clone())
1818 },
1819 );
1820 if host_connection_id != session.connection_id {
1821 session
1822 .peer
1823 .forward_request(session.connection_id, host_connection_id, request.clone())
1824 .await?;
1825 }
1826
1827 response.send(proto::Ack {})?;
1828 Ok(())
1829}
1830
1831async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1832 let project_id = ProjectId::from_proto(request.project_id);
1833 let project_connection_ids = session
1834 .db()
1835 .await
1836 .project_connection_ids(project_id, session.connection_id)
1837 .await?;
1838
1839 broadcast(
1840 Some(session.connection_id),
1841 project_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1852 let project_id = ProjectId::from_proto(request.project_id);
1853 let project_connection_ids = session
1854 .db()
1855 .await
1856 .project_connection_ids(project_id, session.connection_id)
1857 .await?;
1858 broadcast(
1859 Some(session.connection_id),
1860 project_connection_ids.iter().copied(),
1861 |connection_id| {
1862 session
1863 .peer
1864 .forward_send(session.connection_id, connection_id, request.clone())
1865 },
1866 );
1867 Ok(())
1868}
1869
1870async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1871 broadcast_project_message(request.project_id, request, session).await
1872}
1873
1874async fn broadcast_project_message<T: EnvelopedMessage>(
1875 project_id: u64,
1876 request: T,
1877 session: Session,
1878) -> Result<()> {
1879 let project_id = ProjectId::from_proto(project_id);
1880 let project_connection_ids = session
1881 .db()
1882 .await
1883 .project_connection_ids(project_id, session.connection_id)
1884 .await?;
1885 broadcast(
1886 Some(session.connection_id),
1887 project_connection_ids.iter().copied(),
1888 |connection_id| {
1889 session
1890 .peer
1891 .forward_send(session.connection_id, connection_id, request.clone())
1892 },
1893 );
1894 Ok(())
1895}
1896
1897async fn follow(
1898 request: proto::Follow,
1899 response: Response<proto::Follow>,
1900 session: Session,
1901) -> Result<()> {
1902 let room_id = RoomId::from_proto(request.room_id);
1903 let project_id = request.project_id.map(ProjectId::from_proto);
1904 let leader_id = request
1905 .leader_id
1906 .ok_or_else(|| anyhow!("invalid leader id"))?
1907 .into();
1908 let follower_id = session.connection_id;
1909
1910 session
1911 .db()
1912 .await
1913 .check_room_participants(room_id, leader_id, session.connection_id)
1914 .await?;
1915
1916 let response_payload = session
1917 .peer
1918 .forward_request(session.connection_id, leader_id, request)
1919 .await?;
1920 response.send(response_payload)?;
1921
1922 if let Some(project_id) = project_id {
1923 let room = session
1924 .db()
1925 .await
1926 .follow(room_id, project_id, leader_id, follower_id)
1927 .await?;
1928 room_updated(&room, &session.peer);
1929 }
1930
1931 Ok(())
1932}
1933
1934async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1935 let room_id = RoomId::from_proto(request.room_id);
1936 let project_id = request.project_id.map(ProjectId::from_proto);
1937 let leader_id = request
1938 .leader_id
1939 .ok_or_else(|| anyhow!("invalid leader id"))?
1940 .into();
1941 let follower_id = session.connection_id;
1942
1943 session
1944 .db()
1945 .await
1946 .check_room_participants(room_id, leader_id, session.connection_id)
1947 .await?;
1948
1949 session
1950 .peer
1951 .forward_send(session.connection_id, leader_id, request)?;
1952
1953 if let Some(project_id) = project_id {
1954 let room = session
1955 .db()
1956 .await
1957 .unfollow(room_id, project_id, leader_id, follower_id)
1958 .await?;
1959 room_updated(&room, &session.peer);
1960 }
1961
1962 Ok(())
1963}
1964
1965async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1966 let room_id = RoomId::from_proto(request.room_id);
1967 let database = session.db.lock().await;
1968
1969 let connection_ids = if let Some(project_id) = request.project_id {
1970 let project_id = ProjectId::from_proto(project_id);
1971 database
1972 .project_connection_ids(project_id, session.connection_id)
1973 .await?
1974 } else {
1975 database
1976 .room_connection_ids(room_id, session.connection_id)
1977 .await?
1978 };
1979
1980 // For now, don't send view update messages back to that view's current leader.
1981 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1982 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1983 _ => None,
1984 });
1985
1986 for follower_peer_id in request.follower_ids.iter().copied() {
1987 let follower_connection_id = follower_peer_id.into();
1988 if Some(follower_peer_id) != connection_id_to_omit
1989 && connection_ids.contains(&follower_connection_id)
1990 {
1991 session.peer.forward_send(
1992 session.connection_id,
1993 follower_connection_id,
1994 request.clone(),
1995 )?;
1996 }
1997 }
1998 Ok(())
1999}
2000
2001async fn get_users(
2002 request: proto::GetUsers,
2003 response: Response<proto::GetUsers>,
2004 session: Session,
2005) -> Result<()> {
2006 let user_ids = request
2007 .user_ids
2008 .into_iter()
2009 .map(UserId::from_proto)
2010 .collect();
2011 let users = session
2012 .db()
2013 .await
2014 .get_users_by_ids(user_ids)
2015 .await?
2016 .into_iter()
2017 .map(|user| proto::User {
2018 id: user.id.to_proto(),
2019 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2020 github_login: user.github_login,
2021 })
2022 .collect();
2023 response.send(proto::UsersResponse { users })?;
2024 Ok(())
2025}
2026
2027async fn fuzzy_search_users(
2028 request: proto::FuzzySearchUsers,
2029 response: Response<proto::FuzzySearchUsers>,
2030 session: Session,
2031) -> Result<()> {
2032 let query = request.query;
2033 let users = match query.len() {
2034 0 => vec![],
2035 1 | 2 => session
2036 .db()
2037 .await
2038 .get_user_by_github_login(&query)
2039 .await?
2040 .into_iter()
2041 .collect(),
2042 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2043 };
2044 let users = users
2045 .into_iter()
2046 .filter(|user| user.id != session.user_id)
2047 .map(|user| proto::User {
2048 id: user.id.to_proto(),
2049 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2050 github_login: user.github_login,
2051 })
2052 .collect();
2053 response.send(proto::UsersResponse { users })?;
2054 Ok(())
2055}
2056
2057async fn request_contact(
2058 request: proto::RequestContact,
2059 response: Response<proto::RequestContact>,
2060 session: Session,
2061) -> Result<()> {
2062 let requester_id = session.user_id;
2063 let responder_id = UserId::from_proto(request.responder_id);
2064 if requester_id == responder_id {
2065 return Err(anyhow!("cannot add yourself as a contact"))?;
2066 }
2067
2068 let notifications = session
2069 .db()
2070 .await
2071 .send_contact_request(requester_id, responder_id)
2072 .await?;
2073
2074 // Update outgoing contact requests of requester
2075 let mut update = proto::UpdateContacts::default();
2076 update.outgoing_requests.push(responder_id.to_proto());
2077 for connection_id in session
2078 .connection_pool()
2079 .await
2080 .user_connection_ids(requester_id)
2081 {
2082 session.peer.send(connection_id, update.clone())?;
2083 }
2084
2085 // Update incoming contact requests of responder
2086 let mut update = proto::UpdateContacts::default();
2087 update
2088 .incoming_requests
2089 .push(proto::IncomingContactRequest {
2090 requester_id: requester_id.to_proto(),
2091 });
2092 let connection_pool = session.connection_pool().await;
2093 for connection_id in connection_pool.user_connection_ids(responder_id) {
2094 session.peer.send(connection_id, update.clone())?;
2095 }
2096
2097 send_notifications(&*connection_pool, &session.peer, notifications);
2098
2099 response.send(proto::Ack {})?;
2100 Ok(())
2101}
2102
2103async fn respond_to_contact_request(
2104 request: proto::RespondToContactRequest,
2105 response: Response<proto::RespondToContactRequest>,
2106 session: Session,
2107) -> Result<()> {
2108 let responder_id = session.user_id;
2109 let requester_id = UserId::from_proto(request.requester_id);
2110 let db = session.db().await;
2111 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2112 db.dismiss_contact_notification(responder_id, requester_id)
2113 .await?;
2114 } else {
2115 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2116
2117 let notifications = db
2118 .respond_to_contact_request(responder_id, requester_id, accept)
2119 .await?;
2120 let requester_busy = db.is_user_busy(requester_id).await?;
2121 let responder_busy = db.is_user_busy(responder_id).await?;
2122
2123 let pool = session.connection_pool().await;
2124 // Update responder with new contact
2125 let mut update = proto::UpdateContacts::default();
2126 if accept {
2127 update
2128 .contacts
2129 .push(contact_for_user(requester_id, requester_busy, &pool));
2130 }
2131 update
2132 .remove_incoming_requests
2133 .push(requester_id.to_proto());
2134 for connection_id in pool.user_connection_ids(responder_id) {
2135 session.peer.send(connection_id, update.clone())?;
2136 }
2137
2138 // Update requester with new contact
2139 let mut update = proto::UpdateContacts::default();
2140 if accept {
2141 update
2142 .contacts
2143 .push(contact_for_user(responder_id, responder_busy, &pool));
2144 }
2145 update
2146 .remove_outgoing_requests
2147 .push(responder_id.to_proto());
2148
2149 for connection_id in pool.user_connection_ids(requester_id) {
2150 session.peer.send(connection_id, update.clone())?;
2151 }
2152
2153 send_notifications(&*pool, &session.peer, notifications);
2154 }
2155
2156 response.send(proto::Ack {})?;
2157 Ok(())
2158}
2159
2160async fn remove_contact(
2161 request: proto::RemoveContact,
2162 response: Response<proto::RemoveContact>,
2163 session: Session,
2164) -> Result<()> {
2165 let requester_id = session.user_id;
2166 let responder_id = UserId::from_proto(request.user_id);
2167 let db = session.db().await;
2168 let (contact_accepted, deleted_notification_id) =
2169 db.remove_contact(requester_id, responder_id).await?;
2170
2171 let pool = session.connection_pool().await;
2172 // Update outgoing contact requests of requester
2173 let mut update = proto::UpdateContacts::default();
2174 if contact_accepted {
2175 update.remove_contacts.push(responder_id.to_proto());
2176 } else {
2177 update
2178 .remove_outgoing_requests
2179 .push(responder_id.to_proto());
2180 }
2181 for connection_id in pool.user_connection_ids(requester_id) {
2182 session.peer.send(connection_id, update.clone())?;
2183 }
2184
2185 // Update incoming contact requests of responder
2186 let mut update = proto::UpdateContacts::default();
2187 if contact_accepted {
2188 update.remove_contacts.push(requester_id.to_proto());
2189 } else {
2190 update
2191 .remove_incoming_requests
2192 .push(requester_id.to_proto());
2193 }
2194 for connection_id in pool.user_connection_ids(responder_id) {
2195 session.peer.send(connection_id, update.clone())?;
2196 if let Some(notification_id) = deleted_notification_id {
2197 session.peer.send(
2198 connection_id,
2199 proto::DeleteNotification {
2200 notification_id: notification_id.to_proto(),
2201 },
2202 )?;
2203 }
2204 }
2205
2206 response.send(proto::Ack {})?;
2207 Ok(())
2208}
2209
2210async fn create_channel(
2211 request: proto::CreateChannel,
2212 response: Response<proto::CreateChannel>,
2213 session: Session,
2214) -> Result<()> {
2215 let db = session.db().await;
2216
2217 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2218 let CreateChannelResult {
2219 channel,
2220 participants_to_update,
2221 } = db
2222 .create_channel(&request.name, parent_id, session.user_id)
2223 .await?;
2224
2225 response.send(proto::CreateChannelResponse {
2226 channel: Some(channel.to_proto()),
2227 parent_id: request.parent_id,
2228 })?;
2229
2230 let connection_pool = session.connection_pool().await;
2231 for (user_id, channels) in participants_to_update {
2232 let update = build_channels_update(channels, vec![]);
2233 for connection_id in connection_pool.user_connection_ids(user_id) {
2234 if user_id == session.user_id {
2235 continue;
2236 }
2237 session.peer.send(connection_id, update.clone())?;
2238 }
2239 }
2240
2241 Ok(())
2242}
2243
2244async fn delete_channel(
2245 request: proto::DeleteChannel,
2246 response: Response<proto::DeleteChannel>,
2247 session: Session,
2248) -> Result<()> {
2249 let db = session.db().await;
2250
2251 let channel_id = request.channel_id;
2252 let (removed_channels, member_ids) = db
2253 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2254 .await?;
2255 response.send(proto::Ack {})?;
2256
2257 // Notify members of removed channels
2258 let mut update = proto::UpdateChannels::default();
2259 update
2260 .delete_channels
2261 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2262
2263 let connection_pool = session.connection_pool().await;
2264 for member_id in member_ids {
2265 for connection_id in connection_pool.user_connection_ids(member_id) {
2266 session.peer.send(connection_id, update.clone())?;
2267 }
2268 }
2269
2270 Ok(())
2271}
2272
2273async fn invite_channel_member(
2274 request: proto::InviteChannelMember,
2275 response: Response<proto::InviteChannelMember>,
2276 session: Session,
2277) -> Result<()> {
2278 let db = session.db().await;
2279 let channel_id = ChannelId::from_proto(request.channel_id);
2280 let invitee_id = UserId::from_proto(request.user_id);
2281 let InviteMemberResult {
2282 channel,
2283 notifications,
2284 } = db
2285 .invite_channel_member(
2286 channel_id,
2287 invitee_id,
2288 session.user_id,
2289 request.role().into(),
2290 )
2291 .await?;
2292
2293 let update = proto::UpdateChannels {
2294 channel_invitations: vec![channel.to_proto()],
2295 ..Default::default()
2296 };
2297
2298 let connection_pool = session.connection_pool().await;
2299 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2300 session.peer.send(connection_id, update.clone())?;
2301 }
2302
2303 send_notifications(&*connection_pool, &session.peer, notifications);
2304
2305 response.send(proto::Ack {})?;
2306 Ok(())
2307}
2308
2309async fn remove_channel_member(
2310 request: proto::RemoveChannelMember,
2311 response: Response<proto::RemoveChannelMember>,
2312 session: Session,
2313) -> Result<()> {
2314 let db = session.db().await;
2315 let channel_id = ChannelId::from_proto(request.channel_id);
2316 let member_id = UserId::from_proto(request.user_id);
2317
2318 let RemoveChannelMemberResult {
2319 membership_update,
2320 notification_id,
2321 } = db
2322 .remove_channel_member(channel_id, member_id, session.user_id)
2323 .await?;
2324
2325 let connection_pool = &session.connection_pool().await;
2326 notify_membership_updated(
2327 &connection_pool,
2328 membership_update,
2329 member_id,
2330 &session.peer,
2331 );
2332 for connection_id in connection_pool.user_connection_ids(member_id) {
2333 if let Some(notification_id) = notification_id {
2334 session
2335 .peer
2336 .send(
2337 connection_id,
2338 proto::DeleteNotification {
2339 notification_id: notification_id.to_proto(),
2340 },
2341 )
2342 .trace_err();
2343 }
2344 }
2345
2346 response.send(proto::Ack {})?;
2347 Ok(())
2348}
2349
2350async fn set_channel_visibility(
2351 request: proto::SetChannelVisibility,
2352 response: Response<proto::SetChannelVisibility>,
2353 session: Session,
2354) -> Result<()> {
2355 let db = session.db().await;
2356 let channel_id = ChannelId::from_proto(request.channel_id);
2357 let visibility = request.visibility().into();
2358
2359 let SetChannelVisibilityResult {
2360 participants_to_update,
2361 participants_to_remove,
2362 channels_to_remove,
2363 } = db
2364 .set_channel_visibility(channel_id, visibility, session.user_id)
2365 .await?;
2366
2367 let connection_pool = session.connection_pool().await;
2368 for (user_id, channels) in participants_to_update {
2369 let update = build_channels_update(channels, vec![]);
2370 for connection_id in connection_pool.user_connection_ids(user_id) {
2371 session.peer.send(connection_id, update.clone())?;
2372 }
2373 }
2374 for user_id in participants_to_remove {
2375 let update = proto::UpdateChannels {
2376 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2377 ..Default::default()
2378 };
2379 for connection_id in connection_pool.user_connection_ids(user_id) {
2380 session.peer.send(connection_id, update.clone())?;
2381 }
2382 }
2383
2384 response.send(proto::Ack {})?;
2385 Ok(())
2386}
2387
2388async fn set_channel_member_role(
2389 request: proto::SetChannelMemberRole,
2390 response: Response<proto::SetChannelMemberRole>,
2391 session: Session,
2392) -> Result<()> {
2393 let db = session.db().await;
2394 let channel_id = ChannelId::from_proto(request.channel_id);
2395 let member_id = UserId::from_proto(request.user_id);
2396 let result = db
2397 .set_channel_member_role(
2398 channel_id,
2399 session.user_id,
2400 member_id,
2401 request.role().into(),
2402 )
2403 .await?;
2404
2405 match result {
2406 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2407 let connection_pool = session.connection_pool().await;
2408 notify_membership_updated(
2409 &connection_pool,
2410 membership_update,
2411 member_id,
2412 &session.peer,
2413 )
2414 }
2415 db::SetMemberRoleResult::InviteUpdated(channel) => {
2416 let update = proto::UpdateChannels {
2417 channel_invitations: vec![channel.to_proto()],
2418 ..Default::default()
2419 };
2420
2421 for connection_id in session
2422 .connection_pool()
2423 .await
2424 .user_connection_ids(member_id)
2425 {
2426 session.peer.send(connection_id, update.clone())?;
2427 }
2428 }
2429 }
2430
2431 response.send(proto::Ack {})?;
2432 Ok(())
2433}
2434
2435async fn rename_channel(
2436 request: proto::RenameChannel,
2437 response: Response<proto::RenameChannel>,
2438 session: Session,
2439) -> Result<()> {
2440 let db = session.db().await;
2441 let channel_id = ChannelId::from_proto(request.channel_id);
2442 let RenameChannelResult {
2443 channel,
2444 participants_to_update,
2445 } = db
2446 .rename_channel(channel_id, session.user_id, &request.name)
2447 .await?;
2448
2449 response.send(proto::RenameChannelResponse {
2450 channel: Some(channel.to_proto()),
2451 })?;
2452
2453 let connection_pool = session.connection_pool().await;
2454 for (user_id, channel) in participants_to_update {
2455 for connection_id in connection_pool.user_connection_ids(user_id) {
2456 let update = proto::UpdateChannels {
2457 channels: vec![channel.to_proto()],
2458 ..Default::default()
2459 };
2460
2461 session.peer.send(connection_id, update.clone())?;
2462 }
2463 }
2464
2465 Ok(())
2466}
2467
2468async fn move_channel(
2469 request: proto::MoveChannel,
2470 response: Response<proto::MoveChannel>,
2471 session: Session,
2472) -> Result<()> {
2473 let channel_id = ChannelId::from_proto(request.channel_id);
2474 let to = request.to.map(ChannelId::from_proto);
2475
2476 let result = session
2477 .db()
2478 .await
2479 .move_channel(channel_id, to, session.user_id)
2480 .await?;
2481
2482 notify_channel_moved(result, session).await?;
2483
2484 response.send(Ack {})?;
2485 Ok(())
2486}
2487
2488async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2489 let Some(MoveChannelResult {
2490 participants_to_remove,
2491 participants_to_update,
2492 moved_channels,
2493 }) = result
2494 else {
2495 return Ok(());
2496 };
2497 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2498
2499 let connection_pool = session.connection_pool().await;
2500 for (user_id, channels) in participants_to_update {
2501 let mut update = build_channels_update(channels, vec![]);
2502 update.delete_channels = moved_channels.clone();
2503 for connection_id in connection_pool.user_connection_ids(user_id) {
2504 session.peer.send(connection_id, update.clone())?;
2505 }
2506 }
2507
2508 for user_id in participants_to_remove {
2509 let update = proto::UpdateChannels {
2510 delete_channels: moved_channels.clone(),
2511 ..Default::default()
2512 };
2513 for connection_id in connection_pool.user_connection_ids(user_id) {
2514 session.peer.send(connection_id, update.clone())?;
2515 }
2516 }
2517 Ok(())
2518}
2519
2520async fn get_channel_members(
2521 request: proto::GetChannelMembers,
2522 response: Response<proto::GetChannelMembers>,
2523 session: Session,
2524) -> Result<()> {
2525 let db = session.db().await;
2526 let channel_id = ChannelId::from_proto(request.channel_id);
2527 let members = db
2528 .get_channel_participant_details(channel_id, session.user_id)
2529 .await?;
2530 response.send(proto::GetChannelMembersResponse { members })?;
2531 Ok(())
2532}
2533
2534async fn respond_to_channel_invite(
2535 request: proto::RespondToChannelInvite,
2536 response: Response<proto::RespondToChannelInvite>,
2537 session: Session,
2538) -> Result<()> {
2539 let db = session.db().await;
2540 let channel_id = ChannelId::from_proto(request.channel_id);
2541 let RespondToChannelInvite {
2542 membership_update,
2543 notifications,
2544 } = db
2545 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2546 .await?;
2547
2548 let connection_pool = session.connection_pool().await;
2549 if let Some(membership_update) = membership_update {
2550 notify_membership_updated(
2551 &connection_pool,
2552 membership_update,
2553 session.user_id,
2554 &session.peer,
2555 );
2556 } else {
2557 let update = proto::UpdateChannels {
2558 remove_channel_invitations: vec![channel_id.to_proto()],
2559 ..Default::default()
2560 };
2561
2562 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2563 session.peer.send(connection_id, update.clone())?;
2564 }
2565 };
2566
2567 send_notifications(&*connection_pool, &session.peer, notifications);
2568
2569 response.send(proto::Ack {})?;
2570
2571 Ok(())
2572}
2573
2574async fn join_channel(
2575 request: proto::JoinChannel,
2576 response: Response<proto::JoinChannel>,
2577 session: Session,
2578) -> Result<()> {
2579 let channel_id = ChannelId::from_proto(request.channel_id);
2580 join_channel_internal(channel_id, Box::new(response), session).await
2581}
2582
2583trait JoinChannelInternalResponse {
2584 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2585}
2586impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2587 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2588 Response::<proto::JoinChannel>::send(self, result)
2589 }
2590}
2591impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2592 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2593 Response::<proto::JoinRoom>::send(self, result)
2594 }
2595}
2596
/// Moves the session's connection into the call room of `channel_id` and
/// replies through `response`.
///
/// In order, this:
/// 1. leaves whatever room the session is currently in (`leave_room_for_session`);
/// 2. joins the channel's room in the database, which also returns an optional
///    membership update and the user's `ChannelRole` in the channel;
/// 3. replies with the room, its channel id and, when a LiveKit client is
///    configured, LiveKit connection info;
/// 4. sends the membership update (if any) to the user and the new room state
///    to the room's participants;
/// 5. refreshes the channel's participant list for `channel_members` and this
///    user's contact entry for their contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block bounds the lifetimes of the `db` and `connection_pool`
    // guards taken inside it. Both are released before the pool is taken again
    // for `channel_updated` and the database again inside `update_user_contacts`.
    // NOTE(review): presumably required because these guards are non-reentrant
    // locks — confirm against `Session::db` / `Session::connection_pool`.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // LiveKit info is only built when a LiveKit client is configured.
        // Guests get a `guest_token` with `can_publish: false`. Every other
        // role gets a `room_token` with `can_publish: true`. If token creation
        // fails, the error goes through `trace_err` and the reply carries no
        // LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply before broadcasting. A failed reply aborts the rest of the join.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Guards from the block above are released at this point.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2677
2678async fn join_channel_buffer(
2679 request: proto::JoinChannelBuffer,
2680 response: Response<proto::JoinChannelBuffer>,
2681 session: Session,
2682) -> Result<()> {
2683 let db = session.db().await;
2684 let channel_id = ChannelId::from_proto(request.channel_id);
2685
2686 let open_response = db
2687 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2688 .await?;
2689
2690 let collaborators = open_response.collaborators.clone();
2691 response.send(open_response)?;
2692
2693 let update = UpdateChannelBufferCollaborators {
2694 channel_id: channel_id.to_proto(),
2695 collaborators: collaborators.clone(),
2696 };
2697 channel_buffer_updated(
2698 session.connection_id,
2699 collaborators
2700 .iter()
2701 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2702 &update,
2703 &session.peer,
2704 );
2705
2706 Ok(())
2707}
2708
2709async fn update_channel_buffer(
2710 request: proto::UpdateChannelBuffer,
2711 session: Session,
2712) -> Result<()> {
2713 let db = session.db().await;
2714 let channel_id = ChannelId::from_proto(request.channel_id);
2715
2716 let (collaborators, non_collaborators, epoch, version) = db
2717 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2718 .await?;
2719
2720 channel_buffer_updated(
2721 session.connection_id,
2722 collaborators,
2723 &proto::UpdateChannelBuffer {
2724 channel_id: channel_id.to_proto(),
2725 operations: request.operations,
2726 },
2727 &session.peer,
2728 );
2729
2730 let pool = &*session.connection_pool().await;
2731
2732 broadcast(
2733 None,
2734 non_collaborators
2735 .iter()
2736 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2737 |peer_id| {
2738 session.peer.send(
2739 peer_id.into(),
2740 proto::UpdateChannels {
2741 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2742 channel_id: channel_id.to_proto(),
2743 epoch: epoch as u64,
2744 version: version.clone(),
2745 }],
2746 ..Default::default()
2747 },
2748 )
2749 },
2750 );
2751
2752 Ok(())
2753}
2754
2755async fn rejoin_channel_buffers(
2756 request: proto::RejoinChannelBuffers,
2757 response: Response<proto::RejoinChannelBuffers>,
2758 session: Session,
2759) -> Result<()> {
2760 let db = session.db().await;
2761 let buffers = db
2762 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2763 .await?;
2764
2765 for rejoined_buffer in &buffers {
2766 let collaborators_to_notify = rejoined_buffer
2767 .buffer
2768 .collaborators
2769 .iter()
2770 .filter_map(|c| Some(c.peer_id?.into()));
2771 channel_buffer_updated(
2772 session.connection_id,
2773 collaborators_to_notify,
2774 &proto::UpdateChannelBufferCollaborators {
2775 channel_id: rejoined_buffer.buffer.channel_id,
2776 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2777 },
2778 &session.peer,
2779 );
2780 }
2781
2782 response.send(proto::RejoinChannelBuffersResponse {
2783 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2784 })?;
2785
2786 Ok(())
2787}
2788
2789async fn leave_channel_buffer(
2790 request: proto::LeaveChannelBuffer,
2791 response: Response<proto::LeaveChannelBuffer>,
2792 session: Session,
2793) -> Result<()> {
2794 let db = session.db().await;
2795 let channel_id = ChannelId::from_proto(request.channel_id);
2796
2797 let left_buffer = db
2798 .leave_channel_buffer(channel_id, session.connection_id)
2799 .await?;
2800
2801 response.send(Ack {})?;
2802
2803 channel_buffer_updated(
2804 session.connection_id,
2805 left_buffer.connections,
2806 &proto::UpdateChannelBufferCollaborators {
2807 channel_id: channel_id.to_proto(),
2808 collaborators: left_buffer.collaborators,
2809 },
2810 &session.peer,
2811 );
2812
2813 Ok(())
2814}
2815
2816fn channel_buffer_updated<T: EnvelopedMessage>(
2817 sender_id: ConnectionId,
2818 collaborators: impl IntoIterator<Item = ConnectionId>,
2819 message: &T,
2820 peer: &Peer,
2821) {
2822 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2823 peer.send(peer_id.into(), message.clone())
2824 });
2825}
2826
2827fn send_notifications(
2828 connection_pool: &ConnectionPool,
2829 peer: &Peer,
2830 notifications: db::NotificationBatch,
2831) {
2832 for (user_id, notification) in notifications {
2833 for connection_id in connection_pool.user_connection_ids(user_id) {
2834 if let Err(error) = peer.send(
2835 connection_id,
2836 proto::AddNotification {
2837 notification: Some(notification.clone()),
2838 },
2839 ) {
2840 tracing::error!(
2841 "failed to send notification to {:?} {}",
2842 connection_id,
2843 error
2844 );
2845 }
2846 }
2847 }
2848}
2849
2850async fn send_channel_message(
2851 request: proto::SendChannelMessage,
2852 response: Response<proto::SendChannelMessage>,
2853 session: Session,
2854) -> Result<()> {
2855 // Validate the message body.
2856 let body = request.body.trim().to_string();
2857 if body.len() > MAX_MESSAGE_LEN {
2858 return Err(anyhow!("message is too long"))?;
2859 }
2860 if body.is_empty() {
2861 return Err(anyhow!("message can't be blank"))?;
2862 }
2863
2864 // TODO: adjust mentions if body is trimmed
2865
2866 let timestamp = OffsetDateTime::now_utc();
2867 let nonce = request
2868 .nonce
2869 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2870
2871 let channel_id = ChannelId::from_proto(request.channel_id);
2872 let CreatedChannelMessage {
2873 message_id,
2874 participant_connection_ids,
2875 channel_members,
2876 notifications,
2877 } = session
2878 .db()
2879 .await
2880 .create_channel_message(
2881 channel_id,
2882 session.user_id,
2883 &body,
2884 &request.mentions,
2885 timestamp,
2886 nonce.clone().into(),
2887 )
2888 .await?;
2889 let message = proto::ChannelMessage {
2890 sender_id: session.user_id.to_proto(),
2891 id: message_id.to_proto(),
2892 body,
2893 mentions: request.mentions,
2894 timestamp: timestamp.unix_timestamp() as u64,
2895 nonce: Some(nonce),
2896 };
2897 broadcast(
2898 Some(session.connection_id),
2899 participant_connection_ids,
2900 |connection| {
2901 session.peer.send(
2902 connection,
2903 proto::ChannelMessageSent {
2904 channel_id: channel_id.to_proto(),
2905 message: Some(message.clone()),
2906 },
2907 )
2908 },
2909 );
2910 response.send(proto::SendChannelMessageResponse {
2911 message: Some(message),
2912 })?;
2913
2914 let pool = &*session.connection_pool().await;
2915 broadcast(
2916 None,
2917 channel_members
2918 .iter()
2919 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2920 |peer_id| {
2921 session.peer.send(
2922 peer_id.into(),
2923 proto::UpdateChannels {
2924 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2925 channel_id: channel_id.to_proto(),
2926 message_id: message_id.to_proto(),
2927 }],
2928 ..Default::default()
2929 },
2930 )
2931 },
2932 );
2933 send_notifications(pool, &session.peer, notifications);
2934
2935 Ok(())
2936}
2937
2938async fn remove_channel_message(
2939 request: proto::RemoveChannelMessage,
2940 response: Response<proto::RemoveChannelMessage>,
2941 session: Session,
2942) -> Result<()> {
2943 let channel_id = ChannelId::from_proto(request.channel_id);
2944 let message_id = MessageId::from_proto(request.message_id);
2945 let connection_ids = session
2946 .db()
2947 .await
2948 .remove_channel_message(channel_id, message_id, session.user_id)
2949 .await?;
2950 broadcast(Some(session.connection_id), connection_ids, |connection| {
2951 session.peer.send(connection, request.clone())
2952 });
2953 response.send(proto::Ack {})?;
2954 Ok(())
2955}
2956
2957async fn acknowledge_channel_message(
2958 request: proto::AckChannelMessage,
2959 session: Session,
2960) -> Result<()> {
2961 let channel_id = ChannelId::from_proto(request.channel_id);
2962 let message_id = MessageId::from_proto(request.message_id);
2963 let notifications = session
2964 .db()
2965 .await
2966 .observe_channel_message(channel_id, session.user_id, message_id)
2967 .await?;
2968 send_notifications(
2969 &*session.connection_pool().await,
2970 &session.peer,
2971 notifications,
2972 );
2973 Ok(())
2974}
2975
2976async fn acknowledge_buffer_version(
2977 request: proto::AckBufferOperation,
2978 session: Session,
2979) -> Result<()> {
2980 let buffer_id = BufferId::from_proto(request.buffer_id);
2981 session
2982 .db()
2983 .await
2984 .observe_buffer_version(
2985 buffer_id,
2986 session.user_id,
2987 request.epoch as i32,
2988 &request.version,
2989 )
2990 .await?;
2991 Ok(())
2992}
2993
2994async fn join_channel_chat(
2995 request: proto::JoinChannelChat,
2996 response: Response<proto::JoinChannelChat>,
2997 session: Session,
2998) -> Result<()> {
2999 let channel_id = ChannelId::from_proto(request.channel_id);
3000
3001 let db = session.db().await;
3002 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3003 .await?;
3004 let messages = db
3005 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3006 .await?;
3007 response.send(proto::JoinChannelChatResponse {
3008 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3009 messages,
3010 })?;
3011 Ok(())
3012}
3013
3014async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3015 let channel_id = ChannelId::from_proto(request.channel_id);
3016 session
3017 .db()
3018 .await
3019 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3020 .await?;
3021 Ok(())
3022}
3023
3024async fn get_channel_messages(
3025 request: proto::GetChannelMessages,
3026 response: Response<proto::GetChannelMessages>,
3027 session: Session,
3028) -> Result<()> {
3029 let channel_id = ChannelId::from_proto(request.channel_id);
3030 let messages = session
3031 .db()
3032 .await
3033 .get_channel_messages(
3034 channel_id,
3035 session.user_id,
3036 MESSAGE_COUNT_PER_PAGE,
3037 Some(MessageId::from_proto(request.before_message_id)),
3038 )
3039 .await?;
3040 response.send(proto::GetChannelMessagesResponse {
3041 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3042 messages,
3043 })?;
3044 Ok(())
3045}
3046
3047async fn get_channel_messages_by_id(
3048 request: proto::GetChannelMessagesById,
3049 response: Response<proto::GetChannelMessagesById>,
3050 session: Session,
3051) -> Result<()> {
3052 let message_ids = request
3053 .message_ids
3054 .iter()
3055 .map(|id| MessageId::from_proto(*id))
3056 .collect::<Vec<_>>();
3057 let messages = session
3058 .db()
3059 .await
3060 .get_channel_messages_by_id(session.user_id, &message_ids)
3061 .await?;
3062 response.send(proto::GetChannelMessagesResponse {
3063 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3064 messages,
3065 })?;
3066 Ok(())
3067}
3068
3069async fn get_notifications(
3070 request: proto::GetNotifications,
3071 response: Response<proto::GetNotifications>,
3072 session: Session,
3073) -> Result<()> {
3074 let notifications = session
3075 .db()
3076 .await
3077 .get_notifications(
3078 session.user_id,
3079 NOTIFICATION_COUNT_PER_PAGE,
3080 request
3081 .before_id
3082 .map(|id| db::NotificationId::from_proto(id)),
3083 )
3084 .await?;
3085 response.send(proto::GetNotificationsResponse {
3086 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3087 notifications,
3088 })?;
3089 Ok(())
3090}
3091
3092async fn mark_notification_as_read(
3093 request: proto::MarkNotificationRead,
3094 response: Response<proto::MarkNotificationRead>,
3095 session: Session,
3096) -> Result<()> {
3097 let database = &session.db().await;
3098 let notifications = database
3099 .mark_notification_as_read_by_id(
3100 session.user_id,
3101 NotificationId::from_proto(request.notification_id),
3102 )
3103 .await?;
3104 send_notifications(
3105 &*session.connection_pool().await,
3106 &session.peer,
3107 notifications,
3108 );
3109 response.send(proto::Ack {})?;
3110 Ok(())
3111}
3112
3113async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3114 let project_id = ProjectId::from_proto(request.project_id);
3115 let project_connection_ids = session
3116 .db()
3117 .await
3118 .project_connection_ids(project_id, session.connection_id)
3119 .await?;
3120 broadcast(
3121 Some(session.connection_id),
3122 project_connection_ids.iter().copied(),
3123 |connection_id| {
3124 session
3125 .peer
3126 .forward_send(session.connection_id, connection_id, request.clone())
3127 },
3128 );
3129 Ok(())
3130}
3131
3132async fn get_private_user_info(
3133 _request: proto::GetPrivateUserInfo,
3134 response: Response<proto::GetPrivateUserInfo>,
3135 session: Session,
3136) -> Result<()> {
3137 let db = session.db().await;
3138
3139 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3140 let user = db
3141 .get_user_by_id(session.user_id)
3142 .await?
3143 .ok_or_else(|| anyhow!("user not found"))?;
3144 let flags = db.get_user_flags(session.user_id).await?;
3145
3146 response.send(proto::GetPrivateUserInfoResponse {
3147 metrics_id,
3148 staff: user.admin,
3149 flags,
3150 })?;
3151 Ok(())
3152}
3153
3154fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3155 match message {
3156 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3157 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3158 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3159 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3160 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3161 code: frame.code.into(),
3162 reason: frame.reason,
3163 })),
3164 }
3165}
3166
3167fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3168 match message {
3169 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3170 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3171 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3172 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3173 AxumMessage::Close(frame) => {
3174 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3175 code: frame.code.into(),
3176 reason: frame.reason,
3177 }))
3178 }
3179 }
3180}
3181
3182fn notify_membership_updated(
3183 connection_pool: &ConnectionPool,
3184 result: MembershipUpdated,
3185 user_id: UserId,
3186 peer: &Peer,
3187) {
3188 let mut update = build_channels_update(result.new_channels, vec![]);
3189 update.delete_channels = result
3190 .removed_channels
3191 .into_iter()
3192 .map(|id| id.to_proto())
3193 .collect();
3194 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3195
3196 for connection_id in connection_pool.user_connection_ids(user_id) {
3197 peer.send(connection_id, update.clone()).trace_err();
3198 }
3199}
3200
3201fn build_channels_update(
3202 channels: ChannelsForUser,
3203 channel_invites: Vec<db::Channel>,
3204) -> proto::UpdateChannels {
3205 let mut update = proto::UpdateChannels::default();
3206
3207 for channel in channels.channels {
3208 update.channels.push(channel.to_proto());
3209 }
3210
3211 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3212 update.unseen_channel_messages = channels.channel_messages;
3213
3214 for (channel_id, participants) in channels.channel_participants {
3215 update
3216 .channel_participants
3217 .push(proto::ChannelParticipants {
3218 channel_id: channel_id.to_proto(),
3219 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3220 });
3221 }
3222
3223 for channel in channel_invites {
3224 update.channel_invitations.push(channel.to_proto());
3225 }
3226
3227 update
3228}
3229
3230fn build_initial_contacts_update(
3231 contacts: Vec<db::Contact>,
3232 pool: &ConnectionPool,
3233) -> proto::UpdateContacts {
3234 let mut update = proto::UpdateContacts::default();
3235
3236 for contact in contacts {
3237 match contact {
3238 db::Contact::Accepted { user_id, busy } => {
3239 update.contacts.push(contact_for_user(user_id, busy, &pool));
3240 }
3241 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3242 db::Contact::Incoming { user_id } => {
3243 update
3244 .incoming_requests
3245 .push(proto::IncomingContactRequest {
3246 requester_id: user_id.to_proto(),
3247 })
3248 }
3249 }
3250 }
3251
3252 update
3253}
3254
3255fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3256 proto::Contact {
3257 user_id: user_id.to_proto(),
3258 online: pool.is_user_online(user_id),
3259 busy,
3260 }
3261}
3262
3263fn room_updated(room: &proto::Room, peer: &Peer) {
3264 broadcast(
3265 None,
3266 room.participants
3267 .iter()
3268 .filter_map(|participant| Some(participant.peer_id?.into())),
3269 |peer_id| {
3270 peer.send(
3271 peer_id.into(),
3272 proto::RoomUpdated {
3273 room: Some(room.clone()),
3274 },
3275 )
3276 },
3277 );
3278}
3279
3280fn channel_updated(
3281 channel_id: ChannelId,
3282 room: &proto::Room,
3283 channel_members: &[UserId],
3284 peer: &Peer,
3285 pool: &ConnectionPool,
3286) {
3287 let participants = room
3288 .participants
3289 .iter()
3290 .map(|p| p.user_id)
3291 .collect::<Vec<_>>();
3292
3293 broadcast(
3294 None,
3295 channel_members
3296 .iter()
3297 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3298 |peer_id| {
3299 peer.send(
3300 peer_id.into(),
3301 proto::UpdateChannels {
3302 channel_participants: vec![proto::ChannelParticipants {
3303 channel_id: channel_id.to_proto(),
3304 participant_user_ids: participants.clone(),
3305 }],
3306 ..Default::default()
3307 },
3308 )
3309 },
3310 );
3311}
3312
3313async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3314 let db = session.db().await;
3315
3316 let contacts = db.get_contacts(user_id).await?;
3317 let busy = db.is_user_busy(user_id).await?;
3318
3319 let pool = session.connection_pool().await;
3320 let updated_contact = contact_for_user(user_id, busy, &pool);
3321 for contact in contacts {
3322 if let db::Contact::Accepted {
3323 user_id: contact_user_id,
3324 ..
3325 } = contact
3326 {
3327 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3328 session
3329 .peer
3330 .send(
3331 contact_conn_id,
3332 proto::UpdateContacts {
3333 contacts: vec![updated_contact.clone()],
3334 remove_contacts: Default::default(),
3335 incoming_requests: Default::default(),
3336 remove_incoming_requests: Default::default(),
3337 outgoing_requests: Default::default(),
3338 remove_outgoing_requests: Default::default(),
3339 },
3340 )
3341 .trace_err();
3342 }
3343 }
3344 }
3345 Ok(())
3346}
3347
/// Removes the session's connection from its current room, if any, and runs
/// the follow-up notifications.
///
/// Returns `Ok(())` immediately when the database reports no room was left.
/// Otherwise it:
/// - calls `project_left` for every project the connection left;
/// - broadcasts the new room state;
/// - refreshes the channel's participant list for channel members, when the
///   room belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes the contact entries of this user and of those call recipients;
/// - removes the user from the LiveKit room, and deletes that room too when the
///   database marks the room as deleted.
///
/// LiveKit and `CallCanceled` failures go through `trace_err` and are not
/// propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are declared here but assigned inside the `if let` below so they
    // outlive it. NOTE(review): presumably this keeps the database guard created
    // in the `if let` scrutinee (a temporary that lives until the end of the
    // whole `if let`/`else`) from being held across the awaits further down.
    // Confirm before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself, so the `room` broadcast below carries
        // `live_kit_room` reset to its default value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool guard is released before `update_user_contacts`
    // below acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3424
3425async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3426 let left_channel_buffers = session
3427 .db()
3428 .await
3429 .leave_channel_buffers(session.connection_id)
3430 .await?;
3431
3432 for left_buffer in left_channel_buffers {
3433 channel_buffer_updated(
3434 session.connection_id,
3435 left_buffer.connections,
3436 &proto::UpdateChannelBufferCollaborators {
3437 channel_id: left_buffer.channel_id.to_proto(),
3438 collaborators: left_buffer.collaborators,
3439 },
3440 &session.peer,
3441 );
3442 }
3443
3444 Ok(())
3445}
3446
3447fn project_left(project: &db::LeftProject, session: &Session) {
3448 for connection_id in &project.connection_ids {
3449 if project.host_user_id == session.user_id {
3450 session
3451 .peer
3452 .send(
3453 *connection_id,
3454 proto::UnshareProject {
3455 project_id: project.id.to_proto(),
3456 },
3457 )
3458 .trace_err();
3459 } else {
3460 session
3461 .peer
3462 .send(
3463 *connection_id,
3464 proto::RemoveProjectCollaborator {
3465 project_id: project.id.to_proto(),
3466 peer_id: Some(session.connection_id.into()),
3467 },
3468 )
3469 .trace_err();
3470 }
3471 }
3472}
3473
3474pub trait ResultExt {
3475 type Ok;
3476
3477 fn trace_err(self) -> Option<Self::Ok>;
3478}
3479
3480impl<T, E> ResultExt for Result<T, E>
3481where
3482 E: std::fmt::Debug,
3483{
3484 type Ok = T;
3485
3486 fn trace_err(self) -> Option<T> {
3487 match self {
3488 Ok(value) => Some(value),
3489 Err(error) => {
3490 tracing::error!("{:?}", error);
3491 None
3492 }
3493 }
3494 }
3495}