1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69use util::channel::RELEASE_CHANNEL_NAME;
70
/// How long `connection_lost` waits after a disconnect before removing the session's room and
/// channel-buffer memberships (the wait is cut short, with no cleanup, if the server is torn
/// down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `Server::start` waits before clearing resources recorded for stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// The following limits are not referenced in the visible part of this file; presumably they cap
// channel-message page size, message length, and notification page size — confirm at their use
// sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
lazy_static! {
    // Gauge of open non-admin connections; refreshed on each request to `handle_metrics`.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge of shared projects as reported by `project_count_excluding_admins`; refreshed on
    // each request to `handle_metrics`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// Handle given to a request handler for replying to exactly one RPC request.
///
/// `responded` is shared with the dispatcher built in `Server::add_request_handler`, which reads
/// it after the handler finishes to detect handlers that returned without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// Per-connection context handed to every message handler; `handle_connection` clones it for
/// each incoming message.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // An async (tokio) mutex: callers hold the guard returned by `Session::db` across `.await`s.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared with `Server`; lock it through `Session::connection_pool`, which returns a
    // `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // When `None`, handlers omit LiveKit connection info from their replies.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
/// Newtype around the shared `Database`; it dereferences to the wrapped database.
struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server: owns the `Peer`, the pool of live connections, and the table of
/// per-message-type handlers built in `Server::new`.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of each message type (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // Signaled by `teardown()`; each connection subscribes to it in `handle_connection`.
    teardown: watch::Sender<()>,
}
165
/// Lock guard for the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a `Send` future cannot hold it
/// across an `.await`. In test builds, dropping the guard runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the pool stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_project_request::<proto::GetHover>)
221 .add_request_handler(forward_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_project_request::<proto::GetCompletions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
232 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
233 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
234 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
235 .add_request_handler(forward_project_request::<proto::PrepareRename>)
236 .add_request_handler(forward_project_request::<proto::PerformRename>)
237 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
238 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
239 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
240 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
243 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
244 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
245 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
246 .add_request_handler(forward_project_request::<proto::InlayHints>)
247 .add_message_handler(create_buffer_for_peer)
248 .add_request_handler(update_buffer)
249 .add_message_handler(update_buffer_file)
250 .add_message_handler(buffer_reloaded)
251 .add_message_handler(buffer_saved)
252 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
253 .add_request_handler(get_users)
254 .add_request_handler(fuzzy_search_users)
255 .add_request_handler(request_contact)
256 .add_request_handler(remove_contact)
257 .add_request_handler(respond_to_contact_request)
258 .add_request_handler(create_channel)
259 .add_request_handler(delete_channel)
260 .add_request_handler(invite_channel_member)
261 .add_request_handler(remove_channel_member)
262 .add_request_handler(set_channel_member_role)
263 .add_request_handler(set_channel_visibility)
264 .add_request_handler(rename_channel)
265 .add_request_handler(join_channel_buffer)
266 .add_request_handler(leave_channel_buffer)
267 .add_message_handler(update_channel_buffer)
268 .add_request_handler(rejoin_channel_buffers)
269 .add_request_handler(get_channel_members)
270 .add_request_handler(respond_to_channel_invite)
271 .add_request_handler(join_channel)
272 .add_request_handler(join_channel_chat)
273 .add_message_handler(leave_channel_chat)
274 .add_request_handler(send_channel_message)
275 .add_request_handler(remove_channel_message)
276 .add_request_handler(get_channel_messages)
277 .add_request_handler(get_channel_messages_by_id)
278 .add_request_handler(get_notifications)
279 .add_request_handler(mark_notification_as_read)
280 .add_request_handler(link_channel)
281 .add_request_handler(unlink_channel)
282 .add_request_handler(move_channel)
283 .add_request_handler(follow)
284 .add_message_handler(unfollow)
285 .add_message_handler(update_followers)
286 .add_message_handler(update_diff_base)
287 .add_request_handler(get_private_user_info)
288 .add_message_handler(acknowledge_channel_message)
289 .add_message_handler(acknowledge_buffer_version);
290
291 Arc::new(server)
292 }
293
294 pub async fn start(&self) -> Result<()> {
295 let server_id = *self.id.lock();
296 let app_state = self.app_state.clone();
297 let peer = self.peer.clone();
298 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
299 let pool = self.connection_pool.clone();
300 let live_kit_client = self.app_state.live_kit_client.clone();
301
302 let span = info_span!("start server");
303 self.executor.spawn_detached(
304 async move {
305 tracing::info!("waiting for cleanup timeout");
306 timeout.await;
307 tracing::info!("cleanup timeout expired, retrieving stale rooms");
308 if let Some((room_ids, channel_ids)) = app_state
309 .db
310 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
311 .await
312 .trace_err()
313 {
314 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
315 tracing::info!(
316 stale_channel_buffer_count = channel_ids.len(),
317 "retrieved stale channel buffers"
318 );
319
320 for channel_id in channel_ids {
321 if let Some(refreshed_channel_buffer) = app_state
322 .db
323 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
324 .await
325 .trace_err()
326 {
327 for connection_id in refreshed_channel_buffer.connection_ids {
328 peer.send(
329 connection_id,
330 proto::UpdateChannelBufferCollaborators {
331 channel_id: channel_id.to_proto(),
332 collaborators: refreshed_channel_buffer
333 .collaborators
334 .clone(),
335 },
336 )
337 .trace_err();
338 }
339 }
340 }
341
342 for room_id in room_ids {
343 let mut contacts_to_update = HashSet::default();
344 let mut canceled_calls_to_user_ids = Vec::new();
345 let mut live_kit_room = String::new();
346 let mut delete_live_kit_room = false;
347
348 if let Some(mut refreshed_room) = app_state
349 .db
350 .clear_stale_room_participants(room_id, server_id)
351 .await
352 .trace_err()
353 {
354 tracing::info!(
355 room_id = room_id.0,
356 new_participant_count = refreshed_room.room.participants.len(),
357 "refreshed room"
358 );
359 room_updated(&refreshed_room.room, &peer);
360 if let Some(channel_id) = refreshed_room.channel_id {
361 channel_updated(
362 channel_id,
363 &refreshed_room.room,
364 &refreshed_room.channel_members,
365 &peer,
366 &*pool.lock(),
367 );
368 }
369 contacts_to_update
370 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
371 contacts_to_update
372 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
373 canceled_calls_to_user_ids =
374 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
375 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
376 delete_live_kit_room = refreshed_room.room.participants.is_empty();
377 }
378
379 {
380 let pool = pool.lock();
381 for canceled_user_id in canceled_calls_to_user_ids {
382 for connection_id in pool.user_connection_ids(canceled_user_id) {
383 peer.send(
384 connection_id,
385 proto::CallCanceled {
386 room_id: room_id.to_proto(),
387 },
388 )
389 .trace_err();
390 }
391 }
392 }
393
394 for user_id in contacts_to_update {
395 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
396 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
397 if let Some((busy, contacts)) = busy.zip(contacts) {
398 let pool = pool.lock();
399 let updated_contact = contact_for_user(user_id, busy, &pool);
400 for contact in contacts {
401 if let db::Contact::Accepted {
402 user_id: contact_user_id,
403 ..
404 } = contact
405 {
406 for contact_conn_id in
407 pool.user_connection_ids(contact_user_id)
408 {
409 peer.send(
410 contact_conn_id,
411 proto::UpdateContacts {
412 contacts: vec![updated_contact.clone()],
413 remove_contacts: Default::default(),
414 incoming_requests: Default::default(),
415 remove_incoming_requests: Default::default(),
416 outgoing_requests: Default::default(),
417 remove_outgoing_requests: Default::default(),
418 },
419 )
420 .trace_err();
421 }
422 }
423 }
424 }
425 }
426
427 if let Some(live_kit) = live_kit_client.as_ref() {
428 if delete_live_kit_room {
429 live_kit.delete_room(live_kit_room).await.trace_err();
430 }
431 }
432 }
433 }
434
435 app_state
436 .db
437 .delete_stale_servers(&app_state.config.zed_environment, server_id)
438 .await
439 .trace_err();
440 }
441 .instrument(span),
442 );
443 Ok(())
444 }
445
446 pub fn teardown(&self) {
447 self.peer.teardown();
448 self.connection_pool.lock().reset();
449 let _ = self.teardown.send(());
450 }
451
452 #[cfg(test)]
453 pub fn reset(&self, id: ServerId) {
454 self.teardown();
455 *self.id.lock() = id;
456 self.peer.reset(id.0 as u32);
457 }
458
459 #[cfg(test)]
460 pub fn id(&self) -> ServerId {
461 *self.id.lock()
462 }
463
464 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
465 where
466 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
467 Fut: 'static + Send + Future<Output = Result<()>>,
468 M: EnvelopedMessage,
469 {
470 let prev_handler = self.handlers.insert(
471 TypeId::of::<M>(),
472 Box::new(move |envelope, session| {
473 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
474 let span = info_span!(
475 "handle message",
476 payload_type = envelope.payload_type_name()
477 );
478 span.in_scope(|| {
479 tracing::info!(
480 payload_type = envelope.payload_type_name(),
481 "message received"
482 );
483 });
484 let start_time = Instant::now();
485 let future = (handler)(*envelope, session);
486 async move {
487 let result = future.await;
488 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
489 match result {
490 Err(error) => {
491 tracing::error!(%error, ?duration_ms, "error handling message")
492 }
493 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
494 }
495 }
496 .instrument(span)
497 .boxed()
498 }),
499 );
500 if prev_handler.is_some() {
501 panic!("registered a handler for the same message twice");
502 }
503 self
504 }
505
506 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
507 where
508 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
509 Fut: 'static + Send + Future<Output = Result<()>>,
510 M: EnvelopedMessage,
511 {
512 self.add_handler(move |envelope, session| handler(envelope.payload, session));
513 self
514 }
515
516 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
517 where
518 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
519 Fut: Send + Future<Output = Result<()>>,
520 M: RequestMessage,
521 {
522 let handler = Arc::new(handler);
523 self.add_handler(move |envelope, session| {
524 let receipt = envelope.receipt();
525 let handler = handler.clone();
526 async move {
527 let peer = session.peer.clone();
528 let responded = Arc::new(AtomicBool::default());
529 let response = Response {
530 peer: peer.clone(),
531 responded: responded.clone(),
532 receipt,
533 };
534 match (handler)(envelope.payload, response, session).await {
535 Ok(()) => {
536 if responded.load(std::sync::atomic::Ordering::SeqCst) {
537 Ok(())
538 } else {
539 Err(anyhow!("handler did not send a response"))?
540 }
541 }
542 Err(error) => {
543 peer.respond_with_error(
544 receipt,
545 proto::Error {
546 message: error.to_string(),
547 },
548 )?;
549 Err(error)
550 }
551 }
552 }
553 })
554 }
555
556 pub fn handle_connection(
557 self: &Arc<Self>,
558 connection: Connection,
559 address: String,
560 user: User,
561 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
562 executor: Executor,
563 ) -> impl Future<Output = Result<()>> {
564 let this = self.clone();
565 let user_id = user.id;
566 let login = user.github_login;
567 let span = info_span!("handle connection", %user_id, %login, %address);
568 let mut teardown = self.teardown.subscribe();
569 async move {
570 let (connection_id, handle_io, mut incoming_rx) = this
571 .peer
572 .add_connection(connection, {
573 let executor = executor.clone();
574 move |duration| executor.sleep(duration)
575 });
576
577 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
578 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
579 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
580
581 if let Some(send_connection_id) = send_connection_id.take() {
582 let _ = send_connection_id.send(connection_id);
583 }
584
585 if !user.connected_once {
586 this.peer.send(connection_id, proto::ShowContacts {})?;
587 this.app_state.db.set_user_connected_once(user_id, true).await?;
588 }
589
590 let (contacts, channels_for_user, channel_invites) = future::try_join3(
591 this.app_state.db.get_contacts(user_id),
592 this.app_state.db.get_channels_for_user(user_id),
593 this.app_state.db.get_channel_invites_for_user(user_id),
594 ).await?;
595
596 {
597 let mut pool = this.connection_pool.lock();
598 pool.add_connection(connection_id, user_id, user.admin);
599 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
600 this.peer.send(connection_id, build_channels_update(
601 channels_for_user,
602 channel_invites
603 ))?;
604 }
605
606 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
607 this.peer.send(connection_id, incoming_call)?;
608 }
609
610 let session = Session {
611 user_id,
612 connection_id,
613 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
614 peer: this.peer.clone(),
615 connection_pool: this.connection_pool.clone(),
616 live_kit_client: this.app_state.live_kit_client.clone(),
617 executor: executor.clone(),
618 };
619 update_user_contacts(user_id, &session).await?;
620
621 let handle_io = handle_io.fuse();
622 futures::pin_mut!(handle_io);
623
624 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
625 // This prevents deadlocks when e.g., client A performs a request to client B and
626 // client B performs a request to client A. If both clients stop processing further
627 // messages until their respective request completes, they won't have a chance to
628 // respond to the other client's request and cause a deadlock.
629 //
630 // This arrangement ensures we will attempt to process earlier messages first, but fall
631 // back to processing messages arrived later in the spirit of making progress.
632 let mut foreground_message_handlers = FuturesUnordered::new();
633 let concurrent_handlers = Arc::new(Semaphore::new(256));
634 loop {
635 let next_message = async {
636 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
637 let message = incoming_rx.next().await;
638 (permit, message)
639 }.fuse();
640 futures::pin_mut!(next_message);
641 futures::select_biased! {
642 _ = teardown.changed().fuse() => return Ok(()),
643 result = handle_io => {
644 if let Err(error) = result {
645 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
646 }
647 break;
648 }
649 _ = foreground_message_handlers.next() => {}
650 next_message = next_message => {
651 let (permit, message) = next_message;
652 if let Some(message) = message {
653 let type_name = message.payload_type_name();
654 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
655 let span_enter = span.enter();
656 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
657 let is_background = message.is_background();
658 let handle_message = (handler)(message, session.clone());
659 drop(span_enter);
660
661 let handle_message = async move {
662 handle_message.await;
663 drop(permit);
664 }.instrument(span);
665 if is_background {
666 executor.spawn_detached(handle_message);
667 } else {
668 foreground_message_handlers.push(handle_message);
669 }
670 } else {
671 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
672 }
673 } else {
674 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
675 break;
676 }
677 }
678 }
679 }
680
681 drop(foreground_message_handlers);
682 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
683 if let Err(error) = connection_lost(session, teardown, executor).await {
684 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
685 }
686
687 Ok(())
688 }.instrument(span)
689 }
690
691 pub async fn invite_code_redeemed(
692 self: &Arc<Self>,
693 inviter_id: UserId,
694 invitee_id: UserId,
695 ) -> Result<()> {
696 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
697 if let Some(code) = &user.invite_code {
698 let pool = self.connection_pool.lock();
699 let invitee_contact = contact_for_user(invitee_id, false, &pool);
700 for connection_id in pool.user_connection_ids(inviter_id) {
701 self.peer.send(
702 connection_id,
703 proto::UpdateContacts {
704 contacts: vec![invitee_contact.clone()],
705 ..Default::default()
706 },
707 )?;
708 self.peer.send(
709 connection_id,
710 proto::UpdateInviteInfo {
711 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
712 count: user.invite_count as u32,
713 },
714 )?;
715 }
716 }
717 }
718 Ok(())
719 }
720
721 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
722 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
723 if let Some(invite_code) = &user.invite_code {
724 let pool = self.connection_pool.lock();
725 for connection_id in pool.user_connection_ids(user_id) {
726 self.peer.send(
727 connection_id,
728 proto::UpdateInviteInfo {
729 url: format!(
730 "{}{}",
731 self.app_state.config.invite_link_prefix, invite_code
732 ),
733 count: user.invite_count as u32,
734 },
735 )?;
736 }
737 }
738 }
739 Ok(())
740 }
741
742 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
743 ServerSnapshot {
744 connection_pool: ConnectionPoolGuard {
745 guard: self.connection_pool.lock(),
746 _not_send: PhantomData,
747 },
748 peer: &self.peer,
749 }
750 }
751}
752
753impl<'a> Deref for ConnectionPoolGuard<'a> {
754 type Target = ConnectionPool;
755
756 fn deref(&self) -> &Self::Target {
757 &*self.guard
758 }
759}
760
761impl<'a> DerefMut for ConnectionPoolGuard<'a> {
762 fn deref_mut(&mut self) -> &mut Self::Target {
763 &mut *self.guard
764 }
765}
766
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants just before the lock is released (the inner
    /// `MutexGuard` field is dropped after this runs). In other builds this does nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
773
774fn broadcast<F>(
775 sender_id: Option<ConnectionId>,
776 receiver_ids: impl IntoIterator<Item = ConnectionId>,
777 mut f: F,
778) where
779 F: FnMut(ConnectionId) -> anyhow::Result<()>,
780{
781 for receiver_id in receiver_ids {
782 if Some(receiver_id) != sender_id {
783 if let Err(error) = f(receiver_id) {
784 tracing::error!("failed to send to {:?} {}", receiver_id, error);
785 }
786 }
787 }
788}
789
lazy_static! {
    // Name of the header carrying the client's RPC protocol version; decoded by the
    // `ProtocolVersion` typed header.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
793
/// Typed `x-zed-protocol-version` header. `handle_websocket_request` rejects connections whose
/// value differs from `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
795
796impl Header for ProtocolVersion {
797 fn name() -> &'static HeaderName {
798 &ZED_PROTOCOL_VERSION
799 }
800
801 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
802 where
803 Self: Sized,
804 I: Iterator<Item = &'i axum::http::HeaderValue>,
805 {
806 let version = values
807 .next()
808 .ok_or_else(axum::headers::Error::invalid)?
809 .to_str()
810 .map_err(|_| axum::headers::Error::invalid())?
811 .parse()
812 .map_err(|_| axum::headers::Error::invalid())?;
813 Ok(Self(version))
814 }
815
816 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
817 values.extend([self.0.to_string().parse().unwrap()]);
818 }
819}
820
821pub fn routes(server: Arc<Server>) -> Router<Body> {
822 Router::new()
823 .route("/rpc", get(handle_websocket_request))
824 .layer(
825 ServiceBuilder::new()
826 .layer(Extension(server.app_state.clone()))
827 .layer(middleware::from_fn(auth::validate_header)),
828 )
829 .route("/metrics", get(handle_metrics))
830 .layer(Extension(server))
831}
832
833pub async fn handle_websocket_request(
834 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
835 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
836 Extension(server): Extension<Arc<Server>>,
837 Extension(user): Extension<User>,
838 ws: WebSocketUpgrade,
839) -> axum::response::Response {
840 if protocol_version != rpc::PROTOCOL_VERSION {
841 return (
842 StatusCode::UPGRADE_REQUIRED,
843 "client must be upgraded".to_string(),
844 )
845 .into_response();
846 }
847 let socket_address = socket_address.to_string();
848 ws.on_upgrade(move |socket| {
849 use util::ResultExt;
850 let socket = socket
851 .map_ok(to_tungstenite_message)
852 .err_into()
853 .with(|message| async move { Ok(to_axum_message(message)) });
854 let connection = Connection::new(Box::pin(socket));
855 async move {
856 server
857 .handle_connection(connection, socket_address, user, None, Executor::Production)
858 .await
859 .log_err();
860 }
861 })
862}
863
864pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
865 let connections = server
866 .connection_pool
867 .lock()
868 .connections()
869 .filter(|connection| !connection.admin)
870 .count();
871
872 METRIC_CONNECTIONS.set(connections as _);
873
874 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
875 METRIC_SHARED_PROJECTS.set(shared_projects as _);
876
877 let encoder = prometheus::TextEncoder::new();
878 let metric_families = prometheus::gather();
879 let encoded_metrics = encoder
880 .encode_to_string(&metric_families)
881 .map_err(|err| anyhow!("{}", err))?;
882 Ok(encoded_metrics)
883}
884
885#[instrument(err, skip(executor))]
886async fn connection_lost(
887 session: Session,
888 mut teardown: watch::Receiver<()>,
889 executor: Executor,
890) -> Result<()> {
891 session.peer.disconnect(session.connection_id);
892 session
893 .connection_pool()
894 .await
895 .remove_connection(session.connection_id)?;
896
897 session
898 .db()
899 .await
900 .connection_lost(session.connection_id)
901 .await
902 .trace_err();
903
904 futures::select_biased! {
905 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
906 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
907 leave_room_for_session(&session).await.trace_err();
908 leave_channel_buffers_for_session(&session)
909 .await
910 .trace_err();
911
912 if !session
913 .connection_pool()
914 .await
915 .is_user_online(session.user_id)
916 {
917 let db = session.db().await;
918 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
919 room_updated(&room, &session.peer);
920 }
921 }
922
923 update_user_contacts(session.user_id, &session).await?;
924 }
925 _ = teardown.changed().fuse() => {}
926 }
927
928 Ok(())
929}
930
931async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
932 response.send(proto::Ack {})?;
933 Ok(())
934}
935
936async fn create_room(
937 _request: proto::CreateRoom,
938 response: Response<proto::CreateRoom>,
939 session: Session,
940) -> Result<()> {
941 let live_kit_room = nanoid::nanoid!(30);
942
943 let live_kit_connection_info = {
944 let live_kit_room = live_kit_room.clone();
945 let live_kit = session.live_kit_client.as_ref();
946
947 util::async_iife!({
948 let live_kit = live_kit?;
949
950 let token = live_kit
951 .room_token(&live_kit_room, &session.user_id.to_string())
952 .trace_err()?;
953
954 Some(proto::LiveKitConnectionInfo {
955 server_url: live_kit.url().into(),
956 token,
957 can_publish: true,
958 })
959 })
960 }
961 .await;
962
963 let room = session
964 .db()
965 .await
966 .create_room(
967 session.user_id,
968 session.connection_id,
969 &live_kit_room,
970 RELEASE_CHANNEL_NAME.as_str(),
971 )
972 .await?;
973
974 response.send(proto::CreateRoomResponse {
975 room: Some(room.clone()),
976 live_kit_connection_info,
977 })?;
978
979 update_user_contacts(session.user_id, &session).await?;
980 Ok(())
981}
982
983async fn join_room(
984 request: proto::JoinRoom,
985 response: Response<proto::JoinRoom>,
986 session: Session,
987) -> Result<()> {
988 let room_id = RoomId::from_proto(request.id);
989
990 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
991
992 if let Some(channel_id) = channel_id {
993 return join_channel_internal(channel_id, Box::new(response), session).await;
994 }
995
996 let joined_room = {
997 let room = session
998 .db()
999 .await
1000 .join_room(
1001 room_id,
1002 session.user_id,
1003 session.connection_id,
1004 RELEASE_CHANNEL_NAME.as_str(),
1005 )
1006 .await?;
1007 room_updated(&room.room, &session.peer);
1008 room.into_inner()
1009 };
1010
1011 for connection_id in session
1012 .connection_pool()
1013 .await
1014 .user_connection_ids(session.user_id)
1015 {
1016 session
1017 .peer
1018 .send(
1019 connection_id,
1020 proto::CallCanceled {
1021 room_id: room_id.to_proto(),
1022 },
1023 )
1024 .trace_err();
1025 }
1026
1027 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1028 if let Some(token) = live_kit
1029 .room_token(
1030 &joined_room.room.live_kit_room,
1031 &session.user_id.to_string(),
1032 )
1033 .trace_err()
1034 {
1035 Some(proto::LiveKitConnectionInfo {
1036 server_url: live_kit.url().into(),
1037 token,
1038 can_publish: true,
1039 })
1040 } else {
1041 None
1042 }
1043 } else {
1044 None
1045 };
1046
1047 response.send(proto::JoinRoomResponse {
1048 room: Some(joined_room.room),
1049 channel_id: None,
1050 live_kit_connection_info,
1051 })?;
1052
1053 update_user_contacts(session.user_id, &session).await?;
1054 Ok(())
1055}
1056
/// Handles a `RejoinRoom` request from a client that is reconnecting.
///
/// This handler:
/// - restores the session's room membership in the database;
/// - replies with the room plus the reshared and rejoined projects;
/// - notifies collaborators of the client's new connection id;
/// - streams the full worktree state, diagnostics, settings files and
///   language-server status of every rejoined project back to the client.
///
/// If the room belongs to a channel, that channel's members are updated. The
/// user's contacts are then refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below. The database result is consumed via
    // `into_inner` before the block ends, i.e. before `channel_updated` locks
    // the connection pool.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator that the project's old
        // connection id has been replaced by this session's connection id.
        // Then forward the project's current worktree list to every other
        // collaborator.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: likewise tell each collaborator about the
        // session's new connection id.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's state back to the reconnecting client.
        // `mem::take` moves the worktrees out so their contents can be sent
        // without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Entries per `UpdateWorktree` chunk. Tiny in test builds,
                // presumably so tests exercise the chunking path.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` status for each of the
            // project's language servers.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Rooms backed by a channel also push an update to that channel's members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1247
