1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69use util::channel::RELEASE_CHANNEL_NAME;
70
/// How long `connection_lost` waits for a dropped client to reconnect before releasing the
/// user's room, channel buffers and incoming call.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay before the background task spawned by `Server::start` cleans up resources that the
/// database reports as stale for this server and environment.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the constants below are not referenced in this part of the file; their
// descriptions are inferred from the names only — confirm against their call sites.
/// Presumably the page size used when fetching channel-message history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message (unit unconfirmed).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
// Prometheus gauges refreshed by `handle_metrics` and exported on `/metrics`. Each gauge is
// registered lazily on first access; `unwrap` panics if registration fails.
lazy_static! {
    /// Number of connections, excluding admin connections (see `handle_metrics`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Shared-project count as reported by `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
87
/// Type-erased handler for one message type: receives the boxed envelope plus the sender's
/// `Session` and returns a boxed future that processes it. Built by `Server::add_handler`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// Single-use handle a request handler uses to reply to the request identified by `receipt`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    /// Set when a reply is sent; `Server::add_request_handler` inspects it after the handler
    /// finishes.
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// Per-connection state passed to every message handler; a clone is made for each incoming
/// message.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server's shared connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Copied from the app state; `None` when no LiveKit client is configured there.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// per-message-type handlers.
pub struct Server {
    /// Behind a mutex so test builds can reassign it through `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they process.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; connection tasks subscribe to it so they can stop.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the connection pool. The `PhantomData<Rc<()>>` marker makes the guard
/// `!Send`, so it cannot be held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of the server's peer and connection pool. It holds the pool lock for as
/// long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the pool itself rather than as the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_project_request::<proto::GetHover>)
221 .add_request_handler(forward_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_project_request::<proto::GetCompletions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
232 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
233 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
234 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
235 .add_request_handler(forward_project_request::<proto::PrepareRename>)
236 .add_request_handler(forward_project_request::<proto::PerformRename>)
237 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
238 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
239 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
240 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
243 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
244 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
245 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
246 .add_request_handler(forward_project_request::<proto::InlayHints>)
247 .add_message_handler(create_buffer_for_peer)
248 .add_request_handler(update_buffer)
249 .add_message_handler(update_buffer_file)
250 .add_message_handler(buffer_reloaded)
251 .add_message_handler(buffer_saved)
252 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
253 .add_request_handler(get_users)
254 .add_request_handler(fuzzy_search_users)
255 .add_request_handler(request_contact)
256 .add_request_handler(remove_contact)
257 .add_request_handler(respond_to_contact_request)
258 .add_request_handler(create_channel)
259 .add_request_handler(delete_channel)
260 .add_request_handler(invite_channel_member)
261 .add_request_handler(remove_channel_member)
262 .add_request_handler(set_channel_member_role)
263 .add_request_handler(set_channel_visibility)
264 .add_request_handler(rename_channel)
265 .add_request_handler(join_channel_buffer)
266 .add_request_handler(leave_channel_buffer)
267 .add_message_handler(update_channel_buffer)
268 .add_request_handler(rejoin_channel_buffers)
269 .add_request_handler(get_channel_members)
270 .add_request_handler(respond_to_channel_invite)
271 .add_request_handler(join_channel)
272 .add_request_handler(join_channel_chat)
273 .add_message_handler(leave_channel_chat)
274 .add_request_handler(send_channel_message)
275 .add_request_handler(remove_channel_message)
276 .add_request_handler(get_channel_messages)
277 .add_request_handler(get_channel_messages_by_id)
278 .add_request_handler(get_notifications)
279 .add_request_handler(mark_notification_as_read)
280 .add_request_handler(move_channel)
281 .add_request_handler(follow)
282 .add_message_handler(unfollow)
283 .add_message_handler(update_followers)
284 .add_message_handler(update_diff_base)
285 .add_request_handler(get_private_user_info)
286 .add_message_handler(acknowledge_channel_message)
287 .add_message_handler(acknowledge_buffer_version);
288
289 Arc::new(server)
290 }
291
292 pub async fn start(&self) -> Result<()> {
293 let server_id = *self.id.lock();
294 let app_state = self.app_state.clone();
295 let peer = self.peer.clone();
296 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
297 let pool = self.connection_pool.clone();
298 let live_kit_client = self.app_state.live_kit_client.clone();
299
300 let span = info_span!("start server");
301 self.executor.spawn_detached(
302 async move {
303 tracing::info!("waiting for cleanup timeout");
304 timeout.await;
305 tracing::info!("cleanup timeout expired, retrieving stale rooms");
306 if let Some((room_ids, channel_ids)) = app_state
307 .db
308 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
309 .await
310 .trace_err()
311 {
312 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
313 tracing::info!(
314 stale_channel_buffer_count = channel_ids.len(),
315 "retrieved stale channel buffers"
316 );
317
318 for channel_id in channel_ids {
319 if let Some(refreshed_channel_buffer) = app_state
320 .db
321 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
322 .await
323 .trace_err()
324 {
325 for connection_id in refreshed_channel_buffer.connection_ids {
326 peer.send(
327 connection_id,
328 proto::UpdateChannelBufferCollaborators {
329 channel_id: channel_id.to_proto(),
330 collaborators: refreshed_channel_buffer
331 .collaborators
332 .clone(),
333 },
334 )
335 .trace_err();
336 }
337 }
338 }
339
340 for room_id in room_ids {
341 let mut contacts_to_update = HashSet::default();
342 let mut canceled_calls_to_user_ids = Vec::new();
343 let mut live_kit_room = String::new();
344 let mut delete_live_kit_room = false;
345
346 if let Some(mut refreshed_room) = app_state
347 .db
348 .clear_stale_room_participants(room_id, server_id)
349 .await
350 .trace_err()
351 {
352 tracing::info!(
353 room_id = room_id.0,
354 new_participant_count = refreshed_room.room.participants.len(),
355 "refreshed room"
356 );
357 room_updated(&refreshed_room.room, &peer);
358 if let Some(channel_id) = refreshed_room.channel_id {
359 channel_updated(
360 channel_id,
361 &refreshed_room.room,
362 &refreshed_room.channel_members,
363 &peer,
364 &*pool.lock(),
365 );
366 }
367 contacts_to_update
368 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
369 contacts_to_update
370 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
371 canceled_calls_to_user_ids =
372 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
373 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
374 delete_live_kit_room = refreshed_room.room.participants.is_empty();
375 }
376
377 {
378 let pool = pool.lock();
379 for canceled_user_id in canceled_calls_to_user_ids {
380 for connection_id in pool.user_connection_ids(canceled_user_id) {
381 peer.send(
382 connection_id,
383 proto::CallCanceled {
384 room_id: room_id.to_proto(),
385 },
386 )
387 .trace_err();
388 }
389 }
390 }
391
392 for user_id in contacts_to_update {
393 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
394 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
395 if let Some((busy, contacts)) = busy.zip(contacts) {
396 let pool = pool.lock();
397 let updated_contact = contact_for_user(user_id, busy, &pool);
398 for contact in contacts {
399 if let db::Contact::Accepted {
400 user_id: contact_user_id,
401 ..
402 } = contact
403 {
404 for contact_conn_id in
405 pool.user_connection_ids(contact_user_id)
406 {
407 peer.send(
408 contact_conn_id,
409 proto::UpdateContacts {
410 contacts: vec![updated_contact.clone()],
411 remove_contacts: Default::default(),
412 incoming_requests: Default::default(),
413 remove_incoming_requests: Default::default(),
414 outgoing_requests: Default::default(),
415 remove_outgoing_requests: Default::default(),
416 },
417 )
418 .trace_err();
419 }
420 }
421 }
422 }
423 }
424
425 if let Some(live_kit) = live_kit_client.as_ref() {
426 if delete_live_kit_room {
427 live_kit.delete_room(live_kit_room).await.trace_err();
428 }
429 }
430 }
431 }
432
433 app_state
434 .db
435 .delete_stale_servers(&app_state.config.zed_environment, server_id)
436 .await
437 .trace_err();
438 }
439 .instrument(span),
440 );
441 Ok(())
442 }
443
444 pub fn teardown(&self) {
445 self.peer.teardown();
446 self.connection_pool.lock().reset();
447 let _ = self.teardown.send(());
448 }
449
450 #[cfg(test)]
451 pub fn reset(&self, id: ServerId) {
452 self.teardown();
453 *self.id.lock() = id;
454 self.peer.reset(id.0 as u32);
455 }
456
457 #[cfg(test)]
458 pub fn id(&self) -> ServerId {
459 *self.id.lock()
460 }
461
462 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
463 where
464 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
465 Fut: 'static + Send + Future<Output = Result<()>>,
466 M: EnvelopedMessage,
467 {
468 let prev_handler = self.handlers.insert(
469 TypeId::of::<M>(),
470 Box::new(move |envelope, session| {
471 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
472 let span = info_span!(
473 "handle message",
474 payload_type = envelope.payload_type_name()
475 );
476 span.in_scope(|| {
477 tracing::info!(
478 payload_type = envelope.payload_type_name(),
479 "message received"
480 );
481 });
482 let start_time = Instant::now();
483 let future = (handler)(*envelope, session);
484 async move {
485 let result = future.await;
486 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
487 match result {
488 Err(error) => {
489 tracing::error!(%error, ?duration_ms, "error handling message")
490 }
491 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
492 }
493 }
494 .instrument(span)
495 .boxed()
496 }),
497 );
498 if prev_handler.is_some() {
499 panic!("registered a handler for the same message twice");
500 }
501 self
502 }
503
504 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
505 where
506 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
507 Fut: 'static + Send + Future<Output = Result<()>>,
508 M: EnvelopedMessage,
509 {
510 self.add_handler(move |envelope, session| handler(envelope.payload, session));
511 self
512 }
513
514 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
515 where
516 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
517 Fut: Send + Future<Output = Result<()>>,
518 M: RequestMessage,
519 {
520 let handler = Arc::new(handler);
521 self.add_handler(move |envelope, session| {
522 let receipt = envelope.receipt();
523 let handler = handler.clone();
524 async move {
525 let peer = session.peer.clone();
526 let responded = Arc::new(AtomicBool::default());
527 let response = Response {
528 peer: peer.clone(),
529 responded: responded.clone(),
530 receipt,
531 };
532 match (handler)(envelope.payload, response, session).await {
533 Ok(()) => {
534 if responded.load(std::sync::atomic::Ordering::SeqCst) {
535 Ok(())
536 } else {
537 Err(anyhow!("handler did not send a response"))?
538 }
539 }
540 Err(error) => {
541 peer.respond_with_error(
542 receipt,
543 proto::Error {
544 message: error.to_string(),
545 },
546 )?;
547 Err(error)
548 }
549 }
550 }
551 })
552 }
553
554 pub fn handle_connection(
555 self: &Arc<Self>,
556 connection: Connection,
557 address: String,
558 user: User,
559 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
560 executor: Executor,
561 ) -> impl Future<Output = Result<()>> {
562 let this = self.clone();
563 let user_id = user.id;
564 let login = user.github_login;
565 let span = info_span!("handle connection", %user_id, %login, %address);
566 let mut teardown = self.teardown.subscribe();
567 async move {
568 let (connection_id, handle_io, mut incoming_rx) = this
569 .peer
570 .add_connection(connection, {
571 let executor = executor.clone();
572 move |duration| executor.sleep(duration)
573 });
574
575 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
576 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
577 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
578
579 if let Some(send_connection_id) = send_connection_id.take() {
580 let _ = send_connection_id.send(connection_id);
581 }
582
583 if !user.connected_once {
584 this.peer.send(connection_id, proto::ShowContacts {})?;
585 this.app_state.db.set_user_connected_once(user_id, true).await?;
586 }
587
588 let (contacts, channels_for_user, channel_invites) = future::try_join3(
589 this.app_state.db.get_contacts(user_id),
590 this.app_state.db.get_channels_for_user(user_id),
591 this.app_state.db.get_channel_invites_for_user(user_id),
592 ).await?;
593
594 {
595 let mut pool = this.connection_pool.lock();
596 pool.add_connection(connection_id, user_id, user.admin);
597 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
598 this.peer.send(connection_id, build_channels_update(
599 channels_for_user,
600 channel_invites
601 ))?;
602 }
603
604 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
605 this.peer.send(connection_id, incoming_call)?;
606 }
607
608 let session = Session {
609 user_id,
610 connection_id,
611 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
612 peer: this.peer.clone(),
613 connection_pool: this.connection_pool.clone(),
614 live_kit_client: this.app_state.live_kit_client.clone(),
615 executor: executor.clone(),
616 };
617 update_user_contacts(user_id, &session).await?;
618
619 let handle_io = handle_io.fuse();
620 futures::pin_mut!(handle_io);
621
622 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
623 // This prevents deadlocks when e.g., client A performs a request to client B and
624 // client B performs a request to client A. If both clients stop processing further
625 // messages until their respective request completes, they won't have a chance to
626 // respond to the other client's request and cause a deadlock.
627 //
628 // This arrangement ensures we will attempt to process earlier messages first, but fall
629 // back to processing messages arrived later in the spirit of making progress.
630 let mut foreground_message_handlers = FuturesUnordered::new();
631 let concurrent_handlers = Arc::new(Semaphore::new(256));
632 loop {
633 let next_message = async {
634 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
635 let message = incoming_rx.next().await;
636 (permit, message)
637 }.fuse();
638 futures::pin_mut!(next_message);
639 futures::select_biased! {
640 _ = teardown.changed().fuse() => return Ok(()),
641 result = handle_io => {
642 if let Err(error) = result {
643 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
644 }
645 break;
646 }
647 _ = foreground_message_handlers.next() => {}
648 next_message = next_message => {
649 let (permit, message) = next_message;
650 if let Some(message) = message {
651 let type_name = message.payload_type_name();
652 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
653 let span_enter = span.enter();
654 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
655 let is_background = message.is_background();
656 let handle_message = (handler)(message, session.clone());
657 drop(span_enter);
658
659 let handle_message = async move {
660 handle_message.await;
661 drop(permit);
662 }.instrument(span);
663 if is_background {
664 executor.spawn_detached(handle_message);
665 } else {
666 foreground_message_handlers.push(handle_message);
667 }
668 } else {
669 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
670 }
671 } else {
672 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
673 break;
674 }
675 }
676 }
677 }
678
679 drop(foreground_message_handlers);
680 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
681 if let Err(error) = connection_lost(session, teardown, executor).await {
682 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
683 }
684
685 Ok(())
686 }.instrument(span)
687 }
688
689 pub async fn invite_code_redeemed(
690 self: &Arc<Self>,
691 inviter_id: UserId,
692 invitee_id: UserId,
693 ) -> Result<()> {
694 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
695 if let Some(code) = &user.invite_code {
696 let pool = self.connection_pool.lock();
697 let invitee_contact = contact_for_user(invitee_id, false, &pool);
698 for connection_id in pool.user_connection_ids(inviter_id) {
699 self.peer.send(
700 connection_id,
701 proto::UpdateContacts {
702 contacts: vec![invitee_contact.clone()],
703 ..Default::default()
704 },
705 )?;
706 self.peer.send(
707 connection_id,
708 proto::UpdateInviteInfo {
709 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
710 count: user.invite_count as u32,
711 },
712 )?;
713 }
714 }
715 }
716 Ok(())
717 }
718
719 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
720 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
721 if let Some(invite_code) = &user.invite_code {
722 let pool = self.connection_pool.lock();
723 for connection_id in pool.user_connection_ids(user_id) {
724 self.peer.send(
725 connection_id,
726 proto::UpdateInviteInfo {
727 url: format!(
728 "{}{}",
729 self.app_state.config.invite_link_prefix, invite_code
730 ),
731 count: user.invite_count as u32,
732 },
733 )?;
734 }
735 }
736 }
737 Ok(())
738 }
739
740 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
741 ServerSnapshot {
742 connection_pool: ConnectionPoolGuard {
743 guard: self.connection_pool.lock(),
744 _not_send: PhantomData,
745 },
746 peer: &self.peer,
747 }
748 }
749}
750
751impl<'a> Deref for ConnectionPoolGuard<'a> {
752 type Target = ConnectionPool;
753
754 fn deref(&self) -> &Self::Target {
755 &*self.guard
756 }
757}
758
759impl<'a> DerefMut for ConnectionPoolGuard<'a> {
760 fn deref_mut(&mut self) -> &mut Self::Target {
761 &mut *self.guard
762 }
763}
764
765impl<'a> Drop for ConnectionPoolGuard<'a> {
766 fn drop(&mut self) {
767 #[cfg(test)]
768 self.check_invariants();
769 }
770}
771
772fn broadcast<F>(
773 sender_id: Option<ConnectionId>,
774 receiver_ids: impl IntoIterator<Item = ConnectionId>,
775 mut f: F,
776) where
777 F: FnMut(ConnectionId) -> anyhow::Result<()>,
778{
779 for receiver_id in receiver_ids {
780 if Some(receiver_id) != sender_id {
781 if let Err(error) = f(receiver_id) {
782 tracing::error!("failed to send to {:?} {}", receiver_id, error);
783 }
784 }
785 }
786}
787
lazy_static! {
    /// Name of the request header that carries the client's RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
791
792pub struct ProtocolVersion(u32);
793
794impl Header for ProtocolVersion {
795 fn name() -> &'static HeaderName {
796 &ZED_PROTOCOL_VERSION
797 }
798
799 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
800 where
801 Self: Sized,
802 I: Iterator<Item = &'i axum::http::HeaderValue>,
803 {
804 let version = values
805 .next()
806 .ok_or_else(axum::headers::Error::invalid)?
807 .to_str()
808 .map_err(|_| axum::headers::Error::invalid())?
809 .parse()
810 .map_err(|_| axum::headers::Error::invalid())?;
811 Ok(Self(version))
812 }
813
814 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
815 values.extend([self.0.to_string().parse().unwrap()]);
816 }
817}
818
819pub fn routes(server: Arc<Server>) -> Router<Body> {
820 Router::new()
821 .route("/rpc", get(handle_websocket_request))
822 .layer(
823 ServiceBuilder::new()
824 .layer(Extension(server.app_state.clone()))
825 .layer(middleware::from_fn(auth::validate_header)),
826 )
827 .route("/metrics", get(handle_metrics))
828 .layer(Extension(server))
829}
830
831pub async fn handle_websocket_request(
832 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
833 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
834 Extension(server): Extension<Arc<Server>>,
835 Extension(user): Extension<User>,
836 ws: WebSocketUpgrade,
837) -> axum::response::Response {
838 if protocol_version != rpc::PROTOCOL_VERSION {
839 return (
840 StatusCode::UPGRADE_REQUIRED,
841 "client must be upgraded".to_string(),
842 )
843 .into_response();
844 }
845 let socket_address = socket_address.to_string();
846 ws.on_upgrade(move |socket| {
847 use util::ResultExt;
848 let socket = socket
849 .map_ok(to_tungstenite_message)
850 .err_into()
851 .with(|message| async move { Ok(to_axum_message(message)) });
852 let connection = Connection::new(Box::pin(socket));
853 async move {
854 server
855 .handle_connection(connection, socket_address, user, None, Executor::Production)
856 .await
857 .log_err();
858 }
859 })
860}
861
862pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
863 let connections = server
864 .connection_pool
865 .lock()
866 .connections()
867 .filter(|connection| !connection.admin)
868 .count();
869
870 METRIC_CONNECTIONS.set(connections as _);
871
872 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
873 METRIC_SHARED_PROJECTS.set(shared_projects as _);
874
875 let encoder = prometheus::TextEncoder::new();
876 let metric_families = prometheus::gather();
877 let encoded_metrics = encoder
878 .encode_to_string(&metric_families)
879 .map_err(|err| anyhow!("{}", err))?;
880 Ok(encoded_metrics)
881}
882
/// Cleans up after a connection ends.
///
/// Immediately, it disconnects the connection from the peer, removes it from the pool and
/// reports the loss to the database. Then it waits `RECONNECT_TIMEOUT`, unless the server is
/// torn down first. After that wait it:
/// - leaves the user's room and channel buffers;
/// - declines the user's incoming call if they have no other live connection;
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased toward the timeout branch; a teardown signal skips the deferred cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call if no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
928
929async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
930 response.send(proto::Ack {})?;
931 Ok(())
932}
933
934async fn create_room(
935 _request: proto::CreateRoom,
936 response: Response<proto::CreateRoom>,
937 session: Session,
938) -> Result<()> {
939 let live_kit_room = nanoid::nanoid!(30);
940
941 let live_kit_connection_info = {
942 let live_kit_room = live_kit_room.clone();
943 let live_kit = session.live_kit_client.as_ref();
944
945 util::async_maybe!({
946 let live_kit = live_kit?;
947
948 let token = live_kit
949 .room_token(&live_kit_room, &session.user_id.to_string())
950 .trace_err()?;
951
952 Some(proto::LiveKitConnectionInfo {
953 server_url: live_kit.url().into(),
954 token,
955 can_publish: true,
956 })
957 })
958 }
959 .await;
960
961 let room = session
962 .db()
963 .await
964 .create_room(
965 session.user_id,
966 session.connection_id,
967 &live_kit_room,
968 RELEASE_CHANNEL_NAME.as_str(),
969 )
970 .await?;
971
972 response.send(proto::CreateRoomResponse {
973 room: Some(room.clone()),
974 live_kit_connection_info,
975 })?;
976
977 update_user_contacts(session.user_id, &session).await?;
978 Ok(())
979}
980
981async fn join_room(
982 request: proto::JoinRoom,
983 response: Response<proto::JoinRoom>,
984 session: Session,
985) -> Result<()> {
986 let room_id = RoomId::from_proto(request.id);
987
988 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
989
990 if let Some(channel_id) = channel_id {
991 return join_channel_internal(channel_id, Box::new(response), session).await;
992 }
993
994 let joined_room = {
995 let room = session
996 .db()
997 .await
998 .join_room(
999 room_id,
1000 session.user_id,
1001 session.connection_id,
1002 RELEASE_CHANNEL_NAME.as_str(),
1003 )
1004 .await?;
1005 room_updated(&room.room, &session.peer);
1006 room.into_inner()
1007 };
1008
1009 for connection_id in session
1010 .connection_pool()
1011 .await
1012 .user_connection_ids(session.user_id)
1013 {
1014 session
1015 .peer
1016 .send(
1017 connection_id,
1018 proto::CallCanceled {
1019 room_id: room_id.to_proto(),
1020 },
1021 )
1022 .trace_err();
1023 }
1024
1025 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1026 if let Some(token) = live_kit
1027 .room_token(
1028 &joined_room.room.live_kit_room,
1029 &session.user_id.to_string(),
1030 )
1031 .trace_err()
1032 {
1033 Some(proto::LiveKitConnectionInfo {
1034 server_url: live_kit.url().into(),
1035 token,
1036 can_publish: true,
1037 })
1038 } else {
1039 None
1040 }
1041 } else {
1042 None
1043 };
1044
1045 response.send(proto::JoinRoomResponse {
1046 room: Some(joined_room.room),
1047 channel_id: None,
1048 live_kit_connection_info,
1049 })?;
1050
1051 update_user_contacts(session.user_id, &session).await?;
1052 Ok(())
1053}
1054
1055async fn rejoin_room(
1056 request: proto::RejoinRoom,
1057 response: Response<proto::RejoinRoom>,
1058 session: Session,
1059) -> Result<()> {
1060 let room;
1061 let channel_id;
1062 let channel_members;
1063 {
1064 let mut rejoined_room = session
1065 .db()
1066 .await
1067 .rejoin_room(request, session.user_id, session.connection_id)
1068 .await?;
1069
1070 response.send(proto::RejoinRoomResponse {
1071 room: Some(rejoined_room.room.clone()),
1072 reshared_projects: rejoined_room
1073 .reshared_projects
1074 .iter()
1075 .map(|project| proto::ResharedProject {
1076 id: project.id.to_proto(),
1077 collaborators: project
1078 .collaborators
1079 .iter()
1080 .map(|collaborator| collaborator.to_proto())
1081 .collect(),
1082 })
1083 .collect(),
1084 rejoined_projects: rejoined_room
1085 .rejoined_projects
1086 .iter()
1087 .map(|rejoined_project| proto::RejoinedProject {
1088 id: rejoined_project.id.to_proto(),
1089 worktrees: rejoined_project
1090 .worktrees
1091 .iter()
1092 .map(|worktree| proto::WorktreeMetadata {
1093 id: worktree.id,
1094 root_name: worktree.root_name.clone(),
1095 visible: worktree.visible,
1096 abs_path: worktree.abs_path.clone(),
1097 })
1098 .collect(),
1099 collaborators: rejoined_project
1100 .collaborators
1101 .iter()
1102 .map(|collaborator| collaborator.to_proto())
1103 .collect(),
1104 language_servers: rejoined_project.language_servers.clone(),
1105 })
1106 .collect(),
1107 })?;
1108 room_updated(&rejoined_room.room, &session.peer);
1109
1110 for project in &rejoined_room.reshared_projects {
1111 for collaborator in &project.collaborators {
1112 session
1113 .peer
1114 .send(
1115 collaborator.connection_id,
1116 proto::UpdateProjectCollaborator {
1117 project_id: project.id.to_proto(),
1118 old_peer_id: Some(project.old_connection_id.into()),
1119 new_peer_id: Some(session.connection_id.into()),
1120 },
1121 )
1122 .trace_err();
1123 }
1124
1125 broadcast(
1126 Some(session.connection_id),
1127 project
1128 .collaborators
1129 .iter()
1130 .map(|collaborator| collaborator.connection_id),
1131 |connection_id| {
1132 session.peer.forward_send(
1133 session.connection_id,
1134 connection_id,
1135 proto::UpdateProject {
1136 project_id: project.id.to_proto(),
1137 worktrees: project.worktrees.clone(),
1138 },
1139 )
1140 },
1141 );
1142 }
1143
1144 for project in &rejoined_room.rejoined_projects {
1145 for collaborator in &project.collaborators {
1146 session
1147 .peer
1148 .send(
1149 collaborator.connection_id,
1150 proto::UpdateProjectCollaborator {
1151 project_id: project.id.to_proto(),
1152 old_peer_id: Some(project.old_connection_id.into()),
1153 new_peer_id: Some(session.connection_id.into()),
1154 },
1155 )
1156 .trace_err();
1157 }
1158 }
1159
1160 for project in &mut rejoined_room.rejoined_projects {
1161 for worktree in mem::take(&mut project.worktrees) {
1162 #[cfg(any(test, feature = "test-support"))]
1163 const MAX_CHUNK_SIZE: usize = 2;
1164 #[cfg(not(any(test, feature = "test-support")))]
1165 const MAX_CHUNK_SIZE: usize = 256;
1166
1167 // Stream this worktree's entries.
1168 let message = proto::UpdateWorktree {
1169 project_id: project.id.to_proto(),
1170 worktree_id: worktree.id,
1171 abs_path: worktree.abs_path.clone(),
1172 root_name: worktree.root_name,
1173 updated_entries: worktree.updated_entries,
1174 removed_entries: worktree.removed_entries,
1175 scan_id: worktree.scan_id,
1176 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1177 updated_repositories: worktree.updated_repositories,
1178 removed_repositories: worktree.removed_repositories,
1179 };
1180 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1181 session.peer.send(session.connection_id, update.clone())?;
1182 }
1183
1184 // Stream this worktree's diagnostics.
1185 for summary in worktree.diagnostic_summaries {
1186 session.peer.send(
1187 session.connection_id,
1188 proto::UpdateDiagnosticSummary {
1189 project_id: project.id.to_proto(),
1190 worktree_id: worktree.id,
1191 summary: Some(summary),
1192 },
1193 )?;
1194 }
1195
1196 for settings_file in worktree.settings_files {
1197 session.peer.send(
1198 session.connection_id,
1199 proto::UpdateWorktreeSettings {
1200 project_id: project.id.to_proto(),
1201 worktree_id: worktree.id,
1202 path: settings_file.path,
1203 content: Some(settings_file.content),
1204 },
1205 )?;
1206 }
1207 }
1208
1209 for language_server in &project.language_servers {
1210 session.peer.send(
1211 session.connection_id,
1212 proto::UpdateLanguageServer {
1213 project_id: project.id.to_proto(),
1214 language_server_id: language_server.id,
1215 variant: Some(
1216 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1217 proto::LspDiskBasedDiagnosticsUpdated {},
1218 ),
1219 ),
1220 },
1221 )?;
1222 }
1223 }
1224
1225 let rejoined_room = rejoined_room.into_inner();
1226
1227 room = rejoined_room.room;
1228 channel_id = rejoined_room.channel_id;
1229 channel_members = rejoined_room.channel_members;
1230 }
1231
1232 if let Some(channel_id) = channel_id {
1233 channel_updated(
1234 channel_id,
1235 &room,
1236 &channel_members,
1237 &session.peer,
1238 &*session.connection_pool().await,
1239 );
1240 }
1241
1242 update_user_contacts(session.user_id, &session).await?;
1243 Ok(())
1244}
1245
1246async fn leave_room(
1247 _: proto::LeaveRoom,
1248 response: Response<proto::LeaveRoom>,
1249 session: Session,
1250) -> Result<()> {
1251 leave_room_for_session(&session).await?;
1252 response.send(proto::Ack {})?;
1253 Ok(())
1254}
1255
1256async fn call(
1257 request: proto::Call,
1258 response: Response<proto::Call>,
1259 session: Session,
1260) -> Result<()> {
1261 let room_id = RoomId::from_proto(request.room_id);
1262 let calling_user_id = session.user_id;
1263 let calling_connection_id = session.connection_id;
1264 let called_user_id = UserId::from_proto(request.called_user_id);
1265 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1266 if !session
1267 .db()
1268 .await
1269 .has_contact(calling_user_id, called_user_id)
1270 .await?
1271 {
1272 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1273 }
1274
1275 let incoming_call = {
1276 let (room, incoming_call) = &mut *session
1277 .db()
1278 .await
1279 .call(
1280 room_id,
1281 calling_user_id,
1282 calling_connection_id,
1283 called_user_id,
1284 initial_project_id,
1285 )
1286 .await?;
1287 room_updated(&room, &session.peer);
1288 mem::take(incoming_call)
1289 };
1290 update_user_contacts(called_user_id, &session).await?;
1291
1292 let mut calls = session
1293 .connection_pool()
1294 .await
1295 .user_connection_ids(called_user_id)
1296 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1297 .collect::<FuturesUnordered<_>>();
1298
1299 while let Some(call_response) = calls.next().await {
1300 match call_response.as_ref() {
1301 Ok(_) => {
1302 response.send(proto::Ack {})?;
1303 return Ok(());
1304 }
1305 Err(_) => {
1306 call_response.trace_err();
1307 }
1308 }
1309 }
1310
1311 {
1312 let room = session
1313 .db()
1314 .await
1315 .call_failed(room_id, called_user_id)
1316 .await?;
1317 room_updated(&room, &session.peer);
1318 }
1319 update_user_contacts(called_user_id, &session).await?;
1320
1321 Err(anyhow!("failed to ring user"))?
1322}
1323
1324async fn cancel_call(
1325 request: proto::CancelCall,
1326 response: Response<proto::CancelCall>,
1327 session: Session,
1328) -> Result<()> {
1329 let called_user_id = UserId::from_proto(request.called_user_id);
1330 let room_id = RoomId::from_proto(request.room_id);
1331 {
1332 let room = session
1333 .db()
1334 .await
1335 .cancel_call(room_id, session.connection_id, called_user_id)
1336 .await?;
1337 room_updated(&room, &session.peer);
1338 }
1339
1340 for connection_id in session
1341 .connection_pool()
1342 .await
1343 .user_connection_ids(called_user_id)
1344 {
1345 session
1346 .peer
1347 .send(
1348 connection_id,
1349 proto::CallCanceled {
1350 room_id: room_id.to_proto(),
1351 },
1352 )
1353 .trace_err();
1354 }
1355 response.send(proto::Ack {})?;
1356
1357 update_user_contacts(called_user_id, &session).await?;
1358 Ok(())
1359}
1360
1361async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1362 let room_id = RoomId::from_proto(message.room_id);
1363 {
1364 let room = session
1365 .db()
1366 .await
1367 .decline_call(Some(room_id), session.user_id)
1368 .await?
1369 .ok_or_else(|| anyhow!("failed to decline call"))?;
1370 room_updated(&room, &session.peer);
1371 }
1372
1373 for connection_id in session
1374 .connection_pool()
1375 .await
1376 .user_connection_ids(session.user_id)
1377 {
1378 session
1379 .peer
1380 .send(
1381 connection_id,
1382 proto::CallCanceled {
1383 room_id: room_id.to_proto(),
1384 },
1385 )
1386 .trace_err();
1387 }
1388 update_user_contacts(session.user_id, &session).await?;
1389 Ok(())
1390}
1391
1392async fn update_participant_location(
1393 request: proto::UpdateParticipantLocation,
1394 response: Response<proto::UpdateParticipantLocation>,
1395 session: Session,
1396) -> Result<()> {
1397 let room_id = RoomId::from_proto(request.room_id);
1398 let location = request
1399 .location
1400 .ok_or_else(|| anyhow!("invalid location"))?;
1401
1402 let db = session.db().await;
1403 let room = db
1404 .update_room_participant_location(room_id, session.connection_id, location)
1405 .await?;
1406
1407 room_updated(&room, &session.peer);
1408 response.send(proto::Ack {})?;
1409 Ok(())
1410}
1411
1412async fn share_project(
1413 request: proto::ShareProject,
1414 response: Response<proto::ShareProject>,
1415 session: Session,
1416) -> Result<()> {
1417 let (project_id, room) = &*session
1418 .db()
1419 .await
1420 .share_project(
1421 RoomId::from_proto(request.room_id),
1422 session.connection_id,
1423 &request.worktrees,
1424 )
1425 .await?;
1426 response.send(proto::ShareProjectResponse {
1427 project_id: project_id.to_proto(),
1428 })?;
1429 room_updated(&room, &session.peer);
1430
1431 Ok(())
1432}
1433
1434async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1435 let project_id = ProjectId::from_proto(message.project_id);
1436
1437 let (room, guest_connection_ids) = &*session
1438 .db()
1439 .await
1440 .unshare_project(project_id, session.connection_id)
1441 .await?;
1442
1443 broadcast(
1444 Some(session.connection_id),
1445 guest_connection_ids.iter().copied(),
1446 |conn_id| session.peer.send(conn_id, message.clone()),
1447 );
1448 room_updated(&room, &session.peer);
1449
1450 Ok(())
1451}
1452
1453async fn join_project(
1454 request: proto::JoinProject,
1455 response: Response<proto::JoinProject>,
1456 session: Session,
1457) -> Result<()> {
1458 let project_id = ProjectId::from_proto(request.project_id);
1459 let guest_user_id = session.user_id;
1460
1461 tracing::info!(%project_id, "join project");
1462
1463 let (project, replica_id) = &mut *session
1464 .db()
1465 .await
1466 .join_project(project_id, session.connection_id)
1467 .await?;
1468
1469 let collaborators = project
1470 .collaborators
1471 .iter()
1472 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1473 .map(|collaborator| collaborator.to_proto())
1474 .collect::<Vec<_>>();
1475
1476 let worktrees = project
1477 .worktrees
1478 .iter()
1479 .map(|(id, worktree)| proto::WorktreeMetadata {
1480 id: *id,
1481 root_name: worktree.root_name.clone(),
1482 visible: worktree.visible,
1483 abs_path: worktree.abs_path.clone(),
1484 })
1485 .collect::<Vec<_>>();
1486
1487 for collaborator in &collaborators {
1488 session
1489 .peer
1490 .send(
1491 collaborator.peer_id.unwrap().into(),
1492 proto::AddProjectCollaborator {
1493 project_id: project_id.to_proto(),
1494 collaborator: Some(proto::Collaborator {
1495 peer_id: Some(session.connection_id.into()),
1496 replica_id: replica_id.0 as u32,
1497 user_id: guest_user_id.to_proto(),
1498 }),
1499 },
1500 )
1501 .trace_err();
1502 }
1503
1504 // First, we send the metadata associated with each worktree.
1505 response.send(proto::JoinProjectResponse {
1506 worktrees: worktrees.clone(),
1507 replica_id: replica_id.0 as u32,
1508 collaborators: collaborators.clone(),
1509 language_servers: project.language_servers.clone(),
1510 })?;
1511
1512 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1513 #[cfg(any(test, feature = "test-support"))]
1514 const MAX_CHUNK_SIZE: usize = 2;
1515 #[cfg(not(any(test, feature = "test-support")))]
1516 const MAX_CHUNK_SIZE: usize = 256;
1517
1518 // Stream this worktree's entries.
1519 let message = proto::UpdateWorktree {
1520 project_id: project_id.to_proto(),
1521 worktree_id,
1522 abs_path: worktree.abs_path.clone(),
1523 root_name: worktree.root_name,
1524 updated_entries: worktree.entries,
1525 removed_entries: Default::default(),
1526 scan_id: worktree.scan_id,
1527 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1528 updated_repositories: worktree.repository_entries.into_values().collect(),
1529 removed_repositories: Default::default(),
1530 };
1531 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1532 session.peer.send(session.connection_id, update.clone())?;
1533 }
1534
1535 // Stream this worktree's diagnostics.
1536 for summary in worktree.diagnostic_summaries {
1537 session.peer.send(
1538 session.connection_id,
1539 proto::UpdateDiagnosticSummary {
1540 project_id: project_id.to_proto(),
1541 worktree_id: worktree.id,
1542 summary: Some(summary),
1543 },
1544 )?;
1545 }
1546
1547 for settings_file in worktree.settings_files {
1548 session.peer.send(
1549 session.connection_id,
1550 proto::UpdateWorktreeSettings {
1551 project_id: project_id.to_proto(),
1552 worktree_id: worktree.id,
1553 path: settings_file.path,
1554 content: Some(settings_file.content),
1555 },
1556 )?;
1557 }
1558 }
1559
1560 for language_server in &project.language_servers {
1561 session.peer.send(
1562 session.connection_id,
1563 proto::UpdateLanguageServer {
1564 project_id: project_id.to_proto(),
1565 language_server_id: language_server.id,
1566 variant: Some(
1567 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1568 proto::LspDiskBasedDiagnosticsUpdated {},
1569 ),
1570 ),
1571 },
1572 )?;
1573 }
1574
1575 Ok(())
1576}
1577
1578async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1579 let sender_id = session.connection_id;
1580 let project_id = ProjectId::from_proto(request.project_id);
1581
1582 let (room, project) = &*session
1583 .db()
1584 .await
1585 .leave_project(project_id, sender_id)
1586 .await?;
1587 tracing::info!(
1588 %project_id,
1589 host_user_id = %project.host_user_id,
1590 host_connection_id = %project.host_connection_id,
1591 "leave project"
1592 );
1593
1594 project_left(&project, &session);
1595 room_updated(&room, &session.peer);
1596
1597 Ok(())
1598}
1599
1600async fn update_project(
1601 request: proto::UpdateProject,
1602 response: Response<proto::UpdateProject>,
1603 session: Session,
1604) -> Result<()> {
1605 let project_id = ProjectId::from_proto(request.project_id);
1606 let (room, guest_connection_ids) = &*session
1607 .db()
1608 .await
1609 .update_project(project_id, session.connection_id, &request.worktrees)
1610 .await?;
1611 broadcast(
1612 Some(session.connection_id),
1613 guest_connection_ids.iter().copied(),
1614 |connection_id| {
1615 session
1616 .peer
1617 .forward_send(session.connection_id, connection_id, request.clone())
1618 },
1619 );
1620 room_updated(&room, &session.peer);
1621 response.send(proto::Ack {})?;
1622
1623 Ok(())
1624}
1625
1626async fn update_worktree(
1627 request: proto::UpdateWorktree,
1628 response: Response<proto::UpdateWorktree>,
1629 session: Session,
1630) -> Result<()> {
1631 let guest_connection_ids = session
1632 .db()
1633 .await
1634 .update_worktree(&request, session.connection_id)
1635 .await?;
1636
1637 broadcast(
1638 Some(session.connection_id),
1639 guest_connection_ids.iter().copied(),
1640 |connection_id| {
1641 session
1642 .peer
1643 .forward_send(session.connection_id, connection_id, request.clone())
1644 },
1645 );
1646 response.send(proto::Ack {})?;
1647 Ok(())
1648}
1649
1650async fn update_diagnostic_summary(
1651 message: proto::UpdateDiagnosticSummary,
1652 session: Session,
1653) -> Result<()> {
1654 let guest_connection_ids = session
1655 .db()
1656 .await
1657 .update_diagnostic_summary(&message, session.connection_id)
1658 .await?;
1659
1660 broadcast(
1661 Some(session.connection_id),
1662 guest_connection_ids.iter().copied(),
1663 |connection_id| {
1664 session
1665 .peer
1666 .forward_send(session.connection_id, connection_id, message.clone())
1667 },
1668 );
1669
1670 Ok(())
1671}
1672
1673async fn update_worktree_settings(
1674 message: proto::UpdateWorktreeSettings,
1675 session: Session,
1676) -> Result<()> {
1677 let guest_connection_ids = session
1678 .db()
1679 .await
1680 .update_worktree_settings(&message, session.connection_id)
1681 .await?;
1682
1683 broadcast(
1684 Some(session.connection_id),
1685 guest_connection_ids.iter().copied(),
1686 |connection_id| {
1687 session
1688 .peer
1689 .forward_send(session.connection_id, connection_id, message.clone())
1690 },
1691 );
1692
1693 Ok(())
1694}
1695
1696async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1697 broadcast_project_message(request.project_id, request, session).await
1698}
1699
1700async fn start_language_server(
1701 request: proto::StartLanguageServer,
1702 session: Session,
1703) -> Result<()> {
1704 let guest_connection_ids = session
1705 .db()
1706 .await
1707 .start_language_server(&request, session.connection_id)
1708 .await?;
1709
1710 broadcast(
1711 Some(session.connection_id),
1712 guest_connection_ids.iter().copied(),
1713 |connection_id| {
1714 session
1715 .peer
1716 .forward_send(session.connection_id, connection_id, request.clone())
1717 },
1718 );
1719 Ok(())
1720}
1721
1722async fn update_language_server(
1723 request: proto::UpdateLanguageServer,
1724 session: Session,
1725) -> Result<()> {
1726 session.executor.record_backtrace();
1727 let project_id = ProjectId::from_proto(request.project_id);
1728 let project_connection_ids = session
1729 .db()
1730 .await
1731 .project_connection_ids(project_id, session.connection_id)
1732 .await?;
1733 broadcast(
1734 Some(session.connection_id),
1735 project_connection_ids.iter().copied(),
1736 |connection_id| {
1737 session
1738 .peer
1739 .forward_send(session.connection_id, connection_id, request.clone())
1740 },
1741 );
1742 Ok(())
1743}
1744
1745async fn forward_project_request<T>(
1746 request: T,
1747 response: Response<T>,
1748 session: Session,
1749) -> Result<()>
1750where
1751 T: EntityMessage + RequestMessage,
1752{
1753 session.executor.record_backtrace();
1754 let project_id = ProjectId::from_proto(request.remote_entity_id());
1755 let host_connection_id = {
1756 let collaborators = session
1757 .db()
1758 .await
1759 .project_collaborators(project_id, session.connection_id)
1760 .await?;
1761 collaborators
1762 .iter()
1763 .find(|collaborator| collaborator.is_host)
1764 .ok_or_else(|| anyhow!("host not found"))?
1765 .connection_id
1766 };
1767
1768 let payload = session
1769 .peer
1770 .forward_request(session.connection_id, host_connection_id, request)
1771 .await?;
1772
1773 response.send(payload)?;
1774 Ok(())
1775}
1776
1777async fn create_buffer_for_peer(
1778 request: proto::CreateBufferForPeer,
1779 session: Session,
1780) -> Result<()> {
1781 session.executor.record_backtrace();
1782 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1783 session
1784 .peer
1785 .forward_send(session.connection_id, peer_id.into(), request)?;
1786 Ok(())
1787}
1788
1789async fn update_buffer(
1790 request: proto::UpdateBuffer,
1791 response: Response<proto::UpdateBuffer>,
1792 session: Session,
1793) -> Result<()> {
1794 session.executor.record_backtrace();
1795 let project_id = ProjectId::from_proto(request.project_id);
1796 let mut guest_connection_ids;
1797 let mut host_connection_id = None;
1798 {
1799 let collaborators = session
1800 .db()
1801 .await
1802 .project_collaborators(project_id, session.connection_id)
1803 .await?;
1804 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1805 for collaborator in collaborators.iter() {
1806 if collaborator.is_host {
1807 host_connection_id = Some(collaborator.connection_id);
1808 } else {
1809 guest_connection_ids.push(collaborator.connection_id);
1810 }
1811 }
1812 }
1813 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1814
1815 session.executor.record_backtrace();
1816 broadcast(
1817 Some(session.connection_id),
1818 guest_connection_ids,
1819 |connection_id| {
1820 session
1821 .peer
1822 .forward_send(session.connection_id, connection_id, request.clone())
1823 },
1824 );
1825 if host_connection_id != session.connection_id {
1826 session
1827 .peer
1828 .forward_request(session.connection_id, host_connection_id, request.clone())
1829 .await?;
1830 }
1831
1832 response.send(proto::Ack {})?;
1833 Ok(())
1834}
1835
1836async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1837 let project_id = ProjectId::from_proto(request.project_id);
1838 let project_connection_ids = session
1839 .db()
1840 .await
1841 .project_connection_ids(project_id, session.connection_id)
1842 .await?;
1843
1844 broadcast(
1845 Some(session.connection_id),
1846 project_connection_ids.iter().copied(),
1847 |connection_id| {
1848 session
1849 .peer
1850 .forward_send(session.connection_id, connection_id, request.clone())
1851 },
1852 );
1853 Ok(())
1854}
1855
1856async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1857 let project_id = ProjectId::from_proto(request.project_id);
1858 let project_connection_ids = session
1859 .db()
1860 .await
1861 .project_connection_ids(project_id, session.connection_id)
1862 .await?;
1863 broadcast(
1864 Some(session.connection_id),
1865 project_connection_ids.iter().copied(),
1866 |connection_id| {
1867 session
1868 .peer
1869 .forward_send(session.connection_id, connection_id, request.clone())
1870 },
1871 );
1872 Ok(())
1873}
1874
1875async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1876 broadcast_project_message(request.project_id, request, session).await
1877}
1878
1879async fn broadcast_project_message<T: EnvelopedMessage>(
1880 project_id: u64,
1881 request: T,
1882 session: Session,
1883) -> Result<()> {
1884 let project_id = ProjectId::from_proto(project_id);
1885 let project_connection_ids = session
1886 .db()
1887 .await
1888 .project_connection_ids(project_id, session.connection_id)
1889 .await?;
1890 broadcast(
1891 Some(session.connection_id),
1892 project_connection_ids.iter().copied(),
1893 |connection_id| {
1894 session
1895 .peer
1896 .forward_send(session.connection_id, connection_id, request.clone())
1897 },
1898 );
1899 Ok(())
1900}
1901
1902async fn follow(
1903 request: proto::Follow,
1904 response: Response<proto::Follow>,
1905 session: Session,
1906) -> Result<()> {
1907 let room_id = RoomId::from_proto(request.room_id);
1908 let project_id = request.project_id.map(ProjectId::from_proto);
1909 let leader_id = request
1910 .leader_id
1911 .ok_or_else(|| anyhow!("invalid leader id"))?
1912 .into();
1913 let follower_id = session.connection_id;
1914
1915 session
1916 .db()
1917 .await
1918 .check_room_participants(room_id, leader_id, session.connection_id)
1919 .await?;
1920
1921 let response_payload = session
1922 .peer
1923 .forward_request(session.connection_id, leader_id, request)
1924 .await?;
1925 response.send(response_payload)?;
1926
1927 if let Some(project_id) = project_id {
1928 let room = session
1929 .db()
1930 .await
1931 .follow(room_id, project_id, leader_id, follower_id)
1932 .await?;
1933 room_updated(&room, &session.peer);
1934 }
1935
1936 Ok(())
1937}
1938
1939async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1940 let room_id = RoomId::from_proto(request.room_id);
1941 let project_id = request.project_id.map(ProjectId::from_proto);
1942 let leader_id = request
1943 .leader_id
1944 .ok_or_else(|| anyhow!("invalid leader id"))?
1945 .into();
1946 let follower_id = session.connection_id;
1947
1948 session
1949 .db()
1950 .await
1951 .check_room_participants(room_id, leader_id, session.connection_id)
1952 .await?;
1953
1954 session
1955 .peer
1956 .forward_send(session.connection_id, leader_id, request)?;
1957
1958 if let Some(project_id) = project_id {
1959 let room = session
1960 .db()
1961 .await
1962 .unfollow(room_id, project_id, leader_id, follower_id)
1963 .await?;
1964 room_updated(&room, &session.peer);
1965 }
1966
1967 Ok(())
1968}
1969
1970async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1971 let room_id = RoomId::from_proto(request.room_id);
1972 let database = session.db.lock().await;
1973
1974 let connection_ids = if let Some(project_id) = request.project_id {
1975 let project_id = ProjectId::from_proto(project_id);
1976 database
1977 .project_connection_ids(project_id, session.connection_id)
1978 .await?
1979 } else {
1980 database
1981 .room_connection_ids(room_id, session.connection_id)
1982 .await?
1983 };
1984
1985 // For now, don't send view update messages back to that view's current leader.
1986 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1987 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1988 _ => None,
1989 });
1990
1991 for follower_peer_id in request.follower_ids.iter().copied() {
1992 let follower_connection_id = follower_peer_id.into();
1993 if Some(follower_peer_id) != connection_id_to_omit
1994 && connection_ids.contains(&follower_connection_id)
1995 {
1996 session.peer.forward_send(
1997 session.connection_id,
1998 follower_connection_id,
1999 request.clone(),
2000 )?;
2001 }
2002 }
2003 Ok(())
2004}
2005
2006async fn get_users(
2007 request: proto::GetUsers,
2008 response: Response<proto::GetUsers>,
2009 session: Session,
2010) -> Result<()> {
2011 let user_ids = request
2012 .user_ids
2013 .into_iter()
2014 .map(UserId::from_proto)
2015 .collect();
2016 let users = session
2017 .db()
2018 .await
2019 .get_users_by_ids(user_ids)
2020 .await?
2021 .into_iter()
2022 .map(|user| proto::User {
2023 id: user.id.to_proto(),
2024 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2025 github_login: user.github_login,
2026 })
2027 .collect();
2028 response.send(proto::UsersResponse { users })?;
2029 Ok(())
2030}
2031
2032async fn fuzzy_search_users(
2033 request: proto::FuzzySearchUsers,
2034 response: Response<proto::FuzzySearchUsers>,
2035 session: Session,
2036) -> Result<()> {
2037 let query = request.query;
2038 let users = match query.len() {
2039 0 => vec![],
2040 1 | 2 => session
2041 .db()
2042 .await
2043 .get_user_by_github_login(&query)
2044 .await?
2045 .into_iter()
2046 .collect(),
2047 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2048 };
2049 let users = users
2050 .into_iter()
2051 .filter(|user| user.id != session.user_id)
2052 .map(|user| proto::User {
2053 id: user.id.to_proto(),
2054 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2055 github_login: user.github_login,
2056 })
2057 .collect();
2058 response.send(proto::UsersResponse { users })?;
2059 Ok(())
2060}
2061
2062async fn request_contact(
2063 request: proto::RequestContact,
2064 response: Response<proto::RequestContact>,
2065 session: Session,
2066) -> Result<()> {
2067 let requester_id = session.user_id;
2068 let responder_id = UserId::from_proto(request.responder_id);
2069 if requester_id == responder_id {
2070 return Err(anyhow!("cannot add yourself as a contact"))?;
2071 }
2072
2073 let notifications = session
2074 .db()
2075 .await
2076 .send_contact_request(requester_id, responder_id)
2077 .await?;
2078
2079 // Update outgoing contact requests of requester
2080 let mut update = proto::UpdateContacts::default();
2081 update.outgoing_requests.push(responder_id.to_proto());
2082 for connection_id in session
2083 .connection_pool()
2084 .await
2085 .user_connection_ids(requester_id)
2086 {
2087 session.peer.send(connection_id, update.clone())?;
2088 }
2089
2090 // Update incoming contact requests of responder
2091 let mut update = proto::UpdateContacts::default();
2092 update
2093 .incoming_requests
2094 .push(proto::IncomingContactRequest {
2095 requester_id: requester_id.to_proto(),
2096 });
2097 let connection_pool = session.connection_pool().await;
2098 for connection_id in connection_pool.user_connection_ids(responder_id) {
2099 session.peer.send(connection_id, update.clone())?;
2100 }
2101
2102 send_notifications(&*connection_pool, &session.peer, notifications);
2103
2104 response.send(proto::Ack {})?;
2105 Ok(())
2106}
2107
2108async fn respond_to_contact_request(
2109 request: proto::RespondToContactRequest,
2110 response: Response<proto::RespondToContactRequest>,
2111 session: Session,
2112) -> Result<()> {
2113 let responder_id = session.user_id;
2114 let requester_id = UserId::from_proto(request.requester_id);
2115 let db = session.db().await;
2116 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2117 db.dismiss_contact_notification(responder_id, requester_id)
2118 .await?;
2119 } else {
2120 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2121
2122 let notifications = db
2123 .respond_to_contact_request(responder_id, requester_id, accept)
2124 .await?;
2125 let requester_busy = db.is_user_busy(requester_id).await?;
2126 let responder_busy = db.is_user_busy(responder_id).await?;
2127
2128 let pool = session.connection_pool().await;
2129 // Update responder with new contact
2130 let mut update = proto::UpdateContacts::default();
2131 if accept {
2132 update
2133 .contacts
2134 .push(contact_for_user(requester_id, requester_busy, &pool));
2135 }
2136 update
2137 .remove_incoming_requests
2138 .push(requester_id.to_proto());
2139 for connection_id in pool.user_connection_ids(responder_id) {
2140 session.peer.send(connection_id, update.clone())?;
2141 }
2142
2143 // Update requester with new contact
2144 let mut update = proto::UpdateContacts::default();
2145 if accept {
2146 update
2147 .contacts
2148 .push(contact_for_user(responder_id, responder_busy, &pool));
2149 }
2150 update
2151 .remove_outgoing_requests
2152 .push(responder_id.to_proto());
2153
2154 for connection_id in pool.user_connection_ids(requester_id) {
2155 session.peer.send(connection_id, update.clone())?;
2156 }
2157
2158 send_notifications(&*pool, &session.peer, notifications);
2159 }
2160
2161 response.send(proto::Ack {})?;
2162 Ok(())
2163}
2164
2165async fn remove_contact(
2166 request: proto::RemoveContact,
2167 response: Response<proto::RemoveContact>,
2168 session: Session,
2169) -> Result<()> {
2170 let requester_id = session.user_id;
2171 let responder_id = UserId::from_proto(request.user_id);
2172 let db = session.db().await;
2173 let (contact_accepted, deleted_notification_id) =
2174 db.remove_contact(requester_id, responder_id).await?;
2175
2176 let pool = session.connection_pool().await;
2177 // Update outgoing contact requests of requester
2178 let mut update = proto::UpdateContacts::default();
2179 if contact_accepted {
2180 update.remove_contacts.push(responder_id.to_proto());
2181 } else {
2182 update
2183 .remove_outgoing_requests
2184 .push(responder_id.to_proto());
2185 }
2186 for connection_id in pool.user_connection_ids(requester_id) {
2187 session.peer.send(connection_id, update.clone())?;
2188 }
2189
2190 // Update incoming contact requests of responder
2191 let mut update = proto::UpdateContacts::default();
2192 if contact_accepted {
2193 update.remove_contacts.push(requester_id.to_proto());
2194 } else {
2195 update
2196 .remove_incoming_requests
2197 .push(requester_id.to_proto());
2198 }
2199 for connection_id in pool.user_connection_ids(responder_id) {
2200 session.peer.send(connection_id, update.clone())?;
2201 if let Some(notification_id) = deleted_notification_id {
2202 session.peer.send(
2203 connection_id,
2204 proto::DeleteNotification {
2205 notification_id: notification_id.to_proto(),
2206 },
2207 )?;
2208 }
2209 }
2210
2211 response.send(proto::Ack {})?;
2212 Ok(())
2213}
2214
2215async fn create_channel(
2216 request: proto::CreateChannel,
2217 response: Response<proto::CreateChannel>,
2218 session: Session,
2219) -> Result<()> {
2220 let db = session.db().await;
2221
2222 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2223 let CreateChannelResult {
2224 channel,
2225 participants_to_update,
2226 } = db
2227 .create_channel(&request.name, parent_id, session.user_id)
2228 .await?;
2229
2230 response.send(proto::CreateChannelResponse {
2231 channel: Some(channel.to_proto()),
2232 parent_id: request.parent_id,
2233 })?;
2234
2235 let connection_pool = session.connection_pool().await;
2236 for (user_id, channels) in participants_to_update {
2237 let update = build_channels_update(channels, vec![]);
2238 for connection_id in connection_pool.user_connection_ids(user_id) {
2239 if user_id == session.user_id {
2240 continue;
2241 }
2242 session.peer.send(connection_id, update.clone())?;
2243 }
2244 }
2245
2246 Ok(())
2247}
2248
2249async fn delete_channel(
2250 request: proto::DeleteChannel,
2251 response: Response<proto::DeleteChannel>,
2252 session: Session,
2253) -> Result<()> {
2254 let db = session.db().await;
2255
2256 let channel_id = request.channel_id;
2257 let (removed_channels, member_ids) = db
2258 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2259 .await?;
2260 response.send(proto::Ack {})?;
2261
2262 // Notify members of removed channels
2263 let mut update = proto::UpdateChannels::default();
2264 update
2265 .delete_channels
2266 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2267
2268 let connection_pool = session.connection_pool().await;
2269 for member_id in member_ids {
2270 for connection_id in connection_pool.user_connection_ids(member_id) {
2271 session.peer.send(connection_id, update.clone())?;
2272 }
2273 }
2274
2275 Ok(())
2276}
2277
2278async fn invite_channel_member(
2279 request: proto::InviteChannelMember,
2280 response: Response<proto::InviteChannelMember>,
2281 session: Session,
2282) -> Result<()> {
2283 let db = session.db().await;
2284 let channel_id = ChannelId::from_proto(request.channel_id);
2285 let invitee_id = UserId::from_proto(request.user_id);
2286 let InviteMemberResult {
2287 channel,
2288 notifications,
2289 } = db
2290 .invite_channel_member(
2291 channel_id,
2292 invitee_id,
2293 session.user_id,
2294 request.role().into(),
2295 )
2296 .await?;
2297
2298 let update = proto::UpdateChannels {
2299 channel_invitations: vec![channel.to_proto()],
2300 ..Default::default()
2301 };
2302
2303 let connection_pool = session.connection_pool().await;
2304 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2305 session.peer.send(connection_id, update.clone())?;
2306 }
2307
2308 send_notifications(&*connection_pool, &session.peer, notifications);
2309
2310 response.send(proto::Ack {})?;
2311 Ok(())
2312}
2313
2314async fn remove_channel_member(
2315 request: proto::RemoveChannelMember,
2316 response: Response<proto::RemoveChannelMember>,
2317 session: Session,
2318) -> Result<()> {
2319 let db = session.db().await;
2320 let channel_id = ChannelId::from_proto(request.channel_id);
2321 let member_id = UserId::from_proto(request.user_id);
2322
2323 let RemoveChannelMemberResult {
2324 membership_update,
2325 notification_id,
2326 } = db
2327 .remove_channel_member(channel_id, member_id, session.user_id)
2328 .await?;
2329
2330 let connection_pool = &session.connection_pool().await;
2331 notify_membership_updated(
2332 &connection_pool,
2333 membership_update,
2334 member_id,
2335 &session.peer,
2336 );
2337 for connection_id in connection_pool.user_connection_ids(member_id) {
2338 if let Some(notification_id) = notification_id {
2339 session
2340 .peer
2341 .send(
2342 connection_id,
2343 proto::DeleteNotification {
2344 notification_id: notification_id.to_proto(),
2345 },
2346 )
2347 .trace_err();
2348 }
2349 }
2350
2351 response.send(proto::Ack {})?;
2352 Ok(())
2353}
2354
2355async fn set_channel_visibility(
2356 request: proto::SetChannelVisibility,
2357 response: Response<proto::SetChannelVisibility>,
2358 session: Session,
2359) -> Result<()> {
2360 let db = session.db().await;
2361 let channel_id = ChannelId::from_proto(request.channel_id);
2362 let visibility = request.visibility().into();
2363
2364 let SetChannelVisibilityResult {
2365 participants_to_update,
2366 participants_to_remove,
2367 channels_to_remove,
2368 } = db
2369 .set_channel_visibility(channel_id, visibility, session.user_id)
2370 .await?;
2371
2372 let connection_pool = session.connection_pool().await;
2373 for (user_id, channels) in participants_to_update {
2374 let update = build_channels_update(channels, vec![]);
2375 for connection_id in connection_pool.user_connection_ids(user_id) {
2376 session.peer.send(connection_id, update.clone())?;
2377 }
2378 }
2379 for user_id in participants_to_remove {
2380 let update = proto::UpdateChannels {
2381 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2382 ..Default::default()
2383 };
2384 for connection_id in connection_pool.user_connection_ids(user_id) {
2385 session.peer.send(connection_id, update.clone())?;
2386 }
2387 }
2388
2389 response.send(proto::Ack {})?;
2390 Ok(())
2391}
2392
2393async fn set_channel_member_role(
2394 request: proto::SetChannelMemberRole,
2395 response: Response<proto::SetChannelMemberRole>,
2396 session: Session,
2397) -> Result<()> {
2398 let db = session.db().await;
2399 let channel_id = ChannelId::from_proto(request.channel_id);
2400 let member_id = UserId::from_proto(request.user_id);
2401 let result = db
2402 .set_channel_member_role(
2403 channel_id,
2404 session.user_id,
2405 member_id,
2406 request.role().into(),
2407 )
2408 .await?;
2409
2410 match result {
2411 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2412 let connection_pool = session.connection_pool().await;
2413 notify_membership_updated(
2414 &connection_pool,
2415 membership_update,
2416 member_id,
2417 &session.peer,
2418 )
2419 }
2420 db::SetMemberRoleResult::InviteUpdated(channel) => {
2421 let update = proto::UpdateChannels {
2422 channel_invitations: vec![channel.to_proto()],
2423 ..Default::default()
2424 };
2425
2426 for connection_id in session
2427 .connection_pool()
2428 .await
2429 .user_connection_ids(member_id)
2430 {
2431 session.peer.send(connection_id, update.clone())?;
2432 }
2433 }
2434 }
2435
2436 response.send(proto::Ack {})?;
2437 Ok(())
2438}
2439
2440async fn rename_channel(
2441 request: proto::RenameChannel,
2442 response: Response<proto::RenameChannel>,
2443 session: Session,
2444) -> Result<()> {
2445 let db = session.db().await;
2446 let channel_id = ChannelId::from_proto(request.channel_id);
2447 let RenameChannelResult {
2448 channel,
2449 participants_to_update,
2450 } = db
2451 .rename_channel(channel_id, session.user_id, &request.name)
2452 .await?;
2453
2454 response.send(proto::RenameChannelResponse {
2455 channel: Some(channel.to_proto()),
2456 })?;
2457
2458 let connection_pool = session.connection_pool().await;
2459 for (user_id, channel) in participants_to_update {
2460 for connection_id in connection_pool.user_connection_ids(user_id) {
2461 let update = proto::UpdateChannels {
2462 channels: vec![channel.to_proto()],
2463 ..Default::default()
2464 };
2465
2466 session.peer.send(connection_id, update.clone())?;
2467 }
2468 }
2469
2470 Ok(())
2471}
2472
2473async fn move_channel(
2474 request: proto::MoveChannel,
2475 response: Response<proto::MoveChannel>,
2476 session: Session,
2477) -> Result<()> {
2478 let channel_id = ChannelId::from_proto(request.channel_id);
2479 let to = request.to.map(ChannelId::from_proto);
2480
2481 let result = session
2482 .db()
2483 .await
2484 .move_channel(channel_id, to, session.user_id)
2485 .await?;
2486
2487 notify_channel_moved(result, session).await?;
2488
2489 response.send(Ack {})?;
2490 Ok(())
2491}
2492
2493async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2494 let Some(MoveChannelResult {
2495 participants_to_remove,
2496 participants_to_update,
2497 moved_channels,
2498 }) = result
2499 else {
2500 return Ok(());
2501 };
2502 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2503
2504 let connection_pool = session.connection_pool().await;
2505 for (user_id, channels) in participants_to_update {
2506 let mut update = build_channels_update(channels, vec![]);
2507 update.delete_channels = moved_channels.clone();
2508 for connection_id in connection_pool.user_connection_ids(user_id) {
2509 session.peer.send(connection_id, update.clone())?;
2510 }
2511 }
2512
2513 for user_id in participants_to_remove {
2514 let update = proto::UpdateChannels {
2515 delete_channels: moved_channels.clone(),
2516 ..Default::default()
2517 };
2518 for connection_id in connection_pool.user_connection_ids(user_id) {
2519 session.peer.send(connection_id, update.clone())?;
2520 }
2521 }
2522 Ok(())
2523}
2524
2525async fn get_channel_members(
2526 request: proto::GetChannelMembers,
2527 response: Response<proto::GetChannelMembers>,
2528 session: Session,
2529) -> Result<()> {
2530 let db = session.db().await;
2531 let channel_id = ChannelId::from_proto(request.channel_id);
2532 let members = db
2533 .get_channel_participant_details(channel_id, session.user_id)
2534 .await?;
2535 response.send(proto::GetChannelMembersResponse { members })?;
2536 Ok(())
2537}
2538
2539async fn respond_to_channel_invite(
2540 request: proto::RespondToChannelInvite,
2541 response: Response<proto::RespondToChannelInvite>,
2542 session: Session,
2543) -> Result<()> {
2544 let db = session.db().await;
2545 let channel_id = ChannelId::from_proto(request.channel_id);
2546 let RespondToChannelInvite {
2547 membership_update,
2548 notifications,
2549 } = db
2550 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2551 .await?;
2552
2553 let connection_pool = session.connection_pool().await;
2554 if let Some(membership_update) = membership_update {
2555 notify_membership_updated(
2556 &connection_pool,
2557 membership_update,
2558 session.user_id,
2559 &session.peer,
2560 );
2561 } else {
2562 let update = proto::UpdateChannels {
2563 remove_channel_invitations: vec![channel_id.to_proto()],
2564 ..Default::default()
2565 };
2566
2567 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2568 session.peer.send(connection_id, update.clone())?;
2569 }
2570 };
2571
2572 send_notifications(&*connection_pool, &session.peer, notifications);
2573
2574 response.send(proto::Ack {})?;
2575
2576 Ok(())
2577}
2578
2579async fn join_channel(
2580 request: proto::JoinChannel,
2581 response: Response<proto::JoinChannel>,
2582 session: Session,
2583) -> Result<()> {
2584 let channel_id = ChannelId::from_proto(request.channel_id);
2585 join_channel_internal(channel_id, Box::new(response), session).await
2586}
2587
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` serve both `JoinChannel` and `JoinRoom`
/// requests.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The bodies below call `Response::<_>::send` with a fully-qualified path so
// they reach `Response`'s own `send` (presumably an inherent method) and do
// not recurse into this trait's `send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2601
/// Joins the session's connection to the room of `channel_id`.
///
/// Steps, in order:
/// 1. Leave whatever room this connection is currently in.
/// 2. Join the channel's room in the database, which also returns any
///    membership change and the user's role.
/// 3. Optionally mint LiveKit credentials.
/// 4. Reply through `response`.
/// 5. Notify the user of the membership change and broadcast the room update.
/// 6. Broadcast channel participants to channel members and refresh the
///    user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block owns `db` and `connection_pool`, so both guards are
    // dropped before `channel_updated` re-acquires the connection pool and
    // `update_user_contacts` re-acquires the db handle below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // LiveKit info is only produced when a LiveKit client is configured.
        // Guests get a guest token and may not publish; every other role gets
        // a room token with publishing allowed. If token creation fails, the
        // `?` inside this closure yields `None` (presumably after `trace_err`
        // logs the error) and the join proceeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply before notifying others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell every connection of the channel's members who is now in the room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Push this user's refreshed contact status (e.g. busy) to their contacts.
    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2682
