1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69use util::channel::RELEASE_CHANNEL_NAME;
70
71pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
72pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
73
74const MESSAGE_COUNT_PER_PAGE: usize = 100;
75const MAX_MESSAGE_LEN: usize = 1024;
76const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
78lazy_static! {
79 static ref METRIC_CONNECTIONS: IntGauge =
80 register_int_gauge!("connections", "number of connections").unwrap();
81 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
82 "shared_projects",
83 "number of open projects with one or more guests"
84 )
85 .unwrap();
86}
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone)]
106struct Session {
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server: owns the peer, the pool of live connections,
/// and the table of message handlers.
pub struct Server {
    // Mutex so test builds can replace the id in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown`; each connection loop subscribes and exits on change.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared `ConnectionPool`.
///
/// The `Rc` marker makes the guard `!Send`, so a `Send` future cannot hold it
/// across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of server state: the peer plus the connection pool,
/// which stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the pool itself, via the guard's `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_project_request::<proto::GetHover>)
221 .add_request_handler(forward_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_project_request::<proto::GetCompletions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
232 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
233 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
234 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
235 .add_request_handler(forward_project_request::<proto::PrepareRename>)
236 .add_request_handler(forward_project_request::<proto::PerformRename>)
237 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
238 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
239 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
240 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
243 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
244 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
245 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
246 .add_request_handler(forward_project_request::<proto::InlayHints>)
247 .add_message_handler(create_buffer_for_peer)
248 .add_request_handler(update_buffer)
249 .add_message_handler(update_buffer_file)
250 .add_message_handler(buffer_reloaded)
251 .add_message_handler(buffer_saved)
252 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
253 .add_request_handler(get_users)
254 .add_request_handler(fuzzy_search_users)
255 .add_request_handler(request_contact)
256 .add_request_handler(remove_contact)
257 .add_request_handler(respond_to_contact_request)
258 .add_request_handler(create_channel)
259 .add_request_handler(delete_channel)
260 .add_request_handler(invite_channel_member)
261 .add_request_handler(remove_channel_member)
262 .add_request_handler(set_channel_member_role)
263 .add_request_handler(set_channel_visibility)
264 .add_request_handler(rename_channel)
265 .add_request_handler(join_channel_buffer)
266 .add_request_handler(leave_channel_buffer)
267 .add_message_handler(update_channel_buffer)
268 .add_request_handler(rejoin_channel_buffers)
269 .add_request_handler(get_channel_members)
270 .add_request_handler(respond_to_channel_invite)
271 .add_request_handler(join_channel)
272 .add_request_handler(join_channel_chat)
273 .add_message_handler(leave_channel_chat)
274 .add_request_handler(send_channel_message)
275 .add_request_handler(remove_channel_message)
276 .add_request_handler(get_channel_messages)
277 .add_request_handler(get_channel_messages_by_id)
278 .add_request_handler(get_notifications)
279 .add_request_handler(mark_notification_as_read)
280 .add_request_handler(link_channel)
281 .add_request_handler(unlink_channel)
282 .add_request_handler(move_channel)
283 .add_request_handler(follow)
284 .add_message_handler(unfollow)
285 .add_message_handler(update_followers)
286 .add_message_handler(update_diff_base)
287 .add_request_handler(get_private_user_info)
288 .add_message_handler(acknowledge_channel_message)
289 .add_message_handler(acknowledge_buffer_version);
290
291 Arc::new(server)
292 }
293
294 pub async fn start(&self) -> Result<()> {
295 let server_id = *self.id.lock();
296 let app_state = self.app_state.clone();
297 let peer = self.peer.clone();
298 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
299 let pool = self.connection_pool.clone();
300 let live_kit_client = self.app_state.live_kit_client.clone();
301
302 let span = info_span!("start server");
303 self.executor.spawn_detached(
304 async move {
305 tracing::info!("waiting for cleanup timeout");
306 timeout.await;
307 tracing::info!("cleanup timeout expired, retrieving stale rooms");
308 if let Some((room_ids, channel_ids)) = app_state
309 .db
310 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
311 .await
312 .trace_err()
313 {
314 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
315 tracing::info!(
316 stale_channel_buffer_count = channel_ids.len(),
317 "retrieved stale channel buffers"
318 );
319
320 for channel_id in channel_ids {
321 if let Some(refreshed_channel_buffer) = app_state
322 .db
323 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
324 .await
325 .trace_err()
326 {
327 for connection_id in refreshed_channel_buffer.connection_ids {
328 peer.send(
329 connection_id,
330 proto::UpdateChannelBufferCollaborators {
331 channel_id: channel_id.to_proto(),
332 collaborators: refreshed_channel_buffer
333 .collaborators
334 .clone(),
335 },
336 )
337 .trace_err();
338 }
339 }
340 }
341
342 for room_id in room_ids {
343 let mut contacts_to_update = HashSet::default();
344 let mut canceled_calls_to_user_ids = Vec::new();
345 let mut live_kit_room = String::new();
346 let mut delete_live_kit_room = false;
347
348 if let Some(mut refreshed_room) = app_state
349 .db
350 .clear_stale_room_participants(room_id, server_id)
351 .await
352 .trace_err()
353 {
354 tracing::info!(
355 room_id = room_id.0,
356 new_participant_count = refreshed_room.room.participants.len(),
357 "refreshed room"
358 );
359 room_updated(&refreshed_room.room, &peer);
360 if let Some(channel_id) = refreshed_room.channel_id {
361 channel_updated(
362 channel_id,
363 &refreshed_room.room,
364 &refreshed_room.channel_members,
365 &peer,
366 &*pool.lock(),
367 );
368 }
369 contacts_to_update
370 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
371 contacts_to_update
372 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
373 canceled_calls_to_user_ids =
374 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
375 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
376 delete_live_kit_room = refreshed_room.room.participants.is_empty();
377 }
378
379 {
380 let pool = pool.lock();
381 for canceled_user_id in canceled_calls_to_user_ids {
382 for connection_id in pool.user_connection_ids(canceled_user_id) {
383 peer.send(
384 connection_id,
385 proto::CallCanceled {
386 room_id: room_id.to_proto(),
387 },
388 )
389 .trace_err();
390 }
391 }
392 }
393
394 for user_id in contacts_to_update {
395 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
396 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
397 if let Some((busy, contacts)) = busy.zip(contacts) {
398 let pool = pool.lock();
399 let updated_contact = contact_for_user(user_id, busy, &pool);
400 for contact in contacts {
401 if let db::Contact::Accepted {
402 user_id: contact_user_id,
403 ..
404 } = contact
405 {
406 for contact_conn_id in
407 pool.user_connection_ids(contact_user_id)
408 {
409 peer.send(
410 contact_conn_id,
411 proto::UpdateContacts {
412 contacts: vec![updated_contact.clone()],
413 remove_contacts: Default::default(),
414 incoming_requests: Default::default(),
415 remove_incoming_requests: Default::default(),
416 outgoing_requests: Default::default(),
417 remove_outgoing_requests: Default::default(),
418 },
419 )
420 .trace_err();
421 }
422 }
423 }
424 }
425 }
426
427 if let Some(live_kit) = live_kit_client.as_ref() {
428 if delete_live_kit_room {
429 live_kit.delete_room(live_kit_room).await.trace_err();
430 }
431 }
432 }
433 }
434
435 app_state
436 .db
437 .delete_stale_servers(&app_state.config.zed_environment, server_id)
438 .await
439 .trace_err();
440 }
441 .instrument(span),
442 );
443 Ok(())
444 }
445
446 pub fn teardown(&self) {
447 self.peer.teardown();
448 self.connection_pool.lock().reset();
449 let _ = self.teardown.send(());
450 }
451
452 #[cfg(test)]
453 pub fn reset(&self, id: ServerId) {
454 self.teardown();
455 *self.id.lock() = id;
456 self.peer.reset(id.0 as u32);
457 }
458
459 #[cfg(test)]
460 pub fn id(&self) -> ServerId {
461 *self.id.lock()
462 }
463
464 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
465 where
466 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
467 Fut: 'static + Send + Future<Output = Result<()>>,
468 M: EnvelopedMessage,
469 {
470 let prev_handler = self.handlers.insert(
471 TypeId::of::<M>(),
472 Box::new(move |envelope, session| {
473 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
474 let span = info_span!(
475 "handle message",
476 payload_type = envelope.payload_type_name()
477 );
478 span.in_scope(|| {
479 tracing::info!(
480 payload_type = envelope.payload_type_name(),
481 "message received"
482 );
483 });
484 let start_time = Instant::now();
485 let future = (handler)(*envelope, session);
486 async move {
487 let result = future.await;
488 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
489 match result {
490 Err(error) => {
491 tracing::error!(%error, ?duration_ms, "error handling message")
492 }
493 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
494 }
495 }
496 .instrument(span)
497 .boxed()
498 }),
499 );
500 if prev_handler.is_some() {
501 panic!("registered a handler for the same message twice");
502 }
503 self
504 }
505
506 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
507 where
508 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
509 Fut: 'static + Send + Future<Output = Result<()>>,
510 M: EnvelopedMessage,
511 {
512 self.add_handler(move |envelope, session| handler(envelope.payload, session));
513 self
514 }
515
516 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
517 where
518 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
519 Fut: Send + Future<Output = Result<()>>,
520 M: RequestMessage,
521 {
522 let handler = Arc::new(handler);
523 self.add_handler(move |envelope, session| {
524 let receipt = envelope.receipt();
525 let handler = handler.clone();
526 async move {
527 let peer = session.peer.clone();
528 let responded = Arc::new(AtomicBool::default());
529 let response = Response {
530 peer: peer.clone(),
531 responded: responded.clone(),
532 receipt,
533 };
534 match (handler)(envelope.payload, response, session).await {
535 Ok(()) => {
536 if responded.load(std::sync::atomic::Ordering::SeqCst) {
537 Ok(())
538 } else {
539 Err(anyhow!("handler did not send a response"))?
540 }
541 }
542 Err(error) => {
543 peer.respond_with_error(
544 receipt,
545 proto::Error {
546 message: error.to_string(),
547 },
548 )?;
549 Err(error)
550 }
551 }
552 }
553 })
554 }
555
556 pub fn handle_connection(
557 self: &Arc<Self>,
558 connection: Connection,
559 address: String,
560 user: User,
561 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
562 executor: Executor,
563 ) -> impl Future<Output = Result<()>> {
564 let this = self.clone();
565 let user_id = user.id;
566 let login = user.github_login;
567 let span = info_span!("handle connection", %user_id, %login, %address);
568 let mut teardown = self.teardown.subscribe();
569 async move {
570 let (connection_id, handle_io, mut incoming_rx) = this
571 .peer
572 .add_connection(connection, {
573 let executor = executor.clone();
574 move |duration| executor.sleep(duration)
575 });
576
577 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
578 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
579 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
580
581 if let Some(send_connection_id) = send_connection_id.take() {
582 let _ = send_connection_id.send(connection_id);
583 }
584
585 if !user.connected_once {
586 this.peer.send(connection_id, proto::ShowContacts {})?;
587 this.app_state.db.set_user_connected_once(user_id, true).await?;
588 }
589
590 let (contacts, channels_for_user, channel_invites) = future::try_join3(
591 this.app_state.db.get_contacts(user_id),
592 this.app_state.db.get_channels_for_user(user_id),
593 this.app_state.db.get_channel_invites_for_user(user_id),
594 ).await?;
595
596 {
597 let mut pool = this.connection_pool.lock();
598 pool.add_connection(connection_id, user_id, user.admin);
599 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
600 this.peer.send(connection_id, build_channels_update(
601 channels_for_user,
602 channel_invites
603 ))?;
604 }
605
606 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
607 this.peer.send(connection_id, incoming_call)?;
608 }
609
610 let session = Session {
611 user_id,
612 connection_id,
613 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
614 peer: this.peer.clone(),
615 connection_pool: this.connection_pool.clone(),
616 live_kit_client: this.app_state.live_kit_client.clone(),
617 executor: executor.clone(),
618 };
619 update_user_contacts(user_id, &session).await?;
620
621 let handle_io = handle_io.fuse();
622 futures::pin_mut!(handle_io);
623
624 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
625 // This prevents deadlocks when e.g., client A performs a request to client B and
626 // client B performs a request to client A. If both clients stop processing further
627 // messages until their respective request completes, they won't have a chance to
628 // respond to the other client's request and cause a deadlock.
629 //
630 // This arrangement ensures we will attempt to process earlier messages first, but fall
631 // back to processing messages arrived later in the spirit of making progress.
632 let mut foreground_message_handlers = FuturesUnordered::new();
633 let concurrent_handlers = Arc::new(Semaphore::new(256));
634 loop {
635 let next_message = async {
636 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
637 let message = incoming_rx.next().await;
638 (permit, message)
639 }.fuse();
640 futures::pin_mut!(next_message);
641 futures::select_biased! {
642 _ = teardown.changed().fuse() => return Ok(()),
643 result = handle_io => {
644 if let Err(error) = result {
645 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
646 }
647 break;
648 }
649 _ = foreground_message_handlers.next() => {}
650 next_message = next_message => {
651 let (permit, message) = next_message;
652 if let Some(message) = message {
653 let type_name = message.payload_type_name();
654 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
655 let span_enter = span.enter();
656 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
657 let is_background = message.is_background();
658 let handle_message = (handler)(message, session.clone());
659 drop(span_enter);
660
661 let handle_message = async move {
662 handle_message.await;
663 drop(permit);
664 }.instrument(span);
665 if is_background {
666 executor.spawn_detached(handle_message);
667 } else {
668 foreground_message_handlers.push(handle_message);
669 }
670 } else {
671 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
672 }
673 } else {
674 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
675 break;
676 }
677 }
678 }
679 }
680
681 drop(foreground_message_handlers);
682 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
683 if let Err(error) = connection_lost(session, teardown, executor).await {
684 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
685 }
686
687 Ok(())
688 }.instrument(span)
689 }
690
691 pub async fn invite_code_redeemed(
692 self: &Arc<Self>,
693 inviter_id: UserId,
694 invitee_id: UserId,
695 ) -> Result<()> {
696 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
697 if let Some(code) = &user.invite_code {
698 let pool = self.connection_pool.lock();
699 let invitee_contact = contact_for_user(invitee_id, false, &pool);
700 for connection_id in pool.user_connection_ids(inviter_id) {
701 self.peer.send(
702 connection_id,
703 proto::UpdateContacts {
704 contacts: vec![invitee_contact.clone()],
705 ..Default::default()
706 },
707 )?;
708 self.peer.send(
709 connection_id,
710 proto::UpdateInviteInfo {
711 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
712 count: user.invite_count as u32,
713 },
714 )?;
715 }
716 }
717 }
718 Ok(())
719 }
720
721 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
722 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
723 if let Some(invite_code) = &user.invite_code {
724 let pool = self.connection_pool.lock();
725 for connection_id in pool.user_connection_ids(user_id) {
726 self.peer.send(
727 connection_id,
728 proto::UpdateInviteInfo {
729 url: format!(
730 "{}{}",
731 self.app_state.config.invite_link_prefix, invite_code
732 ),
733 count: user.invite_count as u32,
734 },
735 )?;
736 }
737 }
738 }
739 Ok(())
740 }
741
742 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
743 ServerSnapshot {
744 connection_pool: ConnectionPoolGuard {
745 guard: self.connection_pool.lock(),
746 _not_send: PhantomData,
747 },
748 peer: &self.peer,
749 }
750 }
751}
752
753impl<'a> Deref for ConnectionPoolGuard<'a> {
754 type Target = ConnectionPool;
755
756 fn deref(&self) -> &Self::Target {
757 &*self.guard
758 }
759}
760
761impl<'a> DerefMut for ConnectionPoolGuard<'a> {
762 fn deref_mut(&mut self) -> &mut Self::Target {
763 &mut *self.guard
764 }
765}
766
767impl<'a> Drop for ConnectionPoolGuard<'a> {
768 fn drop(&mut self) {
769 #[cfg(test)]
770 self.check_invariants();
771 }
772}
773
774fn broadcast<F>(
775 sender_id: Option<ConnectionId>,
776 receiver_ids: impl IntoIterator<Item = ConnectionId>,
777 mut f: F,
778) where
779 F: FnMut(ConnectionId) -> anyhow::Result<()>,
780{
781 for receiver_id in receiver_ids {
782 if Some(receiver_id) != sender_id {
783 if let Err(error) = f(receiver_id) {
784 tracing::error!("failed to send to {:?} {}", receiver_id, error);
785 }
786 }
787 }
788}
789
790lazy_static! {
791 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
792}
793
794pub struct ProtocolVersion(u32);
795
796impl Header for ProtocolVersion {
797 fn name() -> &'static HeaderName {
798 &ZED_PROTOCOL_VERSION
799 }
800
801 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
802 where
803 Self: Sized,
804 I: Iterator<Item = &'i axum::http::HeaderValue>,
805 {
806 let version = values
807 .next()
808 .ok_or_else(axum::headers::Error::invalid)?
809 .to_str()
810 .map_err(|_| axum::headers::Error::invalid())?
811 .parse()
812 .map_err(|_| axum::headers::Error::invalid())?;
813 Ok(Self(version))
814 }
815
816 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
817 values.extend([self.0.to_string().parse().unwrap()]);
818 }
819}
820
821pub fn routes(server: Arc<Server>) -> Router<Body> {
822 Router::new()
823 .route("/rpc", get(handle_websocket_request))
824 .layer(
825 ServiceBuilder::new()
826 .layer(Extension(server.app_state.clone()))
827 .layer(middleware::from_fn(auth::validate_header)),
828 )
829 .route("/metrics", get(handle_metrics))
830 .layer(Extension(server))
831}
832
833pub async fn handle_websocket_request(
834 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
835 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
836 Extension(server): Extension<Arc<Server>>,
837 Extension(user): Extension<User>,
838 ws: WebSocketUpgrade,
839) -> axum::response::Response {
840 if protocol_version != rpc::PROTOCOL_VERSION {
841 return (
842 StatusCode::UPGRADE_REQUIRED,
843 "client must be upgraded".to_string(),
844 )
845 .into_response();
846 }
847 let socket_address = socket_address.to_string();
848 ws.on_upgrade(move |socket| {
849 use util::ResultExt;
850 let socket = socket
851 .map_ok(to_tungstenite_message)
852 .err_into()
853 .with(|message| async move { Ok(to_axum_message(message)) });
854 let connection = Connection::new(Box::pin(socket));
855 async move {
856 server
857 .handle_connection(connection, socket_address, user, None, Executor::Production)
858 .await
859 .log_err();
860 }
861 })
862}
863
864pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
865 let connections = server
866 .connection_pool
867 .lock()
868 .connections()
869 .filter(|connection| !connection.admin)
870 .count();
871
872 METRIC_CONNECTIONS.set(connections as _);
873
874 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
875 METRIC_SHARED_PROJECTS.set(shared_projects as _);
876
877 let encoder = prometheus::TextEncoder::new();
878 let metric_families = prometheus::gather();
879 let encoded_metrics = encoder
880 .encode_to_string(&metric_families)
881 .map_err(|err| anyhow!("{}", err))?;
882 Ok(encoded_metrics)
883}
884
885#[instrument(err, skip(executor))]
886async fn connection_lost(
887 session: Session,
888 mut teardown: watch::Receiver<()>,
889 executor: Executor,
890) -> Result<()> {
891 session.peer.disconnect(session.connection_id);
892 session
893 .connection_pool()
894 .await
895 .remove_connection(session.connection_id)?;
896
897 session
898 .db()
899 .await
900 .connection_lost(session.connection_id)
901 .await
902 .trace_err();
903
904 futures::select_biased! {
905 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
906 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
907 leave_room_for_session(&session).await.trace_err();
908 leave_channel_buffers_for_session(&session)
909 .await
910 .trace_err();
911
912 if !session
913 .connection_pool()
914 .await
915 .is_user_online(session.user_id)
916 {
917 let db = session.db().await;
918 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
919 room_updated(&room, &session.peer);
920 }
921 }
922
923 update_user_contacts(session.user_id, &session).await?;
924 }
925 _ = teardown.changed().fuse() => {}
926 }
927
928 Ok(())
929}
930
931async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
932 response.send(proto::Ack {})?;
933 Ok(())
934}
935
936async fn create_room(
937 _request: proto::CreateRoom,
938 response: Response<proto::CreateRoom>,
939 session: Session,
940) -> Result<()> {
941 let live_kit_room = nanoid::nanoid!(30);
942
943 let live_kit_connection_info = {
944 let live_kit_room = live_kit_room.clone();
945 let live_kit = session.live_kit_client.as_ref();
946
947 util::async_maybe!({
948 let live_kit = live_kit?;
949
950 let token = live_kit
951 .room_token(&live_kit_room, &session.user_id.to_string())
952 .trace_err()?;
953
954 Some(proto::LiveKitConnectionInfo {
955 server_url: live_kit.url().into(),
956 token,
957 can_publish: true,
958 })
959 })
960 }
961 .await;
962
963 let room = session
964 .db()
965 .await
966 .create_room(
967 session.user_id,
968 session.connection_id,
969 &live_kit_room,
970 RELEASE_CHANNEL_NAME.as_str(),
971 )
972 .await?;
973
974 response.send(proto::CreateRoomResponse {
975 room: Some(room.clone()),
976 live_kit_connection_info,
977 })?;
978
979 update_user_contacts(session.user_id, &session).await?;
980 Ok(())
981}
982
983async fn join_room(
984 request: proto::JoinRoom,
985 response: Response<proto::JoinRoom>,
986 session: Session,
987) -> Result<()> {
988 let room_id = RoomId::from_proto(request.id);
989
990 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
991
992 if let Some(channel_id) = channel_id {
993 return join_channel_internal(channel_id, Box::new(response), session).await;
994 }
995
996 let joined_room = {
997 let room = session
998 .db()
999 .await
1000 .join_room(
1001 room_id,
1002 session.user_id,
1003 session.connection_id,
1004 RELEASE_CHANNEL_NAME.as_str(),
1005 )
1006 .await?;
1007 room_updated(&room.room, &session.peer);
1008 room.into_inner()
1009 };
1010
1011 for connection_id in session
1012 .connection_pool()
1013 .await
1014 .user_connection_ids(session.user_id)
1015 {
1016 session
1017 .peer
1018 .send(
1019 connection_id,
1020 proto::CallCanceled {
1021 room_id: room_id.to_proto(),
1022 },
1023 )
1024 .trace_err();
1025 }
1026
1027 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1028 if let Some(token) = live_kit
1029 .room_token(
1030 &joined_room.room.live_kit_room,
1031 &session.user_id.to_string(),
1032 )
1033 .trace_err()
1034 {
1035 Some(proto::LiveKitConnectionInfo {
1036 server_url: live_kit.url().into(),
1037 token,
1038 can_publish: true,
1039 })
1040 } else {
1041 None
1042 }
1043 } else {
1044 None
1045 };
1046
1047 response.send(proto::JoinRoomResponse {
1048 room: Some(joined_room.room),
1049 channel_id: None,
1050 live_kit_connection_info,
1051 })?;
1052
1053 update_user_contacts(session.user_id, &session).await?;
1054 Ok(())
1055}
1056
1057async fn rejoin_room(
1058 request: proto::RejoinRoom,
1059 response: Response<proto::RejoinRoom>,
1060 session: Session,
1061) -> Result<()> {
1062 let room;
1063 let channel_id;
1064 let channel_members;
1065 {
1066 let mut rejoined_room = session
1067 .db()
1068 .await
1069 .rejoin_room(request, session.user_id, session.connection_id)
1070 .await?;
1071
1072 response.send(proto::RejoinRoomResponse {
1073 room: Some(rejoined_room.room.clone()),
1074 reshared_projects: rejoined_room
1075 .reshared_projects
1076 .iter()
1077 .map(|project| proto::ResharedProject {
1078 id: project.id.to_proto(),
1079 collaborators: project
1080 .collaborators
1081 .iter()
1082 .map(|collaborator| collaborator.to_proto())
1083 .collect(),
1084 })
1085 .collect(),
1086 rejoined_projects: rejoined_room
1087 .rejoined_projects
1088 .iter()
1089 .map(|rejoined_project| proto::RejoinedProject {
1090 id: rejoined_project.id.to_proto(),
1091 worktrees: rejoined_project
1092 .worktrees
1093 .iter()
1094 .map(|worktree| proto::WorktreeMetadata {
1095 id: worktree.id,
1096 root_name: worktree.root_name.clone(),
1097 visible: worktree.visible,
1098 abs_path: worktree.abs_path.clone(),
1099 })
1100 .collect(),
1101 collaborators: rejoined_project
1102 .collaborators
1103 .iter()
1104 .map(|collaborator| collaborator.to_proto())
1105 .collect(),
1106 language_servers: rejoined_project.language_servers.clone(),
1107 })
1108 .collect(),
1109 })?;
1110 room_updated(&rejoined_room.room, &session.peer);
1111
1112 for project in &rejoined_room.reshared_projects {
1113 for collaborator in &project.collaborators {
1114 session
1115 .peer
1116 .send(
1117 collaborator.connection_id,
1118 proto::UpdateProjectCollaborator {
1119 project_id: project.id.to_proto(),
1120 old_peer_id: Some(project.old_connection_id.into()),
1121 new_peer_id: Some(session.connection_id.into()),
1122 },
1123 )
1124 .trace_err();
1125 }
1126
1127 broadcast(
1128 Some(session.connection_id),
1129 project
1130 .collaborators
1131 .iter()
1132 .map(|collaborator| collaborator.connection_id),
1133 |connection_id| {
1134 session.peer.forward_send(
1135 session.connection_id,
1136 connection_id,
1137 proto::UpdateProject {
1138 project_id: project.id.to_proto(),
1139 worktrees: project.worktrees.clone(),
1140 },
1141 )
1142 },
1143 );
1144 }
1145
1146 for project in &rejoined_room.rejoined_projects {
1147 for collaborator in &project.collaborators {
1148 session
1149 .peer
1150 .send(
1151 collaborator.connection_id,
1152 proto::UpdateProjectCollaborator {
1153 project_id: project.id.to_proto(),
1154 old_peer_id: Some(project.old_connection_id.into()),
1155 new_peer_id: Some(session.connection_id.into()),
1156 },
1157 )
1158 .trace_err();
1159 }
1160 }
1161
1162 for project in &mut rejoined_room.rejoined_projects {
1163 for worktree in mem::take(&mut project.worktrees) {
1164 #[cfg(any(test, feature = "test-support"))]
1165 const MAX_CHUNK_SIZE: usize = 2;
1166 #[cfg(not(any(test, feature = "test-support")))]
1167 const MAX_CHUNK_SIZE: usize = 256;
1168
1169 // Stream this worktree's entries.
1170 let message = proto::UpdateWorktree {
1171 project_id: project.id.to_proto(),
1172 worktree_id: worktree.id,
1173 abs_path: worktree.abs_path.clone(),
1174 root_name: worktree.root_name,
1175 updated_entries: worktree.updated_entries,
1176 removed_entries: worktree.removed_entries,
1177 scan_id: worktree.scan_id,
1178 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1179 updated_repositories: worktree.updated_repositories,
1180 removed_repositories: worktree.removed_repositories,
1181 };
1182 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1183 session.peer.send(session.connection_id, update.clone())?;
1184 }
1185
1186 // Stream this worktree's diagnostics.
1187 for summary in worktree.diagnostic_summaries {
1188 session.peer.send(
1189 session.connection_id,
1190 proto::UpdateDiagnosticSummary {
1191 project_id: project.id.to_proto(),
1192 worktree_id: worktree.id,
1193 summary: Some(summary),
1194 },
1195 )?;
1196 }
1197
1198 for settings_file in worktree.settings_files {
1199 session.peer.send(
1200 session.connection_id,
1201 proto::UpdateWorktreeSettings {
1202 project_id: project.id.to_proto(),
1203 worktree_id: worktree.id,
1204 path: settings_file.path,
1205 content: Some(settings_file.content),
1206 },
1207 )?;
1208 }
1209 }
1210
1211 for language_server in &project.language_servers {
1212 session.peer.send(
1213 session.connection_id,
1214 proto::UpdateLanguageServer {
1215 project_id: project.id.to_proto(),
1216 language_server_id: language_server.id,
1217 variant: Some(
1218 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1219 proto::LspDiskBasedDiagnosticsUpdated {},
1220 ),
1221 ),
1222 },
1223 )?;
1224 }
1225 }
1226
1227 let rejoined_room = rejoined_room.into_inner();
1228
1229 room = rejoined_room.room;
1230 channel_id = rejoined_room.channel_id;
1231 channel_members = rejoined_room.channel_members;
1232 }
1233
1234 if let Some(channel_id) = channel_id {
1235 channel_updated(
1236 channel_id,
1237 &room,
1238 &channel_members,
1239 &session.peer,
1240 &*session.connection_pool().await,
1241 );
1242 }
1243
1244 update_user_contacts(session.user_id, &session).await?;
1245 Ok(())
1246}
1247
1248async fn leave_room(
1249 _: proto::LeaveRoom,
1250 response: Response<proto::LeaveRoom>,
1251 session: Session,
1252) -> Result<()> {
1253 leave_room_for_session(&session).await?;
1254 response.send(proto::Ack {})?;
1255 Ok(())
1256}
1257
1258async fn call(
1259 request: proto::Call,
1260 response: Response<proto::Call>,
1261 session: Session,
1262) -> Result<()> {
1263 let room_id = RoomId::from_proto(request.room_id);
1264 let calling_user_id = session.user_id;
1265 let calling_connection_id = session.connection_id;
1266 let called_user_id = UserId::from_proto(request.called_user_id);
1267 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1268 if !session
1269 .db()
1270 .await
1271 .has_contact(calling_user_id, called_user_id)
1272 .await?
1273 {
1274 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1275 }
1276
1277 let incoming_call = {
1278 let (room, incoming_call) = &mut *session
1279 .db()
1280 .await
1281 .call(
1282 room_id,
1283 calling_user_id,
1284 calling_connection_id,
1285 called_user_id,
1286 initial_project_id,
1287 )
1288 .await?;
1289 room_updated(&room, &session.peer);
1290 mem::take(incoming_call)
1291 };
1292 update_user_contacts(called_user_id, &session).await?;
1293
1294 let mut calls = session
1295 .connection_pool()
1296 .await
1297 .user_connection_ids(called_user_id)
1298 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1299 .collect::<FuturesUnordered<_>>();
1300
1301 while let Some(call_response) = calls.next().await {
1302 match call_response.as_ref() {
1303 Ok(_) => {
1304 response.send(proto::Ack {})?;
1305 return Ok(());
1306 }
1307 Err(_) => {
1308 call_response.trace_err();
1309 }
1310 }
1311 }
1312
1313 {
1314 let room = session
1315 .db()
1316 .await
1317 .call_failed(room_id, called_user_id)
1318 .await?;
1319 room_updated(&room, &session.peer);
1320 }
1321 update_user_contacts(called_user_id, &session).await?;
1322
1323 Err(anyhow!("failed to ring user"))?
1324}
1325
1326async fn cancel_call(
1327 request: proto::CancelCall,
1328 response: Response<proto::CancelCall>,
1329 session: Session,
1330) -> Result<()> {
1331 let called_user_id = UserId::from_proto(request.called_user_id);
1332 let room_id = RoomId::from_proto(request.room_id);
1333 {
1334 let room = session
1335 .db()
1336 .await
1337 .cancel_call(room_id, session.connection_id, called_user_id)
1338 .await?;
1339 room_updated(&room, &session.peer);
1340 }
1341
1342 for connection_id in session
1343 .connection_pool()
1344 .await
1345 .user_connection_ids(called_user_id)
1346 {
1347 session
1348 .peer
1349 .send(
1350 connection_id,
1351 proto::CallCanceled {
1352 room_id: room_id.to_proto(),
1353 },
1354 )
1355 .trace_err();
1356 }
1357 response.send(proto::Ack {})?;
1358
1359 update_user_contacts(called_user_id, &session).await?;
1360 Ok(())
1361}
1362
1363async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1364 let room_id = RoomId::from_proto(message.room_id);
1365 {
1366 let room = session
1367 .db()
1368 .await
1369 .decline_call(Some(room_id), session.user_id)
1370 .await?
1371 .ok_or_else(|| anyhow!("failed to decline call"))?;
1372 room_updated(&room, &session.peer);
1373 }
1374
1375 for connection_id in session
1376 .connection_pool()
1377 .await
1378 .user_connection_ids(session.user_id)
1379 {
1380 session
1381 .peer
1382 .send(
1383 connection_id,
1384 proto::CallCanceled {
1385 room_id: room_id.to_proto(),
1386 },
1387 )
1388 .trace_err();
1389 }
1390 update_user_contacts(session.user_id, &session).await?;
1391 Ok(())
1392}
1393
1394async fn update_participant_location(
1395 request: proto::UpdateParticipantLocation,
1396 response: Response<proto::UpdateParticipantLocation>,
1397 session: Session,
1398) -> Result<()> {
1399 let room_id = RoomId::from_proto(request.room_id);
1400 let location = request
1401 .location
1402 .ok_or_else(|| anyhow!("invalid location"))?;
1403
1404 let db = session.db().await;
1405 let room = db
1406 .update_room_participant_location(room_id, session.connection_id, location)
1407 .await?;
1408
1409 room_updated(&room, &session.peer);
1410 response.send(proto::Ack {})?;
1411 Ok(())
1412}
1413
1414async fn share_project(
1415 request: proto::ShareProject,
1416 response: Response<proto::ShareProject>,
1417 session: Session,
1418) -> Result<()> {
1419 let (project_id, room) = &*session
1420 .db()
1421 .await
1422 .share_project(
1423 RoomId::from_proto(request.room_id),
1424 session.connection_id,
1425 &request.worktrees,
1426 )
1427 .await?;
1428 response.send(proto::ShareProjectResponse {
1429 project_id: project_id.to_proto(),
1430 })?;
1431 room_updated(&room, &session.peer);
1432
1433 Ok(())
1434}
1435
1436async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1437 let project_id = ProjectId::from_proto(message.project_id);
1438
1439 let (room, guest_connection_ids) = &*session
1440 .db()
1441 .await
1442 .unshare_project(project_id, session.connection_id)
1443 .await?;
1444
1445 broadcast(
1446 Some(session.connection_id),
1447 guest_connection_ids.iter().copied(),
1448 |conn_id| session.peer.send(conn_id, message.clone()),
1449 );
1450 room_updated(&room, &session.peer);
1451
1452 Ok(())
1453}
1454
1455async fn join_project(
1456 request: proto::JoinProject,
1457 response: Response<proto::JoinProject>,
1458 session: Session,
1459) -> Result<()> {
1460 let project_id = ProjectId::from_proto(request.project_id);
1461 let guest_user_id = session.user_id;
1462
1463 tracing::info!(%project_id, "join project");
1464
1465 let (project, replica_id) = &mut *session
1466 .db()
1467 .await
1468 .join_project(project_id, session.connection_id)
1469 .await?;
1470
1471 let collaborators = project
1472 .collaborators
1473 .iter()
1474 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1475 .map(|collaborator| collaborator.to_proto())
1476 .collect::<Vec<_>>();
1477
1478 let worktrees = project
1479 .worktrees
1480 .iter()
1481 .map(|(id, worktree)| proto::WorktreeMetadata {
1482 id: *id,
1483 root_name: worktree.root_name.clone(),
1484 visible: worktree.visible,
1485 abs_path: worktree.abs_path.clone(),
1486 })
1487 .collect::<Vec<_>>();
1488
1489 for collaborator in &collaborators {
1490 session
1491 .peer
1492 .send(
1493 collaborator.peer_id.unwrap().into(),
1494 proto::AddProjectCollaborator {
1495 project_id: project_id.to_proto(),
1496 collaborator: Some(proto::Collaborator {
1497 peer_id: Some(session.connection_id.into()),
1498 replica_id: replica_id.0 as u32,
1499 user_id: guest_user_id.to_proto(),
1500 }),
1501 },
1502 )
1503 .trace_err();
1504 }
1505
1506 // First, we send the metadata associated with each worktree.
1507 response.send(proto::JoinProjectResponse {
1508 worktrees: worktrees.clone(),
1509 replica_id: replica_id.0 as u32,
1510 collaborators: collaborators.clone(),
1511 language_servers: project.language_servers.clone(),
1512 })?;
1513
1514 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1515 #[cfg(any(test, feature = "test-support"))]
1516 const MAX_CHUNK_SIZE: usize = 2;
1517 #[cfg(not(any(test, feature = "test-support")))]
1518 const MAX_CHUNK_SIZE: usize = 256;
1519
1520 // Stream this worktree's entries.
1521 let message = proto::UpdateWorktree {
1522 project_id: project_id.to_proto(),
1523 worktree_id,
1524 abs_path: worktree.abs_path.clone(),
1525 root_name: worktree.root_name,
1526 updated_entries: worktree.entries,
1527 removed_entries: Default::default(),
1528 scan_id: worktree.scan_id,
1529 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1530 updated_repositories: worktree.repository_entries.into_values().collect(),
1531 removed_repositories: Default::default(),
1532 };
1533 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1534 session.peer.send(session.connection_id, update.clone())?;
1535 }
1536
1537 // Stream this worktree's diagnostics.
1538 for summary in worktree.diagnostic_summaries {
1539 session.peer.send(
1540 session.connection_id,
1541 proto::UpdateDiagnosticSummary {
1542 project_id: project_id.to_proto(),
1543 worktree_id: worktree.id,
1544 summary: Some(summary),
1545 },
1546 )?;
1547 }
1548
1549 for settings_file in worktree.settings_files {
1550 session.peer.send(
1551 session.connection_id,
1552 proto::UpdateWorktreeSettings {
1553 project_id: project_id.to_proto(),
1554 worktree_id: worktree.id,
1555 path: settings_file.path,
1556 content: Some(settings_file.content),
1557 },
1558 )?;
1559 }
1560 }
1561
1562 for language_server in &project.language_servers {
1563 session.peer.send(
1564 session.connection_id,
1565 proto::UpdateLanguageServer {
1566 project_id: project_id.to_proto(),
1567 language_server_id: language_server.id,
1568 variant: Some(
1569 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1570 proto::LspDiskBasedDiagnosticsUpdated {},
1571 ),
1572 ),
1573 },
1574 )?;
1575 }
1576
1577 Ok(())
1578}
1579
1580async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1581 let sender_id = session.connection_id;
1582 let project_id = ProjectId::from_proto(request.project_id);
1583
1584 let (room, project) = &*session
1585 .db()
1586 .await
1587 .leave_project(project_id, sender_id)
1588 .await?;
1589 tracing::info!(
1590 %project_id,
1591 host_user_id = %project.host_user_id,
1592 host_connection_id = %project.host_connection_id,
1593 "leave project"
1594 );
1595
1596 project_left(&project, &session);
1597 room_updated(&room, &session.peer);
1598
1599 Ok(())
1600}
1601
1602async fn update_project(
1603 request: proto::UpdateProject,
1604 response: Response<proto::UpdateProject>,
1605 session: Session,
1606) -> Result<()> {
1607 let project_id = ProjectId::from_proto(request.project_id);
1608 let (room, guest_connection_ids) = &*session
1609 .db()
1610 .await
1611 .update_project(project_id, session.connection_id, &request.worktrees)
1612 .await?;
1613 broadcast(
1614 Some(session.connection_id),
1615 guest_connection_ids.iter().copied(),
1616 |connection_id| {
1617 session
1618 .peer
1619 .forward_send(session.connection_id, connection_id, request.clone())
1620 },
1621 );
1622 room_updated(&room, &session.peer);
1623 response.send(proto::Ack {})?;
1624
1625 Ok(())
1626}
1627
1628async fn update_worktree(
1629 request: proto::UpdateWorktree,
1630 response: Response<proto::UpdateWorktree>,
1631 session: Session,
1632) -> Result<()> {
1633 let guest_connection_ids = session
1634 .db()
1635 .await
1636 .update_worktree(&request, session.connection_id)
1637 .await?;
1638
1639 broadcast(
1640 Some(session.connection_id),
1641 guest_connection_ids.iter().copied(),
1642 |connection_id| {
1643 session
1644 .peer
1645 .forward_send(session.connection_id, connection_id, request.clone())
1646 },
1647 );
1648 response.send(proto::Ack {})?;
1649 Ok(())
1650}
1651
1652async fn update_diagnostic_summary(
1653 message: proto::UpdateDiagnosticSummary,
1654 session: Session,
1655) -> Result<()> {
1656 let guest_connection_ids = session
1657 .db()
1658 .await
1659 .update_diagnostic_summary(&message, session.connection_id)
1660 .await?;
1661
1662 broadcast(
1663 Some(session.connection_id),
1664 guest_connection_ids.iter().copied(),
1665 |connection_id| {
1666 session
1667 .peer
1668 .forward_send(session.connection_id, connection_id, message.clone())
1669 },
1670 );
1671
1672 Ok(())
1673}
1674
1675async fn update_worktree_settings(
1676 message: proto::UpdateWorktreeSettings,
1677 session: Session,
1678) -> Result<()> {
1679 let guest_connection_ids = session
1680 .db()
1681 .await
1682 .update_worktree_settings(&message, session.connection_id)
1683 .await?;
1684
1685 broadcast(
1686 Some(session.connection_id),
1687 guest_connection_ids.iter().copied(),
1688 |connection_id| {
1689 session
1690 .peer
1691 .forward_send(session.connection_id, connection_id, message.clone())
1692 },
1693 );
1694
1695 Ok(())
1696}
1697
1698async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1699 broadcast_project_message(request.project_id, request, session).await
1700}
1701
1702async fn start_language_server(
1703 request: proto::StartLanguageServer,
1704 session: Session,
1705) -> Result<()> {
1706 let guest_connection_ids = session
1707 .db()
1708 .await
1709 .start_language_server(&request, session.connection_id)
1710 .await?;
1711
1712 broadcast(
1713 Some(session.connection_id),
1714 guest_connection_ids.iter().copied(),
1715 |connection_id| {
1716 session
1717 .peer
1718 .forward_send(session.connection_id, connection_id, request.clone())
1719 },
1720 );
1721 Ok(())
1722}
1723
1724async fn update_language_server(
1725 request: proto::UpdateLanguageServer,
1726 session: Session,
1727) -> Result<()> {
1728 session.executor.record_backtrace();
1729 let project_id = ProjectId::from_proto(request.project_id);
1730 let project_connection_ids = session
1731 .db()
1732 .await
1733 .project_connection_ids(project_id, session.connection_id)
1734 .await?;
1735 broadcast(
1736 Some(session.connection_id),
1737 project_connection_ids.iter().copied(),
1738 |connection_id| {
1739 session
1740 .peer
1741 .forward_send(session.connection_id, connection_id, request.clone())
1742 },
1743 );
1744 Ok(())
1745}
1746
1747async fn forward_project_request<T>(
1748 request: T,
1749 response: Response<T>,
1750 session: Session,
1751) -> Result<()>
1752where
1753 T: EntityMessage + RequestMessage,
1754{
1755 session.executor.record_backtrace();
1756 let project_id = ProjectId::from_proto(request.remote_entity_id());
1757 let host_connection_id = {
1758 let collaborators = session
1759 .db()
1760 .await
1761 .project_collaborators(project_id, session.connection_id)
1762 .await?;
1763 collaborators
1764 .iter()
1765 .find(|collaborator| collaborator.is_host)
1766 .ok_or_else(|| anyhow!("host not found"))?
1767 .connection_id
1768 };
1769
1770 let payload = session
1771 .peer
1772 .forward_request(session.connection_id, host_connection_id, request)
1773 .await?;
1774
1775 response.send(payload)?;
1776 Ok(())
1777}
1778
1779async fn create_buffer_for_peer(
1780 request: proto::CreateBufferForPeer,
1781 session: Session,
1782) -> Result<()> {
1783 session.executor.record_backtrace();
1784 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1785 session
1786 .peer
1787 .forward_send(session.connection_id, peer_id.into(), request)?;
1788 Ok(())
1789}
1790
1791async fn update_buffer(
1792 request: proto::UpdateBuffer,
1793 response: Response<proto::UpdateBuffer>,
1794 session: Session,
1795) -> Result<()> {
1796 session.executor.record_backtrace();
1797 let project_id = ProjectId::from_proto(request.project_id);
1798 let mut guest_connection_ids;
1799 let mut host_connection_id = None;
1800 {
1801 let collaborators = session
1802 .db()
1803 .await
1804 .project_collaborators(project_id, session.connection_id)
1805 .await?;
1806 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1807 for collaborator in collaborators.iter() {
1808 if collaborator.is_host {
1809 host_connection_id = Some(collaborator.connection_id);
1810 } else {
1811 guest_connection_ids.push(collaborator.connection_id);
1812 }
1813 }
1814 }
1815 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1816
1817 session.executor.record_backtrace();
1818 broadcast(
1819 Some(session.connection_id),
1820 guest_connection_ids,
1821 |connection_id| {
1822 session
1823 .peer
1824 .forward_send(session.connection_id, connection_id, request.clone())
1825 },
1826 );
1827 if host_connection_id != session.connection_id {
1828 session
1829 .peer
1830 .forward_request(session.connection_id, host_connection_id, request.clone())
1831 .await?;
1832 }
1833
1834 response.send(proto::Ack {})?;
1835 Ok(())
1836}
1837
1838async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1839 let project_id = ProjectId::from_proto(request.project_id);
1840 let project_connection_ids = session
1841 .db()
1842 .await
1843 .project_connection_ids(project_id, session.connection_id)
1844 .await?;
1845
1846 broadcast(
1847 Some(session.connection_id),
1848 project_connection_ids.iter().copied(),
1849 |connection_id| {
1850 session
1851 .peer
1852 .forward_send(session.connection_id, connection_id, request.clone())
1853 },
1854 );
1855 Ok(())
1856}
1857
1858async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1859 let project_id = ProjectId::from_proto(request.project_id);
1860 let project_connection_ids = session
1861 .db()
1862 .await
1863 .project_connection_ids(project_id, session.connection_id)
1864 .await?;
1865 broadcast(
1866 Some(session.connection_id),
1867 project_connection_ids.iter().copied(),
1868 |connection_id| {
1869 session
1870 .peer
1871 .forward_send(session.connection_id, connection_id, request.clone())
1872 },
1873 );
1874 Ok(())
1875}
1876
1877async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1878 broadcast_project_message(request.project_id, request, session).await
1879}
1880
1881async fn broadcast_project_message<T: EnvelopedMessage>(
1882 project_id: u64,
1883 request: T,
1884 session: Session,
1885) -> Result<()> {
1886 let project_id = ProjectId::from_proto(project_id);
1887 let project_connection_ids = session
1888 .db()
1889 .await
1890 .project_connection_ids(project_id, session.connection_id)
1891 .await?;
1892 broadcast(
1893 Some(session.connection_id),
1894 project_connection_ids.iter().copied(),
1895 |connection_id| {
1896 session
1897 .peer
1898 .forward_send(session.connection_id, connection_id, request.clone())
1899 },
1900 );
1901 Ok(())
1902}
1903
1904async fn follow(
1905 request: proto::Follow,
1906 response: Response<proto::Follow>,
1907 session: Session,
1908) -> Result<()> {
1909 let room_id = RoomId::from_proto(request.room_id);
1910 let project_id = request.project_id.map(ProjectId::from_proto);
1911 let leader_id = request
1912 .leader_id
1913 .ok_or_else(|| anyhow!("invalid leader id"))?
1914 .into();
1915 let follower_id = session.connection_id;
1916
1917 session
1918 .db()
1919 .await
1920 .check_room_participants(room_id, leader_id, session.connection_id)
1921 .await?;
1922
1923 let response_payload = session
1924 .peer
1925 .forward_request(session.connection_id, leader_id, request)
1926 .await?;
1927 response.send(response_payload)?;
1928
1929 if let Some(project_id) = project_id {
1930 let room = session
1931 .db()
1932 .await
1933 .follow(room_id, project_id, leader_id, follower_id)
1934 .await?;
1935 room_updated(&room, &session.peer);
1936 }
1937
1938 Ok(())
1939}
1940
1941async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1942 let room_id = RoomId::from_proto(request.room_id);
1943 let project_id = request.project_id.map(ProjectId::from_proto);
1944 let leader_id = request
1945 .leader_id
1946 .ok_or_else(|| anyhow!("invalid leader id"))?
1947 .into();
1948 let follower_id = session.connection_id;
1949
1950 session
1951 .db()
1952 .await
1953 .check_room_participants(room_id, leader_id, session.connection_id)
1954 .await?;
1955
1956 session
1957 .peer
1958 .forward_send(session.connection_id, leader_id, request)?;
1959
1960 if let Some(project_id) = project_id {
1961 let room = session
1962 .db()
1963 .await
1964 .unfollow(room_id, project_id, leader_id, follower_id)
1965 .await?;
1966 room_updated(&room, &session.peer);
1967 }
1968
1969 Ok(())
1970}
1971
1972async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1973 let room_id = RoomId::from_proto(request.room_id);
1974 let database = session.db.lock().await;
1975
1976 let connection_ids = if let Some(project_id) = request.project_id {
1977 let project_id = ProjectId::from_proto(project_id);
1978 database
1979 .project_connection_ids(project_id, session.connection_id)
1980 .await?
1981 } else {
1982 database
1983 .room_connection_ids(room_id, session.connection_id)
1984 .await?
1985 };
1986
1987 // For now, don't send view update messages back to that view's current leader.
1988 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1989 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1990 _ => None,
1991 });
1992
1993 for follower_peer_id in request.follower_ids.iter().copied() {
1994 let follower_connection_id = follower_peer_id.into();
1995 if Some(follower_peer_id) != connection_id_to_omit
1996 && connection_ids.contains(&follower_connection_id)
1997 {
1998 session.peer.forward_send(
1999 session.connection_id,
2000 follower_connection_id,
2001 request.clone(),
2002 )?;
2003 }
2004 }
2005 Ok(())
2006}
2007
2008async fn get_users(
2009 request: proto::GetUsers,
2010 response: Response<proto::GetUsers>,
2011 session: Session,
2012) -> Result<()> {
2013 let user_ids = request
2014 .user_ids
2015 .into_iter()
2016 .map(UserId::from_proto)
2017 .collect();
2018 let users = session
2019 .db()
2020 .await
2021 .get_users_by_ids(user_ids)
2022 .await?
2023 .into_iter()
2024 .map(|user| proto::User {
2025 id: user.id.to_proto(),
2026 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2027 github_login: user.github_login,
2028 })
2029 .collect();
2030 response.send(proto::UsersResponse { users })?;
2031 Ok(())
2032}
2033
2034async fn fuzzy_search_users(
2035 request: proto::FuzzySearchUsers,
2036 response: Response<proto::FuzzySearchUsers>,
2037 session: Session,
2038) -> Result<()> {
2039 let query = request.query;
2040 let users = match query.len() {
2041 0 => vec![],
2042 1 | 2 => session
2043 .db()
2044 .await
2045 .get_user_by_github_login(&query)
2046 .await?
2047 .into_iter()
2048 .collect(),
2049 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2050 };
2051 let users = users
2052 .into_iter()
2053 .filter(|user| user.id != session.user_id)
2054 .map(|user| proto::User {
2055 id: user.id.to_proto(),
2056 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2057 github_login: user.github_login,
2058 })
2059 .collect();
2060 response.send(proto::UsersResponse { users })?;
2061 Ok(())
2062}
2063
2064async fn request_contact(
2065 request: proto::RequestContact,
2066 response: Response<proto::RequestContact>,
2067 session: Session,
2068) -> Result<()> {
2069 let requester_id = session.user_id;
2070 let responder_id = UserId::from_proto(request.responder_id);
2071 if requester_id == responder_id {
2072 return Err(anyhow!("cannot add yourself as a contact"))?;
2073 }
2074
2075 let notifications = session
2076 .db()
2077 .await
2078 .send_contact_request(requester_id, responder_id)
2079 .await?;
2080
2081 // Update outgoing contact requests of requester
2082 let mut update = proto::UpdateContacts::default();
2083 update.outgoing_requests.push(responder_id.to_proto());
2084 for connection_id in session
2085 .connection_pool()
2086 .await
2087 .user_connection_ids(requester_id)
2088 {
2089 session.peer.send(connection_id, update.clone())?;
2090 }
2091
2092 // Update incoming contact requests of responder
2093 let mut update = proto::UpdateContacts::default();
2094 update
2095 .incoming_requests
2096 .push(proto::IncomingContactRequest {
2097 requester_id: requester_id.to_proto(),
2098 });
2099 let connection_pool = session.connection_pool().await;
2100 for connection_id in connection_pool.user_connection_ids(responder_id) {
2101 session.peer.send(connection_id, update.clone())?;
2102 }
2103
2104 send_notifications(&*connection_pool, &session.peer, notifications);
2105
2106 response.send(proto::Ack {})?;
2107 Ok(())
2108}
2109
2110async fn respond_to_contact_request(
2111 request: proto::RespondToContactRequest,
2112 response: Response<proto::RespondToContactRequest>,
2113 session: Session,
2114) -> Result<()> {
2115 let responder_id = session.user_id;
2116 let requester_id = UserId::from_proto(request.requester_id);
2117 let db = session.db().await;
2118 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2119 db.dismiss_contact_notification(responder_id, requester_id)
2120 .await?;
2121 } else {
2122 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2123
2124 let notifications = db
2125 .respond_to_contact_request(responder_id, requester_id, accept)
2126 .await?;
2127 let requester_busy = db.is_user_busy(requester_id).await?;
2128 let responder_busy = db.is_user_busy(responder_id).await?;
2129
2130 let pool = session.connection_pool().await;
2131 // Update responder with new contact
2132 let mut update = proto::UpdateContacts::default();
2133 if accept {
2134 update
2135 .contacts
2136 .push(contact_for_user(requester_id, requester_busy, &pool));
2137 }
2138 update
2139 .remove_incoming_requests
2140 .push(requester_id.to_proto());
2141 for connection_id in pool.user_connection_ids(responder_id) {
2142 session.peer.send(connection_id, update.clone())?;
2143 }
2144
2145 // Update requester with new contact
2146 let mut update = proto::UpdateContacts::default();
2147 if accept {
2148 update
2149 .contacts
2150 .push(contact_for_user(responder_id, responder_busy, &pool));
2151 }
2152 update
2153 .remove_outgoing_requests
2154 .push(responder_id.to_proto());
2155
2156 for connection_id in pool.user_connection_ids(requester_id) {
2157 session.peer.send(connection_id, update.clone())?;
2158 }
2159
2160 send_notifications(&*pool, &session.peer, notifications);
2161 }
2162
2163 response.send(proto::Ack {})?;
2164 Ok(())
2165}
2166
2167async fn remove_contact(
2168 request: proto::RemoveContact,
2169 response: Response<proto::RemoveContact>,
2170 session: Session,
2171) -> Result<()> {
2172 let requester_id = session.user_id;
2173 let responder_id = UserId::from_proto(request.user_id);
2174 let db = session.db().await;
2175 let (contact_accepted, deleted_notification_id) =
2176 db.remove_contact(requester_id, responder_id).await?;
2177
2178 let pool = session.connection_pool().await;
2179 // Update outgoing contact requests of requester
2180 let mut update = proto::UpdateContacts::default();
2181 if contact_accepted {
2182 update.remove_contacts.push(responder_id.to_proto());
2183 } else {
2184 update
2185 .remove_outgoing_requests
2186 .push(responder_id.to_proto());
2187 }
2188 for connection_id in pool.user_connection_ids(requester_id) {
2189 session.peer.send(connection_id, update.clone())?;
2190 }
2191
2192 // Update incoming contact requests of responder
2193 let mut update = proto::UpdateContacts::default();
2194 if contact_accepted {
2195 update.remove_contacts.push(requester_id.to_proto());
2196 } else {
2197 update
2198 .remove_incoming_requests
2199 .push(requester_id.to_proto());
2200 }
2201 for connection_id in pool.user_connection_ids(responder_id) {
2202 session.peer.send(connection_id, update.clone())?;
2203 if let Some(notification_id) = deleted_notification_id {
2204 session.peer.send(
2205 connection_id,
2206 proto::DeleteNotification {
2207 notification_id: notification_id.to_proto(),
2208 },
2209 )?;
2210 }
2211 }
2212
2213 response.send(proto::Ack {})?;
2214 Ok(())
2215}
2216
2217async fn create_channel(
2218 request: proto::CreateChannel,
2219 response: Response<proto::CreateChannel>,
2220 session: Session,
2221) -> Result<()> {
2222 let db = session.db().await;
2223
2224 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2225 let CreateChannelResult {
2226 channel,
2227 participants_to_update,
2228 } = db
2229 .create_channel(&request.name, parent_id, session.user_id)
2230 .await?;
2231
2232 response.send(proto::CreateChannelResponse {
2233 channel: Some(channel.to_proto()),
2234 parent_id: request.parent_id,
2235 })?;
2236
2237 let connection_pool = session.connection_pool().await;
2238 for (user_id, channels) in participants_to_update {
2239 let update = build_channels_update(channels, vec![]);
2240 for connection_id in connection_pool.user_connection_ids(user_id) {
2241 if user_id == session.user_id {
2242 continue;
2243 }
2244 session.peer.send(connection_id, update.clone())?;
2245 }
2246 }
2247
2248 Ok(())
2249}
2250
2251async fn delete_channel(
2252 request: proto::DeleteChannel,
2253 response: Response<proto::DeleteChannel>,
2254 session: Session,
2255) -> Result<()> {
2256 let db = session.db().await;
2257
2258 let channel_id = request.channel_id;
2259 let (removed_channels, member_ids) = db
2260 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2261 .await?;
2262 response.send(proto::Ack {})?;
2263
2264 // Notify members of removed channels
2265 let mut update = proto::UpdateChannels::default();
2266 update
2267 .delete_channels
2268 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2269
2270 let connection_pool = session.connection_pool().await;
2271 for member_id in member_ids {
2272 for connection_id in connection_pool.user_connection_ids(member_id) {
2273 session.peer.send(connection_id, update.clone())?;
2274 }
2275 }
2276
2277 Ok(())
2278}
2279
2280async fn invite_channel_member(
2281 request: proto::InviteChannelMember,
2282 response: Response<proto::InviteChannelMember>,
2283 session: Session,
2284) -> Result<()> {
2285 let db = session.db().await;
2286 let channel_id = ChannelId::from_proto(request.channel_id);
2287 let invitee_id = UserId::from_proto(request.user_id);
2288 let InviteMemberResult {
2289 channel,
2290 notifications,
2291 } = db
2292 .invite_channel_member(
2293 channel_id,
2294 invitee_id,
2295 session.user_id,
2296 request.role().into(),
2297 )
2298 .await?;
2299
2300 let update = proto::UpdateChannels {
2301 channel_invitations: vec![channel.to_proto()],
2302 ..Default::default()
2303 };
2304
2305 let connection_pool = session.connection_pool().await;
2306 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2307 session.peer.send(connection_id, update.clone())?;
2308 }
2309
2310 send_notifications(&*connection_pool, &session.peer, notifications);
2311
2312 response.send(proto::Ack {})?;
2313 Ok(())
2314}
2315
2316async fn remove_channel_member(
2317 request: proto::RemoveChannelMember,
2318 response: Response<proto::RemoveChannelMember>,
2319 session: Session,
2320) -> Result<()> {
2321 let db = session.db().await;
2322 let channel_id = ChannelId::from_proto(request.channel_id);
2323 let member_id = UserId::from_proto(request.user_id);
2324
2325 let RemoveChannelMemberResult {
2326 membership_update,
2327 notification_id,
2328 } = db
2329 .remove_channel_member(channel_id, member_id, session.user_id)
2330 .await?;
2331
2332 let connection_pool = &session.connection_pool().await;
2333 notify_membership_updated(
2334 &connection_pool,
2335 membership_update,
2336 member_id,
2337 &session.peer,
2338 );
2339 for connection_id in connection_pool.user_connection_ids(member_id) {
2340 if let Some(notification_id) = notification_id {
2341 session
2342 .peer
2343 .send(
2344 connection_id,
2345 proto::DeleteNotification {
2346 notification_id: notification_id.to_proto(),
2347 },
2348 )
2349 .trace_err();
2350 }
2351 }
2352
2353 response.send(proto::Ack {})?;
2354 Ok(())
2355}
2356
2357async fn set_channel_visibility(
2358 request: proto::SetChannelVisibility,
2359 response: Response<proto::SetChannelVisibility>,
2360 session: Session,
2361) -> Result<()> {
2362 let db = session.db().await;
2363 let channel_id = ChannelId::from_proto(request.channel_id);
2364 let visibility = request.visibility().into();
2365
2366 let SetChannelVisibilityResult {
2367 participants_to_update,
2368 participants_to_remove,
2369 channels_to_remove,
2370 } = db
2371 .set_channel_visibility(channel_id, visibility, session.user_id)
2372 .await?;
2373
2374 let connection_pool = session.connection_pool().await;
2375 for (user_id, channels) in participants_to_update {
2376 let update = build_channels_update(channels, vec![]);
2377 for connection_id in connection_pool.user_connection_ids(user_id) {
2378 session.peer.send(connection_id, update.clone())?;
2379 }
2380 }
2381 for user_id in participants_to_remove {
2382 let update = proto::UpdateChannels {
2383 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2384 ..Default::default()
2385 };
2386 for connection_id in connection_pool.user_connection_ids(user_id) {
2387 session.peer.send(connection_id, update.clone())?;
2388 }
2389 }
2390
2391 response.send(proto::Ack {})?;
2392 Ok(())
2393}
2394
/// Changes `request.user_id`'s role in `request.channel_id` on behalf of the session user.
///
/// The database reports one of two outcomes:
/// - `MembershipUpdated`: the change is pushed to the member via `notify_membership_updated`.
/// - `InviteUpdated`: the member only has a pending invitation, so the invitation is re-sent
///   as a `channel_invitations` entry.
///
/// The caller is acked in both cases.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id,
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            // The pool guard here is a temporary of the `for` iterator expression, so it is
            // held for the whole loop.
            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2441