1248async fn leave_room(
1249 _: proto::LeaveRoom,
1250 response: Response<proto::LeaveRoom>,
1251 session: Session,
1252) -> Result<()> {
1253 leave_room_for_session(&session).await?;
1254 response.send(proto::Ack {})?;
1255 Ok(())
1256}
1257
1258async fn call(
1259 request: proto::Call,
1260 response: Response<proto::Call>,
1261 session: Session,
1262) -> Result<()> {
1263 let room_id = RoomId::from_proto(request.room_id);
1264 let calling_user_id = session.user_id;
1265 let calling_connection_id = session.connection_id;
1266 let called_user_id = UserId::from_proto(request.called_user_id);
1267 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1268 if !session
1269 .db()
1270 .await
1271 .has_contact(calling_user_id, called_user_id)
1272 .await?
1273 {
1274 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1275 }
1276
1277 let incoming_call = {
1278 let (room, incoming_call) = &mut *session
1279 .db()
1280 .await
1281 .call(
1282 room_id,
1283 calling_user_id,
1284 calling_connection_id,
1285 called_user_id,
1286 initial_project_id,
1287 )
1288 .await?;
1289 room_updated(&room, &session.peer);
1290 mem::take(incoming_call)
1291 };
1292 update_user_contacts(called_user_id, &session).await?;
1293
1294 let mut calls = session
1295 .connection_pool()
1296 .await
1297 .user_connection_ids(called_user_id)
1298 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1299 .collect::<FuturesUnordered<_>>();
1300
1301 while let Some(call_response) = calls.next().await {
1302 match call_response.as_ref() {
1303 Ok(_) => {
1304 response.send(proto::Ack {})?;
1305 return Ok(());
1306 }
1307 Err(_) => {
1308 call_response.trace_err();
1309 }
1310 }
1311 }
1312
1313 {
1314 let room = session
1315 .db()
1316 .await
1317 .call_failed(room_id, called_user_id)
1318 .await?;
1319 room_updated(&room, &session.peer);
1320 }
1321 update_user_contacts(called_user_id, &session).await?;
1322
1323 Err(anyhow!("failed to ring user"))?
1324}
1325
1326async fn cancel_call(
1327 request: proto::CancelCall,
1328 response: Response<proto::CancelCall>,
1329 session: Session,
1330) -> Result<()> {
1331 let called_user_id = UserId::from_proto(request.called_user_id);
1332 let room_id = RoomId::from_proto(request.room_id);
1333 {
1334 let room = session
1335 .db()
1336 .await
1337 .cancel_call(room_id, session.connection_id, called_user_id)
1338 .await?;
1339 room_updated(&room, &session.peer);
1340 }
1341
1342 for connection_id in session
1343 .connection_pool()
1344 .await
1345 .user_connection_ids(called_user_id)
1346 {
1347 session
1348 .peer
1349 .send(
1350 connection_id,
1351 proto::CallCanceled {
1352 room_id: room_id.to_proto(),
1353 },
1354 )
1355 .trace_err();
1356 }
1357 response.send(proto::Ack {})?;
1358
1359 update_user_contacts(called_user_id, &session).await?;
1360 Ok(())
1361}
1362
1363async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1364 let room_id = RoomId::from_proto(message.room_id);
1365 {
1366 let room = session
1367 .db()
1368 .await
1369 .decline_call(Some(room_id), session.user_id)
1370 .await?
1371 .ok_or_else(|| anyhow!("failed to decline call"))?;
1372 room_updated(&room, &session.peer);
1373 }
1374
1375 for connection_id in session
1376 .connection_pool()
1377 .await
1378 .user_connection_ids(session.user_id)
1379 {
1380 session
1381 .peer
1382 .send(
1383 connection_id,
1384 proto::CallCanceled {
1385 room_id: room_id.to_proto(),
1386 },
1387 )
1388 .trace_err();
1389 }
1390 update_user_contacts(session.user_id, &session).await?;
1391 Ok(())
1392}
1393
1394async fn update_participant_location(
1395 request: proto::UpdateParticipantLocation,
1396 response: Response<proto::UpdateParticipantLocation>,
1397 session: Session,
1398) -> Result<()> {
1399 let room_id = RoomId::from_proto(request.room_id);
1400 let location = request
1401 .location
1402 .ok_or_else(|| anyhow!("invalid location"))?;
1403
1404 let db = session.db().await;
1405 let room = db
1406 .update_room_participant_location(room_id, session.connection_id, location)
1407 .await?;
1408
1409 room_updated(&room, &session.peer);
1410 response.send(proto::Ack {})?;
1411 Ok(())
1412}
1413
1414async fn share_project(
1415 request: proto::ShareProject,
1416 response: Response<proto::ShareProject>,
1417 session: Session,
1418) -> Result<()> {
1419 let (project_id, room) = &*session
1420 .db()
1421 .await
1422 .share_project(
1423 RoomId::from_proto(request.room_id),
1424 session.connection_id,
1425 &request.worktrees,
1426 )
1427 .await?;
1428 response.send(proto::ShareProjectResponse {
1429 project_id: project_id.to_proto(),
1430 })?;
1431 room_updated(&room, &session.peer);
1432
1433 Ok(())
1434}
1435
1436async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1437 let project_id = ProjectId::from_proto(message.project_id);
1438
1439 let (room, guest_connection_ids) = &*session
1440 .db()
1441 .await
1442 .unshare_project(project_id, session.connection_id)
1443 .await?;
1444
1445 broadcast(
1446 Some(session.connection_id),
1447 guest_connection_ids.iter().copied(),
1448 |conn_id| session.peer.send(conn_id, message.clone()),
1449 );
1450 room_updated(&room, &session.peer);
1451
1452 Ok(())
1453}
1454
1455async fn join_project(
1456 request: proto::JoinProject,
1457 response: Response<proto::JoinProject>,
1458 session: Session,
1459) -> Result<()> {
1460 let project_id = ProjectId::from_proto(request.project_id);
1461 let guest_user_id = session.user_id;
1462
1463 tracing::info!(%project_id, "join project");
1464
1465 let (project, replica_id) = &mut *session
1466 .db()
1467 .await
1468 .join_project(project_id, session.connection_id)
1469 .await?;
1470
1471 let collaborators = project
1472 .collaborators
1473 .iter()
1474 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1475 .map(|collaborator| collaborator.to_proto())
1476 .collect::<Vec<_>>();
1477
1478 let worktrees = project
1479 .worktrees
1480 .iter()
1481 .map(|(id, worktree)| proto::WorktreeMetadata {
1482 id: *id,
1483 root_name: worktree.root_name.clone(),
1484 visible: worktree.visible,
1485 abs_path: worktree.abs_path.clone(),
1486 })
1487 .collect::<Vec<_>>();
1488
1489 for collaborator in &collaborators {
1490 session
1491 .peer
1492 .send(
1493 collaborator.peer_id.unwrap().into(),
1494 proto::AddProjectCollaborator {
1495 project_id: project_id.to_proto(),
1496 collaborator: Some(proto::Collaborator {
1497 peer_id: Some(session.connection_id.into()),
1498 replica_id: replica_id.0 as u32,
1499 user_id: guest_user_id.to_proto(),
1500 }),
1501 },
1502 )
1503 .trace_err();
1504 }
1505
1506 // First, we send the metadata associated with each worktree.
1507 response.send(proto::JoinProjectResponse {
1508 worktrees: worktrees.clone(),
1509 replica_id: replica_id.0 as u32,
1510 collaborators: collaborators.clone(),
1511 language_servers: project.language_servers.clone(),
1512 })?;
1513
1514 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1515 #[cfg(any(test, feature = "test-support"))]
1516 const MAX_CHUNK_SIZE: usize = 2;
1517 #[cfg(not(any(test, feature = "test-support")))]
1518 const MAX_CHUNK_SIZE: usize = 256;
1519
1520 // Stream this worktree's entries.
1521 let message = proto::UpdateWorktree {
1522 project_id: project_id.to_proto(),
1523 worktree_id,
1524 abs_path: worktree.abs_path.clone(),
1525 root_name: worktree.root_name,
1526 updated_entries: worktree.entries,
1527 removed_entries: Default::default(),
1528 scan_id: worktree.scan_id,
1529 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1530 updated_repositories: worktree.repository_entries.into_values().collect(),
1531 removed_repositories: Default::default(),
1532 };
1533 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1534 session.peer.send(session.connection_id, update.clone())?;
1535 }
1536
1537 // Stream this worktree's diagnostics.
1538 for summary in worktree.diagnostic_summaries {
1539 session.peer.send(
1540 session.connection_id,
1541 proto::UpdateDiagnosticSummary {
1542 project_id: project_id.to_proto(),
1543 worktree_id: worktree.id,
1544 summary: Some(summary),
1545 },
1546 )?;
1547 }
1548
1549 for settings_file in worktree.settings_files {
1550 session.peer.send(
1551 session.connection_id,
1552 proto::UpdateWorktreeSettings {
1553 project_id: project_id.to_proto(),
1554 worktree_id: worktree.id,
1555 path: settings_file.path,
1556 content: Some(settings_file.content),
1557 },
1558 )?;
1559 }
1560 }
1561
1562 for language_server in &project.language_servers {
1563 session.peer.send(
1564 session.connection_id,
1565 proto::UpdateLanguageServer {
1566 project_id: project_id.to_proto(),
1567 language_server_id: language_server.id,
1568 variant: Some(
1569 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1570 proto::LspDiskBasedDiagnosticsUpdated {},
1571 ),
1572 ),
1573 },
1574 )?;
1575 }
1576
1577 Ok(())
1578}
1579
1580async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1581 let sender_id = session.connection_id;
1582 let project_id = ProjectId::from_proto(request.project_id);
1583
1584 let (room, project) = &*session
1585 .db()
1586 .await
1587 .leave_project(project_id, sender_id)
1588 .await?;
1589 tracing::info!(
1590 %project_id,
1591 host_user_id = %project.host_user_id,
1592 host_connection_id = %project.host_connection_id,
1593 "leave project"
1594 );
1595
1596 project_left(&project, &session);
1597 room_updated(&room, &session.peer);
1598
1599 Ok(())
1600}
1601
1602async fn update_project(
1603 request: proto::UpdateProject,
1604 response: Response<proto::UpdateProject>,
1605 session: Session,
1606) -> Result<()> {
1607 let project_id = ProjectId::from_proto(request.project_id);
1608 let (room, guest_connection_ids) = &*session
1609 .db()
1610 .await
1611 .update_project(project_id, session.connection_id, &request.worktrees)
1612 .await?;
1613 broadcast(
1614 Some(session.connection_id),
1615 guest_connection_ids.iter().copied(),
1616 |connection_id| {
1617 session
1618 .peer
1619 .forward_send(session.connection_id, connection_id, request.clone())
1620 },
1621 );
1622 room_updated(&room, &session.peer);
1623 response.send(proto::Ack {})?;
1624
1625 Ok(())
1626}
1627
1628async fn update_worktree(
1629 request: proto::UpdateWorktree,
1630 response: Response<proto::UpdateWorktree>,
1631 session: Session,
1632) -> Result<()> {
1633 let guest_connection_ids = session
1634 .db()
1635 .await
1636 .update_worktree(&request, session.connection_id)
1637 .await?;
1638
1639 broadcast(
1640 Some(session.connection_id),
1641 guest_connection_ids.iter().copied(),
1642 |connection_id| {
1643 session
1644 .peer
1645 .forward_send(session.connection_id, connection_id, request.clone())
1646 },
1647 );
1648 response.send(proto::Ack {})?;
1649 Ok(())
1650}
1651
1652async fn update_diagnostic_summary(
1653 message: proto::UpdateDiagnosticSummary,
1654 session: Session,
1655) -> Result<()> {
1656 let guest_connection_ids = session
1657 .db()
1658 .await
1659 .update_diagnostic_summary(&message, session.connection_id)
1660 .await?;
1661
1662 broadcast(
1663 Some(session.connection_id),
1664 guest_connection_ids.iter().copied(),
1665 |connection_id| {
1666 session
1667 .peer
1668 .forward_send(session.connection_id, connection_id, message.clone())
1669 },
1670 );
1671
1672 Ok(())
1673}
1674
1675async fn update_worktree_settings(
1676 message: proto::UpdateWorktreeSettings,
1677 session: Session,
1678) -> Result<()> {
1679 let guest_connection_ids = session
1680 .db()
1681 .await
1682 .update_worktree_settings(&message, session.connection_id)
1683 .await?;
1684
1685 broadcast(
1686 Some(session.connection_id),
1687 guest_connection_ids.iter().copied(),
1688 |connection_id| {
1689 session
1690 .peer
1691 .forward_send(session.connection_id, connection_id, message.clone())
1692 },
1693 );
1694
1695 Ok(())
1696}
1697
1698async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1699 broadcast_project_message(request.project_id, request, session).await
1700}
1701
1702async fn start_language_server(
1703 request: proto::StartLanguageServer,
1704 session: Session,
1705) -> Result<()> {
1706 let guest_connection_ids = session
1707 .db()
1708 .await
1709 .start_language_server(&request, session.connection_id)
1710 .await?;
1711
1712 broadcast(
1713 Some(session.connection_id),
1714 guest_connection_ids.iter().copied(),
1715 |connection_id| {
1716 session
1717 .peer
1718 .forward_send(session.connection_id, connection_id, request.clone())
1719 },
1720 );
1721 Ok(())
1722}
1723
1724async fn update_language_server(
1725 request: proto::UpdateLanguageServer,
1726 session: Session,
1727) -> Result<()> {
1728 session.executor.record_backtrace();
1729 let project_id = ProjectId::from_proto(request.project_id);
1730 let project_connection_ids = session
1731 .db()
1732 .await
1733 .project_connection_ids(project_id, session.connection_id)
1734 .await?;
1735 broadcast(
1736 Some(session.connection_id),
1737 project_connection_ids.iter().copied(),
1738 |connection_id| {
1739 session
1740 .peer
1741 .forward_send(session.connection_id, connection_id, request.clone())
1742 },
1743 );
1744 Ok(())
1745}
1746
1747async fn forward_project_request<T>(
1748 request: T,
1749 response: Response<T>,
1750 session: Session,
1751) -> Result<()>
1752where
1753 T: EntityMessage + RequestMessage,
1754{
1755 session.executor.record_backtrace();
1756 let project_id = ProjectId::from_proto(request.remote_entity_id());
1757 let host_connection_id = {
1758 let collaborators = session
1759 .db()
1760 .await
1761 .project_collaborators(project_id, session.connection_id)
1762 .await?;
1763 collaborators
1764 .iter()
1765 .find(|collaborator| collaborator.is_host)
1766 .ok_or_else(|| anyhow!("host not found"))?
1767 .connection_id
1768 };
1769
1770 let payload = session
1771 .peer
1772 .forward_request(session.connection_id, host_connection_id, request)
1773 .await?;
1774
1775 response.send(payload)?;
1776 Ok(())
1777}
1778
1779async fn create_buffer_for_peer(
1780 request: proto::CreateBufferForPeer,
1781 session: Session,
1782) -> Result<()> {
1783 session.executor.record_backtrace();
1784 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1785 session
1786 .peer
1787 .forward_send(session.connection_id, peer_id.into(), request)?;
1788 Ok(())
1789}
1790
1791async fn update_buffer(
1792 request: proto::UpdateBuffer,
1793 response: Response<proto::UpdateBuffer>,
1794 session: Session,
1795) -> Result<()> {
1796 session.executor.record_backtrace();
1797 let project_id = ProjectId::from_proto(request.project_id);
1798 let mut guest_connection_ids;
1799 let mut host_connection_id = None;
1800 {
1801 let collaborators = session
1802 .db()
1803 .await
1804 .project_collaborators(project_id, session.connection_id)
1805 .await?;
1806 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1807 for collaborator in collaborators.iter() {
1808 if collaborator.is_host {
1809 host_connection_id = Some(collaborator.connection_id);
1810 } else {
1811 guest_connection_ids.push(collaborator.connection_id);
1812 }
1813 }
1814 }
1815 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1816
1817 session.executor.record_backtrace();
1818 broadcast(
1819 Some(session.connection_id),
1820 guest_connection_ids,
1821 |connection_id| {
1822 session
1823 .peer
1824 .forward_send(session.connection_id, connection_id, request.clone())
1825 },
1826 );
1827 if host_connection_id != session.connection_id {
1828 session
1829 .peer
1830 .forward_request(session.connection_id, host_connection_id, request.clone())
1831 .await?;
1832 }
1833
1834 response.send(proto::Ack {})?;
1835 Ok(())
1836}
1837
1838async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1839 let project_id = ProjectId::from_proto(request.project_id);
1840 let project_connection_ids = session
1841 .db()
1842 .await
1843 .project_connection_ids(project_id, session.connection_id)
1844 .await?;
1845
1846 broadcast(
1847 Some(session.connection_id),
1848 project_connection_ids.iter().copied(),
1849 |connection_id| {
1850 session
1851 .peer
1852 .forward_send(session.connection_id, connection_id, request.clone())
1853 },
1854 );
1855 Ok(())
1856}
1857
1858async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1859 let project_id = ProjectId::from_proto(request.project_id);
1860 let project_connection_ids = session
1861 .db()
1862 .await
1863 .project_connection_ids(project_id, session.connection_id)
1864 .await?;
1865 broadcast(
1866 Some(session.connection_id),
1867 project_connection_ids.iter().copied(),
1868 |connection_id| {
1869 session
1870 .peer
1871 .forward_send(session.connection_id, connection_id, request.clone())
1872 },
1873 );
1874 Ok(())
1875}
1876
1877async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1878 broadcast_project_message(request.project_id, request, session).await
1879}
1880
1881async fn broadcast_project_message<T: EnvelopedMessage>(
1882 project_id: u64,
1883 request: T,
1884 session: Session,
1885) -> Result<()> {
1886 let project_id = ProjectId::from_proto(project_id);
1887 let project_connection_ids = session
1888 .db()
1889 .await
1890 .project_connection_ids(project_id, session.connection_id)
1891 .await?;
1892 broadcast(
1893 Some(session.connection_id),
1894 project_connection_ids.iter().copied(),
1895 |connection_id| {
1896 session
1897 .peer
1898 .forward_send(session.connection_id, connection_id, request.clone())
1899 },
1900 );
1901 Ok(())
1902}
1903
1904async fn follow(
1905 request: proto::Follow,
1906 response: Response<proto::Follow>,
1907 session: Session,
1908) -> Result<()> {
1909 let room_id = RoomId::from_proto(request.room_id);
1910 let project_id = request.project_id.map(ProjectId::from_proto);
1911 let leader_id = request
1912 .leader_id
1913 .ok_or_else(|| anyhow!("invalid leader id"))?
1914 .into();
1915 let follower_id = session.connection_id;
1916
1917 session
1918 .db()
1919 .await
1920 .check_room_participants(room_id, leader_id, session.connection_id)
1921 .await?;
1922
1923 let response_payload = session
1924 .peer
1925 .forward_request(session.connection_id, leader_id, request)
1926 .await?;
1927 response.send(response_payload)?;
1928
1929 if let Some(project_id) = project_id {
1930 let room = session
1931 .db()
1932 .await
1933 .follow(room_id, project_id, leader_id, follower_id)
1934 .await?;
1935 room_updated(&room, &session.peer);
1936 }
1937
1938 Ok(())
1939}
1940
1941async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1942 let room_id = RoomId::from_proto(request.room_id);
1943 let project_id = request.project_id.map(ProjectId::from_proto);
1944 let leader_id = request
1945 .leader_id
1946 .ok_or_else(|| anyhow!("invalid leader id"))?
1947 .into();
1948 let follower_id = session.connection_id;
1949
1950 session
1951 .db()
1952 .await
1953 .check_room_participants(room_id, leader_id, session.connection_id)
1954 .await?;
1955
1956 session
1957 .peer
1958 .forward_send(session.connection_id, leader_id, request)?;
1959
1960 if let Some(project_id) = project_id {
1961 let room = session
1962 .db()
1963 .await
1964 .unfollow(room_id, project_id, leader_id, follower_id)
1965 .await?;
1966 room_updated(&room, &session.peer);
1967 }
1968
1969 Ok(())
1970}
1971
1972async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1973 let room_id = RoomId::from_proto(request.room_id);
1974 let database = session.db.lock().await;
1975
1976 let connection_ids = if let Some(project_id) = request.project_id {
1977 let project_id = ProjectId::from_proto(project_id);
1978 database
1979 .project_connection_ids(project_id, session.connection_id)
1980 .await?
1981 } else {
1982 database
1983 .room_connection_ids(room_id, session.connection_id)
1984 .await?
1985 };
1986
1987 // For now, don't send view update messages back to that view's current leader.
1988 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1989 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1990 _ => None,
1991 });
1992
1993 for follower_peer_id in request.follower_ids.iter().copied() {
1994 let follower_connection_id = follower_peer_id.into();
1995 if Some(follower_peer_id) != connection_id_to_omit
1996 && connection_ids.contains(&follower_connection_id)
1997 {
1998 session.peer.forward_send(
1999 session.connection_id,
2000 follower_connection_id,
2001 request.clone(),
2002 )?;
2003 }
2004 }
2005 Ok(())
2006}
2007
2008async fn get_users(
2009 request: proto::GetUsers,
2010 response: Response<proto::GetUsers>,
2011 session: Session,
2012) -> Result<()> {
2013 let user_ids = request
2014 .user_ids
2015 .into_iter()
2016 .map(UserId::from_proto)
2017 .collect();
2018 let users = session
2019 .db()
2020 .await
2021 .get_users_by_ids(user_ids)
2022 .await?
2023 .into_iter()
2024 .map(|user| proto::User {
2025 id: user.id.to_proto(),
2026 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2027 github_login: user.github_login,
2028 })
2029 .collect();
2030 response.send(proto::UsersResponse { users })?;
2031 Ok(())
2032}
2033
2034async fn fuzzy_search_users(
2035 request: proto::FuzzySearchUsers,
2036 response: Response<proto::FuzzySearchUsers>,
2037 session: Session,
2038) -> Result<()> {
2039 let query = request.query;
2040 let users = match query.len() {
2041 0 => vec![],
2042 1 | 2 => session
2043 .db()
2044 .await
2045 .get_user_by_github_login(&query)
2046 .await?
2047 .into_iter()
2048 .collect(),
2049 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2050 };
2051 let users = users
2052 .into_iter()
2053 .filter(|user| user.id != session.user_id)
2054 .map(|user| proto::User {
2055 id: user.id.to_proto(),
2056 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2057 github_login: user.github_login,
2058 })
2059 .collect();
2060 response.send(proto::UsersResponse { users })?;
2061 Ok(())
2062}
2063
2064async fn request_contact(
2065 request: proto::RequestContact,
2066 response: Response<proto::RequestContact>,
2067 session: Session,
2068) -> Result<()> {
2069 let requester_id = session.user_id;
2070 let responder_id = UserId::from_proto(request.responder_id);
2071 if requester_id == responder_id {
2072 return Err(anyhow!("cannot add yourself as a contact"))?;
2073 }
2074
2075 let notifications = session
2076 .db()
2077 .await
2078 .send_contact_request(requester_id, responder_id)
2079 .await?;
2080
2081 // Update outgoing contact requests of requester
2082 let mut update = proto::UpdateContacts::default();
2083 update.outgoing_requests.push(responder_id.to_proto());
2084 for connection_id in session
2085 .connection_pool()
2086 .await
2087 .user_connection_ids(requester_id)
2088 {
2089 session.peer.send(connection_id, update.clone())?;
2090 }
2091
2092 // Update incoming contact requests of responder
2093 let mut update = proto::UpdateContacts::default();
2094 update
2095 .incoming_requests
2096 .push(proto::IncomingContactRequest {
2097 requester_id: requester_id.to_proto(),
2098 });
2099 let connection_pool = session.connection_pool().await;
2100 for connection_id in connection_pool.user_connection_ids(responder_id) {
2101 session.peer.send(connection_id, update.clone())?;
2102 }
2103
2104 send_notifications(&*connection_pool, &session.peer, notifications);
2105
2106 response.send(proto::Ack {})?;
2107 Ok(())
2108}
2109
2110async fn respond_to_contact_request(
2111 request: proto::RespondToContactRequest,
2112 response: Response<proto::RespondToContactRequest>,
2113 session: Session,
2114) -> Result<()> {
2115 let responder_id = session.user_id;
2116 let requester_id = UserId::from_proto(request.requester_id);
2117 let db = session.db().await;
2118 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2119 db.dismiss_contact_notification(responder_id, requester_id)
2120 .await?;
2121 } else {
2122 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2123
2124 let notifications = db
2125 .respond_to_contact_request(responder_id, requester_id, accept)
2126 .await?;
2127 let requester_busy = db.is_user_busy(requester_id).await?;
2128 let responder_busy = db.is_user_busy(responder_id).await?;
2129
2130 let pool = session.connection_pool().await;
2131 // Update responder with new contact
2132 let mut update = proto::UpdateContacts::default();
2133 if accept {
2134 update
2135 .contacts
2136 .push(contact_for_user(requester_id, requester_busy, &pool));
2137 }
2138 update
2139 .remove_incoming_requests
2140 .push(requester_id.to_proto());
2141 for connection_id in pool.user_connection_ids(responder_id) {
2142 session.peer.send(connection_id, update.clone())?;
2143 }
2144
2145 // Update requester with new contact
2146 let mut update = proto::UpdateContacts::default();
2147 if accept {
2148 update
2149 .contacts
2150 .push(contact_for_user(responder_id, responder_busy, &pool));
2151 }
2152 update
2153 .remove_outgoing_requests
2154 .push(responder_id.to_proto());
2155
2156 for connection_id in pool.user_connection_ids(requester_id) {
2157 session.peer.send(connection_id, update.clone())?;
2158 }
2159
2160 send_notifications(&*pool, &session.peer, notifications);
2161 }
2162
2163 response.send(proto::Ack {})?;
2164 Ok(())
2165}
2166
2167async fn remove_contact(
2168 request: proto::RemoveContact,
2169 response: Response<proto::RemoveContact>,
2170 session: Session,
2171) -> Result<()> {
2172 let requester_id = session.user_id;
2173 let responder_id = UserId::from_proto(request.user_id);
2174 let db = session.db().await;
2175 let (contact_accepted, deleted_notification_id) =
2176 db.remove_contact(requester_id, responder_id).await?;
2177
2178 let pool = session.connection_pool().await;
2179 // Update outgoing contact requests of requester
2180 let mut update = proto::UpdateContacts::default();
2181 if contact_accepted {
2182 update.remove_contacts.push(responder_id.to_proto());
2183 } else {
2184 update
2185 .remove_outgoing_requests
2186 .push(responder_id.to_proto());
2187 }
2188 for connection_id in pool.user_connection_ids(requester_id) {
2189 session.peer.send(connection_id, update.clone())?;
2190 }
2191
2192 // Update incoming contact requests of responder
2193 let mut update = proto::UpdateContacts::default();
2194 if contact_accepted {
2195 update.remove_contacts.push(requester_id.to_proto());
2196 } else {
2197 update
2198 .remove_incoming_requests
2199 .push(requester_id.to_proto());
2200 }
2201 for connection_id in pool.user_connection_ids(responder_id) {
2202 session.peer.send(connection_id, update.clone())?;
2203 if let Some(notification_id) = deleted_notification_id {
2204 session.peer.send(
2205 connection_id,
2206 proto::DeleteNotification {
2207 notification_id: notification_id.to_proto(),
2208 },
2209 )?;
2210 }
2211 }
2212
2213 response.send(proto::Ack {})?;
2214 Ok(())
2215}
2216
2217async fn create_channel(
2218 request: proto::CreateChannel,
2219 response: Response<proto::CreateChannel>,
2220 session: Session,
2221) -> Result<()> {
2222 let db = session.db().await;
2223
2224 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2225 let CreateChannelResult {
2226 channel,
2227 participants_to_update,
2228 } = db
2229 .create_channel(&request.name, parent_id, session.user_id)
2230 .await?;
2231
2232 response.send(proto::CreateChannelResponse {
2233 channel: Some(channel.to_proto()),
2234 parent_id: request.parent_id,
2235 })?;
2236
2237 let connection_pool = session.connection_pool().await;
2238 for (user_id, channels) in participants_to_update {
2239 let update = build_channels_update(channels, vec![]);
2240 for connection_id in connection_pool.user_connection_ids(user_id) {
2241 if user_id == session.user_id {
2242 continue;
2243 }
2244 session.peer.send(connection_id, update.clone())?;
2245 }
2246 }
2247
2248 Ok(())
2249}
2250
2251async fn delete_channel(
2252 request: proto::DeleteChannel,
2253 response: Response<proto::DeleteChannel>,
2254 session: Session,
2255) -> Result<()> {
2256 let db = session.db().await;
2257
2258 let channel_id = request.channel_id;
2259 let (removed_channels, member_ids) = db
2260 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2261 .await?;
2262 response.send(proto::Ack {})?;
2263
2264 // Notify members of removed channels
2265 let mut update = proto::UpdateChannels::default();
2266 update
2267 .delete_channels
2268 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2269
2270 let connection_pool = session.connection_pool().await;
2271 for member_id in member_ids {
2272 for connection_id in connection_pool.user_connection_ids(member_id) {
2273 session.peer.send(connection_id, update.clone())?;
2274 }
2275 }
2276
2277 Ok(())
2278}
2279
2280async fn invite_channel_member(
2281 request: proto::InviteChannelMember,
2282 response: Response<proto::InviteChannelMember>,
2283 session: Session,
2284) -> Result<()> {
2285 let db = session.db().await;
2286 let channel_id = ChannelId::from_proto(request.channel_id);
2287 let invitee_id = UserId::from_proto(request.user_id);
2288 let InviteMemberResult {
2289 channel,
2290 notifications,
2291 } = db
2292 .invite_channel_member(
2293 channel_id,
2294 invitee_id,
2295 session.user_id,
2296 request.role().into(),
2297 )
2298 .await?;
2299
2300 let update = proto::UpdateChannels {
2301 channel_invitations: vec![channel.to_proto()],
2302 ..Default::default()
2303 };
2304
2305 let connection_pool = session.connection_pool().await;
2306 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2307 session.peer.send(connection_id, update.clone())?;
2308 }
2309
2310 send_notifications(&*connection_pool, &session.peer, notifications);
2311
2312 response.send(proto::Ack {})?;
2313 Ok(())
2314}
2315
2316async fn remove_channel_member(
2317 request: proto::RemoveChannelMember,
2318 response: Response<proto::RemoveChannelMember>,
2319 session: Session,
2320) -> Result<()> {
2321 let db = session.db().await;
2322 let channel_id = ChannelId::from_proto(request.channel_id);
2323 let member_id = UserId::from_proto(request.user_id);
2324
2325 let RemoveChannelMemberResult {
2326 membership_update,
2327 notification_id,
2328 } = db
2329 .remove_channel_member(channel_id, member_id, session.user_id)
2330 .await?;
2331
2332 let connection_pool = &session.connection_pool().await;
2333 notify_membership_updated(
2334 &connection_pool,
2335 membership_update,
2336 member_id,
2337 &session.peer,
2338 );
2339 for connection_id in connection_pool.user_connection_ids(member_id) {
2340 if let Some(notification_id) = notification_id {
2341 session
2342 .peer
2343 .send(
2344 connection_id,
2345 proto::DeleteNotification {
2346 notification_id: notification_id.to_proto(),
2347 },
2348 )
2349 .trace_err();
2350 }
2351 }
2352
2353 response.send(proto::Ack {})?;
2354 Ok(())
2355}
2356
2357async fn set_channel_visibility(
2358 request: proto::SetChannelVisibility,
2359 response: Response<proto::SetChannelVisibility>,
2360 session: Session,
2361) -> Result<()> {
2362 let db = session.db().await;
2363 let channel_id = ChannelId::from_proto(request.channel_id);
2364 let visibility = request.visibility().into();
2365
2366 let SetChannelVisibilityResult {
2367 participants_to_update,
2368 participants_to_remove,
2369 } = db
2370 .set_channel_visibility(channel_id, visibility, session.user_id)
2371 .await?;
2372
2373 let connection_pool = session.connection_pool().await;
2374 for (user_id, channels) in participants_to_update {
2375 let update = build_channels_update(channels, vec![]);
2376 for connection_id in connection_pool.user_connection_ids(user_id) {
2377 session.peer.send(connection_id, update.clone())?;
2378 }
2379 }
2380 for user_id in participants_to_remove {
2381 let update = proto::UpdateChannels {
2382 // for public participants we only need to remove the current channel
2383 // (not descendants)
2384 // because they can still see any public descendants
2385 delete_channels: vec![channel_id.to_proto()],
2386 ..Default::default()
2387 };
2388 for connection_id in connection_pool.user_connection_ids(user_id) {
2389 session.peer.send(connection_id, update.clone())?;
2390 }
2391 }
2392
2393 response.send(proto::Ack {})?;
2394 Ok(())
2395}
2396
2397async fn set_channel_member_role(
2398 request: proto::SetChannelMemberRole,
2399 response: Response<proto::SetChannelMemberRole>,
2400 session: Session,
2401) -> Result<()> {
2402 let db = session.db().await;
2403 let channel_id = ChannelId::from_proto(request.channel_id);
2404 let member_id = UserId::from_proto(request.user_id);
2405 let result = db
2406 .set_channel_member_role(
2407 channel_id,
2408 session.user_id,
2409 member_id,
2410 request.role().into(),
2411 )
2412 .await?;
2413
2414 match result {
2415 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2416 let connection_pool = session.connection_pool().await;
2417 notify_membership_updated(
2418 &connection_pool,
2419 membership_update,
2420 member_id,
2421 &session.peer,
2422 )
2423 }
2424 db::SetMemberRoleResult::InviteUpdated(channel) => {
2425 let update = proto::UpdateChannels {
2426 channel_invitations: vec![channel.to_proto()],
2427 ..Default::default()
2428 };
2429
2430 for connection_id in session
2431 .connection_pool()
2432 .await
2433 .user_connection_ids(member_id)
2434 {
2435 session.peer.send(connection_id, update.clone())?;
2436 }
2437 }
2438 }
2439
2440 response.send(proto::Ack {})?;
2441 Ok(())
2442}
2443
2444async fn rename_channel(
2445 request: proto::RenameChannel,
2446 response: Response<proto::RenameChannel>,
2447 session: Session,
2448) -> Result<()> {
2449 let db = session.db().await;
2450 let channel_id = ChannelId::from_proto(request.channel_id);
2451 let RenameChannelResult {
2452 channel,
2453 participants_to_update,
2454 } = db
2455 .rename_channel(channel_id, session.user_id, &request.name)
2456 .await?;
2457
2458 response.send(proto::RenameChannelResponse {
2459 channel: Some(channel.to_proto()),
2460 })?;
2461
2462 let connection_pool = session.connection_pool().await;
2463 for (user_id, channel) in participants_to_update {
2464 for connection_id in connection_pool.user_connection_ids(user_id) {
2465 let update = proto::UpdateChannels {
2466 channels: vec![channel.to_proto()],
2467 ..Default::default()
2468 };
2469
2470 session.peer.send(connection_id, update.clone())?;
2471 }
2472 }
2473
2474 Ok(())
2475}
2476
2477// TODO: Implement in terms of symlinks
2478// Current behavior of this is more like 'Move root channel'
2479async fn link_channel(
2480 request: proto::LinkChannel,
2481 response: Response<proto::LinkChannel>,
2482 session: Session,
2483) -> Result<()> {
2484 let db = session.db().await;
2485 let channel_id = ChannelId::from_proto(request.channel_id);
2486 let to = ChannelId::from_proto(request.to);
2487
2488 let result = db
2489 .move_channel(channel_id, None, to, session.user_id)
2490 .await?;
2491 drop(db);
2492
2493 notify_channel_moved(result, session).await?;
2494
2495 response.send(Ack {})?;
2496
2497 Ok(())
2498}
2499
2500// TODO: Implement in terms of symlinks
2501async fn unlink_channel(
2502 _request: proto::UnlinkChannel,
2503 _response: Response<proto::UnlinkChannel>,
2504 _session: Session,
2505) -> Result<()> {
2506 Err(anyhow!("unimplemented").into())
2507}
2508
2509async fn move_channel(
2510 request: proto::MoveChannel,
2511 response: Response<proto::MoveChannel>,
2512 session: Session,
2513) -> Result<()> {
2514 let db = session.db().await;
2515 let channel_id = ChannelId::from_proto(request.channel_id);
2516 let from_parent = ChannelId::from_proto(request.from);
2517 let to = ChannelId::from_proto(request.to);
2518
2519 let result = db
2520 .move_channel(channel_id, Some(from_parent), to, session.user_id)
2521 .await?;
2522 drop(db);
2523
2524 notify_channel_moved(result, session).await?;
2525
2526 response.send(Ack {})?;
2527 Ok(())
2528}
2529
2530async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2531 let Some(MoveChannelResult {
2532 participants_to_remove,
2533 participants_to_update,
2534 moved_channels,
2535 }) = result
2536 else {
2537 return Ok(());
2538 };
2539 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2540
2541 let connection_pool = session.connection_pool().await;
2542 for (user_id, channels) in participants_to_update {
2543 let mut update = build_channels_update(channels, vec![]);
2544 update.delete_channels = moved_channels.clone();
2545 for connection_id in connection_pool.user_connection_ids(user_id) {
2546 session.peer.send(connection_id, update.clone())?;
2547 }
2548 }
2549
2550 for user_id in participants_to_remove {
2551 let update = proto::UpdateChannels {
2552 delete_channels: moved_channels.clone(),
2553 ..Default::default()
2554 };
2555 for connection_id in connection_pool.user_connection_ids(user_id) {
2556 session.peer.send(connection_id, update.clone())?;
2557 }
2558 }
2559 Ok(())
2560}
2561
2562async fn get_channel_members(
2563 request: proto::GetChannelMembers,
2564 response: Response<proto::GetChannelMembers>,
2565 session: Session,
2566) -> Result<()> {
2567 let db = session.db().await;
2568 let channel_id = ChannelId::from_proto(request.channel_id);
2569 let members = db
2570 .get_channel_participant_details(channel_id, session.user_id)
2571 .await?;
2572 response.send(proto::GetChannelMembersResponse { members })?;
2573 Ok(())
2574}
2575
2576async fn respond_to_channel_invite(
2577 request: proto::RespondToChannelInvite,
2578 response: Response<proto::RespondToChannelInvite>,
2579 session: Session,
2580) -> Result<()> {
2581 let db = session.db().await;
2582 let channel_id = ChannelId::from_proto(request.channel_id);
2583 let RespondToChannelInvite {
2584 membership_update,
2585 notifications,
2586 } = db
2587 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2588 .await?;
2589
2590 let connection_pool = session.connection_pool().await;
2591 if let Some(membership_update) = membership_update {
2592 notify_membership_updated(
2593 &connection_pool,
2594 membership_update,
2595 session.user_id,
2596 &session.peer,
2597 );
2598 } else {
2599 let update = proto::UpdateChannels {
2600 remove_channel_invitations: vec![channel_id.to_proto()],
2601 ..Default::default()
2602 };
2603
2604 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2605 session.peer.send(connection_id, update.clone())?;
2606 }
2607 };
2608
2609 send_notifications(&*connection_pool, &session.peer, notifications);
2610
2611 response.send(proto::Ack {})?;
2612
2613 Ok(())
2614}
2615
2616async fn join_channel(
2617 request: proto::JoinChannel,
2618 response: Response<proto::JoinChannel>,
2619 session: Session,
2620) -> Result<()> {
2621 let channel_id = ChannelId::from_proto(request.channel_id);
2622 join_channel_internal(channel_id, Box::new(response), session).await
2623}
2624
2625trait JoinChannelInternalResponse {
2626 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2627}
2628impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2629 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2630 Response::<proto::JoinChannel>::send(self, result)
2631 }
2632}
2633impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2634 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2635 Response::<proto::JoinRoom>::send(self, result)
2636 }
2637}
2638
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// Steps:
/// 1. Leaves the session's current room, if any.
/// 2. Joins the channel's room in the database.
/// 3. Replies through `response` with the room, channel id and (optional)
///    LiveKit credentials.
/// 4. Pushes any membership change and the updated room to participants.
/// 5. Refreshes the channel's participant list for its members and updates
///    the user's contacts.
///
/// `response` is generic over `JoinChannelInternalResponse` (implemented for
/// `Response<JoinChannel>` and `Response<JoinRoom>`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that the database handle and the connection-pool guard taken
    // inside are dropped before the pool is acquired again for `channel_updated`.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        // `accept_invite_result` is `Some` when joining also changed the user's
        // channel membership (presumably by accepting a pending invite); `role`
        // is the user's role in this channel.
        let (joined_room, accept_invite_result, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // LiveKit credentials exist only when a LiveKit client is configured.
        // Guests get a guest token and may not publish; every other role gets a
        // room token with publishing allowed. A token error is passed through
        // `trace_err()` and results in no connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(accept_invite_result) = accept_invite_result {
            notify_membership_updated(
                &connection_pool,
                accept_invite_result,
                session.user_id,
                &session.peer,
            );
        }

        // Broadcast the new room state to the room's participants.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell the channel's members which users are now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Push this user's refreshed contact entry (including busy status) to
    // their accepted contacts.
    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2719