2683async fn join_channel_buffer(
2684 request: proto::JoinChannelBuffer,
2685 response: Response<proto::JoinChannelBuffer>,
2686 session: Session,
2687) -> Result<()> {
2688 let db = session.db().await;
2689 let channel_id = ChannelId::from_proto(request.channel_id);
2690
2691 let open_response = db
2692 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2693 .await?;
2694
2695 let collaborators = open_response.collaborators.clone();
2696 response.send(open_response)?;
2697
2698 let update = UpdateChannelBufferCollaborators {
2699 channel_id: channel_id.to_proto(),
2700 collaborators: collaborators.clone(),
2701 };
2702 channel_buffer_updated(
2703 session.connection_id,
2704 collaborators
2705 .iter()
2706 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2707 &update,
2708 &session.peer,
2709 );
2710
2711 Ok(())
2712}
2713
2714async fn update_channel_buffer(
2715 request: proto::UpdateChannelBuffer,
2716 session: Session,
2717) -> Result<()> {
2718 let db = session.db().await;
2719 let channel_id = ChannelId::from_proto(request.channel_id);
2720
2721 let (collaborators, non_collaborators, epoch, version) = db
2722 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2723 .await?;
2724
2725 channel_buffer_updated(
2726 session.connection_id,
2727 collaborators,
2728 &proto::UpdateChannelBuffer {
2729 channel_id: channel_id.to_proto(),
2730 operations: request.operations,
2731 },
2732 &session.peer,
2733 );
2734
2735 let pool = &*session.connection_pool().await;
2736
2737 broadcast(
2738 None,
2739 non_collaborators
2740 .iter()
2741 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2742 |peer_id| {
2743 session.peer.send(
2744 peer_id.into(),
2745 proto::UpdateChannels {
2746 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2747 channel_id: channel_id.to_proto(),
2748 epoch: epoch as u64,
2749 version: version.clone(),
2750 }],
2751 ..Default::default()
2752 },
2753 )
2754 },
2755 );
2756
2757 Ok(())
2758}
2759
2760async fn rejoin_channel_buffers(
2761 request: proto::RejoinChannelBuffers,
2762 response: Response<proto::RejoinChannelBuffers>,
2763 session: Session,
2764) -> Result<()> {
2765 let db = session.db().await;
2766 let buffers = db
2767 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2768 .await?;
2769
2770 for rejoined_buffer in &buffers {
2771 let collaborators_to_notify = rejoined_buffer
2772 .buffer
2773 .collaborators
2774 .iter()
2775 .filter_map(|c| Some(c.peer_id?.into()));
2776 channel_buffer_updated(
2777 session.connection_id,
2778 collaborators_to_notify,
2779 &proto::UpdateChannelBufferCollaborators {
2780 channel_id: rejoined_buffer.buffer.channel_id,
2781 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2782 },
2783 &session.peer,
2784 );
2785 }
2786
2787 response.send(proto::RejoinChannelBuffersResponse {
2788 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2789 })?;
2790
2791 Ok(())
2792}
2793
2794async fn leave_channel_buffer(
2795 request: proto::LeaveChannelBuffer,
2796 response: Response<proto::LeaveChannelBuffer>,
2797 session: Session,
2798) -> Result<()> {
2799 let db = session.db().await;
2800 let channel_id = ChannelId::from_proto(request.channel_id);
2801
2802 let left_buffer = db
2803 .leave_channel_buffer(channel_id, session.connection_id)
2804 .await?;
2805
2806 response.send(Ack {})?;
2807
2808 channel_buffer_updated(
2809 session.connection_id,
2810 left_buffer.connections,
2811 &proto::UpdateChannelBufferCollaborators {
2812 channel_id: channel_id.to_proto(),
2813 collaborators: left_buffer.collaborators,
2814 },
2815 &session.peer,
2816 );
2817
2818 Ok(())
2819}
2820
2821fn channel_buffer_updated<T: EnvelopedMessage>(
2822 sender_id: ConnectionId,
2823 collaborators: impl IntoIterator<Item = ConnectionId>,
2824 message: &T,
2825 peer: &Peer,
2826) {
2827 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2828 peer.send(peer_id.into(), message.clone())
2829 });
2830}
2831
2832fn send_notifications(
2833 connection_pool: &ConnectionPool,
2834 peer: &Peer,
2835 notifications: db::NotificationBatch,
2836) {
2837 for (user_id, notification) in notifications {
2838 for connection_id in connection_pool.user_connection_ids(user_id) {
2839 if let Err(error) = peer.send(
2840 connection_id,
2841 proto::AddNotification {
2842 notification: Some(notification.clone()),
2843 },
2844 ) {
2845 tracing::error!(
2846 "failed to send notification to {:?} {}",
2847 connection_id,
2848 error
2849 );
2850 }
2851 }
2852 }
2853}
2854
2855async fn send_channel_message(
2856 request: proto::SendChannelMessage,
2857 response: Response<proto::SendChannelMessage>,
2858 session: Session,
2859) -> Result<()> {
2860 // Validate the message body.
2861 let body = request.body.trim().to_string();
2862 if body.len() > MAX_MESSAGE_LEN {
2863 return Err(anyhow!("message is too long"))?;
2864 }
2865 if body.is_empty() {
2866 return Err(anyhow!("message can't be blank"))?;
2867 }
2868
2869 // TODO: adjust mentions if body is trimmed
2870
2871 let timestamp = OffsetDateTime::now_utc();
2872 let nonce = request
2873 .nonce
2874 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2875
2876 let channel_id = ChannelId::from_proto(request.channel_id);
2877 let CreatedChannelMessage {
2878 message_id,
2879 participant_connection_ids,
2880 channel_members,
2881 notifications,
2882 } = session
2883 .db()
2884 .await
2885 .create_channel_message(
2886 channel_id,
2887 session.user_id,
2888 &body,
2889 &request.mentions,
2890 timestamp,
2891 nonce.clone().into(),
2892 )
2893 .await?;
2894 let message = proto::ChannelMessage {
2895 sender_id: session.user_id.to_proto(),
2896 id: message_id.to_proto(),
2897 body,
2898 mentions: request.mentions,
2899 timestamp: timestamp.unix_timestamp() as u64,
2900 nonce: Some(nonce),
2901 };
2902 broadcast(
2903 Some(session.connection_id),
2904 participant_connection_ids,
2905 |connection| {
2906 session.peer.send(
2907 connection,
2908 proto::ChannelMessageSent {
2909 channel_id: channel_id.to_proto(),
2910 message: Some(message.clone()),
2911 },
2912 )
2913 },
2914 );
2915 response.send(proto::SendChannelMessageResponse {
2916 message: Some(message),
2917 })?;
2918
2919 let pool = &*session.connection_pool().await;
2920 broadcast(
2921 None,
2922 channel_members
2923 .iter()
2924 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2925 |peer_id| {
2926 session.peer.send(
2927 peer_id.into(),
2928 proto::UpdateChannels {
2929 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2930 channel_id: channel_id.to_proto(),
2931 message_id: message_id.to_proto(),
2932 }],
2933 ..Default::default()
2934 },
2935 )
2936 },
2937 );
2938 send_notifications(pool, &session.peer, notifications);
2939
2940 Ok(())
2941}
2942
2943async fn remove_channel_message(
2944 request: proto::RemoveChannelMessage,
2945 response: Response<proto::RemoveChannelMessage>,
2946 session: Session,
2947) -> Result<()> {
2948 let channel_id = ChannelId::from_proto(request.channel_id);
2949 let message_id = MessageId::from_proto(request.message_id);
2950 let connection_ids = session
2951 .db()
2952 .await
2953 .remove_channel_message(channel_id, message_id, session.user_id)
2954 .await?;
2955 broadcast(Some(session.connection_id), connection_ids, |connection| {
2956 session.peer.send(connection, request.clone())
2957 });
2958 response.send(proto::Ack {})?;
2959 Ok(())
2960}
2961
2962async fn acknowledge_channel_message(
2963 request: proto::AckChannelMessage,
2964 session: Session,
2965) -> Result<()> {
2966 let channel_id = ChannelId::from_proto(request.channel_id);
2967 let message_id = MessageId::from_proto(request.message_id);
2968 let notifications = session
2969 .db()
2970 .await
2971 .observe_channel_message(channel_id, session.user_id, message_id)
2972 .await?;
2973 send_notifications(
2974 &*session.connection_pool().await,
2975 &session.peer,
2976 notifications,
2977 );
2978 Ok(())
2979}
2980
2981async fn acknowledge_buffer_version(
2982 request: proto::AckBufferOperation,
2983 session: Session,
2984) -> Result<()> {
2985 let buffer_id = BufferId::from_proto(request.buffer_id);
2986 session
2987 .db()
2988 .await
2989 .observe_buffer_version(
2990 buffer_id,
2991 session.user_id,
2992 request.epoch as i32,
2993 &request.version,
2994 )
2995 .await?;
2996 Ok(())
2997}
2998
2999async fn join_channel_chat(
3000 request: proto::JoinChannelChat,
3001 response: Response<proto::JoinChannelChat>,
3002 session: Session,
3003) -> Result<()> {
3004 let channel_id = ChannelId::from_proto(request.channel_id);
3005
3006 let db = session.db().await;
3007 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3008 .await?;
3009 let messages = db
3010 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3011 .await?;
3012 response.send(proto::JoinChannelChatResponse {
3013 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3014 messages,
3015 })?;
3016 Ok(())
3017}
3018
3019async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3020 let channel_id = ChannelId::from_proto(request.channel_id);
3021 session
3022 .db()
3023 .await
3024 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3025 .await?;
3026 Ok(())
3027}
3028
3029async fn get_channel_messages(
3030 request: proto::GetChannelMessages,
3031 response: Response<proto::GetChannelMessages>,
3032 session: Session,
3033) -> Result<()> {
3034 let channel_id = ChannelId::from_proto(request.channel_id);
3035 let messages = session
3036 .db()
3037 .await
3038 .get_channel_messages(
3039 channel_id,
3040 session.user_id,
3041 MESSAGE_COUNT_PER_PAGE,
3042 Some(MessageId::from_proto(request.before_message_id)),
3043 )
3044 .await?;
3045 response.send(proto::GetChannelMessagesResponse {
3046 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3047 messages,
3048 })?;
3049 Ok(())
3050}
3051
3052async fn get_channel_messages_by_id(
3053 request: proto::GetChannelMessagesById,
3054 response: Response<proto::GetChannelMessagesById>,
3055 session: Session,
3056) -> Result<()> {
3057 let message_ids = request
3058 .message_ids
3059 .iter()
3060 .map(|id| MessageId::from_proto(*id))
3061 .collect::<Vec<_>>();
3062 let messages = session
3063 .db()
3064 .await
3065 .get_channel_messages_by_id(session.user_id, &message_ids)
3066 .await?;
3067 response.send(proto::GetChannelMessagesResponse {
3068 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3069 messages,
3070 })?;
3071 Ok(())
3072}
3073
3074async fn get_notifications(
3075 request: proto::GetNotifications,
3076 response: Response<proto::GetNotifications>,
3077 session: Session,
3078) -> Result<()> {
3079 let notifications = session
3080 .db()
3081 .await
3082 .get_notifications(
3083 session.user_id,
3084 NOTIFICATION_COUNT_PER_PAGE,
3085 request
3086 .before_id
3087 .map(|id| db::NotificationId::from_proto(id)),
3088 )
3089 .await?;
3090 response.send(proto::GetNotificationsResponse {
3091 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3092 notifications,
3093 })?;
3094 Ok(())
3095}
3096
3097async fn mark_notification_as_read(
3098 request: proto::MarkNotificationRead,
3099 response: Response<proto::MarkNotificationRead>,
3100 session: Session,
3101) -> Result<()> {
3102 let database = &session.db().await;
3103 let notifications = database
3104 .mark_notification_as_read_by_id(
3105 session.user_id,
3106 NotificationId::from_proto(request.notification_id),
3107 )
3108 .await?;
3109 send_notifications(
3110 &*session.connection_pool().await,
3111 &session.peer,
3112 notifications,
3113 );
3114 response.send(proto::Ack {})?;
3115 Ok(())
3116}
3117
3118async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3119 let project_id = ProjectId::from_proto(request.project_id);
3120 let project_connection_ids = session
3121 .db()
3122 .await
3123 .project_connection_ids(project_id, session.connection_id)
3124 .await?;
3125 broadcast(
3126 Some(session.connection_id),
3127 project_connection_ids.iter().copied(),
3128 |connection_id| {
3129 session
3130 .peer
3131 .forward_send(session.connection_id, connection_id, request.clone())
3132 },
3133 );
3134 Ok(())
3135}
3136
3137async fn get_private_user_info(
3138 _request: proto::GetPrivateUserInfo,
3139 response: Response<proto::GetPrivateUserInfo>,
3140 session: Session,
3141) -> Result<()> {
3142 let db = session.db().await;
3143
3144 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3145 let user = db
3146 .get_user_by_id(session.user_id)
3147 .await?
3148 .ok_or_else(|| anyhow!("user not found"))?;
3149 let flags = db.get_user_flags(session.user_id).await?;
3150
3151 response.send(proto::GetPrivateUserInfoResponse {
3152 metrics_id,
3153 staff: user.admin,
3154 flags,
3155 })?;
3156 Ok(())
3157}
3158
3159fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3160 match message {
3161 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3162 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3163 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3164 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3165 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3166 code: frame.code.into(),
3167 reason: frame.reason,
3168 })),
3169 }
3170}
3171
3172fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3173 match message {
3174 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3175 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3176 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3177 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3178 AxumMessage::Close(frame) => {
3179 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3180 code: frame.code.into(),
3181 reason: frame.reason,
3182 }))
3183 }
3184 }
3185}
3186
3187fn notify_membership_updated(
3188 connection_pool: &ConnectionPool,
3189 result: MembershipUpdated,
3190 user_id: UserId,
3191 peer: &Peer,
3192) {
3193 let mut update = build_channels_update(result.new_channels, vec![]);
3194 update.delete_channels = result
3195 .removed_channels
3196 .into_iter()
3197 .map(|id| id.to_proto())
3198 .collect();
3199 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3200
3201 for connection_id in connection_pool.user_connection_ids(user_id) {
3202 peer.send(connection_id, update.clone()).trace_err();
3203 }
3204}
3205
3206fn build_channels_update(
3207 channels: ChannelsForUser,
3208 channel_invites: Vec<db::Channel>,
3209) -> proto::UpdateChannels {
3210 let mut update = proto::UpdateChannels::default();
3211
3212 for channel in channels.channels {
3213 update.channels.push(channel.to_proto());
3214 }
3215
3216 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3217 update.unseen_channel_messages = channels.channel_messages;
3218
3219 for (channel_id, participants) in channels.channel_participants {
3220 update
3221 .channel_participants
3222 .push(proto::ChannelParticipants {
3223 channel_id: channel_id.to_proto(),
3224 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3225 });
3226 }
3227
3228 for channel in channel_invites {
3229 update.channel_invitations.push(channel.to_proto());
3230 }
3231
3232 update
3233}
3234
3235fn build_initial_contacts_update(
3236 contacts: Vec<db::Contact>,
3237 pool: &ConnectionPool,
3238) -> proto::UpdateContacts {
3239 let mut update = proto::UpdateContacts::default();
3240
3241 for contact in contacts {
3242 match contact {
3243 db::Contact::Accepted { user_id, busy } => {
3244 update.contacts.push(contact_for_user(user_id, busy, &pool));
3245 }
3246 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3247 db::Contact::Incoming { user_id } => {
3248 update
3249 .incoming_requests
3250 .push(proto::IncomingContactRequest {
3251 requester_id: user_id.to_proto(),
3252 })
3253 }
3254 }
3255 }
3256
3257 update
3258}
3259
3260fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3261 proto::Contact {
3262 user_id: user_id.to_proto(),
3263 online: pool.is_user_online(user_id),
3264 busy,
3265 }
3266}
3267
3268fn room_updated(room: &proto::Room, peer: &Peer) {
3269 broadcast(
3270 None,
3271 room.participants
3272 .iter()
3273 .filter_map(|participant| Some(participant.peer_id?.into())),
3274 |peer_id| {
3275 peer.send(
3276 peer_id.into(),
3277 proto::RoomUpdated {
3278 room: Some(room.clone()),
3279 },
3280 )
3281 },
3282 );
3283}
3284
3285fn channel_updated(
3286 channel_id: ChannelId,
3287 room: &proto::Room,
3288 channel_members: &[UserId],
3289 peer: &Peer,
3290 pool: &ConnectionPool,
3291) {
3292 let participants = room
3293 .participants
3294 .iter()
3295 .map(|p| p.user_id)
3296 .collect::<Vec<_>>();
3297
3298 broadcast(
3299 None,
3300 channel_members
3301 .iter()
3302 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3303 |peer_id| {
3304 peer.send(
3305 peer_id.into(),
3306 proto::UpdateChannels {
3307 channel_participants: vec![proto::ChannelParticipants {
3308 channel_id: channel_id.to_proto(),
3309 participant_user_ids: participants.clone(),
3310 }],
3311 ..Default::default()
3312 },
3313 )
3314 },
3315 );
3316}
3317
3318async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3319 let db = session.db().await;
3320
3321 let contacts = db.get_contacts(user_id).await?;
3322 let busy = db.is_user_busy(user_id).await?;
3323
3324 let pool = session.connection_pool().await;
3325 let updated_contact = contact_for_user(user_id, busy, &pool);
3326 for contact in contacts {
3327 if let db::Contact::Accepted {
3328 user_id: contact_user_id,
3329 ..
3330 } = contact
3331 {
3332 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3333 session
3334 .peer
3335 .send(
3336 contact_conn_id,
3337 proto::UpdateContacts {
3338 contacts: vec![updated_contact.clone()],
3339 remove_contacts: Default::default(),
3340 incoming_requests: Default::default(),
3341 remove_incoming_requests: Default::default(),
3342 outgoing_requests: Default::default(),
3343 remove_outgoing_requests: Default::default(),
3344 },
3345 )
3346 .trace_err();
3347 }
3348 }
3349 }
3350 Ok(())
3351}
3352
/// Removes the session's connection from the room it is in, then performs the
/// follow-up notifications and cleanup.
///
/// Steps, in order:
/// 1. Ask the database to leave the room. If the connection was not in a room,
///    return `Ok(())` immediately.
/// 2. Notify collaborators of every project that was left (`project_left`) and
///    broadcast the updated room to the remaining participants (`room_updated`).
/// 3. If the room belongs to a channel, push the new participant list to the
///    channel's members (`channel_updated`).
/// 4. Send `CallCanceled` to every connection of users whose pending call was canceled.
/// 5. Refresh the contact entry of this user and of every canceled callee.
/// 6. If a LiveKit client is configured, remove this user from the LiveKit room.
///    Also delete that room when the database reported the room as deleted.
///
/// # Errors
/// Returns errors from `leave_room` and from `update_user_contacts`. `CallCanceled`
/// send failures and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entry must be re-sent once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned only in the `Some` branch below. The `else`
    // branch returns, so all of them are initialized after the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the db guard produced by `session.db().await` is a temporary of this
    // `if let` scrutinee, so it stays alive for the whole `Some` body below.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Must happen before `left_room.room` is moved out on the next lines.
        // `mem::take` leaves a default value behind, so the `room` broadcast by
        // `room_updated` below carries an emptied `live_kit_room`.
        // NOTE(review): confirm clients don't rely on that field in this update.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            // The pool guard is a temporary, released at the end of this call statement.
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires `session.connection_pool()` itself (presumably to avoid holding it twice).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            // A canceled callee's busy status may have changed, so refresh their contact entry too.
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?` and skips the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // The participant is identified by the stringified user id.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3429
3430async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3431 let left_channel_buffers = session
3432 .db()
3433 .await
3434 .leave_channel_buffers(session.connection_id)
3435 .await?;
3436
3437 for left_buffer in left_channel_buffers {
3438 channel_buffer_updated(
3439 session.connection_id,
3440 left_buffer.connections,
3441 &proto::UpdateChannelBufferCollaborators {
3442 channel_id: left_buffer.channel_id.to_proto(),
3443 collaborators: left_buffer.collaborators,
3444 },
3445 &session.peer,
3446 );
3447 }
3448
3449 Ok(())
3450}
3451
3452fn project_left(project: &db::LeftProject, session: &Session) {
3453 for connection_id in &project.connection_ids {
3454 if project.host_user_id == session.user_id {
3455 session
3456 .peer
3457 .send(
3458 *connection_id,
3459 proto::UnshareProject {
3460 project_id: project.id.to_proto(),
3461 },
3462 )
3463 .trace_err();
3464 } else {
3465 session
3466 .peer
3467 .send(
3468 *connection_id,
3469 proto::RemoveProjectCollaborator {
3470 project_id: project.id.to_proto(),
3471 peer_id: Some(session.connection_id.into()),
3472 },
3473 )
3474 .trace_err();
3475 }
3476 }
3477}
3478
3479pub trait ResultExt {
3480 type Ok;
3481
3482 fn trace_err(self) -> Option<Self::Ok>;
3483}
3484
3485impl<T, E> ResultExt for Result<T, E>
3486where
3487 E: std::fmt::Debug,
3488{
3489 type Ok = T;
3490
3491 fn trace_err(self) -> Option<T> {
3492 match self {
3493 Ok(value) => Some(value),
3494 Err(error) => {
3495 tracing::error!("{:?}", error);
3496 None
3497 }
3498 }
3499 }
3500}