1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69use util::channel::RELEASE_CHANNEL_NAME;
70
/// Grace period `connection_lost` waits after a disconnect before it releases the
/// session's room and channel-buffer resources.
71pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` sleeps before it begins clearing stale rooms and channel buffers.
72pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
73
// Paging and size limits. None of them is referenced in this part of the file.
// NOTE(review): presumably used by the channel-message and notification handlers — confirm.
74const MESSAGE_COUNT_PER_PAGE: usize = 100;
75const MAX_MESSAGE_LEN: usize = 1024;
76const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
// Prometheus gauges, registered lazily on first access. Registration failure panics
// through `unwrap`.
78lazy_static! {
 // Set by `handle_metrics` to the number of connections whose `admin` flag is false.
79 static ref METRIC_CONNECTIONS: IntGauge =
80 register_int_gauge!("connections", "number of connections").unwrap();
 // Set by `handle_metrics` to the value of `project_count_excluding_admins`.
81 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
82 "shared_projects",
83 "number of open projects with one or more guests"
84 )
85 .unwrap();
86}
87
/// Type-erased message handler stored in `Server::handlers`: it receives the boxed envelope
/// and the sender's `Session`, and returns a boxed future that resolves to `()`.
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone)]
106struct Session {
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 _executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
156pub struct Server {
 // Current server id; behind a mutex because the test-only `reset` replaces it.
157 id: parking_lot::Mutex<ServerId>,
158 peer: Arc<Peer>,
 // Shared with every `Session` created by `handle_connection`.
159 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
160 app_state: Arc<AppState>,
161 executor: Executor,
 // Keyed by the `TypeId` of the message type; filled in by `add_handler`.
162 handlers: HashMap<TypeId, MessageHandler>,
 // `teardown()` sends on this channel; connection tasks subscribe to it.
163 teardown: watch::Sender<()>,
164}
165
/// Wrapper around the connection pool's mutex guard.
166pub(crate) struct ConnectionPoolGuard<'a> {
167 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 // `Rc` is not `Send`, so this marker makes the wrapper `!Send` as well.
168 _not_send: PhantomData<Rc<()>>,
169}
170
/// Serializable view of the server built by `Server::snapshot`: a borrow of the peer plus
/// a held lock on the connection pool.
171#[derive(Serialize)]
172pub struct ServerSnapshot<'a> {
173 peer: &'a Peer,
 // Serialized through `serialize_deref`, i.e. as the `ConnectionPool` behind the guard.
174 #[serde(serialize_with = "serialize_deref")]
175 connection_pool: ConnectionPoolGuard<'a>,
176}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
/// Builds a server with an empty connection pool, registers every RPC handler and
/// returns it wrapped in an `Arc`.
///
/// # Panics
///
/// `add_handler` panics if two handlers are registered for the same message type.
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
 // Only the sender half is kept; receivers are created later via `subscribe()`.
196 teardown: watch::channel(()).0,
197 };
198
 // `add_request_handler` is for messages that expect a reply; `add_message_handler`
 // is for one-way messages.
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
221 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
230 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
231 .add_request_handler(forward_mutating_project_request::<proto::OpenBufferByPath>)
232 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
233 .add_request_handler(
234 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
235 )
236 .add_request_handler(
237 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
238 )
239 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
240 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
241 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
242 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
243 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
244 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
245 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
246 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
247 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
248 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
249 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
250 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
251 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
252 .add_message_handler(create_buffer_for_peer)
253 .add_request_handler(update_buffer)
254 .add_message_handler(update_buffer_file)
255 .add_message_handler(buffer_reloaded)
256 .add_message_handler(buffer_saved)
257 .add_request_handler(get_users)
258 .add_request_handler(fuzzy_search_users)
259 .add_request_handler(request_contact)
260 .add_request_handler(remove_contact)
261 .add_request_handler(respond_to_contact_request)
262 .add_request_handler(create_channel)
263 .add_request_handler(delete_channel)
264 .add_request_handler(invite_channel_member)
265 .add_request_handler(remove_channel_member)
266 .add_request_handler(set_channel_member_role)
267 .add_request_handler(set_channel_visibility)
268 .add_request_handler(rename_channel)
269 .add_request_handler(join_channel_buffer)
270 .add_request_handler(leave_channel_buffer)
271 .add_message_handler(update_channel_buffer)
272 .add_request_handler(rejoin_channel_buffers)
273 .add_request_handler(get_channel_members)
274 .add_request_handler(respond_to_channel_invite)
275 .add_request_handler(join_channel)
276 .add_request_handler(join_channel_chat)
277 .add_message_handler(leave_channel_chat)
278 .add_request_handler(send_channel_message)
279 .add_request_handler(remove_channel_message)
280 .add_request_handler(get_channel_messages)
281 .add_request_handler(get_channel_messages_by_id)
282 .add_request_handler(get_notifications)
283 .add_request_handler(mark_notification_as_read)
284 .add_request_handler(move_channel)
285 .add_request_handler(follow)
286 .add_message_handler(unfollow)
287 .add_message_handler(update_followers)
288 .add_message_handler(update_diff_base)
289 .add_request_handler(get_private_user_info)
290 .add_message_handler(acknowledge_channel_message)
291 .add_message_handler(acknowledge_buffer_version);
292
293 Arc::new(server)
294 }
295
/// Spawns a detached background task that cleans up stale resources, then returns
/// `Ok(())` immediately without waiting for that task.
///
/// After sleeping for `CLEANUP_TIMEOUT` the task:
/// 1. asks the database for stale room and channel ids,
/// 2. clears stale channel-buffer collaborators and notifies the remaining connections,
/// 3. clears stale room participants, broadcasts the refreshed room, cancels pending
///    calls, refreshes the affected users' contacts and deletes LiveKit rooms that
///    ended up empty,
/// 4. deletes stale server records.
///
/// Failures inside the task go through `trace_err` and do not abort the remaining steps.
296 pub async fn start(&self) -> Result<()> {
297 let server_id = *self.id.lock();
298 let app_state = self.app_state.clone();
299 let peer = self.peer.clone();
 // The sleep future is created here and awaited inside the spawned task.
300 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
301 let pool = self.connection_pool.clone();
302 let live_kit_client = self.app_state.live_kit_client.clone();
303
304 let span = info_span!("start server");
305 self.executor.spawn_detached(
306 async move {
307 tracing::info!("waiting for cleanup timeout");
308 timeout.await;
309 tracing::info!("cleanup timeout expired, retrieving stale rooms");
310 if let Some((room_ids, channel_ids)) = app_state
311 .db
312 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
313 .await
314 .trace_err()
315 {
316 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
317 tracing::info!(
318 stale_channel_buffer_count = channel_ids.len(),
319 "retrieved stale channel buffers"
320 );
321
 // Stale channel buffers: send the refreshed collaborator list to every
 // connection the database reports for that buffer.
322 for channel_id in channel_ids {
323 if let Some(refreshed_channel_buffer) = app_state
324 .db
325 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
326 .await
327 .trace_err()
328 {
329 for connection_id in refreshed_channel_buffer.connection_ids {
330 peer.send(
331 connection_id,
332 proto::UpdateChannelBufferCollaborators {
333 channel_id: channel_id.to_proto(),
334 collaborators: refreshed_channel_buffer
335 .collaborators
336 .clone(),
337 },
338 )
339 .trace_err();
340 }
341 }
342 }
343
344 for room_id in room_ids {
 // These keep their empty defaults when clearing the room fails, which
 // turns the follow-up steps below into no-ops for this room.
345 let mut contacts_to_update = HashSet::default();
346 let mut canceled_calls_to_user_ids = Vec::new();
347 let mut live_kit_room = String::new();
348 let mut delete_live_kit_room = false;
349
350 if let Some(mut refreshed_room) = app_state
351 .db
352 .clear_stale_room_participants(room_id, server_id)
353 .await
354 .trace_err()
355 {
356 tracing::info!(
357 room_id = room_id.0,
358 new_participant_count = refreshed_room.room.participants.len(),
359 "refreshed room"
360 );
361 room_updated(&refreshed_room.room, &peer);
362 if let Some(channel_id) = refreshed_room.channel_id {
363 channel_updated(
364 channel_id,
365 &refreshed_room.room,
366 &refreshed_room.channel_members,
367 &peer,
368 &*pool.lock(),
369 );
370 }
371 contacts_to_update
372 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
373 contacts_to_update
374 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
375 canceled_calls_to_user_ids =
376 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
377 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 // The LiveKit room is deleted only when no participants remain.
378 delete_live_kit_room = refreshed_room.room.participants.is_empty();
379 }
380
 // The pool lock is confined to this block, which contains no `.await`.
381 {
382 let pool = pool.lock();
383 for canceled_user_id in canceled_calls_to_user_ids {
384 for connection_id in pool.user_connection_ids(canceled_user_id) {
385 peer.send(
386 connection_id,
387 proto::CallCanceled {
388 room_id: room_id.to_proto(),
389 },
390 )
391 .trace_err();
392 }
393 }
394 }
395
 // Push each affected user's refreshed contact entry to the connections of
 // all of their accepted contacts.
396 for user_id in contacts_to_update {
397 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
398 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
399 if let Some((busy, contacts)) = busy.zip(contacts) {
 // Locked only after both database calls have completed.
400 let pool = pool.lock();
401 let updated_contact = contact_for_user(user_id, busy, &pool);
402 for contact in contacts {
403 if let db::Contact::Accepted {
404 user_id: contact_user_id,
405 ..
406 } = contact
407 {
408 for contact_conn_id in
409 pool.user_connection_ids(contact_user_id)
410 {
411 peer.send(
412 contact_conn_id,
413 proto::UpdateContacts {
414 contacts: vec![updated_contact.clone()],
415 remove_contacts: Default::default(),
416 incoming_requests: Default::default(),
417 remove_incoming_requests: Default::default(),
418 outgoing_requests: Default::default(),
419 remove_outgoing_requests: Default::default(),
420 },
421 )
422 .trace_err();
423 }
424 }
425 }
426 }
427 }
428
429 if let Some(live_kit) = live_kit_client.as_ref() {
430 if delete_live_kit_room {
431 live_kit.delete_room(live_kit_room).await.trace_err();
432 }
433 }
434 }
435 }
436
 // Runs even when fetching the stale resource ids failed above.
437 app_state
438 .db
439 .delete_stale_servers(&app_state.config.zed_environment, server_id)
440 .await
441 .trace_err();
442 }
443 .instrument(span),
444 );
445 Ok(())
446 }
447
448 pub fn teardown(&self) {
449 self.peer.teardown();
450 self.connection_pool.lock().reset();
451 let _ = self.teardown.send(());
452 }
453
/// Test-only: tears the server down, then gives both the server and its peer the new `id`.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
}
460
/// Test-only: returns a copy of the current server id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = self.id.lock();
    *current_id
}
465
466 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
467 where
468 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
469 Fut: 'static + Send + Future<Output = Result<()>>,
470 M: EnvelopedMessage,
471 {
472 let prev_handler = self.handlers.insert(
473 TypeId::of::<M>(),
474 Box::new(move |envelope, session| {
475 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
476 let span = info_span!(
477 "handle message",
478 payload_type = envelope.payload_type_name()
479 );
480 span.in_scope(|| {
481 tracing::info!(
482 payload_type = envelope.payload_type_name(),
483 "message received"
484 );
485 });
486 let start_time = Instant::now();
487 let future = (handler)(*envelope, session);
488 async move {
489 let result = future.await;
490 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
491 match result {
492 Err(error) => {
493 tracing::error!(%error, ?duration_ms, "error handling message")
494 }
495 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
496 }
497 }
498 .instrument(span)
499 .boxed()
500 }),
501 );
502 if prev_handler.is_some() {
503 panic!("registered a handler for the same message twice");
504 }
505 self
506 }
507
508 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
509 where
510 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
511 Fut: 'static + Send + Future<Output = Result<()>>,
512 M: EnvelopedMessage,
513 {
514 self.add_handler(move |envelope, session| handler(envelope.payload, session));
515 self
516 }
517
518 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
519 where
520 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
521 Fut: Send + Future<Output = Result<()>>,
522 M: RequestMessage,
523 {
524 let handler = Arc::new(handler);
525 self.add_handler(move |envelope, session| {
526 let receipt = envelope.receipt();
527 let handler = handler.clone();
528 async move {
529 let peer = session.peer.clone();
530 let responded = Arc::new(AtomicBool::default());
531 let response = Response {
532 peer: peer.clone(),
533 responded: responded.clone(),
534 receipt,
535 };
536 match (handler)(envelope.payload, response, session).await {
537 Ok(()) => {
538 if responded.load(std::sync::atomic::Ordering::SeqCst) {
539 Ok(())
540 } else {
541 Err(anyhow!("handler did not send a response"))?
542 }
543 }
544 Err(error) => {
545 peer.respond_with_error(
546 receipt,
547 proto::Error {
548 message: error.to_string(),
549 },
550 )?;
551 Err(error)
552 }
553 }
554 }
555 })
556 }
557
/// Returns a future that drives one client connection from handshake to sign-out.
///
/// The future registers the connection with the peer, sends `Hello`, optionally reports
/// the connection id through `send_connection_id`, sends the user's initial contacts,
/// channels and any incoming call, and then dispatches incoming messages to the
/// registered handlers until the connection closes, I/O handling finishes, or the
/// server is torn down. Except on teardown, it ends by calling `connection_lost`.
558 pub fn handle_connection(
559 self: &Arc<Self>,
560 connection: Connection,
561 address: String,
562 user: User,
563 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
564 executor: Executor,
565 ) -> impl Future<Output = Result<()>> {
566 let this = self.clone();
567 let user_id = user.id;
568 let login = user.github_login;
569 let span = info_span!("handle connection", %user_id, %login, %address);
570 let mut teardown = self.teardown.subscribe();
571 async move {
572 let (connection_id, handle_io, mut incoming_rx) = this
573 .peer
574 .add_connection(connection, {
575 let executor = executor.clone();
576 move |duration| executor.sleep(duration)
577 });
578
579 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
580 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
581 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
582
583 if let Some(send_connection_id) = send_connection_id.take() {
584 let _ = send_connection_id.send(connection_id);
585 }
586
 // First connection for this user: send `ShowContacts` and persist the flag.
587 if !user.connected_once {
588 this.peer.send(connection_id, proto::ShowContacts {})?;
589 this.app_state.db.set_user_connected_once(user_id, true).await?;
590 }
591
592 let (contacts, channels_for_user, channel_invites) = future::try_join3(
593 this.app_state.db.get_contacts(user_id),
594 this.app_state.db.get_channels_for_user(user_id),
595 this.app_state.db.get_channel_invites_for_user(user_id),
596 ).await?;
597
 // The pool lock is confined to this block, which contains no `.await`.
598 {
599 let mut pool = this.connection_pool.lock();
600 pool.add_connection(connection_id, user_id, user.admin);
601 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
602 this.peer.send(connection_id, build_channels_update(
603 channels_for_user,
604 channel_invites
605 ))?;
606 }
607
608 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
609 this.peer.send(connection_id, incoming_call)?;
610 }
611
612 let session = Session {
613 user_id,
614 connection_id,
615 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
616 peer: this.peer.clone(),
617 connection_pool: this.connection_pool.clone(),
618 live_kit_client: this.app_state.live_kit_client.clone(),
619 _executor: executor.clone()
620 };
621 update_user_contacts(user_id, &session).await?;
622
623 let handle_io = handle_io.fuse();
624 futures::pin_mut!(handle_io);
625
626 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
627 // This prevents deadlocks when e.g., client A performs a request to client B and
628 // client B performs a request to client A. If both clients stop processing further
629 // messages until their respective request completes, they won't have a chance to
630 // respond to the other client's request and cause a deadlock.
631 //
632 // This arrangement ensures we will attempt to process earlier messages first, but fall
633 // back to processing messages arrived later in the spirit of making progress.
634 let mut foreground_message_handlers = FuturesUnordered::new();
 // Each in-flight handler holds one permit, so at most 256 run at once per connection.
635 let concurrent_handlers = Arc::new(Semaphore::new(256));
636 loop {
 // A permit is acquired before the next message is read, so no further
 // messages are pulled while all permits are in use.
637 let next_message = async {
638 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
639 let message = incoming_rx.next().await;
640 (permit, message)
641 }.fuse();
642 futures::pin_mut!(next_message);
 // `select_biased!` polls the arms in the order written.
643 futures::select_biased! {
 // Teardown returns right away, skipping the `connection_lost` call below.
644 _ = teardown.changed().fuse() => return Ok(()),
645 result = handle_io => {
646 if let Err(error) = result {
647 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
648 }
649 break;
650 }
651 _ = foreground_message_handlers.next() => {}
652 next_message = next_message => {
653 let (permit, message) = next_message;
654 if let Some(message) = message {
655 let type_name = message.payload_type_name();
656 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
657 let span_enter = span.enter();
658 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
659 let is_background = message.is_background();
660 let handle_message = (handler)(message, session.clone());
661 drop(span_enter);
662
 // The permit is released only once the handler has finished.
663 let handle_message = async move {
664 handle_message.await;
665 drop(permit);
666 }.instrument(span);
 // Background messages are spawned as detached tasks; foreground ones
 // are polled by this loop.
667 if is_background {
668 executor.spawn_detached(handle_message);
669 } else {
670 foreground_message_handlers.push(handle_message);
671 }
672 } else {
673 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
674 }
675 } else {
676 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
677 break;
678 }
679 }
680 }
681 }
682
 // Drops any foreground handlers that have not completed yet.
683 drop(foreground_message_handlers);
684 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
685 if let Err(error) = connection_lost(session, teardown, executor).await {
686 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
687 }
688
689 Ok(())
690 }.instrument(span)
691 }
692
693 pub async fn invite_code_redeemed(
694 self: &Arc<Self>,
695 inviter_id: UserId,
696 invitee_id: UserId,
697 ) -> Result<()> {
698 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
699 if let Some(code) = &user.invite_code {
700 let pool = self.connection_pool.lock();
701 let invitee_contact = contact_for_user(invitee_id, false, &pool);
702 for connection_id in pool.user_connection_ids(inviter_id) {
703 self.peer.send(
704 connection_id,
705 proto::UpdateContacts {
706 contacts: vec![invitee_contact.clone()],
707 ..Default::default()
708 },
709 )?;
710 self.peer.send(
711 connection_id,
712 proto::UpdateInviteInfo {
713 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
714 count: user.invite_count as u32,
715 },
716 )?;
717 }
718 }
719 }
720 Ok(())
721 }
722
723 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
724 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
725 if let Some(invite_code) = &user.invite_code {
726 let pool = self.connection_pool.lock();
727 for connection_id in pool.user_connection_ids(user_id) {
728 self.peer.send(
729 connection_id,
730 proto::UpdateInviteInfo {
731 url: format!(
732 "{}{}",
733 self.app_state.config.invite_link_prefix, invite_code
734 ),
735 count: user.invite_count as u32,
736 },
737 )?;
738 }
739 }
740 }
741 Ok(())
742 }
743
744 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
745 ServerSnapshot {
746 connection_pool: ConnectionPoolGuard {
747 guard: self.connection_pool.lock(),
748 _not_send: PhantomData,
749 },
750 peer: &self.peer,
751 }
752 }
753}
754
755impl<'a> Deref for ConnectionPoolGuard<'a> {
756 type Target = ConnectionPool;
757
758 fn deref(&self) -> &Self::Target {
759 &*self.guard
760 }
761}
762
763impl<'a> DerefMut for ConnectionPoolGuard<'a> {
764 fn deref_mut(&mut self) -> &mut Self::Target {
765 &mut *self.guard
766 }
767}
768
/// In test builds the pool's invariants are checked each time a guard is released.
/// In other builds the body is empty.
769impl<'a> Drop for ConnectionPoolGuard<'a> {
770 fn drop(&mut self) {
771 #[cfg(test)]
772 self.check_invariants();
773 }
774}
775
776fn broadcast<F>(
777 sender_id: Option<ConnectionId>,
778 receiver_ids: impl IntoIterator<Item = ConnectionId>,
779 mut f: F,
780) where
781 F: FnMut(ConnectionId) -> anyhow::Result<()>,
782{
783 for receiver_id in receiver_ids {
784 if Some(receiver_id) != sender_id {
785 if let Err(error) = f(receiver_id) {
786 tracing::error!("failed to send to {:?} {}", receiver_id, error);
787 }
788 }
789 }
790}
791
// Name of the request header that `ProtocolVersion` is decoded from.
792lazy_static! {
793 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
794}
795
796pub struct ProtocolVersion(u32);
797
798impl Header for ProtocolVersion {
799 fn name() -> &'static HeaderName {
800 &ZED_PROTOCOL_VERSION
801 }
802
803 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
804 where
805 Self: Sized,
806 I: Iterator<Item = &'i axum::http::HeaderValue>,
807 {
808 let version = values
809 .next()
810 .ok_or_else(axum::headers::Error::invalid)?
811 .to_str()
812 .map_err(|_| axum::headers::Error::invalid())?
813 .parse()
814 .map_err(|_| axum::headers::Error::invalid())?;
815 Ok(Self(version))
816 }
817
818 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
819 values.extend([self.0.to_string().parse().unwrap()]);
820 }
821}
822
823pub fn routes(server: Arc<Server>) -> Router<Body> {
824 Router::new()
825 .route("/rpc", get(handle_websocket_request))
826 .layer(
827 ServiceBuilder::new()
828 .layer(Extension(server.app_state.clone()))
829 .layer(middleware::from_fn(auth::validate_header)),
830 )
831 .route("/metrics", get(handle_metrics))
832 .layer(Extension(server))
833}
834
/// Axum handler for `GET /rpc`.
///
/// Responds with `426 Upgrade Required` when the client's protocol version differs from
/// `rpc::PROTOCOL_VERSION`. Otherwise it upgrades the request to a websocket, adapts the
/// axum socket to the message type `Connection` expects, and runs
/// `Server::handle_connection` on it, logging any error that call returns.
835pub async fn handle_websocket_request(
836 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
837 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
838 Extension(server): Extension<Arc<Server>>,
839 Extension(user): Extension<User>,
840 ws: WebSocketUpgrade,
841) -> axum::response::Response {
842 if protocol_version != rpc::PROTOCOL_VERSION {
843 return (
844 StatusCode::UPGRADE_REQUIRED,
845 "client must be upgraded".to_string(),
846 )
847 .into_response();
848 }
849 let socket_address = socket_address.to_string();
850 ws.on_upgrade(move |socket| {
851 use util::ResultExt;
 // Incoming axum messages are mapped to tungstenite messages; outgoing tungstenite
 // messages are mapped back to axum messages before they are sent.
852 let socket = socket
853 .map_ok(to_tungstenite_message)
854 .err_into()
855 .with(|message| async move { Ok(to_axum_message(message)) });
856 let connection = Connection::new(Box::pin(socket));
857 async move {
858 server
859 .handle_connection(connection, socket_address, user, None, Executor::Production)
860 .await
861 .log_err();
862 }
863 })
864}
865
866pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
867 let connections = server
868 .connection_pool
869 .lock()
870 .connections()
871 .filter(|connection| !connection.admin)
872 .count();
873
874 METRIC_CONNECTIONS.set(connections as _);
875
876 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
877 METRIC_SHARED_PROJECTS.set(shared_projects as _);
878
879 let encoder = prometheus::TextEncoder::new();
880 let metric_families = prometheus::gather();
881 let encoded_metrics = encoder
882 .encode_to_string(&metric_families)
883 .map_err(|err| anyhow!("{}", err))?;
884 Ok(encoded_metrics)
885}
886
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer, removed from the pool and reported to
/// the database right away. The function then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers, a
///   pending call is declined if the user has no other connection online, and the
///   user's contacts are updated;
/// - the server is torn down: nothing further is done.
///
/// # Errors
///
/// Fails when the connection cannot be removed from the pool or when the final contact
/// update fails. The other steps only trace their errors.
887#[instrument(err, skip(executor))]
888async fn connection_lost(
889 session: Session,
890 mut teardown: watch::Receiver<()>,
891 executor: Executor,
892) -> Result<()> {
893 session.peer.disconnect(session.connection_id);
894 session
895 .connection_pool()
896 .await
897 .remove_connection(session.connection_id)?;
898
899 session
900 .db()
901 .await
902 .connection_lost(session.connection_id)
903 .await
904 .trace_err();
905
 // `select_biased!` polls the arms in the order written.
906 futures::select_biased! {
907 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
908 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
909 leave_room_for_session(&session).await.trace_err();
910 leave_channel_buffers_for_session(&session)
911 .await
912 .trace_err();
913
 // Decline the pending call only when the user has no connection left.
914 if !session
915 .connection_pool()
916 .await
917 .is_user_online(session.user_id)
918 {
919 let db = session.db().await;
920 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
921 room_updated(&room, &session.peer);
922 }
923 }
924
925 update_user_contacts(session.user_id, &session).await?;
926 }
927 _ = teardown.changed().fuse() => {}
928 }
929
930 Ok(())
931}
932
933async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
934 response.send(proto::Ack {})?;
935 Ok(())
936}
937
/// Handles `CreateRoom`: creates a room for the requesting user, replies with the room
/// and optional LiveKit connection info, then updates the user's contacts.
///
/// When no LiveKit client is configured, or a token cannot be issued, the reply carries
/// no connection info instead of failing the request.
938async fn create_room(
939 _request: proto::CreateRoom,
940 response: Response<proto::CreateRoom>,
941 session: Session,
942) -> Result<()> {
 // Random 30-character id used as the LiveKit room name.
943 let live_kit_room = nanoid::nanoid!(30);
944
945 let live_kit_connection_info = {
946 let live_kit_room = live_kit_room.clone();
947 let live_kit = session.live_kit_client.as_ref();
948
 // NOTE(review): `async_maybe!` appears to wrap the block in a future yielding an
 // `Option`, so each `?` below short-circuits to `None` — confirm against `util`.
949 util::async_maybe!({
950 let live_kit = live_kit?;
951
952 let token = live_kit
953 .room_token(&live_kit_room, &session.user_id.to_string())
954 .trace_err()?;
955
956 Some(proto::LiveKitConnectionInfo {
957 server_url: live_kit.url().into(),
958 token,
959 can_publish: true,
960 })
961 })
962 }
963 .await;
964
965 let room = session
966 .db()
967 .await
968 .create_room(
969 session.user_id,
970 session.connection_id,
971 &live_kit_room,
972 RELEASE_CHANNEL_NAME.as_str(),
973 )
974 .await?;
975
976 response.send(proto::CreateRoomResponse {
977 room: Some(room.clone()),
978 live_kit_connection_info,
979 })?;
980
 // Runs after the reply has been sent.
981 update_user_contacts(session.user_id, &session).await?;
982 Ok(())
983}
984
985async fn join_room(
986 request: proto::JoinRoom,
987 response: Response<proto::JoinRoom>,
988 session: Session,
989) -> Result<()> {
990 let room_id = RoomId::from_proto(request.id);
991
992 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
993
994 if let Some(channel_id) = channel_id {
995 return join_channel_internal(channel_id, Box::new(response), session).await;
996 }
997
998 let joined_room = {
999 let room = session
1000 .db()
1001 .await
1002 .join_room(
1003 room_id,
1004 session.user_id,
1005 session.connection_id,
1006 RELEASE_CHANNEL_NAME.as_str(),
1007 )
1008 .await?;
1009 room_updated(&room.room, &session.peer);
1010 room.into_inner()
1011 };
1012
1013 for connection_id in session
1014 .connection_pool()
1015 .await
1016 .user_connection_ids(session.user_id)
1017 {
1018 session
1019 .peer
1020 .send(
1021 connection_id,
1022 proto::CallCanceled {
1023 room_id: room_id.to_proto(),
1024 },
1025 )
1026 .trace_err();
1027 }
1028
1029 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1030 if let Some(token) = live_kit
1031 .room_token(
1032 &joined_room.room.live_kit_room,
1033 &session.user_id.to_string(),
1034 )
1035 .trace_err()
1036 {
1037 Some(proto::LiveKitConnectionInfo {
1038 server_url: live_kit.url().into(),
1039 token,
1040 can_publish: true,
1041 })
1042 } else {
1043 None
1044 }
1045 } else {
1046 None
1047 };
1048
1049 response.send(proto::JoinRoomResponse {
1050 room: Some(joined_room.room),
1051 channel_id: None,
1052 live_kit_connection_info,
1053 })?;
1054
1055 update_user_contacts(session.user_id, &session).await?;
1056 Ok(())
1057}
1058
1059async fn rejoin_room(
1060 request: proto::RejoinRoom,
1061 response: Response<proto::RejoinRoom>,
1062 session: Session,
1063) -> Result<()> {
1064 let room;
1065 let channel_id;
1066 let channel_members;
1067 {
1068 let mut rejoined_room = session
1069 .db()
1070 .await
1071 .rejoin_room(request, session.user_id, session.connection_id)
1072 .await?;
1073
1074 response.send(proto::RejoinRoomResponse {
1075 room: Some(rejoined_room.room.clone()),
1076 reshared_projects: rejoined_room
1077 .reshared_projects
1078 .iter()
1079 .map(|project| proto::ResharedProject {
1080 id: project.id.to_proto(),
1081 collaborators: project
1082 .collaborators
1083 .iter()
1084 .map(|collaborator| collaborator.to_proto())
1085 .collect(),
1086 })
1087 .collect(),
1088 rejoined_projects: rejoined_room
1089 .rejoined_projects
1090 .iter()
1091 .map(|rejoined_project| proto::RejoinedProject {
1092 id: rejoined_project.id.to_proto(),
1093 worktrees: rejoined_project
1094 .worktrees
1095 .iter()
1096 .map(|worktree| proto::WorktreeMetadata {
1097 id: worktree.id,
1098 root_name: worktree.root_name.clone(),
1099 visible: worktree.visible,
1100 abs_path: worktree.abs_path.clone(),
1101 })
1102 .collect(),
1103 collaborators: rejoined_project
1104 .collaborators
1105 .iter()
1106 .map(|collaborator| collaborator.to_proto())
1107 .collect(),
1108 language_servers: rejoined_project.language_servers.clone(),
1109 })
1110 .collect(),
1111 })?;
1112 room_updated(&rejoined_room.room, &session.peer);
1113
1114 for project in &rejoined_room.reshared_projects {
1115 for collaborator in &project.collaborators {
1116 session
1117 .peer
1118 .send(
1119 collaborator.connection_id,
1120 proto::UpdateProjectCollaborator {
1121 project_id: project.id.to_proto(),
1122 old_peer_id: Some(project.old_connection_id.into()),
1123 new_peer_id: Some(session.connection_id.into()),
1124 },
1125 )
1126 .trace_err();
1127 }
1128
1129 broadcast(
1130 Some(session.connection_id),
1131 project
1132 .collaborators
1133 .iter()
1134 .map(|collaborator| collaborator.connection_id),
1135 |connection_id| {
1136 session.peer.forward_send(
1137 session.connection_id,
1138 connection_id,
1139 proto::UpdateProject {
1140 project_id: project.id.to_proto(),
1141 worktrees: project.worktrees.clone(),
1142 },
1143 )
1144 },
1145 );
1146 }
1147
1148 for project in &rejoined_room.rejoined_projects {
1149 for collaborator in &project.collaborators {
1150 session
1151 .peer
1152 .send(
1153 collaborator.connection_id,
1154 proto::UpdateProjectCollaborator {
1155 project_id: project.id.to_proto(),
1156 old_peer_id: Some(project.old_connection_id.into()),
1157 new_peer_id: Some(session.connection_id.into()),
1158 },
1159 )
1160 .trace_err();
1161 }
1162 }
1163
1164 for project in &mut rejoined_room.rejoined_projects {
1165 for worktree in mem::take(&mut project.worktrees) {
1166 #[cfg(any(test, feature = "test-support"))]
1167 const MAX_CHUNK_SIZE: usize = 2;
1168 #[cfg(not(any(test, feature = "test-support")))]
1169 const MAX_CHUNK_SIZE: usize = 256;
1170
1171 // Stream this worktree's entries.
1172 let message = proto::UpdateWorktree {
1173 project_id: project.id.to_proto(),
1174 worktree_id: worktree.id,
1175 abs_path: worktree.abs_path.clone(),
1176 root_name: worktree.root_name,
1177 updated_entries: worktree.updated_entries,
1178 removed_entries: worktree.removed_entries,
1179 scan_id: worktree.scan_id,
1180 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1181 updated_repositories: worktree.updated_repositories,
1182 removed_repositories: worktree.removed_repositories,
1183 };
1184 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1185 session.peer.send(session.connection_id, update.clone())?;
1186 }
1187
1188 // Stream this worktree's diagnostics.
1189 for summary in worktree.diagnostic_summaries {
1190 session.peer.send(
1191 session.connection_id,
1192 proto::UpdateDiagnosticSummary {
1193 project_id: project.id.to_proto(),
1194 worktree_id: worktree.id,
1195 summary: Some(summary),
1196 },
1197 )?;
1198 }
1199
1200 for settings_file in worktree.settings_files {
1201 session.peer.send(
1202 session.connection_id,
1203 proto::UpdateWorktreeSettings {
1204 project_id: project.id.to_proto(),
1205 worktree_id: worktree.id,
1206 path: settings_file.path,
1207 content: Some(settings_file.content),
1208 },
1209 )?;
1210 }
1211 }
1212
1213 for language_server in &project.language_servers {
1214 session.peer.send(
1215 session.connection_id,
1216 proto::UpdateLanguageServer {
1217 project_id: project.id.to_proto(),
1218 language_server_id: language_server.id,
1219 variant: Some(
1220 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1221 proto::LspDiskBasedDiagnosticsUpdated {},
1222 ),
1223 ),
1224 },
1225 )?;
1226 }
1227 }
1228
1229 let rejoined_room = rejoined_room.into_inner();
1230
1231 room = rejoined_room.room;
1232 channel_id = rejoined_room.channel_id;
1233 channel_members = rejoined_room.channel_members;
1234 }
1235
1236 if let Some(channel_id) = channel_id {
1237 channel_updated(
1238 channel_id,
1239 &room,
1240 &channel_members,
1241 &session.peer,
1242 &*session.connection_pool().await,
1243 );
1244 }
1245
1246 update_user_contacts(session.user_id, &session).await?;
1247 Ok(())
1248}
1249
1250async fn leave_room(
1251 _: proto::LeaveRoom,
1252 response: Response<proto::LeaveRoom>,
1253 session: Session,
1254) -> Result<()> {
1255 leave_room_for_session(&session).await?;
1256 response.send(proto::Ack {})?;
1257 Ok(())
1258}
1259
1260async fn call(
1261 request: proto::Call,
1262 response: Response<proto::Call>,
1263 session: Session,
1264) -> Result<()> {
1265 let room_id = RoomId::from_proto(request.room_id);
1266 let calling_user_id = session.user_id;
1267 let calling_connection_id = session.connection_id;
1268 let called_user_id = UserId::from_proto(request.called_user_id);
1269 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1270 if !session
1271 .db()
1272 .await
1273 .has_contact(calling_user_id, called_user_id)
1274 .await?
1275 {
1276 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1277 }
1278
1279 let incoming_call = {
1280 let (room, incoming_call) = &mut *session
1281 .db()
1282 .await
1283 .call(
1284 room_id,
1285 calling_user_id,
1286 calling_connection_id,
1287 called_user_id,
1288 initial_project_id,
1289 )
1290 .await?;
1291 room_updated(&room, &session.peer);
1292 mem::take(incoming_call)
1293 };
1294 update_user_contacts(called_user_id, &session).await?;
1295
1296 let mut calls = session
1297 .connection_pool()
1298 .await
1299 .user_connection_ids(called_user_id)
1300 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1301 .collect::<FuturesUnordered<_>>();
1302
1303 while let Some(call_response) = calls.next().await {
1304 match call_response.as_ref() {
1305 Ok(_) => {
1306 response.send(proto::Ack {})?;
1307 return Ok(());
1308 }
1309 Err(_) => {
1310 call_response.trace_err();
1311 }
1312 }
1313 }
1314
1315 {
1316 let room = session
1317 .db()
1318 .await
1319 .call_failed(room_id, called_user_id)
1320 .await?;
1321 room_updated(&room, &session.peer);
1322 }
1323 update_user_contacts(called_user_id, &session).await?;
1324
1325 Err(anyhow!("failed to ring user"))?
1326}
1327
1328async fn cancel_call(
1329 request: proto::CancelCall,
1330 response: Response<proto::CancelCall>,
1331 session: Session,
1332) -> Result<()> {
1333 let called_user_id = UserId::from_proto(request.called_user_id);
1334 let room_id = RoomId::from_proto(request.room_id);
1335 {
1336 let room = session
1337 .db()
1338 .await
1339 .cancel_call(room_id, session.connection_id, called_user_id)
1340 .await?;
1341 room_updated(&room, &session.peer);
1342 }
1343
1344 for connection_id in session
1345 .connection_pool()
1346 .await
1347 .user_connection_ids(called_user_id)
1348 {
1349 session
1350 .peer
1351 .send(
1352 connection_id,
1353 proto::CallCanceled {
1354 room_id: room_id.to_proto(),
1355 },
1356 )
1357 .trace_err();
1358 }
1359 response.send(proto::Ack {})?;
1360
1361 update_user_contacts(called_user_id, &session).await?;
1362 Ok(())
1363}
1364
1365async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1366 let room_id = RoomId::from_proto(message.room_id);
1367 {
1368 let room = session
1369 .db()
1370 .await
1371 .decline_call(Some(room_id), session.user_id)
1372 .await?
1373 .ok_or_else(|| anyhow!("failed to decline call"))?;
1374 room_updated(&room, &session.peer);
1375 }
1376
1377 for connection_id in session
1378 .connection_pool()
1379 .await
1380 .user_connection_ids(session.user_id)
1381 {
1382 session
1383 .peer
1384 .send(
1385 connection_id,
1386 proto::CallCanceled {
1387 room_id: room_id.to_proto(),
1388 },
1389 )
1390 .trace_err();
1391 }
1392 update_user_contacts(session.user_id, &session).await?;
1393 Ok(())
1394}
1395
1396async fn update_participant_location(
1397 request: proto::UpdateParticipantLocation,
1398 response: Response<proto::UpdateParticipantLocation>,
1399 session: Session,
1400) -> Result<()> {
1401 let room_id = RoomId::from_proto(request.room_id);
1402 let location = request
1403 .location
1404 .ok_or_else(|| anyhow!("invalid location"))?;
1405
1406 let db = session.db().await;
1407 let room = db
1408 .update_room_participant_location(room_id, session.connection_id, location)
1409 .await?;
1410
1411 room_updated(&room, &session.peer);
1412 response.send(proto::Ack {})?;
1413 Ok(())
1414}
1415
1416async fn share_project(
1417 request: proto::ShareProject,
1418 response: Response<proto::ShareProject>,
1419 session: Session,
1420) -> Result<()> {
1421 let (project_id, room) = &*session
1422 .db()
1423 .await
1424 .share_project(
1425 RoomId::from_proto(request.room_id),
1426 session.connection_id,
1427 &request.worktrees,
1428 )
1429 .await?;
1430 response.send(proto::ShareProjectResponse {
1431 project_id: project_id.to_proto(),
1432 })?;
1433 room_updated(&room, &session.peer);
1434
1435 Ok(())
1436}
1437
1438async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1439 let project_id = ProjectId::from_proto(message.project_id);
1440
1441 let (room, guest_connection_ids) = &*session
1442 .db()
1443 .await
1444 .unshare_project(project_id, session.connection_id)
1445 .await?;
1446
1447 broadcast(
1448 Some(session.connection_id),
1449 guest_connection_ids.iter().copied(),
1450 |conn_id| session.peer.send(conn_id, message.clone()),
1451 );
1452 room_updated(&room, &session.peer);
1453
1454 Ok(())
1455}
1456
1457async fn join_project(
1458 request: proto::JoinProject,
1459 response: Response<proto::JoinProject>,
1460 session: Session,
1461) -> Result<()> {
1462 let project_id = ProjectId::from_proto(request.project_id);
1463 let guest_user_id = session.user_id;
1464
1465 tracing::info!(%project_id, "join project");
1466
1467 let (project, replica_id) = &mut *session
1468 .db()
1469 .await
1470 .join_project(project_id, session.connection_id)
1471 .await?;
1472
1473 let collaborators = project
1474 .collaborators
1475 .iter()
1476 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1477 .map(|collaborator| collaborator.to_proto())
1478 .collect::<Vec<_>>();
1479
1480 let worktrees = project
1481 .worktrees
1482 .iter()
1483 .map(|(id, worktree)| proto::WorktreeMetadata {
1484 id: *id,
1485 root_name: worktree.root_name.clone(),
1486 visible: worktree.visible,
1487 abs_path: worktree.abs_path.clone(),
1488 })
1489 .collect::<Vec<_>>();
1490
1491 for collaborator in &collaborators {
1492 session
1493 .peer
1494 .send(
1495 collaborator.peer_id.unwrap().into(),
1496 proto::AddProjectCollaborator {
1497 project_id: project_id.to_proto(),
1498 collaborator: Some(proto::Collaborator {
1499 peer_id: Some(session.connection_id.into()),
1500 replica_id: replica_id.0 as u32,
1501 user_id: guest_user_id.to_proto(),
1502 }),
1503 },
1504 )
1505 .trace_err();
1506 }
1507
1508 // First, we send the metadata associated with each worktree.
1509 response.send(proto::JoinProjectResponse {
1510 worktrees: worktrees.clone(),
1511 replica_id: replica_id.0 as u32,
1512 collaborators: collaborators.clone(),
1513 language_servers: project.language_servers.clone(),
1514 })?;
1515
1516 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1517 #[cfg(any(test, feature = "test-support"))]
1518 const MAX_CHUNK_SIZE: usize = 2;
1519 #[cfg(not(any(test, feature = "test-support")))]
1520 const MAX_CHUNK_SIZE: usize = 256;
1521
1522 // Stream this worktree's entries.
1523 let message = proto::UpdateWorktree {
1524 project_id: project_id.to_proto(),
1525 worktree_id,
1526 abs_path: worktree.abs_path.clone(),
1527 root_name: worktree.root_name,
1528 updated_entries: worktree.entries,
1529 removed_entries: Default::default(),
1530 scan_id: worktree.scan_id,
1531 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1532 updated_repositories: worktree.repository_entries.into_values().collect(),
1533 removed_repositories: Default::default(),
1534 };
1535 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1536 session.peer.send(session.connection_id, update.clone())?;
1537 }
1538
1539 // Stream this worktree's diagnostics.
1540 for summary in worktree.diagnostic_summaries {
1541 session.peer.send(
1542 session.connection_id,
1543 proto::UpdateDiagnosticSummary {
1544 project_id: project_id.to_proto(),
1545 worktree_id: worktree.id,
1546 summary: Some(summary),
1547 },
1548 )?;
1549 }
1550
1551 for settings_file in worktree.settings_files {
1552 session.peer.send(
1553 session.connection_id,
1554 proto::UpdateWorktreeSettings {
1555 project_id: project_id.to_proto(),
1556 worktree_id: worktree.id,
1557 path: settings_file.path,
1558 content: Some(settings_file.content),
1559 },
1560 )?;
1561 }
1562 }
1563
1564 for language_server in &project.language_servers {
1565 session.peer.send(
1566 session.connection_id,
1567 proto::UpdateLanguageServer {
1568 project_id: project_id.to_proto(),
1569 language_server_id: language_server.id,
1570 variant: Some(
1571 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1572 proto::LspDiskBasedDiagnosticsUpdated {},
1573 ),
1574 ),
1575 },
1576 )?;
1577 }
1578
1579 Ok(())
1580}
1581
1582async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1583 let sender_id = session.connection_id;
1584 let project_id = ProjectId::from_proto(request.project_id);
1585
1586 let (room, project) = &*session
1587 .db()
1588 .await
1589 .leave_project(project_id, sender_id)
1590 .await?;
1591 tracing::info!(
1592 %project_id,
1593 host_user_id = %project.host_user_id,
1594 host_connection_id = %project.host_connection_id,
1595 "leave project"
1596 );
1597
1598 project_left(&project, &session);
1599 room_updated(&room, &session.peer);
1600
1601 Ok(())
1602}
1603
1604async fn update_project(
1605 request: proto::UpdateProject,
1606 response: Response<proto::UpdateProject>,
1607 session: Session,
1608) -> Result<()> {
1609 let project_id = ProjectId::from_proto(request.project_id);
1610 let (room, guest_connection_ids) = &*session
1611 .db()
1612 .await
1613 .update_project(project_id, session.connection_id, &request.worktrees)
1614 .await?;
1615 broadcast(
1616 Some(session.connection_id),
1617 guest_connection_ids.iter().copied(),
1618 |connection_id| {
1619 session
1620 .peer
1621 .forward_send(session.connection_id, connection_id, request.clone())
1622 },
1623 );
1624 room_updated(&room, &session.peer);
1625 response.send(proto::Ack {})?;
1626
1627 Ok(())
1628}
1629
1630async fn update_worktree(
1631 request: proto::UpdateWorktree,
1632 response: Response<proto::UpdateWorktree>,
1633 session: Session,
1634) -> Result<()> {
1635 let guest_connection_ids = session
1636 .db()
1637 .await
1638 .update_worktree(&request, session.connection_id)
1639 .await?;
1640
1641 broadcast(
1642 Some(session.connection_id),
1643 guest_connection_ids.iter().copied(),
1644 |connection_id| {
1645 session
1646 .peer
1647 .forward_send(session.connection_id, connection_id, request.clone())
1648 },
1649 );
1650 response.send(proto::Ack {})?;
1651 Ok(())
1652}
1653
1654async fn update_diagnostic_summary(
1655 message: proto::UpdateDiagnosticSummary,
1656 session: Session,
1657) -> Result<()> {
1658 let guest_connection_ids = session
1659 .db()
1660 .await
1661 .update_diagnostic_summary(&message, session.connection_id)
1662 .await?;
1663
1664 broadcast(
1665 Some(session.connection_id),
1666 guest_connection_ids.iter().copied(),
1667 |connection_id| {
1668 session
1669 .peer
1670 .forward_send(session.connection_id, connection_id, message.clone())
1671 },
1672 );
1673
1674 Ok(())
1675}
1676
1677async fn update_worktree_settings(
1678 message: proto::UpdateWorktreeSettings,
1679 session: Session,
1680) -> Result<()> {
1681 let guest_connection_ids = session
1682 .db()
1683 .await
1684 .update_worktree_settings(&message, session.connection_id)
1685 .await?;
1686
1687 broadcast(
1688 Some(session.connection_id),
1689 guest_connection_ids.iter().copied(),
1690 |connection_id| {
1691 session
1692 .peer
1693 .forward_send(session.connection_id, connection_id, message.clone())
1694 },
1695 );
1696
1697 Ok(())
1698}
1699
1700async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1701 broadcast_project_message(request.project_id, request, session).await
1702}
1703
1704async fn start_language_server(
1705 request: proto::StartLanguageServer,
1706 session: Session,
1707) -> Result<()> {
1708 let guest_connection_ids = session
1709 .db()
1710 .await
1711 .start_language_server(&request, session.connection_id)
1712 .await?;
1713
1714 broadcast(
1715 Some(session.connection_id),
1716 guest_connection_ids.iter().copied(),
1717 |connection_id| {
1718 session
1719 .peer
1720 .forward_send(session.connection_id, connection_id, request.clone())
1721 },
1722 );
1723 Ok(())
1724}
1725
1726async fn update_language_server(
1727 request: proto::UpdateLanguageServer,
1728 session: Session,
1729) -> Result<()> {
1730 let project_id = ProjectId::from_proto(request.project_id);
1731 let project_connection_ids = session
1732 .db()
1733 .await
1734 .project_connection_ids(project_id, session.connection_id)
1735 .await?;
1736 broadcast(
1737 Some(session.connection_id),
1738 project_connection_ids.iter().copied(),
1739 |connection_id| {
1740 session
1741 .peer
1742 .forward_send(session.connection_id, connection_id, request.clone())
1743 },
1744 );
1745 Ok(())
1746}
1747
1748async fn forward_read_only_project_request<T>(
1749 request: T,
1750 response: Response<T>,
1751 session: Session,
1752) -> Result<()>
1753where
1754 T: EntityMessage + RequestMessage,
1755{
1756 let project_id = ProjectId::from_proto(request.remote_entity_id());
1757 let host_connection_id = {
1758 let collaborators = session
1759 .db()
1760 .await
1761 .project_collaborators(project_id, session.connection_id)
1762 .await?;
1763 collaborators
1764 .iter()
1765 .find(|collaborator| collaborator.is_host)
1766 .ok_or_else(|| anyhow!("host not found"))?
1767 .connection_id
1768 };
1769
1770 let payload = session
1771 .peer
1772 .forward_request(session.connection_id, host_connection_id, request)
1773 .await?;
1774
1775 response.send(payload)?;
1776 Ok(())
1777}
1778
1779async fn forward_mutating_project_request<T>(
1780 request: T,
1781 response: Response<T>,
1782 session: Session,
1783) -> Result<()>
1784where
1785 T: EntityMessage + RequestMessage,
1786{
1787 let project_id = ProjectId::from_proto(request.remote_entity_id());
1788 let host_connection_id = session
1789 .db()
1790 .await
1791 .host_for_mutating_project_request(project_id, session.connection_id)
1792 .await?;
1793
1794 let payload = session
1795 .peer
1796 .forward_request(session.connection_id, host_connection_id, request)
1797 .await?;
1798
1799 response.send(payload)?;
1800 Ok(())
1801}
1802
1803async fn create_buffer_for_peer(
1804 request: proto::CreateBufferForPeer,
1805 session: Session,
1806) -> Result<()> {
1807 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1808 session
1809 .peer
1810 .forward_send(session.connection_id, peer_id.into(), request)?;
1811 Ok(())
1812}
1813
1814async fn update_buffer(
1815 request: proto::UpdateBuffer,
1816 response: Response<proto::UpdateBuffer>,
1817 session: Session,
1818) -> Result<()> {
1819 let project_id = ProjectId::from_proto(request.project_id);
1820 let mut guest_connection_ids;
1821 let mut host_connection_id = None;
1822 {
1823 let collaborators = session
1824 .db()
1825 .await
1826 .project_collaborators(project_id, session.connection_id)
1827 .await?;
1828 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1829 for collaborator in collaborators.iter() {
1830 if collaborator.is_host {
1831 host_connection_id = Some(collaborator.connection_id);
1832 } else {
1833 guest_connection_ids.push(collaborator.connection_id);
1834 }
1835 }
1836 }
1837 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1838
1839 broadcast(
1840 Some(session.connection_id),
1841 guest_connection_ids,
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 if host_connection_id != session.connection_id {
1849 session
1850 .peer
1851 .forward_request(session.connection_id, host_connection_id, request.clone())
1852 .await?;
1853 }
1854
1855 response.send(proto::Ack {})?;
1856 Ok(())
1857}
1858
1859async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1860 let project_id = ProjectId::from_proto(request.project_id);
1861 let project_connection_ids = session
1862 .db()
1863 .await
1864 .project_connection_ids(project_id, session.connection_id)
1865 .await?;
1866
1867 broadcast(
1868 Some(session.connection_id),
1869 project_connection_ids.iter().copied(),
1870 |connection_id| {
1871 session
1872 .peer
1873 .forward_send(session.connection_id, connection_id, request.clone())
1874 },
1875 );
1876 Ok(())
1877}
1878
1879async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1880 let project_id = ProjectId::from_proto(request.project_id);
1881 let project_connection_ids = session
1882 .db()
1883 .await
1884 .project_connection_ids(project_id, session.connection_id)
1885 .await?;
1886 broadcast(
1887 Some(session.connection_id),
1888 project_connection_ids.iter().copied(),
1889 |connection_id| {
1890 session
1891 .peer
1892 .forward_send(session.connection_id, connection_id, request.clone())
1893 },
1894 );
1895 Ok(())
1896}
1897
1898async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1899 broadcast_project_message(request.project_id, request, session).await
1900}
1901
1902async fn broadcast_project_message<T: EnvelopedMessage>(
1903 project_id: u64,
1904 request: T,
1905 session: Session,
1906) -> Result<()> {
1907 let project_id = ProjectId::from_proto(project_id);
1908 let project_connection_ids = session
1909 .db()
1910 .await
1911 .project_connection_ids(project_id, session.connection_id)
1912 .await?;
1913 broadcast(
1914 Some(session.connection_id),
1915 project_connection_ids.iter().copied(),
1916 |connection_id| {
1917 session
1918 .peer
1919 .forward_send(session.connection_id, connection_id, request.clone())
1920 },
1921 );
1922 Ok(())
1923}
1924
1925async fn follow(
1926 request: proto::Follow,
1927 response: Response<proto::Follow>,
1928 session: Session,
1929) -> Result<()> {
1930 let room_id = RoomId::from_proto(request.room_id);
1931 let project_id = request.project_id.map(ProjectId::from_proto);
1932 let leader_id = request
1933 .leader_id
1934 .ok_or_else(|| anyhow!("invalid leader id"))?
1935 .into();
1936 let follower_id = session.connection_id;
1937
1938 session
1939 .db()
1940 .await
1941 .check_room_participants(room_id, leader_id, session.connection_id)
1942 .await?;
1943
1944 let response_payload = session
1945 .peer
1946 .forward_request(session.connection_id, leader_id, request)
1947 .await?;
1948 response.send(response_payload)?;
1949
1950 if let Some(project_id) = project_id {
1951 let room = session
1952 .db()
1953 .await
1954 .follow(room_id, project_id, leader_id, follower_id)
1955 .await?;
1956 room_updated(&room, &session.peer);
1957 }
1958
1959 Ok(())
1960}
1961
1962async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1963 let room_id = RoomId::from_proto(request.room_id);
1964 let project_id = request.project_id.map(ProjectId::from_proto);
1965 let leader_id = request
1966 .leader_id
1967 .ok_or_else(|| anyhow!("invalid leader id"))?
1968 .into();
1969 let follower_id = session.connection_id;
1970
1971 session
1972 .db()
1973 .await
1974 .check_room_participants(room_id, leader_id, session.connection_id)
1975 .await?;
1976
1977 session
1978 .peer
1979 .forward_send(session.connection_id, leader_id, request)?;
1980
1981 if let Some(project_id) = project_id {
1982 let room = session
1983 .db()
1984 .await
1985 .unfollow(room_id, project_id, leader_id, follower_id)
1986 .await?;
1987 room_updated(&room, &session.peer);
1988 }
1989
1990 Ok(())
1991}
1992
1993async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1994 let room_id = RoomId::from_proto(request.room_id);
1995 let database = session.db.lock().await;
1996
1997 let connection_ids = if let Some(project_id) = request.project_id {
1998 let project_id = ProjectId::from_proto(project_id);
1999 database
2000 .project_connection_ids(project_id, session.connection_id)
2001 .await?
2002 } else {
2003 database
2004 .room_connection_ids(room_id, session.connection_id)
2005 .await?
2006 };
2007
2008 // For now, don't send view update messages back to that view's current leader.
2009 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2010 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2011 _ => None,
2012 });
2013
2014 for follower_peer_id in request.follower_ids.iter().copied() {
2015 let follower_connection_id = follower_peer_id.into();
2016 if Some(follower_peer_id) != connection_id_to_omit
2017 && connection_ids.contains(&follower_connection_id)
2018 {
2019 session.peer.forward_send(
2020 session.connection_id,
2021 follower_connection_id,
2022 request.clone(),
2023 )?;
2024 }
2025 }
2026 Ok(())
2027}
2028
2029async fn get_users(
2030 request: proto::GetUsers,
2031 response: Response<proto::GetUsers>,
2032 session: Session,
2033) -> Result<()> {
2034 let user_ids = request
2035 .user_ids
2036 .into_iter()
2037 .map(UserId::from_proto)
2038 .collect();
2039 let users = session
2040 .db()
2041 .await
2042 .get_users_by_ids(user_ids)
2043 .await?
2044 .into_iter()
2045 .map(|user| proto::User {
2046 id: user.id.to_proto(),
2047 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2048 github_login: user.github_login,
2049 })
2050 .collect();
2051 response.send(proto::UsersResponse { users })?;
2052 Ok(())
2053}
2054
2055async fn fuzzy_search_users(
2056 request: proto::FuzzySearchUsers,
2057 response: Response<proto::FuzzySearchUsers>,
2058 session: Session,
2059) -> Result<()> {
2060 let query = request.query;
2061 let users = match query.len() {
2062 0 => vec![],
2063 1 | 2 => session
2064 .db()
2065 .await
2066 .get_user_by_github_login(&query)
2067 .await?
2068 .into_iter()
2069 .collect(),
2070 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2071 };
2072 let users = users
2073 .into_iter()
2074 .filter(|user| user.id != session.user_id)
2075 .map(|user| proto::User {
2076 id: user.id.to_proto(),
2077 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2078 github_login: user.github_login,
2079 })
2080 .collect();
2081 response.send(proto::UsersResponse { users })?;
2082 Ok(())
2083}
2084
2085async fn request_contact(
2086 request: proto::RequestContact,
2087 response: Response<proto::RequestContact>,
2088 session: Session,
2089) -> Result<()> {
2090 let requester_id = session.user_id;
2091 let responder_id = UserId::from_proto(request.responder_id);
2092 if requester_id == responder_id {
2093 return Err(anyhow!("cannot add yourself as a contact"))?;
2094 }
2095
2096 let notifications = session
2097 .db()
2098 .await
2099 .send_contact_request(requester_id, responder_id)
2100 .await?;
2101
2102 // Update outgoing contact requests of requester
2103 let mut update = proto::UpdateContacts::default();
2104 update.outgoing_requests.push(responder_id.to_proto());
2105 for connection_id in session
2106 .connection_pool()
2107 .await
2108 .user_connection_ids(requester_id)
2109 {
2110 session.peer.send(connection_id, update.clone())?;
2111 }
2112
2113 // Update incoming contact requests of responder
2114 let mut update = proto::UpdateContacts::default();
2115 update
2116 .incoming_requests
2117 .push(proto::IncomingContactRequest {
2118 requester_id: requester_id.to_proto(),
2119 });
2120 let connection_pool = session.connection_pool().await;
2121 for connection_id in connection_pool.user_connection_ids(responder_id) {
2122 session.peer.send(connection_id, update.clone())?;
2123 }
2124
2125 send_notifications(&*connection_pool, &session.peer, notifications);
2126
2127 response.send(proto::Ack {})?;
2128 Ok(())
2129}
2130
2131async fn respond_to_contact_request(
2132 request: proto::RespondToContactRequest,
2133 response: Response<proto::RespondToContactRequest>,
2134 session: Session,
2135) -> Result<()> {
2136 let responder_id = session.user_id;
2137 let requester_id = UserId::from_proto(request.requester_id);
2138 let db = session.db().await;
2139 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2140 db.dismiss_contact_notification(responder_id, requester_id)
2141 .await?;
2142 } else {
2143 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2144
2145 let notifications = db
2146 .respond_to_contact_request(responder_id, requester_id, accept)
2147 .await?;
2148 let requester_busy = db.is_user_busy(requester_id).await?;
2149 let responder_busy = db.is_user_busy(responder_id).await?;
2150
2151 let pool = session.connection_pool().await;
2152 // Update responder with new contact
2153 let mut update = proto::UpdateContacts::default();
2154 if accept {
2155 update
2156 .contacts
2157 .push(contact_for_user(requester_id, requester_busy, &pool));
2158 }
2159 update
2160 .remove_incoming_requests
2161 .push(requester_id.to_proto());
2162 for connection_id in pool.user_connection_ids(responder_id) {
2163 session.peer.send(connection_id, update.clone())?;
2164 }
2165
2166 // Update requester with new contact
2167 let mut update = proto::UpdateContacts::default();
2168 if accept {
2169 update
2170 .contacts
2171 .push(contact_for_user(responder_id, responder_busy, &pool));
2172 }
2173 update
2174 .remove_outgoing_requests
2175 .push(responder_id.to_proto());
2176
2177 for connection_id in pool.user_connection_ids(requester_id) {
2178 session.peer.send(connection_id, update.clone())?;
2179 }
2180
2181 send_notifications(&*pool, &session.peer, notifications);
2182 }
2183
2184 response.send(proto::Ack {})?;
2185 Ok(())
2186}
2187
2188async fn remove_contact(
2189 request: proto::RemoveContact,
2190 response: Response<proto::RemoveContact>,
2191 session: Session,
2192) -> Result<()> {
2193 let requester_id = session.user_id;
2194 let responder_id = UserId::from_proto(request.user_id);
2195 let db = session.db().await;
2196 let (contact_accepted, deleted_notification_id) =
2197 db.remove_contact(requester_id, responder_id).await?;
2198
2199 let pool = session.connection_pool().await;
2200 // Update outgoing contact requests of requester
2201 let mut update = proto::UpdateContacts::default();
2202 if contact_accepted {
2203 update.remove_contacts.push(responder_id.to_proto());
2204 } else {
2205 update
2206 .remove_outgoing_requests
2207 .push(responder_id.to_proto());
2208 }
2209 for connection_id in pool.user_connection_ids(requester_id) {
2210 session.peer.send(connection_id, update.clone())?;
2211 }
2212
2213 // Update incoming contact requests of responder
2214 let mut update = proto::UpdateContacts::default();
2215 if contact_accepted {
2216 update.remove_contacts.push(requester_id.to_proto());
2217 } else {
2218 update
2219 .remove_incoming_requests
2220 .push(requester_id.to_proto());
2221 }
2222 for connection_id in pool.user_connection_ids(responder_id) {
2223 session.peer.send(connection_id, update.clone())?;
2224 if let Some(notification_id) = deleted_notification_id {
2225 session.peer.send(
2226 connection_id,
2227 proto::DeleteNotification {
2228 notification_id: notification_id.to_proto(),
2229 },
2230 )?;
2231 }
2232 }
2233
2234 response.send(proto::Ack {})?;
2235 Ok(())
2236}
2237
2238async fn create_channel(
2239 request: proto::CreateChannel,
2240 response: Response<proto::CreateChannel>,
2241 session: Session,
2242) -> Result<()> {
2243 let db = session.db().await;
2244
2245 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2246 let CreateChannelResult {
2247 channel,
2248 participants_to_update,
2249 } = db
2250 .create_channel(&request.name, parent_id, session.user_id)
2251 .await?;
2252
2253 response.send(proto::CreateChannelResponse {
2254 channel: Some(channel.to_proto()),
2255 parent_id: request.parent_id,
2256 })?;
2257
2258 let connection_pool = session.connection_pool().await;
2259 for (user_id, channels) in participants_to_update {
2260 let update = build_channels_update(channels, vec![]);
2261 for connection_id in connection_pool.user_connection_ids(user_id) {
2262 if user_id == session.user_id {
2263 continue;
2264 }
2265 session.peer.send(connection_id, update.clone())?;
2266 }
2267 }
2268
2269 Ok(())
2270}
2271
2272async fn delete_channel(
2273 request: proto::DeleteChannel,
2274 response: Response<proto::DeleteChannel>,
2275 session: Session,
2276) -> Result<()> {
2277 let db = session.db().await;
2278
2279 let channel_id = request.channel_id;
2280 let (removed_channels, member_ids) = db
2281 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2282 .await?;
2283 response.send(proto::Ack {})?;
2284
2285 // Notify members of removed channels
2286 let mut update = proto::UpdateChannels::default();
2287 update
2288 .delete_channels
2289 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2290
2291 let connection_pool = session.connection_pool().await;
2292 for member_id in member_ids {
2293 for connection_id in connection_pool.user_connection_ids(member_id) {
2294 session.peer.send(connection_id, update.clone())?;
2295 }
2296 }
2297
2298 Ok(())
2299}
2300
2301async fn invite_channel_member(
2302 request: proto::InviteChannelMember,
2303 response: Response<proto::InviteChannelMember>,
2304 session: Session,
2305) -> Result<()> {
2306 let db = session.db().await;
2307 let channel_id = ChannelId::from_proto(request.channel_id);
2308 let invitee_id = UserId::from_proto(request.user_id);
2309 let InviteMemberResult {
2310 channel,
2311 notifications,
2312 } = db
2313 .invite_channel_member(
2314 channel_id,
2315 invitee_id,
2316 session.user_id,
2317 request.role().into(),
2318 )
2319 .await?;
2320
2321 let update = proto::UpdateChannels {
2322 channel_invitations: vec![channel.to_proto()],
2323 ..Default::default()
2324 };
2325
2326 let connection_pool = session.connection_pool().await;
2327 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2328 session.peer.send(connection_id, update.clone())?;
2329 }
2330
2331 send_notifications(&*connection_pool, &session.peer, notifications);
2332
2333 response.send(proto::Ack {})?;
2334 Ok(())
2335}
2336
2337async fn remove_channel_member(
2338 request: proto::RemoveChannelMember,
2339 response: Response<proto::RemoveChannelMember>,
2340 session: Session,
2341) -> Result<()> {
2342 let db = session.db().await;
2343 let channel_id = ChannelId::from_proto(request.channel_id);
2344 let member_id = UserId::from_proto(request.user_id);
2345
2346 let RemoveChannelMemberResult {
2347 membership_update,
2348 notification_id,
2349 } = db
2350 .remove_channel_member(channel_id, member_id, session.user_id)
2351 .await?;
2352
2353 let connection_pool = &session.connection_pool().await;
2354 notify_membership_updated(
2355 &connection_pool,
2356 membership_update,
2357 member_id,
2358 &session.peer,
2359 );
2360 for connection_id in connection_pool.user_connection_ids(member_id) {
2361 if let Some(notification_id) = notification_id {
2362 session
2363 .peer
2364 .send(
2365 connection_id,
2366 proto::DeleteNotification {
2367 notification_id: notification_id.to_proto(),
2368 },
2369 )
2370 .trace_err();
2371 }
2372 }
2373
2374 response.send(proto::Ack {})?;
2375 Ok(())
2376}
2377
2378async fn set_channel_visibility(
2379 request: proto::SetChannelVisibility,
2380 response: Response<proto::SetChannelVisibility>,
2381 session: Session,
2382) -> Result<()> {
2383 let db = session.db().await;
2384 let channel_id = ChannelId::from_proto(request.channel_id);
2385 let visibility = request.visibility().into();
2386
2387 let SetChannelVisibilityResult {
2388 participants_to_update,
2389 participants_to_remove,
2390 channels_to_remove,
2391 } = db
2392 .set_channel_visibility(channel_id, visibility, session.user_id)
2393 .await?;
2394
2395 let connection_pool = session.connection_pool().await;
2396 for (user_id, channels) in participants_to_update {
2397 let update = build_channels_update(channels, vec![]);
2398 for connection_id in connection_pool.user_connection_ids(user_id) {
2399 session.peer.send(connection_id, update.clone())?;
2400 }
2401 }
2402 for user_id in participants_to_remove {
2403 let update = proto::UpdateChannels {
2404 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2405 ..Default::default()
2406 };
2407 for connection_id in connection_pool.user_connection_ids(user_id) {
2408 session.peer.send(connection_id, update.clone())?;
2409 }
2410 }
2411
2412 response.send(proto::Ack {})?;
2413 Ok(())
2414}
2415
2416async fn set_channel_member_role(
2417 request: proto::SetChannelMemberRole,
2418 response: Response<proto::SetChannelMemberRole>,
2419 session: Session,
2420) -> Result<()> {
2421 let db = session.db().await;
2422 let channel_id = ChannelId::from_proto(request.channel_id);
2423 let member_id = UserId::from_proto(request.user_id);
2424 let result = db
2425 .set_channel_member_role(
2426 channel_id,
2427 session.user_id,
2428 member_id,
2429 request.role().into(),
2430 )
2431 .await?;
2432
2433 match result {
2434 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2435 let connection_pool = session.connection_pool().await;
2436 notify_membership_updated(
2437 &connection_pool,
2438 membership_update,
2439 member_id,
2440 &session.peer,
2441 )
2442 }
2443 db::SetMemberRoleResult::InviteUpdated(channel) => {
2444 let update = proto::UpdateChannels {
2445 channel_invitations: vec![channel.to_proto()],
2446 ..Default::default()
2447 };
2448
2449 for connection_id in session
2450 .connection_pool()
2451 .await
2452 .user_connection_ids(member_id)
2453 {
2454 session.peer.send(connection_id, update.clone())?;
2455 }
2456 }
2457 }
2458
2459 response.send(proto::Ack {})?;
2460 Ok(())
2461}
2462
2463async fn rename_channel(
2464 request: proto::RenameChannel,
2465 response: Response<proto::RenameChannel>,
2466 session: Session,
2467) -> Result<()> {
2468 let db = session.db().await;
2469 let channel_id = ChannelId::from_proto(request.channel_id);
2470 let RenameChannelResult {
2471 channel,
2472 participants_to_update,
2473 } = db
2474 .rename_channel(channel_id, session.user_id, &request.name)
2475 .await?;
2476
2477 response.send(proto::RenameChannelResponse {
2478 channel: Some(channel.to_proto()),
2479 })?;
2480
2481 let connection_pool = session.connection_pool().await;
2482 for (user_id, channel) in participants_to_update {
2483 for connection_id in connection_pool.user_connection_ids(user_id) {
2484 let update = proto::UpdateChannels {
2485 channels: vec![channel.to_proto()],
2486 ..Default::default()
2487 };
2488
2489 session.peer.send(connection_id, update.clone())?;
2490 }
2491 }
2492
2493 Ok(())
2494}
2495
2496async fn move_channel(
2497 request: proto::MoveChannel,
2498 response: Response<proto::MoveChannel>,
2499 session: Session,
2500) -> Result<()> {
2501 let channel_id = ChannelId::from_proto(request.channel_id);
2502 let to = request.to.map(ChannelId::from_proto);
2503
2504 let result = session
2505 .db()
2506 .await
2507 .move_channel(channel_id, to, session.user_id)
2508 .await?;
2509
2510 notify_channel_moved(result, session).await?;
2511
2512 response.send(Ack {})?;
2513 Ok(())
2514}
2515
2516async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2517 let Some(MoveChannelResult {
2518 participants_to_remove,
2519 participants_to_update,
2520 moved_channels,
2521 }) = result
2522 else {
2523 return Ok(());
2524 };
2525 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2526
2527 let connection_pool = session.connection_pool().await;
2528 for (user_id, channels) in participants_to_update {
2529 let mut update = build_channels_update(channels, vec![]);
2530 update.delete_channels = moved_channels.clone();
2531 for connection_id in connection_pool.user_connection_ids(user_id) {
2532 session.peer.send(connection_id, update.clone())?;
2533 }
2534 }
2535
2536 for user_id in participants_to_remove {
2537 let update = proto::UpdateChannels {
2538 delete_channels: moved_channels.clone(),
2539 ..Default::default()
2540 };
2541 for connection_id in connection_pool.user_connection_ids(user_id) {
2542 session.peer.send(connection_id, update.clone())?;
2543 }
2544 }
2545 Ok(())
2546}
2547
2548async fn get_channel_members(
2549 request: proto::GetChannelMembers,
2550 response: Response<proto::GetChannelMembers>,
2551 session: Session,
2552) -> Result<()> {
2553 let db = session.db().await;
2554 let channel_id = ChannelId::from_proto(request.channel_id);
2555 let members = db
2556 .get_channel_participant_details(channel_id, session.user_id)
2557 .await?;
2558 response.send(proto::GetChannelMembersResponse { members })?;
2559 Ok(())
2560}
2561
2562async fn respond_to_channel_invite(
2563 request: proto::RespondToChannelInvite,
2564 response: Response<proto::RespondToChannelInvite>,
2565 session: Session,
2566) -> Result<()> {
2567 let db = session.db().await;
2568 let channel_id = ChannelId::from_proto(request.channel_id);
2569 let RespondToChannelInvite {
2570 membership_update,
2571 notifications,
2572 } = db
2573 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2574 .await?;
2575
2576 let connection_pool = session.connection_pool().await;
2577 if let Some(membership_update) = membership_update {
2578 notify_membership_updated(
2579 &connection_pool,
2580 membership_update,
2581 session.user_id,
2582 &session.peer,
2583 );
2584 } else {
2585 let update = proto::UpdateChannels {
2586 remove_channel_invitations: vec![channel_id.to_proto()],
2587 ..Default::default()
2588 };
2589
2590 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2591 session.peer.send(connection_id, update.clone())?;
2592 }
2593 };
2594
2595 send_notifications(&*connection_pool, &session.peer, notifications);
2596
2597 response.send(proto::Ack {})?;
2598
2599 Ok(())
2600}
2601
2602async fn join_channel(
2603 request: proto::JoinChannel,
2604 response: Response<proto::JoinChannel>,
2605 session: Session,
2606) -> Result<()> {
2607 let channel_id = ChannelId::from_proto(request.channel_id);
2608 join_channel_internal(channel_id, Box::new(response), session).await
2609}
2610
2611trait JoinChannelInternalResponse {
2612 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2613}
2614impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2615 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2616 Response::<proto::JoinChannel>::send(self, result)
2617 }
2618}
2619impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2620 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2621 Response::<proto::JoinRoom>::send(self, result)
2622 }
2623}
2624
2625async fn join_channel_internal(
2626 channel_id: ChannelId,
2627 response: Box<impl JoinChannelInternalResponse>,
2628 session: Session,
2629) -> Result<()> {
2630 let joined_room = {
2631 leave_room_for_session(&session).await?;
2632 let db = session.db().await;
2633
2634 let (joined_room, membership_updated, role) = db
2635 .join_channel(
2636 channel_id,
2637 session.user_id,
2638 session.connection_id,
2639 RELEASE_CHANNEL_NAME.as_str(),
2640 )
2641 .await?;
2642
2643 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2644 let (can_publish, token) = if role == ChannelRole::Guest {
2645 (
2646 false,
2647 live_kit
2648 .guest_token(
2649 &joined_room.room.live_kit_room,
2650 &session.user_id.to_string(),
2651 )
2652 .trace_err()?,
2653 )
2654 } else {
2655 (
2656 true,
2657 live_kit
2658 .room_token(
2659 &joined_room.room.live_kit_room,
2660 &session.user_id.to_string(),
2661 )
2662 .trace_err()?,
2663 )
2664 };
2665
2666 Some(LiveKitConnectionInfo {
2667 server_url: live_kit.url().into(),
2668 token,
2669 can_publish,
2670 })
2671 });
2672
2673 response.send(proto::JoinRoomResponse {
2674 room: Some(joined_room.room.clone()),
2675 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2676 live_kit_connection_info,
2677 })?;
2678
2679 let connection_pool = session.connection_pool().await;
2680 if let Some(membership_updated) = membership_updated {
2681 notify_membership_updated(
2682 &connection_pool,
2683 membership_updated,
2684 session.user_id,
2685 &session.peer,
2686 );
2687 }
2688
2689 room_updated(&joined_room.room, &session.peer);
2690
2691 joined_room
2692 };
2693
2694 channel_updated(
2695 channel_id,
2696 &joined_room.room,
2697 &joined_room.channel_members,
2698 &session.peer,
2699 &*session.connection_pool().await,
2700 );
2701
2702 update_user_contacts(session.user_id, &session).await?;
2703 Ok(())
2704}
2705
2706async fn join_channel_buffer(
2707 request: proto::JoinChannelBuffer,
2708 response: Response<proto::JoinChannelBuffer>,
2709 session: Session,
2710) -> Result<()> {
2711 let db = session.db().await;
2712 let channel_id = ChannelId::from_proto(request.channel_id);
2713
2714 let open_response = db
2715 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2716 .await?;
2717
2718 let collaborators = open_response.collaborators.clone();
2719 response.send(open_response)?;
2720
2721 let update = UpdateChannelBufferCollaborators {
2722 channel_id: channel_id.to_proto(),
2723 collaborators: collaborators.clone(),
2724 };
2725 channel_buffer_updated(
2726 session.connection_id,
2727 collaborators
2728 .iter()
2729 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2730 &update,
2731 &session.peer,
2732 );
2733
2734 Ok(())
2735}
2736
2737async fn update_channel_buffer(
2738 request: proto::UpdateChannelBuffer,
2739 session: Session,
2740) -> Result<()> {
2741 let db = session.db().await;
2742 let channel_id = ChannelId::from_proto(request.channel_id);
2743
2744 let (collaborators, non_collaborators, epoch, version) = db
2745 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2746 .await?;
2747
2748 channel_buffer_updated(
2749 session.connection_id,
2750 collaborators,
2751 &proto::UpdateChannelBuffer {
2752 channel_id: channel_id.to_proto(),
2753 operations: request.operations,
2754 },
2755 &session.peer,
2756 );
2757
2758 let pool = &*session.connection_pool().await;
2759
2760 broadcast(
2761 None,
2762 non_collaborators
2763 .iter()
2764 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2765 |peer_id| {
2766 session.peer.send(
2767 peer_id.into(),
2768 proto::UpdateChannels {
2769 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2770 channel_id: channel_id.to_proto(),
2771 epoch: epoch as u64,
2772 version: version.clone(),
2773 }],
2774 ..Default::default()
2775 },
2776 )
2777 },
2778 );
2779
2780 Ok(())
2781}
2782
2783async fn rejoin_channel_buffers(
2784 request: proto::RejoinChannelBuffers,
2785 response: Response<proto::RejoinChannelBuffers>,
2786 session: Session,
2787) -> Result<()> {
2788 let db = session.db().await;
2789 let buffers = db
2790 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2791 .await?;
2792
2793 for rejoined_buffer in &buffers {
2794 let collaborators_to_notify = rejoined_buffer
2795 .buffer
2796 .collaborators
2797 .iter()
2798 .filter_map(|c| Some(c.peer_id?.into()));
2799 channel_buffer_updated(
2800 session.connection_id,
2801 collaborators_to_notify,
2802 &proto::UpdateChannelBufferCollaborators {
2803 channel_id: rejoined_buffer.buffer.channel_id,
2804 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2805 },
2806 &session.peer,
2807 );
2808 }
2809
2810 response.send(proto::RejoinChannelBuffersResponse {
2811 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2812 })?;
2813
2814 Ok(())
2815}
2816
2817async fn leave_channel_buffer(
2818 request: proto::LeaveChannelBuffer,
2819 response: Response<proto::LeaveChannelBuffer>,
2820 session: Session,
2821) -> Result<()> {
2822 let db = session.db().await;
2823 let channel_id = ChannelId::from_proto(request.channel_id);
2824
2825 let left_buffer = db
2826 .leave_channel_buffer(channel_id, session.connection_id)
2827 .await?;
2828
2829 response.send(Ack {})?;
2830
2831 channel_buffer_updated(
2832 session.connection_id,
2833 left_buffer.connections,
2834 &proto::UpdateChannelBufferCollaborators {
2835 channel_id: channel_id.to_proto(),
2836 collaborators: left_buffer.collaborators,
2837 },
2838 &session.peer,
2839 );
2840
2841 Ok(())
2842}
2843
2844fn channel_buffer_updated<T: EnvelopedMessage>(
2845 sender_id: ConnectionId,
2846 collaborators: impl IntoIterator<Item = ConnectionId>,
2847 message: &T,
2848 peer: &Peer,
2849) {
2850 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2851 peer.send(peer_id.into(), message.clone())
2852 });
2853}
2854
2855fn send_notifications(
2856 connection_pool: &ConnectionPool,
2857 peer: &Peer,
2858 notifications: db::NotificationBatch,
2859) {
2860 for (user_id, notification) in notifications {
2861 for connection_id in connection_pool.user_connection_ids(user_id) {
2862 if let Err(error) = peer.send(
2863 connection_id,
2864 proto::AddNotification {
2865 notification: Some(notification.clone()),
2866 },
2867 ) {
2868 tracing::error!(
2869 "failed to send notification to {:?} {}",
2870 connection_id,
2871 error
2872 );
2873 }
2874 }
2875 }
2876}
2877
2878async fn send_channel_message(
2879 request: proto::SendChannelMessage,
2880 response: Response<proto::SendChannelMessage>,
2881 session: Session,
2882) -> Result<()> {
2883 // Validate the message body.
2884 let body = request.body.trim().to_string();
2885 if body.len() > MAX_MESSAGE_LEN {
2886 return Err(anyhow!("message is too long"))?;
2887 }
2888 if body.is_empty() {
2889 return Err(anyhow!("message can't be blank"))?;
2890 }
2891
2892 // TODO: adjust mentions if body is trimmed
2893
2894 let timestamp = OffsetDateTime::now_utc();
2895 let nonce = request
2896 .nonce
2897 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2898
2899 let channel_id = ChannelId::from_proto(request.channel_id);
2900 let CreatedChannelMessage {
2901 message_id,
2902 participant_connection_ids,
2903 channel_members,
2904 notifications,
2905 } = session
2906 .db()
2907 .await
2908 .create_channel_message(
2909 channel_id,
2910 session.user_id,
2911 &body,
2912 &request.mentions,
2913 timestamp,
2914 nonce.clone().into(),
2915 )
2916 .await?;
2917 let message = proto::ChannelMessage {
2918 sender_id: session.user_id.to_proto(),
2919 id: message_id.to_proto(),
2920 body,
2921 mentions: request.mentions,
2922 timestamp: timestamp.unix_timestamp() as u64,
2923 nonce: Some(nonce),
2924 };
2925 broadcast(
2926 Some(session.connection_id),
2927 participant_connection_ids,
2928 |connection| {
2929 session.peer.send(
2930 connection,
2931 proto::ChannelMessageSent {
2932 channel_id: channel_id.to_proto(),
2933 message: Some(message.clone()),
2934 },
2935 )
2936 },
2937 );
2938 response.send(proto::SendChannelMessageResponse {
2939 message: Some(message),
2940 })?;
2941
2942 let pool = &*session.connection_pool().await;
2943 broadcast(
2944 None,
2945 channel_members
2946 .iter()
2947 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2948 |peer_id| {
2949 session.peer.send(
2950 peer_id.into(),
2951 proto::UpdateChannels {
2952 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2953 channel_id: channel_id.to_proto(),
2954 message_id: message_id.to_proto(),
2955 }],
2956 ..Default::default()
2957 },
2958 )
2959 },
2960 );
2961 send_notifications(pool, &session.peer, notifications);
2962
2963 Ok(())
2964}
2965
2966async fn remove_channel_message(
2967 request: proto::RemoveChannelMessage,
2968 response: Response<proto::RemoveChannelMessage>,
2969 session: Session,
2970) -> Result<()> {
2971 let channel_id = ChannelId::from_proto(request.channel_id);
2972 let message_id = MessageId::from_proto(request.message_id);
2973 let connection_ids = session
2974 .db()
2975 .await
2976 .remove_channel_message(channel_id, message_id, session.user_id)
2977 .await?;
2978 broadcast(Some(session.connection_id), connection_ids, |connection| {
2979 session.peer.send(connection, request.clone())
2980 });
2981 response.send(proto::Ack {})?;
2982 Ok(())
2983}
2984
2985async fn acknowledge_channel_message(
2986 request: proto::AckChannelMessage,
2987 session: Session,
2988) -> Result<()> {
2989 let channel_id = ChannelId::from_proto(request.channel_id);
2990 let message_id = MessageId::from_proto(request.message_id);
2991 let notifications = session
2992 .db()
2993 .await
2994 .observe_channel_message(channel_id, session.user_id, message_id)
2995 .await?;
2996 send_notifications(
2997 &*session.connection_pool().await,
2998 &session.peer,
2999 notifications,
3000 );
3001 Ok(())
3002}
3003
3004async fn acknowledge_buffer_version(
3005 request: proto::AckBufferOperation,
3006 session: Session,
3007) -> Result<()> {
3008 let buffer_id = BufferId::from_proto(request.buffer_id);
3009 session
3010 .db()
3011 .await
3012 .observe_buffer_version(
3013 buffer_id,
3014 session.user_id,
3015 request.epoch as i32,
3016 &request.version,
3017 )
3018 .await?;
3019 Ok(())
3020}
3021
3022async fn join_channel_chat(
3023 request: proto::JoinChannelChat,
3024 response: Response<proto::JoinChannelChat>,
3025 session: Session,
3026) -> Result<()> {
3027 let channel_id = ChannelId::from_proto(request.channel_id);
3028
3029 let db = session.db().await;
3030 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3031 .await?;
3032 let messages = db
3033 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3034 .await?;
3035 response.send(proto::JoinChannelChatResponse {
3036 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3037 messages,
3038 })?;
3039 Ok(())
3040}
3041
3042async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3043 let channel_id = ChannelId::from_proto(request.channel_id);
3044 session
3045 .db()
3046 .await
3047 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3048 .await?;
3049 Ok(())
3050}
3051
3052async fn get_channel_messages(
3053 request: proto::GetChannelMessages,
3054 response: Response<proto::GetChannelMessages>,
3055 session: Session,
3056) -> Result<()> {
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 let messages = session
3059 .db()
3060 .await
3061 .get_channel_messages(
3062 channel_id,
3063 session.user_id,
3064 MESSAGE_COUNT_PER_PAGE,
3065 Some(MessageId::from_proto(request.before_message_id)),
3066 )
3067 .await?;
3068 response.send(proto::GetChannelMessagesResponse {
3069 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3070 messages,
3071 })?;
3072 Ok(())
3073}
3074
3075async fn get_channel_messages_by_id(
3076 request: proto::GetChannelMessagesById,
3077 response: Response<proto::GetChannelMessagesById>,
3078 session: Session,
3079) -> Result<()> {
3080 let message_ids = request
3081 .message_ids
3082 .iter()
3083 .map(|id| MessageId::from_proto(*id))
3084 .collect::<Vec<_>>();
3085 let messages = session
3086 .db()
3087 .await
3088 .get_channel_messages_by_id(session.user_id, &message_ids)
3089 .await?;
3090 response.send(proto::GetChannelMessagesResponse {
3091 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3092 messages,
3093 })?;
3094 Ok(())
3095}
3096
3097async fn get_notifications(
3098 request: proto::GetNotifications,
3099 response: Response<proto::GetNotifications>,
3100 session: Session,
3101) -> Result<()> {
3102 let notifications = session
3103 .db()
3104 .await
3105 .get_notifications(
3106 session.user_id,
3107 NOTIFICATION_COUNT_PER_PAGE,
3108 request
3109 .before_id
3110 .map(|id| db::NotificationId::from_proto(id)),
3111 )
3112 .await?;
3113 response.send(proto::GetNotificationsResponse {
3114 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3115 notifications,
3116 })?;
3117 Ok(())
3118}
3119
3120async fn mark_notification_as_read(
3121 request: proto::MarkNotificationRead,
3122 response: Response<proto::MarkNotificationRead>,
3123 session: Session,
3124) -> Result<()> {
3125 let database = &session.db().await;
3126 let notifications = database
3127 .mark_notification_as_read_by_id(
3128 session.user_id,
3129 NotificationId::from_proto(request.notification_id),
3130 )
3131 .await?;
3132 send_notifications(
3133 &*session.connection_pool().await,
3134 &session.peer,
3135 notifications,
3136 );
3137 response.send(proto::Ack {})?;
3138 Ok(())
3139}
3140
3141async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3142 let project_id = ProjectId::from_proto(request.project_id);
3143 let project_connection_ids = session
3144 .db()
3145 .await
3146 .project_connection_ids(project_id, session.connection_id)
3147 .await?;
3148 broadcast(
3149 Some(session.connection_id),
3150 project_connection_ids.iter().copied(),
3151 |connection_id| {
3152 session
3153 .peer
3154 .forward_send(session.connection_id, connection_id, request.clone())
3155 },
3156 );
3157 Ok(())
3158}
3159
3160async fn get_private_user_info(
3161 _request: proto::GetPrivateUserInfo,
3162 response: Response<proto::GetPrivateUserInfo>,
3163 session: Session,
3164) -> Result<()> {
3165 let db = session.db().await;
3166
3167 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3168 let user = db
3169 .get_user_by_id(session.user_id)
3170 .await?
3171 .ok_or_else(|| anyhow!("user not found"))?;
3172 let flags = db.get_user_flags(session.user_id).await?;
3173
3174 response.send(proto::GetPrivateUserInfoResponse {
3175 metrics_id,
3176 staff: user.admin,
3177 flags,
3178 })?;
3179 Ok(())
3180}
3181
3182fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3183 match message {
3184 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3185 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3186 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3187 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3188 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3189 code: frame.code.into(),
3190 reason: frame.reason,
3191 })),
3192 }
3193}
3194
3195fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3196 match message {
3197 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3198 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3199 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3200 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3201 AxumMessage::Close(frame) => {
3202 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3203 code: frame.code.into(),
3204 reason: frame.reason,
3205 }))
3206 }
3207 }
3208}
3209
3210fn notify_membership_updated(
3211 connection_pool: &ConnectionPool,
3212 result: MembershipUpdated,
3213 user_id: UserId,
3214 peer: &Peer,
3215) {
3216 let mut update = build_channels_update(result.new_channels, vec![]);
3217 update.delete_channels = result
3218 .removed_channels
3219 .into_iter()
3220 .map(|id| id.to_proto())
3221 .collect();
3222 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3223
3224 for connection_id in connection_pool.user_connection_ids(user_id) {
3225 peer.send(connection_id, update.clone()).trace_err();
3226 }
3227}
3228
3229fn build_channels_update(
3230 channels: ChannelsForUser,
3231 channel_invites: Vec<db::Channel>,
3232) -> proto::UpdateChannels {
3233 let mut update = proto::UpdateChannels::default();
3234
3235 for channel in channels.channels {
3236 update.channels.push(channel.to_proto());
3237 }
3238
3239 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3240 update.unseen_channel_messages = channels.channel_messages;
3241
3242 for (channel_id, participants) in channels.channel_participants {
3243 update
3244 .channel_participants
3245 .push(proto::ChannelParticipants {
3246 channel_id: channel_id.to_proto(),
3247 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3248 });
3249 }
3250
3251 for channel in channel_invites {
3252 update.channel_invitations.push(channel.to_proto());
3253 }
3254
3255 update
3256}
3257
3258fn build_initial_contacts_update(
3259 contacts: Vec<db::Contact>,
3260 pool: &ConnectionPool,
3261) -> proto::UpdateContacts {
3262 let mut update = proto::UpdateContacts::default();
3263
3264 for contact in contacts {
3265 match contact {
3266 db::Contact::Accepted { user_id, busy } => {
3267 update.contacts.push(contact_for_user(user_id, busy, &pool));
3268 }
3269 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3270 db::Contact::Incoming { user_id } => {
3271 update
3272 .incoming_requests
3273 .push(proto::IncomingContactRequest {
3274 requester_id: user_id.to_proto(),
3275 })
3276 }
3277 }
3278 }
3279
3280 update
3281}
3282
3283fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3284 proto::Contact {
3285 user_id: user_id.to_proto(),
3286 online: pool.is_user_online(user_id),
3287 busy,
3288 }
3289}
3290
3291fn room_updated(room: &proto::Room, peer: &Peer) {
3292 broadcast(
3293 None,
3294 room.participants
3295 .iter()
3296 .filter_map(|participant| Some(participant.peer_id?.into())),
3297 |peer_id| {
3298 peer.send(
3299 peer_id.into(),
3300 proto::RoomUpdated {
3301 room: Some(room.clone()),
3302 },
3303 )
3304 },
3305 );
3306}
3307
3308fn channel_updated(
3309 channel_id: ChannelId,
3310 room: &proto::Room,
3311 channel_members: &[UserId],
3312 peer: &Peer,
3313 pool: &ConnectionPool,
3314) {
3315 let participants = room
3316 .participants
3317 .iter()
3318 .map(|p| p.user_id)
3319 .collect::<Vec<_>>();
3320
3321 broadcast(
3322 None,
3323 channel_members
3324 .iter()
3325 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3326 |peer_id| {
3327 peer.send(
3328 peer_id.into(),
3329 proto::UpdateChannels {
3330 channel_participants: vec![proto::ChannelParticipants {
3331 channel_id: channel_id.to_proto(),
3332 participant_user_ids: participants.clone(),
3333 }],
3334 ..Default::default()
3335 },
3336 )
3337 },
3338 );
3339}
3340
3341async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3342 let db = session.db().await;
3343
3344 let contacts = db.get_contacts(user_id).await?;
3345 let busy = db.is_user_busy(user_id).await?;
3346
3347 let pool = session.connection_pool().await;
3348 let updated_contact = contact_for_user(user_id, busy, &pool);
3349 for contact in contacts {
3350 if let db::Contact::Accepted {
3351 user_id: contact_user_id,
3352 ..
3353 } = contact
3354 {
3355 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3356 session
3357 .peer
3358 .send(
3359 contact_conn_id,
3360 proto::UpdateContacts {
3361 contacts: vec![updated_contact.clone()],
3362 remove_contacts: Default::default(),
3363 incoming_requests: Default::default(),
3364 remove_incoming_requests: Default::default(),
3365 outgoing_requests: Default::default(),
3366 remove_outgoing_requests: Default::default(),
3367 },
3368 )
3369 .trace_err();
3370 }
3371 }
3372 }
3373 Ok(())
3374}
3375
3376async fn leave_room_for_session(session: &Session) -> Result<()> {
3377 let mut contacts_to_update = HashSet::default();
3378
3379 let room_id;
3380 let canceled_calls_to_user_ids;
3381 let live_kit_room;
3382 let delete_live_kit_room;
3383 let room;
3384 let channel_members;
3385 let channel_id;
3386
3387 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3388 contacts_to_update.insert(session.user_id);
3389
3390 for project in left_room.left_projects.values() {
3391 project_left(project, session);
3392 }
3393
3394 room_id = RoomId::from_proto(left_room.room.id);
3395 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3396 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3397 delete_live_kit_room = left_room.deleted;
3398 room = mem::take(&mut left_room.room);
3399 channel_members = mem::take(&mut left_room.channel_members);
3400 channel_id = left_room.channel_id;
3401
3402 room_updated(&room, &session.peer);
3403 } else {
3404 return Ok(());
3405 }
3406
3407 if let Some(channel_id) = channel_id {
3408 channel_updated(
3409 channel_id,
3410 &room,
3411 &channel_members,
3412 &session.peer,
3413 &*session.connection_pool().await,
3414 );
3415 }
3416
3417 {
3418 let pool = session.connection_pool().await;
3419 for canceled_user_id in canceled_calls_to_user_ids {
3420 for connection_id in pool.user_connection_ids(canceled_user_id) {
3421 session
3422 .peer
3423 .send(
3424 connection_id,
3425 proto::CallCanceled {
3426 room_id: room_id.to_proto(),
3427 },
3428 )
3429 .trace_err();
3430 }
3431 contacts_to_update.insert(canceled_user_id);
3432 }
3433 }
3434
3435 for contact_user_id in contacts_to_update {
3436 update_user_contacts(contact_user_id, &session).await?;
3437 }
3438
3439 if let Some(live_kit) = session.live_kit_client.as_ref() {
3440 live_kit
3441 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3442 .await
3443 .trace_err();
3444
3445 if delete_live_kit_room {
3446 live_kit.delete_room(live_kit_room).await.trace_err();
3447 }
3448 }
3449
3450 Ok(())
3451}
3452
3453async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3454 let left_channel_buffers = session
3455 .db()
3456 .await
3457 .leave_channel_buffers(session.connection_id)
3458 .await?;
3459
3460 for left_buffer in left_channel_buffers {
3461 channel_buffer_updated(
3462 session.connection_id,
3463 left_buffer.connections,
3464 &proto::UpdateChannelBufferCollaborators {
3465 channel_id: left_buffer.channel_id.to_proto(),
3466 collaborators: left_buffer.collaborators,
3467 },
3468 &session.peer,
3469 );
3470 }
3471
3472 Ok(())
3473}
3474
3475fn project_left(project: &db::LeftProject, session: &Session) {
3476 for connection_id in &project.connection_ids {
3477 if project.host_user_id == session.user_id {
3478 session
3479 .peer
3480 .send(
3481 *connection_id,
3482 proto::UnshareProject {
3483 project_id: project.id.to_proto(),
3484 },
3485 )
3486 .trace_err();
3487 } else {
3488 session
3489 .peer
3490 .send(
3491 *connection_id,
3492 proto::RemoveProjectCollaborator {
3493 project_id: project.id.to_proto(),
3494 peer_id: Some(session.connection_id.into()),
3495 },
3496 )
3497 .trace_err();
3498 }
3499 }
3500}
3501
3502pub trait ResultExt {
3503 type Ok;
3504
3505 fn trace_err(self) -> Option<Self::Ok>;
3506}
3507
3508impl<T, E> ResultExt for Result<T, E>
3509where
3510 E: std::fmt::Debug,
3511{
3512 type Ok = T;
3513
3514 fn trace_err(self) -> Option<T> {
3515 match self {
3516 Ok(value) => Some(value),
3517 Err(error) => {
3518 tracing::error!("{:?}", error);
3519 None
3520 }
3521 }
3522 }
3523}