2720async fn join_channel_buffer(
2721 request: proto::JoinChannelBuffer,
2722 response: Response<proto::JoinChannelBuffer>,
2723 session: Session,
2724) -> Result<()> {
2725 let db = session.db().await;
2726 let channel_id = ChannelId::from_proto(request.channel_id);
2727
2728 let open_response = db
2729 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2730 .await?;
2731
2732 let collaborators = open_response.collaborators.clone();
2733 response.send(open_response)?;
2734
2735 let update = UpdateChannelBufferCollaborators {
2736 channel_id: channel_id.to_proto(),
2737 collaborators: collaborators.clone(),
2738 };
2739 channel_buffer_updated(
2740 session.connection_id,
2741 collaborators
2742 .iter()
2743 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2744 &update,
2745 &session.peer,
2746 );
2747
2748 Ok(())
2749}
2750
2751async fn update_channel_buffer(
2752 request: proto::UpdateChannelBuffer,
2753 session: Session,
2754) -> Result<()> {
2755 let db = session.db().await;
2756 let channel_id = ChannelId::from_proto(request.channel_id);
2757
2758 let (collaborators, non_collaborators, epoch, version) = db
2759 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2760 .await?;
2761
2762 channel_buffer_updated(
2763 session.connection_id,
2764 collaborators,
2765 &proto::UpdateChannelBuffer {
2766 channel_id: channel_id.to_proto(),
2767 operations: request.operations,
2768 },
2769 &session.peer,
2770 );
2771
2772 let pool = &*session.connection_pool().await;
2773
2774 broadcast(
2775 None,
2776 non_collaborators
2777 .iter()
2778 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2779 |peer_id| {
2780 session.peer.send(
2781 peer_id.into(),
2782 proto::UpdateChannels {
2783 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2784 channel_id: channel_id.to_proto(),
2785 epoch: epoch as u64,
2786 version: version.clone(),
2787 }],
2788 ..Default::default()
2789 },
2790 )
2791 },
2792 );
2793
2794 Ok(())
2795}
2796
2797async fn rejoin_channel_buffers(
2798 request: proto::RejoinChannelBuffers,
2799 response: Response<proto::RejoinChannelBuffers>,
2800 session: Session,
2801) -> Result<()> {
2802 let db = session.db().await;
2803 let buffers = db
2804 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2805 .await?;
2806
2807 for rejoined_buffer in &buffers {
2808 let collaborators_to_notify = rejoined_buffer
2809 .buffer
2810 .collaborators
2811 .iter()
2812 .filter_map(|c| Some(c.peer_id?.into()));
2813 channel_buffer_updated(
2814 session.connection_id,
2815 collaborators_to_notify,
2816 &proto::UpdateChannelBufferCollaborators {
2817 channel_id: rejoined_buffer.buffer.channel_id,
2818 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2819 },
2820 &session.peer,
2821 );
2822 }
2823
2824 response.send(proto::RejoinChannelBuffersResponse {
2825 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2826 })?;
2827
2828 Ok(())
2829}
2830
2831async fn leave_channel_buffer(
2832 request: proto::LeaveChannelBuffer,
2833 response: Response<proto::LeaveChannelBuffer>,
2834 session: Session,
2835) -> Result<()> {
2836 let db = session.db().await;
2837 let channel_id = ChannelId::from_proto(request.channel_id);
2838
2839 let left_buffer = db
2840 .leave_channel_buffer(channel_id, session.connection_id)
2841 .await?;
2842
2843 response.send(Ack {})?;
2844
2845 channel_buffer_updated(
2846 session.connection_id,
2847 left_buffer.connections,
2848 &proto::UpdateChannelBufferCollaborators {
2849 channel_id: channel_id.to_proto(),
2850 collaborators: left_buffer.collaborators,
2851 },
2852 &session.peer,
2853 );
2854
2855 Ok(())
2856}
2857
2858fn channel_buffer_updated<T: EnvelopedMessage>(
2859 sender_id: ConnectionId,
2860 collaborators: impl IntoIterator<Item = ConnectionId>,
2861 message: &T,
2862 peer: &Peer,
2863) {
2864 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2865 peer.send(peer_id.into(), message.clone())
2866 });
2867}
2868
2869fn send_notifications(
2870 connection_pool: &ConnectionPool,
2871 peer: &Peer,
2872 notifications: db::NotificationBatch,
2873) {
2874 for (user_id, notification) in notifications {
2875 for connection_id in connection_pool.user_connection_ids(user_id) {
2876 if let Err(error) = peer.send(
2877 connection_id,
2878 proto::AddNotification {
2879 notification: Some(notification.clone()),
2880 },
2881 ) {
2882 tracing::error!(
2883 "failed to send notification to {:?} {}",
2884 connection_id,
2885 error
2886 );
2887 }
2888 }
2889 }
2890}
2891
2892async fn send_channel_message(
2893 request: proto::SendChannelMessage,
2894 response: Response<proto::SendChannelMessage>,
2895 session: Session,
2896) -> Result<()> {
2897 // Validate the message body.
2898 let body = request.body.trim().to_string();
2899 if body.len() > MAX_MESSAGE_LEN {
2900 return Err(anyhow!("message is too long"))?;
2901 }
2902 if body.is_empty() {
2903 return Err(anyhow!("message can't be blank"))?;
2904 }
2905
2906 // TODO: adjust mentions if body is trimmed
2907
2908 let timestamp = OffsetDateTime::now_utc();
2909 let nonce = request
2910 .nonce
2911 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2912
2913 let channel_id = ChannelId::from_proto(request.channel_id);
2914 let CreatedChannelMessage {
2915 message_id,
2916 participant_connection_ids,
2917 channel_members,
2918 notifications,
2919 } = session
2920 .db()
2921 .await
2922 .create_channel_message(
2923 channel_id,
2924 session.user_id,
2925 &body,
2926 &request.mentions,
2927 timestamp,
2928 nonce.clone().into(),
2929 )
2930 .await?;
2931 let message = proto::ChannelMessage {
2932 sender_id: session.user_id.to_proto(),
2933 id: message_id.to_proto(),
2934 body,
2935 mentions: request.mentions,
2936 timestamp: timestamp.unix_timestamp() as u64,
2937 nonce: Some(nonce),
2938 };
2939 broadcast(
2940 Some(session.connection_id),
2941 participant_connection_ids,
2942 |connection| {
2943 session.peer.send(
2944 connection,
2945 proto::ChannelMessageSent {
2946 channel_id: channel_id.to_proto(),
2947 message: Some(message.clone()),
2948 },
2949 )
2950 },
2951 );
2952 response.send(proto::SendChannelMessageResponse {
2953 message: Some(message),
2954 })?;
2955
2956 let pool = &*session.connection_pool().await;
2957 broadcast(
2958 None,
2959 channel_members
2960 .iter()
2961 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2962 |peer_id| {
2963 session.peer.send(
2964 peer_id.into(),
2965 proto::UpdateChannels {
2966 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2967 channel_id: channel_id.to_proto(),
2968 message_id: message_id.to_proto(),
2969 }],
2970 ..Default::default()
2971 },
2972 )
2973 },
2974 );
2975 send_notifications(pool, &session.peer, notifications);
2976
2977 Ok(())
2978}
2979
2980async fn remove_channel_message(
2981 request: proto::RemoveChannelMessage,
2982 response: Response<proto::RemoveChannelMessage>,
2983 session: Session,
2984) -> Result<()> {
2985 let channel_id = ChannelId::from_proto(request.channel_id);
2986 let message_id = MessageId::from_proto(request.message_id);
2987 let connection_ids = session
2988 .db()
2989 .await
2990 .remove_channel_message(channel_id, message_id, session.user_id)
2991 .await?;
2992 broadcast(Some(session.connection_id), connection_ids, |connection| {
2993 session.peer.send(connection, request.clone())
2994 });
2995 response.send(proto::Ack {})?;
2996 Ok(())
2997}
2998
2999async fn acknowledge_channel_message(
3000 request: proto::AckChannelMessage,
3001 session: Session,
3002) -> Result<()> {
3003 let channel_id = ChannelId::from_proto(request.channel_id);
3004 let message_id = MessageId::from_proto(request.message_id);
3005 let notifications = session
3006 .db()
3007 .await
3008 .observe_channel_message(channel_id, session.user_id, message_id)
3009 .await?;
3010 send_notifications(
3011 &*session.connection_pool().await,
3012 &session.peer,
3013 notifications,
3014 );
3015 Ok(())
3016}
3017
3018async fn acknowledge_buffer_version(
3019 request: proto::AckBufferOperation,
3020 session: Session,
3021) -> Result<()> {
3022 let buffer_id = BufferId::from_proto(request.buffer_id);
3023 session
3024 .db()
3025 .await
3026 .observe_buffer_version(
3027 buffer_id,
3028 session.user_id,
3029 request.epoch as i32,
3030 &request.version,
3031 )
3032 .await?;
3033 Ok(())
3034}
3035
3036async fn join_channel_chat(
3037 request: proto::JoinChannelChat,
3038 response: Response<proto::JoinChannelChat>,
3039 session: Session,
3040) -> Result<()> {
3041 let channel_id = ChannelId::from_proto(request.channel_id);
3042
3043 let db = session.db().await;
3044 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3045 .await?;
3046 let messages = db
3047 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3048 .await?;
3049 response.send(proto::JoinChannelChatResponse {
3050 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3051 messages,
3052 })?;
3053 Ok(())
3054}
3055
3056async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 session
3059 .db()
3060 .await
3061 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3062 .await?;
3063 Ok(())
3064}
3065
3066async fn get_channel_messages(
3067 request: proto::GetChannelMessages,
3068 response: Response<proto::GetChannelMessages>,
3069 session: Session,
3070) -> Result<()> {
3071 let channel_id = ChannelId::from_proto(request.channel_id);
3072 let messages = session
3073 .db()
3074 .await
3075 .get_channel_messages(
3076 channel_id,
3077 session.user_id,
3078 MESSAGE_COUNT_PER_PAGE,
3079 Some(MessageId::from_proto(request.before_message_id)),
3080 )
3081 .await?;
3082 response.send(proto::GetChannelMessagesResponse {
3083 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3084 messages,
3085 })?;
3086 Ok(())
3087}
3088
3089async fn get_channel_messages_by_id(
3090 request: proto::GetChannelMessagesById,
3091 response: Response<proto::GetChannelMessagesById>,
3092 session: Session,
3093) -> Result<()> {
3094 let message_ids = request
3095 .message_ids
3096 .iter()
3097 .map(|id| MessageId::from_proto(*id))
3098 .collect::<Vec<_>>();
3099 let messages = session
3100 .db()
3101 .await
3102 .get_channel_messages_by_id(session.user_id, &message_ids)
3103 .await?;
3104 response.send(proto::GetChannelMessagesResponse {
3105 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3106 messages,
3107 })?;
3108 Ok(())
3109}
3110
3111async fn get_notifications(
3112 request: proto::GetNotifications,
3113 response: Response<proto::GetNotifications>,
3114 session: Session,
3115) -> Result<()> {
3116 let notifications = session
3117 .db()
3118 .await
3119 .get_notifications(
3120 session.user_id,
3121 NOTIFICATION_COUNT_PER_PAGE,
3122 request
3123 .before_id
3124 .map(|id| db::NotificationId::from_proto(id)),
3125 )
3126 .await?;
3127 response.send(proto::GetNotificationsResponse {
3128 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3129 notifications,
3130 })?;
3131 Ok(())
3132}
3133
3134async fn mark_notification_as_read(
3135 request: proto::MarkNotificationRead,
3136 response: Response<proto::MarkNotificationRead>,
3137 session: Session,
3138) -> Result<()> {
3139 let database = &session.db().await;
3140 let notifications = database
3141 .mark_notification_as_read_by_id(
3142 session.user_id,
3143 NotificationId::from_proto(request.notification_id),
3144 )
3145 .await?;
3146 send_notifications(
3147 &*session.connection_pool().await,
3148 &session.peer,
3149 notifications,
3150 );
3151 response.send(proto::Ack {})?;
3152 Ok(())
3153}
3154
3155async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3156 let project_id = ProjectId::from_proto(request.project_id);
3157 let project_connection_ids = session
3158 .db()
3159 .await
3160 .project_connection_ids(project_id, session.connection_id)
3161 .await?;
3162 broadcast(
3163 Some(session.connection_id),
3164 project_connection_ids.iter().copied(),
3165 |connection_id| {
3166 session
3167 .peer
3168 .forward_send(session.connection_id, connection_id, request.clone())
3169 },
3170 );
3171 Ok(())
3172}
3173
3174async fn get_private_user_info(
3175 _request: proto::GetPrivateUserInfo,
3176 response: Response<proto::GetPrivateUserInfo>,
3177 session: Session,
3178) -> Result<()> {
3179 let db = session.db().await;
3180
3181 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3182 let user = db
3183 .get_user_by_id(session.user_id)
3184 .await?
3185 .ok_or_else(|| anyhow!("user not found"))?;
3186 let flags = db.get_user_flags(session.user_id).await?;
3187
3188 response.send(proto::GetPrivateUserInfoResponse {
3189 metrics_id,
3190 staff: user.admin,
3191 flags,
3192 })?;
3193 Ok(())
3194}
3195
3196fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3197 match message {
3198 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3199 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3200 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3201 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3202 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3203 code: frame.code.into(),
3204 reason: frame.reason,
3205 })),
3206 }
3207}
3208
3209fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3210 match message {
3211 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3212 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3213 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3214 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3215 AxumMessage::Close(frame) => {
3216 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3217 code: frame.code.into(),
3218 reason: frame.reason,
3219 }))
3220 }
3221 }
3222}
3223
3224fn notify_membership_updated(
3225 connection_pool: &ConnectionPool,
3226 result: MembershipUpdated,
3227 user_id: UserId,
3228 peer: &Peer,
3229) {
3230 let mut update = build_channels_update(result.new_channels, vec![]);
3231 update.delete_channels = result
3232 .removed_channels
3233 .into_iter()
3234 .map(|id| id.to_proto())
3235 .collect();
3236 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3237
3238 for connection_id in connection_pool.user_connection_ids(user_id) {
3239 peer.send(connection_id, update.clone()).trace_err();
3240 }
3241}
3242
3243fn build_channels_update(
3244 channels: ChannelsForUser,
3245 channel_invites: Vec<db::Channel>,
3246) -> proto::UpdateChannels {
3247 let mut update = proto::UpdateChannels::default();
3248
3249 for channel in channels.channels.channels {
3250 update.channels.push(proto::Channel {
3251 id: channel.id.to_proto(),
3252 name: channel.name,
3253 visibility: channel.visibility.into(),
3254 role: channel.role.into(),
3255 });
3256 }
3257
3258 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3259 update.unseen_channel_messages = channels.channel_messages;
3260 update.insert_edge = channels.channels.edges;
3261
3262 for (channel_id, participants) in channels.channel_participants {
3263 update
3264 .channel_participants
3265 .push(proto::ChannelParticipants {
3266 channel_id: channel_id.to_proto(),
3267 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3268 });
3269 }
3270
3271 for channel in channel_invites {
3272 update.channel_invitations.push(proto::Channel {
3273 id: channel.id.to_proto(),
3274 name: channel.name,
3275 visibility: channel.visibility.into(),
3276 role: channel.role.into(),
3277 });
3278 }
3279
3280 update
3281}
3282
3283fn build_initial_contacts_update(
3284 contacts: Vec<db::Contact>,
3285 pool: &ConnectionPool,
3286) -> proto::UpdateContacts {
3287 let mut update = proto::UpdateContacts::default();
3288
3289 for contact in contacts {
3290 match contact {
3291 db::Contact::Accepted { user_id, busy } => {
3292 update.contacts.push(contact_for_user(user_id, busy, &pool));
3293 }
3294 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3295 db::Contact::Incoming { user_id } => {
3296 update
3297 .incoming_requests
3298 .push(proto::IncomingContactRequest {
3299 requester_id: user_id.to_proto(),
3300 })
3301 }
3302 }
3303 }
3304
3305 update
3306}
3307
3308fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3309 proto::Contact {
3310 user_id: user_id.to_proto(),
3311 online: pool.is_user_online(user_id),
3312 busy,
3313 }
3314}
3315
3316fn room_updated(room: &proto::Room, peer: &Peer) {
3317 broadcast(
3318 None,
3319 room.participants
3320 .iter()
3321 .filter_map(|participant| Some(participant.peer_id?.into())),
3322 |peer_id| {
3323 peer.send(
3324 peer_id.into(),
3325 proto::RoomUpdated {
3326 room: Some(room.clone()),
3327 },
3328 )
3329 },
3330 );
3331}
3332
3333fn channel_updated(
3334 channel_id: ChannelId,
3335 room: &proto::Room,
3336 channel_members: &[UserId],
3337 peer: &Peer,
3338 pool: &ConnectionPool,
3339) {
3340 let participants = room
3341 .participants
3342 .iter()
3343 .map(|p| p.user_id)
3344 .collect::<Vec<_>>();
3345
3346 broadcast(
3347 None,
3348 channel_members
3349 .iter()
3350 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3351 |peer_id| {
3352 peer.send(
3353 peer_id.into(),
3354 proto::UpdateChannels {
3355 channel_participants: vec![proto::ChannelParticipants {
3356 channel_id: channel_id.to_proto(),
3357 participant_user_ids: participants.clone(),
3358 }],
3359 ..Default::default()
3360 },
3361 )
3362 },
3363 );
3364}
3365
3366async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3367 let db = session.db().await;
3368
3369 let contacts = db.get_contacts(user_id).await?;
3370 let busy = db.is_user_busy(user_id).await?;
3371
3372 let pool = session.connection_pool().await;
3373 let updated_contact = contact_for_user(user_id, busy, &pool);
3374 for contact in contacts {
3375 if let db::Contact::Accepted {
3376 user_id: contact_user_id,
3377 ..
3378 } = contact
3379 {
3380 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3381 session
3382 .peer
3383 .send(
3384 contact_conn_id,
3385 proto::UpdateContacts {
3386 contacts: vec![updated_contact.clone()],
3387 remove_contacts: Default::default(),
3388 incoming_requests: Default::default(),
3389 remove_incoming_requests: Default::default(),
3390 outgoing_requests: Default::default(),
3391 remove_outgoing_requests: Default::default(),
3392 },
3393 )
3394 .trace_err();
3395 }
3396 }
3397 }
3398 Ok(())
3399}
3400
3401async fn leave_room_for_session(session: &Session) -> Result<()> {
3402 let mut contacts_to_update = HashSet::default();
3403
3404 let room_id;
3405 let canceled_calls_to_user_ids;
3406 let live_kit_room;
3407 let delete_live_kit_room;
3408 let room;
3409 let channel_members;
3410 let channel_id;
3411
3412 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3413 contacts_to_update.insert(session.user_id);
3414
3415 for project in left_room.left_projects.values() {
3416 project_left(project, session);
3417 }
3418
3419 room_id = RoomId::from_proto(left_room.room.id);
3420 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3421 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3422 delete_live_kit_room = left_room.deleted;
3423 room = mem::take(&mut left_room.room);
3424 channel_members = mem::take(&mut left_room.channel_members);
3425 channel_id = left_room.channel_id;
3426
3427 room_updated(&room, &session.peer);
3428 } else {
3429 return Ok(());
3430 }
3431
3432 if let Some(channel_id) = channel_id {
3433 channel_updated(
3434 channel_id,
3435 &room,
3436 &channel_members,
3437 &session.peer,
3438 &*session.connection_pool().await,
3439 );
3440 }
3441
3442 {
3443 let pool = session.connection_pool().await;
3444 for canceled_user_id in canceled_calls_to_user_ids {
3445 for connection_id in pool.user_connection_ids(canceled_user_id) {
3446 session
3447 .peer
3448 .send(
3449 connection_id,
3450 proto::CallCanceled {
3451 room_id: room_id.to_proto(),
3452 },
3453 )
3454 .trace_err();
3455 }
3456 contacts_to_update.insert(canceled_user_id);
3457 }
3458 }
3459
3460 for contact_user_id in contacts_to_update {
3461 update_user_contacts(contact_user_id, &session).await?;
3462 }
3463
3464 if let Some(live_kit) = session.live_kit_client.as_ref() {
3465 live_kit
3466 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3467 .await
3468 .trace_err();
3469
3470 if delete_live_kit_room {
3471 live_kit.delete_room(live_kit_room).await.trace_err();
3472 }
3473 }
3474
3475 Ok(())
3476}
3477
3478async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3479 let left_channel_buffers = session
3480 .db()
3481 .await
3482 .leave_channel_buffers(session.connection_id)
3483 .await?;
3484
3485 for left_buffer in left_channel_buffers {
3486 channel_buffer_updated(
3487 session.connection_id,
3488 left_buffer.connections,
3489 &proto::UpdateChannelBufferCollaborators {
3490 channel_id: left_buffer.channel_id.to_proto(),
3491 collaborators: left_buffer.collaborators,
3492 },
3493 &session.peer,
3494 );
3495 }
3496
3497 Ok(())
3498}
3499
3500fn project_left(project: &db::LeftProject, session: &Session) {
3501 for connection_id in &project.connection_ids {
3502 if project.host_user_id == session.user_id {
3503 session
3504 .peer
3505 .send(
3506 *connection_id,
3507 proto::UnshareProject {
3508 project_id: project.id.to_proto(),
3509 },
3510 )
3511 .trace_err();
3512 } else {
3513 session
3514 .peer
3515 .send(
3516 *connection_id,
3517 proto::RemoveProjectCollaborator {
3518 project_id: project.id.to_proto(),
3519 peer_id: Some(session.connection_id.into()),
3520 },
3521 )
3522 .trace_err();
3523 }
3524 }
3525}
3526
3527pub trait ResultExt {
3528 type Ok;
3529
3530 fn trace_err(self) -> Option<Self::Ok>;
3531}
3532
3533impl<T, E> ResultExt for Result<T, E>
3534where
3535 E: std::fmt::Debug,
3536{
3537 type Ok = T;
3538
3539 fn trace_err(self) -> Option<T> {
3540 match self {
3541 Ok(value) => Some(value),
3542 Err(error) => {
3543 tracing::error!("{:?}", error);
3544 None
3545 }
3546 }
3547 }
3548}