2442async fn rename_channel(
2443 request: proto::RenameChannel,
2444 response: Response<proto::RenameChannel>,
2445 session: Session,
2446) -> Result<()> {
2447 let db = session.db().await;
2448 let channel_id = ChannelId::from_proto(request.channel_id);
2449 let RenameChannelResult {
2450 channel,
2451 participants_to_update,
2452 } = db
2453 .rename_channel(channel_id, session.user_id, &request.name)
2454 .await?;
2455
2456 response.send(proto::RenameChannelResponse {
2457 channel: Some(channel.to_proto()),
2458 })?;
2459
2460 let connection_pool = session.connection_pool().await;
2461 for (user_id, channel) in participants_to_update {
2462 for connection_id in connection_pool.user_connection_ids(user_id) {
2463 let update = proto::UpdateChannels {
2464 channels: vec![channel.to_proto()],
2465 ..Default::default()
2466 };
2467
2468 session.peer.send(connection_id, update.clone())?;
2469 }
2470 }
2471
2472 Ok(())
2473}
2474
2475// TODO: Implement in terms of symlinks
2476// Current behavior of this is more like 'Move root channel'
2477async fn link_channel(
2478 request: proto::LinkChannel,
2479 response: Response<proto::LinkChannel>,
2480 session: Session,
2481) -> Result<()> {
2482 let db = session.db().await;
2483 let channel_id = ChannelId::from_proto(request.channel_id);
2484 let to = ChannelId::from_proto(request.to);
2485
2486 let result = db
2487 .move_channel(channel_id, None, to, session.user_id)
2488 .await?;
2489 drop(db);
2490
2491 notify_channel_moved(result, session).await?;
2492
2493 response.send(Ack {})?;
2494
2495 Ok(())
2496}
2497
2498// TODO: Implement in terms of symlinks
2499async fn unlink_channel(
2500 _request: proto::UnlinkChannel,
2501 _response: Response<proto::UnlinkChannel>,
2502 _session: Session,
2503) -> Result<()> {
2504 Err(anyhow!("unimplemented").into())
2505}
2506
2507async fn move_channel(
2508 request: proto::MoveChannel,
2509 response: Response<proto::MoveChannel>,
2510 session: Session,
2511) -> Result<()> {
2512 let db = session.db().await;
2513 let channel_id = ChannelId::from_proto(request.channel_id);
2514 let from_parent = ChannelId::from_proto(request.from);
2515 let to = ChannelId::from_proto(request.to);
2516
2517 let result = db
2518 .move_channel(channel_id, Some(from_parent), to, session.user_id)
2519 .await?;
2520 drop(db);
2521
2522 notify_channel_moved(result, session).await?;
2523
2524 response.send(Ack {})?;
2525 Ok(())
2526}
2527
2528async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2529 let Some(MoveChannelResult {
2530 participants_to_remove,
2531 participants_to_update,
2532 moved_channels,
2533 }) = result
2534 else {
2535 return Ok(());
2536 };
2537 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2538
2539 let connection_pool = session.connection_pool().await;
2540 for (user_id, channels) in participants_to_update {
2541 let mut update = build_channels_update(channels, vec![]);
2542 update.delete_channels = moved_channels.clone();
2543 for connection_id in connection_pool.user_connection_ids(user_id) {
2544 session.peer.send(connection_id, update.clone())?;
2545 }
2546 }
2547
2548 for user_id in participants_to_remove {
2549 let update = proto::UpdateChannels {
2550 delete_channels: moved_channels.clone(),
2551 ..Default::default()
2552 };
2553 for connection_id in connection_pool.user_connection_ids(user_id) {
2554 session.peer.send(connection_id, update.clone())?;
2555 }
2556 }
2557 Ok(())
2558}
2559
2560async fn get_channel_members(
2561 request: proto::GetChannelMembers,
2562 response: Response<proto::GetChannelMembers>,
2563 session: Session,
2564) -> Result<()> {
2565 let db = session.db().await;
2566 let channel_id = ChannelId::from_proto(request.channel_id);
2567 let members = db
2568 .get_channel_participant_details(channel_id, session.user_id)
2569 .await?;
2570 response.send(proto::GetChannelMembersResponse { members })?;
2571 Ok(())
2572}
2573
2574async fn respond_to_channel_invite(
2575 request: proto::RespondToChannelInvite,
2576 response: Response<proto::RespondToChannelInvite>,
2577 session: Session,
2578) -> Result<()> {
2579 let db = session.db().await;
2580 let channel_id = ChannelId::from_proto(request.channel_id);
2581 let RespondToChannelInvite {
2582 membership_update,
2583 notifications,
2584 } = db
2585 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2586 .await?;
2587
2588 let connection_pool = session.connection_pool().await;
2589 if let Some(membership_update) = membership_update {
2590 notify_membership_updated(
2591 &connection_pool,
2592 membership_update,
2593 session.user_id,
2594 &session.peer,
2595 );
2596 } else {
2597 let update = proto::UpdateChannels {
2598 remove_channel_invitations: vec![channel_id.to_proto()],
2599 ..Default::default()
2600 };
2601
2602 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2603 session.peer.send(connection_id, update.clone())?;
2604 }
2605 };
2606
2607 send_notifications(&*connection_pool, &session.peer, notifications);
2608
2609 response.send(proto::Ack {})?;
2610
2611 Ok(())
2612}
2613
2614async fn join_channel(
2615 request: proto::JoinChannel,
2616 response: Response<proto::JoinChannel>,
2617 session: Session,
2618) -> Result<()> {
2619 let channel_id = ChannelId::from_proto(request.channel_id);
2620 join_channel_internal(channel_id, Box::new(response), session).await
2621}
2622
/// Lets `join_channel_internal` reply to either a `JoinChannel` or a `JoinRoom` request.
/// Both replies carry a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2636
/// Shared implementation for joining a channel's call room.
///
/// Steps, in order:
/// 1. Leave whatever room the session is currently in.
/// 2. Join the channel's room in the database.
/// 3. Mint LiveKit credentials: publish-capable unless the caller's role is `Guest`.
/// 4. Reply to the caller.
/// 5. Push membership and room updates.
/// 6. After the scoped database guard is released, broadcast the new participant list to
///    channel members and refresh the user's contact entry for their contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database and connection-pool guards; both are dropped before
    // `channel_updated` / `update_user_contacts` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `None` when no LiveKit client is configured or token generation failed (the
        // failure is logged by `trace_err`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2717
2718async fn join_channel_buffer(
2719 request: proto::JoinChannelBuffer,
2720 response: Response<proto::JoinChannelBuffer>,
2721 session: Session,
2722) -> Result<()> {
2723 let db = session.db().await;
2724 let channel_id = ChannelId::from_proto(request.channel_id);
2725
2726 let open_response = db
2727 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2728 .await?;
2729
2730 let collaborators = open_response.collaborators.clone();
2731 response.send(open_response)?;
2732
2733 let update = UpdateChannelBufferCollaborators {
2734 channel_id: channel_id.to_proto(),
2735 collaborators: collaborators.clone(),
2736 };
2737 channel_buffer_updated(
2738 session.connection_id,
2739 collaborators
2740 .iter()
2741 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2742 &update,
2743 &session.peer,
2744 );
2745
2746 Ok(())
2747}
2748
2749async fn update_channel_buffer(
2750 request: proto::UpdateChannelBuffer,
2751 session: Session,
2752) -> Result<()> {
2753 let db = session.db().await;
2754 let channel_id = ChannelId::from_proto(request.channel_id);
2755
2756 let (collaborators, non_collaborators, epoch, version) = db
2757 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2758 .await?;
2759
2760 channel_buffer_updated(
2761 session.connection_id,
2762 collaborators,
2763 &proto::UpdateChannelBuffer {
2764 channel_id: channel_id.to_proto(),
2765 operations: request.operations,
2766 },
2767 &session.peer,
2768 );
2769
2770 let pool = &*session.connection_pool().await;
2771
2772 broadcast(
2773 None,
2774 non_collaborators
2775 .iter()
2776 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2777 |peer_id| {
2778 session.peer.send(
2779 peer_id.into(),
2780 proto::UpdateChannels {
2781 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2782 channel_id: channel_id.to_proto(),
2783 epoch: epoch as u64,
2784 version: version.clone(),
2785 }],
2786 ..Default::default()
2787 },
2788 )
2789 },
2790 );
2791
2792 Ok(())
2793}
2794
2795async fn rejoin_channel_buffers(
2796 request: proto::RejoinChannelBuffers,
2797 response: Response<proto::RejoinChannelBuffers>,
2798 session: Session,
2799) -> Result<()> {
2800 let db = session.db().await;
2801 let buffers = db
2802 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2803 .await?;
2804
2805 for rejoined_buffer in &buffers {
2806 let collaborators_to_notify = rejoined_buffer
2807 .buffer
2808 .collaborators
2809 .iter()
2810 .filter_map(|c| Some(c.peer_id?.into()));
2811 channel_buffer_updated(
2812 session.connection_id,
2813 collaborators_to_notify,
2814 &proto::UpdateChannelBufferCollaborators {
2815 channel_id: rejoined_buffer.buffer.channel_id,
2816 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2817 },
2818 &session.peer,
2819 );
2820 }
2821
2822 response.send(proto::RejoinChannelBuffersResponse {
2823 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2824 })?;
2825
2826 Ok(())
2827}
2828
2829async fn leave_channel_buffer(
2830 request: proto::LeaveChannelBuffer,
2831 response: Response<proto::LeaveChannelBuffer>,
2832 session: Session,
2833) -> Result<()> {
2834 let db = session.db().await;
2835 let channel_id = ChannelId::from_proto(request.channel_id);
2836
2837 let left_buffer = db
2838 .leave_channel_buffer(channel_id, session.connection_id)
2839 .await?;
2840
2841 response.send(Ack {})?;
2842
2843 channel_buffer_updated(
2844 session.connection_id,
2845 left_buffer.connections,
2846 &proto::UpdateChannelBufferCollaborators {
2847 channel_id: channel_id.to_proto(),
2848 collaborators: left_buffer.collaborators,
2849 },
2850 &session.peer,
2851 );
2852
2853 Ok(())
2854}
2855
2856fn channel_buffer_updated<T: EnvelopedMessage>(
2857 sender_id: ConnectionId,
2858 collaborators: impl IntoIterator<Item = ConnectionId>,
2859 message: &T,
2860 peer: &Peer,
2861) {
2862 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2863 peer.send(peer_id.into(), message.clone())
2864 });
2865}
2866
2867fn send_notifications(
2868 connection_pool: &ConnectionPool,
2869 peer: &Peer,
2870 notifications: db::NotificationBatch,
2871) {
2872 for (user_id, notification) in notifications {
2873 for connection_id in connection_pool.user_connection_ids(user_id) {
2874 if let Err(error) = peer.send(
2875 connection_id,
2876 proto::AddNotification {
2877 notification: Some(notification.clone()),
2878 },
2879 ) {
2880 tracing::error!(
2881 "failed to send notification to {:?} {}",
2882 connection_id,
2883 error
2884 );
2885 }
2886 }
2887 }
2888}
2889
2890async fn send_channel_message(
2891 request: proto::SendChannelMessage,
2892 response: Response<proto::SendChannelMessage>,
2893 session: Session,
2894) -> Result<()> {
2895 // Validate the message body.
2896 let body = request.body.trim().to_string();
2897 if body.len() > MAX_MESSAGE_LEN {
2898 return Err(anyhow!("message is too long"))?;
2899 }
2900 if body.is_empty() {
2901 return Err(anyhow!("message can't be blank"))?;
2902 }
2903
2904 // TODO: adjust mentions if body is trimmed
2905
2906 let timestamp = OffsetDateTime::now_utc();
2907 let nonce = request
2908 .nonce
2909 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2910
2911 let channel_id = ChannelId::from_proto(request.channel_id);
2912 let CreatedChannelMessage {
2913 message_id,
2914 participant_connection_ids,
2915 channel_members,
2916 notifications,
2917 } = session
2918 .db()
2919 .await
2920 .create_channel_message(
2921 channel_id,
2922 session.user_id,
2923 &body,
2924 &request.mentions,
2925 timestamp,
2926 nonce.clone().into(),
2927 )
2928 .await?;
2929 let message = proto::ChannelMessage {
2930 sender_id: session.user_id.to_proto(),
2931 id: message_id.to_proto(),
2932 body,
2933 mentions: request.mentions,
2934 timestamp: timestamp.unix_timestamp() as u64,
2935 nonce: Some(nonce),
2936 };
2937 broadcast(
2938 Some(session.connection_id),
2939 participant_connection_ids,
2940 |connection| {
2941 session.peer.send(
2942 connection,
2943 proto::ChannelMessageSent {
2944 channel_id: channel_id.to_proto(),
2945 message: Some(message.clone()),
2946 },
2947 )
2948 },
2949 );
2950 response.send(proto::SendChannelMessageResponse {
2951 message: Some(message),
2952 })?;
2953
2954 let pool = &*session.connection_pool().await;
2955 broadcast(
2956 None,
2957 channel_members
2958 .iter()
2959 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2960 |peer_id| {
2961 session.peer.send(
2962 peer_id.into(),
2963 proto::UpdateChannels {
2964 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2965 channel_id: channel_id.to_proto(),
2966 message_id: message_id.to_proto(),
2967 }],
2968 ..Default::default()
2969 },
2970 )
2971 },
2972 );
2973 send_notifications(pool, &session.peer, notifications);
2974
2975 Ok(())
2976}
2977
2978async fn remove_channel_message(
2979 request: proto::RemoveChannelMessage,
2980 response: Response<proto::RemoveChannelMessage>,
2981 session: Session,
2982) -> Result<()> {
2983 let channel_id = ChannelId::from_proto(request.channel_id);
2984 let message_id = MessageId::from_proto(request.message_id);
2985 let connection_ids = session
2986 .db()
2987 .await
2988 .remove_channel_message(channel_id, message_id, session.user_id)
2989 .await?;
2990 broadcast(Some(session.connection_id), connection_ids, |connection| {
2991 session.peer.send(connection, request.clone())
2992 });
2993 response.send(proto::Ack {})?;
2994 Ok(())
2995}
2996
2997async fn acknowledge_channel_message(
2998 request: proto::AckChannelMessage,
2999 session: Session,
3000) -> Result<()> {
3001 let channel_id = ChannelId::from_proto(request.channel_id);
3002 let message_id = MessageId::from_proto(request.message_id);
3003 let notifications = session
3004 .db()
3005 .await
3006 .observe_channel_message(channel_id, session.user_id, message_id)
3007 .await?;
3008 send_notifications(
3009 &*session.connection_pool().await,
3010 &session.peer,
3011 notifications,
3012 );
3013 Ok(())
3014}
3015
3016async fn acknowledge_buffer_version(
3017 request: proto::AckBufferOperation,
3018 session: Session,
3019) -> Result<()> {
3020 let buffer_id = BufferId::from_proto(request.buffer_id);
3021 session
3022 .db()
3023 .await
3024 .observe_buffer_version(
3025 buffer_id,
3026 session.user_id,
3027 request.epoch as i32,
3028 &request.version,
3029 )
3030 .await?;
3031 Ok(())
3032}
3033
3034async fn join_channel_chat(
3035 request: proto::JoinChannelChat,
3036 response: Response<proto::JoinChannelChat>,
3037 session: Session,
3038) -> Result<()> {
3039 let channel_id = ChannelId::from_proto(request.channel_id);
3040
3041 let db = session.db().await;
3042 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3043 .await?;
3044 let messages = db
3045 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3046 .await?;
3047 response.send(proto::JoinChannelChatResponse {
3048 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3049 messages,
3050 })?;
3051 Ok(())
3052}
3053
3054async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3055 let channel_id = ChannelId::from_proto(request.channel_id);
3056 session
3057 .db()
3058 .await
3059 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3060 .await?;
3061 Ok(())
3062}
3063
3064async fn get_channel_messages(
3065 request: proto::GetChannelMessages,
3066 response: Response<proto::GetChannelMessages>,
3067 session: Session,
3068) -> Result<()> {
3069 let channel_id = ChannelId::from_proto(request.channel_id);
3070 let messages = session
3071 .db()
3072 .await
3073 .get_channel_messages(
3074 channel_id,
3075 session.user_id,
3076 MESSAGE_COUNT_PER_PAGE,
3077 Some(MessageId::from_proto(request.before_message_id)),
3078 )
3079 .await?;
3080 response.send(proto::GetChannelMessagesResponse {
3081 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3082 messages,
3083 })?;
3084 Ok(())
3085}
3086
3087async fn get_channel_messages_by_id(
3088 request: proto::GetChannelMessagesById,
3089 response: Response<proto::GetChannelMessagesById>,
3090 session: Session,
3091) -> Result<()> {
3092 let message_ids = request
3093 .message_ids
3094 .iter()
3095 .map(|id| MessageId::from_proto(*id))
3096 .collect::<Vec<_>>();
3097 let messages = session
3098 .db()
3099 .await
3100 .get_channel_messages_by_id(session.user_id, &message_ids)
3101 .await?;
3102 response.send(proto::GetChannelMessagesResponse {
3103 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3104 messages,
3105 })?;
3106 Ok(())
3107}
3108
3109async fn get_notifications(
3110 request: proto::GetNotifications,
3111 response: Response<proto::GetNotifications>,
3112 session: Session,
3113) -> Result<()> {
3114 let notifications = session
3115 .db()
3116 .await
3117 .get_notifications(
3118 session.user_id,
3119 NOTIFICATION_COUNT_PER_PAGE,
3120 request
3121 .before_id
3122 .map(|id| db::NotificationId::from_proto(id)),
3123 )
3124 .await?;
3125 response.send(proto::GetNotificationsResponse {
3126 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3127 notifications,
3128 })?;
3129 Ok(())
3130}
3131
3132async fn mark_notification_as_read(
3133 request: proto::MarkNotificationRead,
3134 response: Response<proto::MarkNotificationRead>,
3135 session: Session,
3136) -> Result<()> {
3137 let database = &session.db().await;
3138 let notifications = database
3139 .mark_notification_as_read_by_id(
3140 session.user_id,
3141 NotificationId::from_proto(request.notification_id),
3142 )
3143 .await?;
3144 send_notifications(
3145 &*session.connection_pool().await,
3146 &session.peer,
3147 notifications,
3148 );
3149 response.send(proto::Ack {})?;
3150 Ok(())
3151}
3152
3153async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3154 let project_id = ProjectId::from_proto(request.project_id);
3155 let project_connection_ids = session
3156 .db()
3157 .await
3158 .project_connection_ids(project_id, session.connection_id)
3159 .await?;
3160 broadcast(
3161 Some(session.connection_id),
3162 project_connection_ids.iter().copied(),
3163 |connection_id| {
3164 session
3165 .peer
3166 .forward_send(session.connection_id, connection_id, request.clone())
3167 },
3168 );
3169 Ok(())
3170}
3171
3172async fn get_private_user_info(
3173 _request: proto::GetPrivateUserInfo,
3174 response: Response<proto::GetPrivateUserInfo>,
3175 session: Session,
3176) -> Result<()> {
3177 let db = session.db().await;
3178
3179 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3180 let user = db
3181 .get_user_by_id(session.user_id)
3182 .await?
3183 .ok_or_else(|| anyhow!("user not found"))?;
3184 let flags = db.get_user_flags(session.user_id).await?;
3185
3186 response.send(proto::GetPrivateUserInfoResponse {
3187 metrics_id,
3188 staff: user.admin,
3189 flags,
3190 })?;
3191 Ok(())
3192}
3193
3194fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3195 match message {
3196 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3197 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3198 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3199 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3200 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3201 code: frame.code.into(),
3202 reason: frame.reason,
3203 })),
3204 }
3205}
3206
3207fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3208 match message {
3209 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3210 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3211 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3212 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3213 AxumMessage::Close(frame) => {
3214 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3215 code: frame.code.into(),
3216 reason: frame.reason,
3217 }))
3218 }
3219 }
3220}
3221
3222fn notify_membership_updated(
3223 connection_pool: &ConnectionPool,
3224 result: MembershipUpdated,
3225 user_id: UserId,
3226 peer: &Peer,
3227) {
3228 let mut update = build_channels_update(result.new_channels, vec![]);
3229 update.delete_channels = result
3230 .removed_channels
3231 .into_iter()
3232 .map(|id| id.to_proto())
3233 .collect();
3234 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3235
3236 for connection_id in connection_pool.user_connection_ids(user_id) {
3237 peer.send(connection_id, update.clone()).trace_err();
3238 }
3239}
3240
3241fn build_channels_update(
3242 channels: ChannelsForUser,
3243 channel_invites: Vec<db::Channel>,
3244) -> proto::UpdateChannels {
3245 let mut update = proto::UpdateChannels::default();
3246
3247 for channel in channels.channels.channels {
3248 update.channels.push(proto::Channel {
3249 id: channel.id.to_proto(),
3250 name: channel.name,
3251 visibility: channel.visibility.into(),
3252 role: channel.role.into(),
3253 });
3254 }
3255
3256 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3257 update.unseen_channel_messages = channels.channel_messages;
3258 update.insert_edge = channels.channels.edges;
3259
3260 for (channel_id, participants) in channels.channel_participants {
3261 update
3262 .channel_participants
3263 .push(proto::ChannelParticipants {
3264 channel_id: channel_id.to_proto(),
3265 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3266 });
3267 }
3268
3269 for channel in channel_invites {
3270 update.channel_invitations.push(proto::Channel {
3271 id: channel.id.to_proto(),
3272 name: channel.name,
3273 visibility: channel.visibility.into(),
3274 role: channel.role.into(),
3275 });
3276 }
3277
3278 update
3279}
3280
3281fn build_initial_contacts_update(
3282 contacts: Vec<db::Contact>,
3283 pool: &ConnectionPool,
3284) -> proto::UpdateContacts {
3285 let mut update = proto::UpdateContacts::default();
3286
3287 for contact in contacts {
3288 match contact {
3289 db::Contact::Accepted { user_id, busy } => {
3290 update.contacts.push(contact_for_user(user_id, busy, &pool));
3291 }
3292 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3293 db::Contact::Incoming { user_id } => {
3294 update
3295 .incoming_requests
3296 .push(proto::IncomingContactRequest {
3297 requester_id: user_id.to_proto(),
3298 })
3299 }
3300 }
3301 }
3302
3303 update
3304}
3305
3306fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3307 proto::Contact {
3308 user_id: user_id.to_proto(),
3309 online: pool.is_user_online(user_id),
3310 busy,
3311 }
3312}
3313
3314fn room_updated(room: &proto::Room, peer: &Peer) {
3315 broadcast(
3316 None,
3317 room.participants
3318 .iter()
3319 .filter_map(|participant| Some(participant.peer_id?.into())),
3320 |peer_id| {
3321 peer.send(
3322 peer_id.into(),
3323 proto::RoomUpdated {
3324 room: Some(room.clone()),
3325 },
3326 )
3327 },
3328 );
3329}
3330
3331fn channel_updated(
3332 channel_id: ChannelId,
3333 room: &proto::Room,
3334 channel_members: &[UserId],
3335 peer: &Peer,
3336 pool: &ConnectionPool,
3337) {
3338 let participants = room
3339 .participants
3340 .iter()
3341 .map(|p| p.user_id)
3342 .collect::<Vec<_>>();
3343
3344 broadcast(
3345 None,
3346 channel_members
3347 .iter()
3348 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3349 |peer_id| {
3350 peer.send(
3351 peer_id.into(),
3352 proto::UpdateChannels {
3353 channel_participants: vec![proto::ChannelParticipants {
3354 channel_id: channel_id.to_proto(),
3355 participant_user_ids: participants.clone(),
3356 }],
3357 ..Default::default()
3358 },
3359 )
3360 },
3361 );
3362}
3363
3364async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3365 let db = session.db().await;
3366
3367 let contacts = db.get_contacts(user_id).await?;
3368 let busy = db.is_user_busy(user_id).await?;
3369
3370 let pool = session.connection_pool().await;
3371 let updated_contact = contact_for_user(user_id, busy, &pool);
3372 for contact in contacts {
3373 if let db::Contact::Accepted {
3374 user_id: contact_user_id,
3375 ..
3376 } = contact
3377 {
3378 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3379 session
3380 .peer
3381 .send(
3382 contact_conn_id,
3383 proto::UpdateContacts {
3384 contacts: vec![updated_contact.clone()],
3385 remove_contacts: Default::default(),
3386 incoming_requests: Default::default(),
3387 remove_incoming_requests: Default::default(),
3388 outgoing_requests: Default::default(),
3389 remove_outgoing_requests: Default::default(),
3390 },
3391 )
3392 .trace_err();
3393 }
3394 }
3395 }
3396 Ok(())
3397}
3398
3399async fn leave_room_for_session(session: &Session) -> Result<()> {
3400 let mut contacts_to_update = HashSet::default();
3401
3402 let room_id;
3403 let canceled_calls_to_user_ids;
3404 let live_kit_room;
3405 let delete_live_kit_room;
3406 let room;
3407 let channel_members;
3408 let channel_id;
3409
3410 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3411 contacts_to_update.insert(session.user_id);
3412
3413 for project in left_room.left_projects.values() {
3414 project_left(project, session);
3415 }
3416
3417 room_id = RoomId::from_proto(left_room.room.id);
3418 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3419 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3420 delete_live_kit_room = left_room.deleted;
3421 room = mem::take(&mut left_room.room);
3422 channel_members = mem::take(&mut left_room.channel_members);
3423 channel_id = left_room.channel_id;
3424
3425 room_updated(&room, &session.peer);
3426 } else {
3427 return Ok(());
3428 }
3429
3430 if let Some(channel_id) = channel_id {
3431 channel_updated(
3432 channel_id,
3433 &room,
3434 &channel_members,
3435 &session.peer,
3436 &*session.connection_pool().await,
3437 );
3438 }
3439
3440 {
3441 let pool = session.connection_pool().await;
3442 for canceled_user_id in canceled_calls_to_user_ids {
3443 for connection_id in pool.user_connection_ids(canceled_user_id) {
3444 session
3445 .peer
3446 .send(
3447 connection_id,
3448 proto::CallCanceled {
3449 room_id: room_id.to_proto(),
3450 },
3451 )
3452 .trace_err();
3453 }
3454 contacts_to_update.insert(canceled_user_id);
3455 }
3456 }
3457
3458 for contact_user_id in contacts_to_update {
3459 update_user_contacts(contact_user_id, &session).await?;
3460 }
3461
3462 if let Some(live_kit) = session.live_kit_client.as_ref() {
3463 live_kit
3464 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3465 .await
3466 .trace_err();
3467
3468 if delete_live_kit_room {
3469 live_kit.delete_room(live_kit_room).await.trace_err();
3470 }
3471 }
3472
3473 Ok(())
3474}
3475
3476async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3477 let left_channel_buffers = session
3478 .db()
3479 .await
3480 .leave_channel_buffers(session.connection_id)
3481 .await?;
3482
3483 for left_buffer in left_channel_buffers {
3484 channel_buffer_updated(
3485 session.connection_id,
3486 left_buffer.connections,
3487 &proto::UpdateChannelBufferCollaborators {
3488 channel_id: left_buffer.channel_id.to_proto(),
3489 collaborators: left_buffer.collaborators,
3490 },
3491 &session.peer,
3492 );
3493 }
3494
3495 Ok(())
3496}
3497
3498fn project_left(project: &db::LeftProject, session: &Session) {
3499 for connection_id in &project.connection_ids {
3500 if project.host_user_id == session.user_id {
3501 session
3502 .peer
3503 .send(
3504 *connection_id,
3505 proto::UnshareProject {
3506 project_id: project.id.to_proto(),
3507 },
3508 )
3509 .trace_err();
3510 } else {
3511 session
3512 .peer
3513 .send(
3514 *connection_id,
3515 proto::RemoveProjectCollaborator {
3516 project_id: project.id.to_proto(),
3517 peer_id: Some(session.connection_id.into()),
3518 },
3519 )
3520 .trace_err();
3521 }
3522 }
3523}
3524
/// Extension for fallible values whose failures should be reported and then
/// discarded rather than propagated.
///
/// In this module it is implemented for `Result<T, E>` with `E: Debug`, where
/// `trace_err` logs the error through `tracing::error!`.
pub trait ResultExt {
    /// The success value produced when there is no error.
    type Ok;

    /// Turns `self` into an `Option`: `Some` with the success value, or `None`
    /// after the error has been reported.
    fn trace_err(self) -> Option<Self::Ok>;
}
3530
3531impl<T, E> ResultExt for Result<T, E>
3532where
3533 E: std::fmt::Debug,
3534{
3535 type Ok = T;
3536
3537 fn trace_err(self) -> Option<T> {
3538 match self {
3539 Ok(value) => Some(value),
3540 Err(error) => {
3541 tracing::error!("{:?}", error);
3542 None
3543 }
3544 }
3545 }
3546}