1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69
/// Grace period `connection_lost` waits for a dropped client to reconnect before releasing
/// the rooms and channel buffers it held.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay `Server::start` waits before sweeping resources reported as stale for this server.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);

// NOTE(review): the three constants below are not used in this chunk of the file; their
// meanings are inferred from their names — confirm against their call sites.
/// Presumably the page size used when fetching channel message history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message body.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
76
77lazy_static! {
78 static ref METRIC_CONNECTIONS: IntGauge =
79 register_int_gauge!("connections", "number of connections").unwrap();
80 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
81 "shared_projects",
82 "number of open projects with one or more guests"
83 )
84 .unwrap();
85}
86
87type MessageHandler =
88 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone)]
105struct Session {
106 zed_environment: Arc<str>,
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 _executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server.
///
/// Owns the [`Peer`], the shared connection pool, and the table of message handlers that
/// `Server::new` registers.
pub struct Server {
    // Mutable so the test-only `reset` can replace it; read with `*self.id.lock()`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by `TypeId` of the message type; `add_handler` panics on a duplicate key.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown`; `handle_connection` loops and `connection_lost` subscribe to it.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that must be `Send`
/// therefore cannot hold it across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of the server's peer and connection pool.
///
/// The pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard derefs to, not as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_message_handler(refresh_inlay_hints)
220 .add_request_handler(forward_project_request::<proto::GetHover>)
221 .add_request_handler(forward_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_project_request::<proto::GetCompletions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
232 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
233 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
234 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
235 .add_request_handler(forward_project_request::<proto::PrepareRename>)
236 .add_request_handler(forward_project_request::<proto::PerformRename>)
237 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
238 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
239 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
240 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
243 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
244 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
245 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
246 .add_request_handler(forward_project_request::<proto::InlayHints>)
247 .add_message_handler(create_buffer_for_peer)
248 .add_request_handler(update_buffer)
249 .add_message_handler(update_buffer_file)
250 .add_message_handler(buffer_reloaded)
251 .add_message_handler(buffer_saved)
252 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
253 .add_request_handler(get_users)
254 .add_request_handler(fuzzy_search_users)
255 .add_request_handler(request_contact)
256 .add_request_handler(remove_contact)
257 .add_request_handler(respond_to_contact_request)
258 .add_request_handler(create_channel)
259 .add_request_handler(delete_channel)
260 .add_request_handler(invite_channel_member)
261 .add_request_handler(remove_channel_member)
262 .add_request_handler(set_channel_member_role)
263 .add_request_handler(set_channel_visibility)
264 .add_request_handler(rename_channel)
265 .add_request_handler(join_channel_buffer)
266 .add_request_handler(leave_channel_buffer)
267 .add_message_handler(update_channel_buffer)
268 .add_request_handler(rejoin_channel_buffers)
269 .add_request_handler(get_channel_members)
270 .add_request_handler(respond_to_channel_invite)
271 .add_request_handler(join_channel)
272 .add_request_handler(join_channel_chat)
273 .add_message_handler(leave_channel_chat)
274 .add_request_handler(send_channel_message)
275 .add_request_handler(remove_channel_message)
276 .add_request_handler(get_channel_messages)
277 .add_request_handler(get_channel_messages_by_id)
278 .add_request_handler(get_notifications)
279 .add_request_handler(mark_notification_as_read)
280 .add_request_handler(move_channel)
281 .add_request_handler(follow)
282 .add_message_handler(unfollow)
283 .add_message_handler(update_followers)
284 .add_message_handler(update_diff_base)
285 .add_request_handler(get_private_user_info)
286 .add_message_handler(acknowledge_channel_message)
287 .add_message_handler(acknowledge_buffer_version);
288
289 Arc::new(server)
290 }
291
292 pub async fn start(&self) -> Result<()> {
293 let server_id = *self.id.lock();
294 let app_state = self.app_state.clone();
295 let peer = self.peer.clone();
296 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
297 let pool = self.connection_pool.clone();
298 let live_kit_client = self.app_state.live_kit_client.clone();
299
300 let span = info_span!("start server");
301 self.executor.spawn_detached(
302 async move {
303 tracing::info!("waiting for cleanup timeout");
304 timeout.await;
305 tracing::info!("cleanup timeout expired, retrieving stale rooms");
306 if let Some((room_ids, channel_ids)) = app_state
307 .db
308 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
309 .await
310 .trace_err()
311 {
312 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
313 tracing::info!(
314 stale_channel_buffer_count = channel_ids.len(),
315 "retrieved stale channel buffers"
316 );
317
318 for channel_id in channel_ids {
319 if let Some(refreshed_channel_buffer) = app_state
320 .db
321 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
322 .await
323 .trace_err()
324 {
325 for connection_id in refreshed_channel_buffer.connection_ids {
326 peer.send(
327 connection_id,
328 proto::UpdateChannelBufferCollaborators {
329 channel_id: channel_id.to_proto(),
330 collaborators: refreshed_channel_buffer
331 .collaborators
332 .clone(),
333 },
334 )
335 .trace_err();
336 }
337 }
338 }
339
340 for room_id in room_ids {
341 let mut contacts_to_update = HashSet::default();
342 let mut canceled_calls_to_user_ids = Vec::new();
343 let mut live_kit_room = String::new();
344 let mut delete_live_kit_room = false;
345
346 if let Some(mut refreshed_room) = app_state
347 .db
348 .clear_stale_room_participants(room_id, server_id)
349 .await
350 .trace_err()
351 {
352 tracing::info!(
353 room_id = room_id.0,
354 new_participant_count = refreshed_room.room.participants.len(),
355 "refreshed room"
356 );
357 room_updated(&refreshed_room.room, &peer);
358 if let Some(channel_id) = refreshed_room.channel_id {
359 channel_updated(
360 channel_id,
361 &refreshed_room.room,
362 &refreshed_room.channel_members,
363 &peer,
364 &*pool.lock(),
365 );
366 }
367 contacts_to_update
368 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
369 contacts_to_update
370 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
371 canceled_calls_to_user_ids =
372 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
373 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
374 delete_live_kit_room = refreshed_room.room.participants.is_empty();
375 }
376
377 {
378 let pool = pool.lock();
379 for canceled_user_id in canceled_calls_to_user_ids {
380 for connection_id in pool.user_connection_ids(canceled_user_id) {
381 peer.send(
382 connection_id,
383 proto::CallCanceled {
384 room_id: room_id.to_proto(),
385 },
386 )
387 .trace_err();
388 }
389 }
390 }
391
392 for user_id in contacts_to_update {
393 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
394 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
395 if let Some((busy, contacts)) = busy.zip(contacts) {
396 let pool = pool.lock();
397 let updated_contact = contact_for_user(user_id, busy, &pool);
398 for contact in contacts {
399 if let db::Contact::Accepted {
400 user_id: contact_user_id,
401 ..
402 } = contact
403 {
404 for contact_conn_id in
405 pool.user_connection_ids(contact_user_id)
406 {
407 peer.send(
408 contact_conn_id,
409 proto::UpdateContacts {
410 contacts: vec![updated_contact.clone()],
411 remove_contacts: Default::default(),
412 incoming_requests: Default::default(),
413 remove_incoming_requests: Default::default(),
414 outgoing_requests: Default::default(),
415 remove_outgoing_requests: Default::default(),
416 },
417 )
418 .trace_err();
419 }
420 }
421 }
422 }
423 }
424
425 if let Some(live_kit) = live_kit_client.as_ref() {
426 if delete_live_kit_room {
427 live_kit.delete_room(live_kit_room).await.trace_err();
428 }
429 }
430 }
431 }
432
433 app_state
434 .db
435 .delete_stale_servers(&app_state.config.zed_environment, server_id)
436 .await
437 .trace_err();
438 }
439 .instrument(span),
440 );
441 Ok(())
442 }
443
444 pub fn teardown(&self) {
445 self.peer.teardown();
446 self.connection_pool.lock().reset();
447 let _ = self.teardown.send(());
448 }
449
450 #[cfg(test)]
451 pub fn reset(&self, id: ServerId) {
452 self.teardown();
453 *self.id.lock() = id;
454 self.peer.reset(id.0 as u32);
455 }
456
457 #[cfg(test)]
458 pub fn id(&self) -> ServerId {
459 *self.id.lock()
460 }
461
462 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
463 where
464 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
465 Fut: 'static + Send + Future<Output = Result<()>>,
466 M: EnvelopedMessage,
467 {
468 let prev_handler = self.handlers.insert(
469 TypeId::of::<M>(),
470 Box::new(move |envelope, session| {
471 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
472 let span = info_span!(
473 "handle message",
474 payload_type = envelope.payload_type_name()
475 );
476 span.in_scope(|| {
477 tracing::info!(
478 payload_type = envelope.payload_type_name(),
479 "message received"
480 );
481 });
482 let start_time = Instant::now();
483 let future = (handler)(*envelope, session);
484 async move {
485 let result = future.await;
486 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
487 match result {
488 Err(error) => {
489 tracing::error!(%error, ?duration_ms, "error handling message")
490 }
491 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
492 }
493 }
494 .instrument(span)
495 .boxed()
496 }),
497 );
498 if prev_handler.is_some() {
499 panic!("registered a handler for the same message twice");
500 }
501 self
502 }
503
504 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
505 where
506 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
507 Fut: 'static + Send + Future<Output = Result<()>>,
508 M: EnvelopedMessage,
509 {
510 self.add_handler(move |envelope, session| handler(envelope.payload, session));
511 self
512 }
513
514 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
515 where
516 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
517 Fut: Send + Future<Output = Result<()>>,
518 M: RequestMessage,
519 {
520 let handler = Arc::new(handler);
521 self.add_handler(move |envelope, session| {
522 let receipt = envelope.receipt();
523 let handler = handler.clone();
524 async move {
525 let peer = session.peer.clone();
526 let responded = Arc::new(AtomicBool::default());
527 let response = Response {
528 peer: peer.clone(),
529 responded: responded.clone(),
530 receipt,
531 };
532 match (handler)(envelope.payload, response, session).await {
533 Ok(()) => {
534 if responded.load(std::sync::atomic::Ordering::SeqCst) {
535 Ok(())
536 } else {
537 Err(anyhow!("handler did not send a response"))?
538 }
539 }
540 Err(error) => {
541 peer.respond_with_error(
542 receipt,
543 proto::Error {
544 message: error.to_string(),
545 },
546 )?;
547 Err(error)
548 }
549 }
550 }
551 })
552 }
553
554 pub fn handle_connection(
555 self: &Arc<Self>,
556 connection: Connection,
557 address: String,
558 user: User,
559 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
560 executor: Executor,
561 ) -> impl Future<Output = Result<()>> {
562 let this = self.clone();
563 let user_id = user.id;
564 let login = user.github_login;
565 let span = info_span!("handle connection", %user_id, %login, %address);
566 let mut teardown = self.teardown.subscribe();
567 async move {
568 let (connection_id, handle_io, mut incoming_rx) = this
569 .peer
570 .add_connection(connection, {
571 let executor = executor.clone();
572 move |duration| executor.sleep(duration)
573 });
574
575 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
576 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
577 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
578
579 if let Some(send_connection_id) = send_connection_id.take() {
580 let _ = send_connection_id.send(connection_id);
581 }
582
583 if !user.connected_once {
584 this.peer.send(connection_id, proto::ShowContacts {})?;
585 this.app_state.db.set_user_connected_once(user_id, true).await?;
586 }
587
588 let (contacts, channels_for_user, channel_invites) = future::try_join3(
589 this.app_state.db.get_contacts(user_id),
590 this.app_state.db.get_channels_for_user(user_id),
591 this.app_state.db.get_channel_invites_for_user(user_id),
592 ).await?;
593
594 {
595 let mut pool = this.connection_pool.lock();
596 pool.add_connection(connection_id, user_id, user.admin);
597 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
598 this.peer.send(connection_id, build_channels_update(
599 channels_for_user,
600 channel_invites
601 ))?;
602 }
603
604 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
605 this.peer.send(connection_id, incoming_call)?;
606 }
607
608 let session = Session {
609 user_id,
610 connection_id,
611 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
612 zed_environment: this.app_state.config.zed_environment.clone(),
613 peer: this.peer.clone(),
614 connection_pool: this.connection_pool.clone(),
615 live_kit_client: this.app_state.live_kit_client.clone(),
616 _executor: executor.clone()
617 };
618 update_user_contacts(user_id, &session).await?;
619
620 let handle_io = handle_io.fuse();
621 futures::pin_mut!(handle_io);
622
623 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
624 // This prevents deadlocks when e.g., client A performs a request to client B and
625 // client B performs a request to client A. If both clients stop processing further
626 // messages until their respective request completes, they won't have a chance to
627 // respond to the other client's request and cause a deadlock.
628 //
629 // This arrangement ensures we will attempt to process earlier messages first, but fall
630 // back to processing messages arrived later in the spirit of making progress.
631 let mut foreground_message_handlers = FuturesUnordered::new();
632 let concurrent_handlers = Arc::new(Semaphore::new(256));
633 loop {
634 let next_message = async {
635 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
636 let message = incoming_rx.next().await;
637 (permit, message)
638 }.fuse();
639 futures::pin_mut!(next_message);
640 futures::select_biased! {
641 _ = teardown.changed().fuse() => return Ok(()),
642 result = handle_io => {
643 if let Err(error) = result {
644 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
645 }
646 break;
647 }
648 _ = foreground_message_handlers.next() => {}
649 next_message = next_message => {
650 let (permit, message) = next_message;
651 if let Some(message) = message {
652 let type_name = message.payload_type_name();
653 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
654 let span_enter = span.enter();
655 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
656 let is_background = message.is_background();
657 let handle_message = (handler)(message, session.clone());
658 drop(span_enter);
659
660 let handle_message = async move {
661 handle_message.await;
662 drop(permit);
663 }.instrument(span);
664 if is_background {
665 executor.spawn_detached(handle_message);
666 } else {
667 foreground_message_handlers.push(handle_message);
668 }
669 } else {
670 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
671 }
672 } else {
673 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
674 break;
675 }
676 }
677 }
678 }
679
680 drop(foreground_message_handlers);
681 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
682 if let Err(error) = connection_lost(session, teardown, executor).await {
683 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
684 }
685
686 Ok(())
687 }.instrument(span)
688 }
689
690 pub async fn invite_code_redeemed(
691 self: &Arc<Self>,
692 inviter_id: UserId,
693 invitee_id: UserId,
694 ) -> Result<()> {
695 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
696 if let Some(code) = &user.invite_code {
697 let pool = self.connection_pool.lock();
698 let invitee_contact = contact_for_user(invitee_id, false, &pool);
699 for connection_id in pool.user_connection_ids(inviter_id) {
700 self.peer.send(
701 connection_id,
702 proto::UpdateContacts {
703 contacts: vec![invitee_contact.clone()],
704 ..Default::default()
705 },
706 )?;
707 self.peer.send(
708 connection_id,
709 proto::UpdateInviteInfo {
710 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
711 count: user.invite_count as u32,
712 },
713 )?;
714 }
715 }
716 }
717 Ok(())
718 }
719
720 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
721 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
722 if let Some(invite_code) = &user.invite_code {
723 let pool = self.connection_pool.lock();
724 for connection_id in pool.user_connection_ids(user_id) {
725 self.peer.send(
726 connection_id,
727 proto::UpdateInviteInfo {
728 url: format!(
729 "{}{}",
730 self.app_state.config.invite_link_prefix, invite_code
731 ),
732 count: user.invite_count as u32,
733 },
734 )?;
735 }
736 }
737 }
738 Ok(())
739 }
740
741 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
742 ServerSnapshot {
743 connection_pool: ConnectionPoolGuard {
744 guard: self.connection_pool.lock(),
745 _not_send: PhantomData,
746 },
747 peer: &self.peer,
748 }
749 }
750}
751
752impl<'a> Deref for ConnectionPoolGuard<'a> {
753 type Target = ConnectionPool;
754
755 fn deref(&self) -> &Self::Target {
756 &*self.guard
757 }
758}
759
760impl<'a> DerefMut for ConnectionPoolGuard<'a> {
761 fn deref_mut(&mut self) -> &mut Self::Target {
762 &mut *self.guard
763 }
764}
765
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` (presumably `ConnectionPool`'s, reached through
    /// `Deref`) each time the guard is released. It is a no-op in non-test builds.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
772
773fn broadcast<F>(
774 sender_id: Option<ConnectionId>,
775 receiver_ids: impl IntoIterator<Item = ConnectionId>,
776 mut f: F,
777) where
778 F: FnMut(ConnectionId) -> anyhow::Result<()>,
779{
780 for receiver_id in receiver_ids {
781 if Some(receiver_id) != sender_id {
782 if let Err(error) = f(receiver_id) {
783 tracing::error!("failed to send to {:?} {}", receiver_id, error);
784 }
785 }
786 }
787}
788
789lazy_static! {
790 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
791}
792
793pub struct ProtocolVersion(u32);
794
795impl Header for ProtocolVersion {
796 fn name() -> &'static HeaderName {
797 &ZED_PROTOCOL_VERSION
798 }
799
800 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
801 where
802 Self: Sized,
803 I: Iterator<Item = &'i axum::http::HeaderValue>,
804 {
805 let version = values
806 .next()
807 .ok_or_else(axum::headers::Error::invalid)?
808 .to_str()
809 .map_err(|_| axum::headers::Error::invalid())?
810 .parse()
811 .map_err(|_| axum::headers::Error::invalid())?;
812 Ok(Self(version))
813 }
814
815 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
816 values.extend([self.0.to_string().parse().unwrap()]);
817 }
818}
819
820pub fn routes(server: Arc<Server>) -> Router<Body> {
821 Router::new()
822 .route("/rpc", get(handle_websocket_request))
823 .layer(
824 ServiceBuilder::new()
825 .layer(Extension(server.app_state.clone()))
826 .layer(middleware::from_fn(auth::validate_header)),
827 )
828 .route("/metrics", get(handle_metrics))
829 .layer(Extension(server))
830}
831
832pub async fn handle_websocket_request(
833 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
834 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
835 Extension(server): Extension<Arc<Server>>,
836 Extension(user): Extension<User>,
837 ws: WebSocketUpgrade,
838) -> axum::response::Response {
839 if protocol_version != rpc::PROTOCOL_VERSION {
840 return (
841 StatusCode::UPGRADE_REQUIRED,
842 "client must be upgraded".to_string(),
843 )
844 .into_response();
845 }
846 let socket_address = socket_address.to_string();
847 ws.on_upgrade(move |socket| {
848 use util::ResultExt;
849 let socket = socket
850 .map_ok(to_tungstenite_message)
851 .err_into()
852 .with(|message| async move { Ok(to_axum_message(message)) });
853 let connection = Connection::new(Box::pin(socket));
854 async move {
855 server
856 .handle_connection(connection, socket_address, user, None, Executor::Production)
857 .await
858 .log_err();
859 }
860 })
861}
862
863pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
864 let connections = server
865 .connection_pool
866 .lock()
867 .connections()
868 .filter(|connection| !connection.admin)
869 .count();
870
871 METRIC_CONNECTIONS.set(connections as _);
872
873 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
874 METRIC_SHARED_PROJECTS.set(shared_projects as _);
875
876 let encoder = prometheus::TextEncoder::new();
877 let metric_families = prometheus::gather();
878 let encoded_metrics = encoder
879 .encode_to_string(&metric_families)
880 .map_err(|err| anyhow!("{}", err))?;
881 Ok(encoded_metrics)
882}
883
/// Cleans up after a connection ends.
///
/// Immediately, it disconnects the peer, removes the connection from the pool, and calls the
/// database's `connection_lost` for it. It then waits for either [`RECONNECT_TIMEOUT`] or a
/// server teardown, whichever comes first:
/// - On timeout, it leaves the session's room and channel buffers. If the user has no other
///   connection, it declines any pending call for them. It then pushes contact updates.
/// - On teardown, it returns without further cleanup.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails or if updating contacts
/// fails; the other cleanup steps only log their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout branch first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
929
930async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
931 response.send(proto::Ack {})?;
932 Ok(())
933}
934
935async fn create_room(
936 _request: proto::CreateRoom,
937 response: Response<proto::CreateRoom>,
938 session: Session,
939) -> Result<()> {
940 let live_kit_room = nanoid::nanoid!(30);
941
942 let live_kit_connection_info = {
943 let live_kit_room = live_kit_room.clone();
944 let live_kit = session.live_kit_client.as_ref();
945
946 util::async_maybe!({
947 let live_kit = live_kit?;
948
949 let token = live_kit
950 .room_token(&live_kit_room, &session.user_id.to_string())
951 .trace_err()?;
952
953 Some(proto::LiveKitConnectionInfo {
954 server_url: live_kit.url().into(),
955 token,
956 can_publish: true,
957 })
958 })
959 }
960 .await;
961
962 let room = session
963 .db()
964 .await
965 .create_room(
966 session.user_id,
967 session.connection_id,
968 &live_kit_room,
969 &session.zed_environment,
970 )
971 .await?;
972
973 response.send(proto::CreateRoomResponse {
974 room: Some(room.clone()),
975 live_kit_connection_info,
976 })?;
977
978 update_user_contacts(session.user_id, &session).await?;
979 Ok(())
980}
981
982async fn join_room(
983 request: proto::JoinRoom,
984 response: Response<proto::JoinRoom>,
985 session: Session,
986) -> Result<()> {
987 let room_id = RoomId::from_proto(request.id);
988
989 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
990
991 if let Some(channel_id) = channel_id {
992 return join_channel_internal(channel_id, Box::new(response), session).await;
993 }
994
995 let joined_room = {
996 let room = session
997 .db()
998 .await
999 .join_room(
1000 room_id,
1001 session.user_id,
1002 session.connection_id,
1003 session.zed_environment.as_ref(),
1004 )
1005 .await?;
1006 room_updated(&room.room, &session.peer);
1007 room.into_inner()
1008 };
1009
1010 for connection_id in session
1011 .connection_pool()
1012 .await
1013 .user_connection_ids(session.user_id)
1014 {
1015 session
1016 .peer
1017 .send(
1018 connection_id,
1019 proto::CallCanceled {
1020 room_id: room_id.to_proto(),
1021 },
1022 )
1023 .trace_err();
1024 }
1025
1026 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1027 if let Some(token) = live_kit
1028 .room_token(
1029 &joined_room.room.live_kit_room,
1030 &session.user_id.to_string(),
1031 )
1032 .trace_err()
1033 {
1034 Some(proto::LiveKitConnectionInfo {
1035 server_url: live_kit.url().into(),
1036 token,
1037 can_publish: true,
1038 })
1039 } else {
1040 None
1041 }
1042 } else {
1043 None
1044 };
1045
1046 response.send(proto::JoinRoomResponse {
1047 room: Some(joined_room.room),
1048 channel_id: None,
1049 live_kit_connection_info,
1050 })?;
1051
1052 update_user_contacts(session.user_id, &session).await?;
1053 Ok(())
1054}
1055
/// Handles a `RejoinRoom` request for a client rejoining a room it was
/// previously in.
///
/// The rejoin is recorded in the database. The reply carries the room plus its
/// reshared and rejoined projects, and the room update is broadcast.
/// Collaborators of each of those projects are told that this user's peer id
/// changed from `old_connection_id` to the current connection. The full state
/// of each rejoined project is re-streamed to this connection: worktree
/// entries in chunks, diagnostic summaries, settings files, and one
/// `DiskBasedDiagnosticsUpdated` per language server. If the room belongs to a
/// channel, `channel_updated` is called. Finally the user's contacts are
/// refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the scope below and assigned at its end, so these values
    // outlive `rejoined_room`, which is consumed by `into_inner`.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first with the restored room and project metadata.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: announce the peer-id change to each collaborator,
        // then forward the project's current worktree list to them.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: announce the peer-id change to each collaborator.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream each rejoined project's state to the rejoining connection.
        // Unlike the best-effort notifications above, send failures here abort
        // the handler via `?`.
        for project in &mut rejoined_room.rejoined_projects {
            // Take the worktrees so their fields can be moved into the messages.
            for worktree in mem::take(&mut project.worktrees) {
                // Chunk size for `split_worktree_update`: 2 in test builds, 256
                // otherwise.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // One `DiskBasedDiagnosticsUpdated` notice per language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Consume the guard and keep only what is needed after this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Only rooms that belong to a channel have a `channel_id`.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1246
1247async fn leave_room(
1248 _: proto::LeaveRoom,
1249 response: Response<proto::LeaveRoom>,
1250 session: Session,
1251) -> Result<()> {
1252 leave_room_for_session(&session).await?;
1253 response.send(proto::Ack {})?;
1254 Ok(())
1255}
1256
1257async fn call(
1258 request: proto::Call,
1259 response: Response<proto::Call>,
1260 session: Session,
1261) -> Result<()> {
1262 let room_id = RoomId::from_proto(request.room_id);
1263 let calling_user_id = session.user_id;
1264 let calling_connection_id = session.connection_id;
1265 let called_user_id = UserId::from_proto(request.called_user_id);
1266 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1267 if !session
1268 .db()
1269 .await
1270 .has_contact(calling_user_id, called_user_id)
1271 .await?
1272 {
1273 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1274 }
1275
1276 let incoming_call = {
1277 let (room, incoming_call) = &mut *session
1278 .db()
1279 .await
1280 .call(
1281 room_id,
1282 calling_user_id,
1283 calling_connection_id,
1284 called_user_id,
1285 initial_project_id,
1286 )
1287 .await?;
1288 room_updated(&room, &session.peer);
1289 mem::take(incoming_call)
1290 };
1291 update_user_contacts(called_user_id, &session).await?;
1292
1293 let mut calls = session
1294 .connection_pool()
1295 .await
1296 .user_connection_ids(called_user_id)
1297 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1298 .collect::<FuturesUnordered<_>>();
1299
1300 while let Some(call_response) = calls.next().await {
1301 match call_response.as_ref() {
1302 Ok(_) => {
1303 response.send(proto::Ack {})?;
1304 return Ok(());
1305 }
1306 Err(_) => {
1307 call_response.trace_err();
1308 }
1309 }
1310 }
1311
1312 {
1313 let room = session
1314 .db()
1315 .await
1316 .call_failed(room_id, called_user_id)
1317 .await?;
1318 room_updated(&room, &session.peer);
1319 }
1320 update_user_contacts(called_user_id, &session).await?;
1321
1322 Err(anyhow!("failed to ring user"))?
1323}
1324
1325async fn cancel_call(
1326 request: proto::CancelCall,
1327 response: Response<proto::CancelCall>,
1328 session: Session,
1329) -> Result<()> {
1330 let called_user_id = UserId::from_proto(request.called_user_id);
1331 let room_id = RoomId::from_proto(request.room_id);
1332 {
1333 let room = session
1334 .db()
1335 .await
1336 .cancel_call(room_id, session.connection_id, called_user_id)
1337 .await?;
1338 room_updated(&room, &session.peer);
1339 }
1340
1341 for connection_id in session
1342 .connection_pool()
1343 .await
1344 .user_connection_ids(called_user_id)
1345 {
1346 session
1347 .peer
1348 .send(
1349 connection_id,
1350 proto::CallCanceled {
1351 room_id: room_id.to_proto(),
1352 },
1353 )
1354 .trace_err();
1355 }
1356 response.send(proto::Ack {})?;
1357
1358 update_user_contacts(called_user_id, &session).await?;
1359 Ok(())
1360}
1361
1362async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1363 let room_id = RoomId::from_proto(message.room_id);
1364 {
1365 let room = session
1366 .db()
1367 .await
1368 .decline_call(Some(room_id), session.user_id)
1369 .await?
1370 .ok_or_else(|| anyhow!("failed to decline call"))?;
1371 room_updated(&room, &session.peer);
1372 }
1373
1374 for connection_id in session
1375 .connection_pool()
1376 .await
1377 .user_connection_ids(session.user_id)
1378 {
1379 session
1380 .peer
1381 .send(
1382 connection_id,
1383 proto::CallCanceled {
1384 room_id: room_id.to_proto(),
1385 },
1386 )
1387 .trace_err();
1388 }
1389 update_user_contacts(session.user_id, &session).await?;
1390 Ok(())
1391}
1392
1393async fn update_participant_location(
1394 request: proto::UpdateParticipantLocation,
1395 response: Response<proto::UpdateParticipantLocation>,
1396 session: Session,
1397) -> Result<()> {
1398 let room_id = RoomId::from_proto(request.room_id);
1399 let location = request
1400 .location
1401 .ok_or_else(|| anyhow!("invalid location"))?;
1402
1403 let db = session.db().await;
1404 let room = db
1405 .update_room_participant_location(room_id, session.connection_id, location)
1406 .await?;
1407
1408 room_updated(&room, &session.peer);
1409 response.send(proto::Ack {})?;
1410 Ok(())
1411}
1412
1413async fn share_project(
1414 request: proto::ShareProject,
1415 response: Response<proto::ShareProject>,
1416 session: Session,
1417) -> Result<()> {
1418 let (project_id, room) = &*session
1419 .db()
1420 .await
1421 .share_project(
1422 RoomId::from_proto(request.room_id),
1423 session.connection_id,
1424 &request.worktrees,
1425 )
1426 .await?;
1427 response.send(proto::ShareProjectResponse {
1428 project_id: project_id.to_proto(),
1429 })?;
1430 room_updated(&room, &session.peer);
1431
1432 Ok(())
1433}
1434
1435async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1436 let project_id = ProjectId::from_proto(message.project_id);
1437
1438 let (room, guest_connection_ids) = &*session
1439 .db()
1440 .await
1441 .unshare_project(project_id, session.connection_id)
1442 .await?;
1443
1444 broadcast(
1445 Some(session.connection_id),
1446 guest_connection_ids.iter().copied(),
1447 |conn_id| session.peer.send(conn_id, message.clone()),
1448 );
1449 room_updated(&room, &session.peer);
1450
1451 Ok(())
1452}
1453
1454async fn join_project(
1455 request: proto::JoinProject,
1456 response: Response<proto::JoinProject>,
1457 session: Session,
1458) -> Result<()> {
1459 let project_id = ProjectId::from_proto(request.project_id);
1460 let guest_user_id = session.user_id;
1461
1462 tracing::info!(%project_id, "join project");
1463
1464 let (project, replica_id) = &mut *session
1465 .db()
1466 .await
1467 .join_project(project_id, session.connection_id)
1468 .await?;
1469
1470 let collaborators = project
1471 .collaborators
1472 .iter()
1473 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1474 .map(|collaborator| collaborator.to_proto())
1475 .collect::<Vec<_>>();
1476
1477 let worktrees = project
1478 .worktrees
1479 .iter()
1480 .map(|(id, worktree)| proto::WorktreeMetadata {
1481 id: *id,
1482 root_name: worktree.root_name.clone(),
1483 visible: worktree.visible,
1484 abs_path: worktree.abs_path.clone(),
1485 })
1486 .collect::<Vec<_>>();
1487
1488 for collaborator in &collaborators {
1489 session
1490 .peer
1491 .send(
1492 collaborator.peer_id.unwrap().into(),
1493 proto::AddProjectCollaborator {
1494 project_id: project_id.to_proto(),
1495 collaborator: Some(proto::Collaborator {
1496 peer_id: Some(session.connection_id.into()),
1497 replica_id: replica_id.0 as u32,
1498 user_id: guest_user_id.to_proto(),
1499 }),
1500 },
1501 )
1502 .trace_err();
1503 }
1504
1505 // First, we send the metadata associated with each worktree.
1506 response.send(proto::JoinProjectResponse {
1507 worktrees: worktrees.clone(),
1508 replica_id: replica_id.0 as u32,
1509 collaborators: collaborators.clone(),
1510 language_servers: project.language_servers.clone(),
1511 })?;
1512
1513 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1514 #[cfg(any(test, feature = "test-support"))]
1515 const MAX_CHUNK_SIZE: usize = 2;
1516 #[cfg(not(any(test, feature = "test-support")))]
1517 const MAX_CHUNK_SIZE: usize = 256;
1518
1519 // Stream this worktree's entries.
1520 let message = proto::UpdateWorktree {
1521 project_id: project_id.to_proto(),
1522 worktree_id,
1523 abs_path: worktree.abs_path.clone(),
1524 root_name: worktree.root_name,
1525 updated_entries: worktree.entries,
1526 removed_entries: Default::default(),
1527 scan_id: worktree.scan_id,
1528 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1529 updated_repositories: worktree.repository_entries.into_values().collect(),
1530 removed_repositories: Default::default(),
1531 };
1532 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1533 session.peer.send(session.connection_id, update.clone())?;
1534 }
1535
1536 // Stream this worktree's diagnostics.
1537 for summary in worktree.diagnostic_summaries {
1538 session.peer.send(
1539 session.connection_id,
1540 proto::UpdateDiagnosticSummary {
1541 project_id: project_id.to_proto(),
1542 worktree_id: worktree.id,
1543 summary: Some(summary),
1544 },
1545 )?;
1546 }
1547
1548 for settings_file in worktree.settings_files {
1549 session.peer.send(
1550 session.connection_id,
1551 proto::UpdateWorktreeSettings {
1552 project_id: project_id.to_proto(),
1553 worktree_id: worktree.id,
1554 path: settings_file.path,
1555 content: Some(settings_file.content),
1556 },
1557 )?;
1558 }
1559 }
1560
1561 for language_server in &project.language_servers {
1562 session.peer.send(
1563 session.connection_id,
1564 proto::UpdateLanguageServer {
1565 project_id: project_id.to_proto(),
1566 language_server_id: language_server.id,
1567 variant: Some(
1568 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1569 proto::LspDiskBasedDiagnosticsUpdated {},
1570 ),
1571 ),
1572 },
1573 )?;
1574 }
1575
1576 Ok(())
1577}
1578
1579async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1580 let sender_id = session.connection_id;
1581 let project_id = ProjectId::from_proto(request.project_id);
1582
1583 let (room, project) = &*session
1584 .db()
1585 .await
1586 .leave_project(project_id, sender_id)
1587 .await?;
1588 tracing::info!(
1589 %project_id,
1590 host_user_id = %project.host_user_id,
1591 host_connection_id = %project.host_connection_id,
1592 "leave project"
1593 );
1594
1595 project_left(&project, &session);
1596 room_updated(&room, &session.peer);
1597
1598 Ok(())
1599}
1600
1601async fn update_project(
1602 request: proto::UpdateProject,
1603 response: Response<proto::UpdateProject>,
1604 session: Session,
1605) -> Result<()> {
1606 let project_id = ProjectId::from_proto(request.project_id);
1607 let (room, guest_connection_ids) = &*session
1608 .db()
1609 .await
1610 .update_project(project_id, session.connection_id, &request.worktrees)
1611 .await?;
1612 broadcast(
1613 Some(session.connection_id),
1614 guest_connection_ids.iter().copied(),
1615 |connection_id| {
1616 session
1617 .peer
1618 .forward_send(session.connection_id, connection_id, request.clone())
1619 },
1620 );
1621 room_updated(&room, &session.peer);
1622 response.send(proto::Ack {})?;
1623
1624 Ok(())
1625}
1626
1627async fn update_worktree(
1628 request: proto::UpdateWorktree,
1629 response: Response<proto::UpdateWorktree>,
1630 session: Session,
1631) -> Result<()> {
1632 let guest_connection_ids = session
1633 .db()
1634 .await
1635 .update_worktree(&request, session.connection_id)
1636 .await?;
1637
1638 broadcast(
1639 Some(session.connection_id),
1640 guest_connection_ids.iter().copied(),
1641 |connection_id| {
1642 session
1643 .peer
1644 .forward_send(session.connection_id, connection_id, request.clone())
1645 },
1646 );
1647 response.send(proto::Ack {})?;
1648 Ok(())
1649}
1650
1651async fn update_diagnostic_summary(
1652 message: proto::UpdateDiagnosticSummary,
1653 session: Session,
1654) -> Result<()> {
1655 let guest_connection_ids = session
1656 .db()
1657 .await
1658 .update_diagnostic_summary(&message, session.connection_id)
1659 .await?;
1660
1661 broadcast(
1662 Some(session.connection_id),
1663 guest_connection_ids.iter().copied(),
1664 |connection_id| {
1665 session
1666 .peer
1667 .forward_send(session.connection_id, connection_id, message.clone())
1668 },
1669 );
1670
1671 Ok(())
1672}
1673
1674async fn update_worktree_settings(
1675 message: proto::UpdateWorktreeSettings,
1676 session: Session,
1677) -> Result<()> {
1678 let guest_connection_ids = session
1679 .db()
1680 .await
1681 .update_worktree_settings(&message, session.connection_id)
1682 .await?;
1683
1684 broadcast(
1685 Some(session.connection_id),
1686 guest_connection_ids.iter().copied(),
1687 |connection_id| {
1688 session
1689 .peer
1690 .forward_send(session.connection_id, connection_id, message.clone())
1691 },
1692 );
1693
1694 Ok(())
1695}
1696
1697async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1698 broadcast_project_message(request.project_id, request, session).await
1699}
1700
1701async fn start_language_server(
1702 request: proto::StartLanguageServer,
1703 session: Session,
1704) -> Result<()> {
1705 let guest_connection_ids = session
1706 .db()
1707 .await
1708 .start_language_server(&request, session.connection_id)
1709 .await?;
1710
1711 broadcast(
1712 Some(session.connection_id),
1713 guest_connection_ids.iter().copied(),
1714 |connection_id| {
1715 session
1716 .peer
1717 .forward_send(session.connection_id, connection_id, request.clone())
1718 },
1719 );
1720 Ok(())
1721}
1722
1723async fn update_language_server(
1724 request: proto::UpdateLanguageServer,
1725 session: Session,
1726) -> Result<()> {
1727 let project_id = ProjectId::from_proto(request.project_id);
1728 let project_connection_ids = session
1729 .db()
1730 .await
1731 .project_connection_ids(project_id, session.connection_id)
1732 .await?;
1733 broadcast(
1734 Some(session.connection_id),
1735 project_connection_ids.iter().copied(),
1736 |connection_id| {
1737 session
1738 .peer
1739 .forward_send(session.connection_id, connection_id, request.clone())
1740 },
1741 );
1742 Ok(())
1743}
1744
1745async fn forward_project_request<T>(
1746 request: T,
1747 response: Response<T>,
1748 session: Session,
1749) -> Result<()>
1750where
1751 T: EntityMessage + RequestMessage,
1752{
1753 let project_id = ProjectId::from_proto(request.remote_entity_id());
1754 let host_connection_id = {
1755 let collaborators = session
1756 .db()
1757 .await
1758 .project_collaborators(project_id, session.connection_id)
1759 .await?;
1760 collaborators
1761 .iter()
1762 .find(|collaborator| collaborator.is_host)
1763 .ok_or_else(|| anyhow!("host not found"))?
1764 .connection_id
1765 };
1766
1767 let payload = session
1768 .peer
1769 .forward_request(session.connection_id, host_connection_id, request)
1770 .await?;
1771
1772 response.send(payload)?;
1773 Ok(())
1774}
1775
1776async fn create_buffer_for_peer(
1777 request: proto::CreateBufferForPeer,
1778 session: Session,
1779) -> Result<()> {
1780 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1781 session
1782 .peer
1783 .forward_send(session.connection_id, peer_id.into(), request)?;
1784 Ok(())
1785}
1786
1787async fn update_buffer(
1788 request: proto::UpdateBuffer,
1789 response: Response<proto::UpdateBuffer>,
1790 session: Session,
1791) -> Result<()> {
1792 let project_id = ProjectId::from_proto(request.project_id);
1793 let mut guest_connection_ids;
1794 let mut host_connection_id = None;
1795 {
1796 let collaborators = session
1797 .db()
1798 .await
1799 .project_collaborators(project_id, session.connection_id)
1800 .await?;
1801 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1802 for collaborator in collaborators.iter() {
1803 if collaborator.is_host {
1804 host_connection_id = Some(collaborator.connection_id);
1805 } else {
1806 guest_connection_ids.push(collaborator.connection_id);
1807 }
1808 }
1809 }
1810 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1811
1812 broadcast(
1813 Some(session.connection_id),
1814 guest_connection_ids,
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 if host_connection_id != session.connection_id {
1822 session
1823 .peer
1824 .forward_request(session.connection_id, host_connection_id, request.clone())
1825 .await?;
1826 }
1827
1828 response.send(proto::Ack {})?;
1829 Ok(())
1830}
1831
1832async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1833 let project_id = ProjectId::from_proto(request.project_id);
1834 let project_connection_ids = session
1835 .db()
1836 .await
1837 .project_connection_ids(project_id, session.connection_id)
1838 .await?;
1839
1840 broadcast(
1841 Some(session.connection_id),
1842 project_connection_ids.iter().copied(),
1843 |connection_id| {
1844 session
1845 .peer
1846 .forward_send(session.connection_id, connection_id, request.clone())
1847 },
1848 );
1849 Ok(())
1850}
1851
1852async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1853 let project_id = ProjectId::from_proto(request.project_id);
1854 let project_connection_ids = session
1855 .db()
1856 .await
1857 .project_connection_ids(project_id, session.connection_id)
1858 .await?;
1859 broadcast(
1860 Some(session.connection_id),
1861 project_connection_ids.iter().copied(),
1862 |connection_id| {
1863 session
1864 .peer
1865 .forward_send(session.connection_id, connection_id, request.clone())
1866 },
1867 );
1868 Ok(())
1869}
1870
1871async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1872 broadcast_project_message(request.project_id, request, session).await
1873}
1874
1875async fn broadcast_project_message<T: EnvelopedMessage>(
1876 project_id: u64,
1877 request: T,
1878 session: Session,
1879) -> Result<()> {
1880 let project_id = ProjectId::from_proto(project_id);
1881 let project_connection_ids = session
1882 .db()
1883 .await
1884 .project_connection_ids(project_id, session.connection_id)
1885 .await?;
1886 broadcast(
1887 Some(session.connection_id),
1888 project_connection_ids.iter().copied(),
1889 |connection_id| {
1890 session
1891 .peer
1892 .forward_send(session.connection_id, connection_id, request.clone())
1893 },
1894 );
1895 Ok(())
1896}
1897
1898async fn follow(
1899 request: proto::Follow,
1900 response: Response<proto::Follow>,
1901 session: Session,
1902) -> Result<()> {
1903 let room_id = RoomId::from_proto(request.room_id);
1904 let project_id = request.project_id.map(ProjectId::from_proto);
1905 let leader_id = request
1906 .leader_id
1907 .ok_or_else(|| anyhow!("invalid leader id"))?
1908 .into();
1909 let follower_id = session.connection_id;
1910
1911 session
1912 .db()
1913 .await
1914 .check_room_participants(room_id, leader_id, session.connection_id)
1915 .await?;
1916
1917 let response_payload = session
1918 .peer
1919 .forward_request(session.connection_id, leader_id, request)
1920 .await?;
1921 response.send(response_payload)?;
1922
1923 if let Some(project_id) = project_id {
1924 let room = session
1925 .db()
1926 .await
1927 .follow(room_id, project_id, leader_id, follower_id)
1928 .await?;
1929 room_updated(&room, &session.peer);
1930 }
1931
1932 Ok(())
1933}
1934
1935async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1936 let room_id = RoomId::from_proto(request.room_id);
1937 let project_id = request.project_id.map(ProjectId::from_proto);
1938 let leader_id = request
1939 .leader_id
1940 .ok_or_else(|| anyhow!("invalid leader id"))?
1941 .into();
1942 let follower_id = session.connection_id;
1943
1944 session
1945 .db()
1946 .await
1947 .check_room_participants(room_id, leader_id, session.connection_id)
1948 .await?;
1949
1950 session
1951 .peer
1952 .forward_send(session.connection_id, leader_id, request)?;
1953
1954 if let Some(project_id) = project_id {
1955 let room = session
1956 .db()
1957 .await
1958 .unfollow(room_id, project_id, leader_id, follower_id)
1959 .await?;
1960 room_updated(&room, &session.peer);
1961 }
1962
1963 Ok(())
1964}
1965
1966async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1967 let room_id = RoomId::from_proto(request.room_id);
1968 let database = session.db.lock().await;
1969
1970 let connection_ids = if let Some(project_id) = request.project_id {
1971 let project_id = ProjectId::from_proto(project_id);
1972 database
1973 .project_connection_ids(project_id, session.connection_id)
1974 .await?
1975 } else {
1976 database
1977 .room_connection_ids(room_id, session.connection_id)
1978 .await?
1979 };
1980
1981 // For now, don't send view update messages back to that view's current leader.
1982 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1983 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1984 _ => None,
1985 });
1986
1987 for follower_peer_id in request.follower_ids.iter().copied() {
1988 let follower_connection_id = follower_peer_id.into();
1989 if Some(follower_peer_id) != connection_id_to_omit
1990 && connection_ids.contains(&follower_connection_id)
1991 {
1992 session.peer.forward_send(
1993 session.connection_id,
1994 follower_connection_id,
1995 request.clone(),
1996 )?;
1997 }
1998 }
1999 Ok(())
2000}
2001
2002async fn get_users(
2003 request: proto::GetUsers,
2004 response: Response<proto::GetUsers>,
2005 session: Session,
2006) -> Result<()> {
2007 let user_ids = request
2008 .user_ids
2009 .into_iter()
2010 .map(UserId::from_proto)
2011 .collect();
2012 let users = session
2013 .db()
2014 .await
2015 .get_users_by_ids(user_ids)
2016 .await?
2017 .into_iter()
2018 .map(|user| proto::User {
2019 id: user.id.to_proto(),
2020 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2021 github_login: user.github_login,
2022 })
2023 .collect();
2024 response.send(proto::UsersResponse { users })?;
2025 Ok(())
2026}
2027
2028async fn fuzzy_search_users(
2029 request: proto::FuzzySearchUsers,
2030 response: Response<proto::FuzzySearchUsers>,
2031 session: Session,
2032) -> Result<()> {
2033 let query = request.query;
2034 let users = match query.len() {
2035 0 => vec![],
2036 1 | 2 => session
2037 .db()
2038 .await
2039 .get_user_by_github_login(&query)
2040 .await?
2041 .into_iter()
2042 .collect(),
2043 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2044 };
2045 let users = users
2046 .into_iter()
2047 .filter(|user| user.id != session.user_id)
2048 .map(|user| proto::User {
2049 id: user.id.to_proto(),
2050 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2051 github_login: user.github_login,
2052 })
2053 .collect();
2054 response.send(proto::UsersResponse { users })?;
2055 Ok(())
2056}
2057
2058async fn request_contact(
2059 request: proto::RequestContact,
2060 response: Response<proto::RequestContact>,
2061 session: Session,
2062) -> Result<()> {
2063 let requester_id = session.user_id;
2064 let responder_id = UserId::from_proto(request.responder_id);
2065 if requester_id == responder_id {
2066 return Err(anyhow!("cannot add yourself as a contact"))?;
2067 }
2068
2069 let notifications = session
2070 .db()
2071 .await
2072 .send_contact_request(requester_id, responder_id)
2073 .await?;
2074
2075 // Update outgoing contact requests of requester
2076 let mut update = proto::UpdateContacts::default();
2077 update.outgoing_requests.push(responder_id.to_proto());
2078 for connection_id in session
2079 .connection_pool()
2080 .await
2081 .user_connection_ids(requester_id)
2082 {
2083 session.peer.send(connection_id, update.clone())?;
2084 }
2085
2086 // Update incoming contact requests of responder
2087 let mut update = proto::UpdateContacts::default();
2088 update
2089 .incoming_requests
2090 .push(proto::IncomingContactRequest {
2091 requester_id: requester_id.to_proto(),
2092 });
2093 let connection_pool = session.connection_pool().await;
2094 for connection_id in connection_pool.user_connection_ids(responder_id) {
2095 session.peer.send(connection_id, update.clone())?;
2096 }
2097
2098 send_notifications(&*connection_pool, &session.peer, notifications);
2099
2100 response.send(proto::Ack {})?;
2101 Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105 request: proto::RespondToContactRequest,
2106 response: Response<proto::RespondToContactRequest>,
2107 session: Session,
2108) -> Result<()> {
2109 let responder_id = session.user_id;
2110 let requester_id = UserId::from_proto(request.requester_id);
2111 let db = session.db().await;
2112 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113 db.dismiss_contact_notification(responder_id, requester_id)
2114 .await?;
2115 } else {
2116 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118 let notifications = db
2119 .respond_to_contact_request(responder_id, requester_id, accept)
2120 .await?;
2121 let requester_busy = db.is_user_busy(requester_id).await?;
2122 let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124 let pool = session.connection_pool().await;
2125 // Update responder with new contact
2126 let mut update = proto::UpdateContacts::default();
2127 if accept {
2128 update
2129 .contacts
2130 .push(contact_for_user(requester_id, requester_busy, &pool));
2131 }
2132 update
2133 .remove_incoming_requests
2134 .push(requester_id.to_proto());
2135 for connection_id in pool.user_connection_ids(responder_id) {
2136 session.peer.send(connection_id, update.clone())?;
2137 }
2138
2139 // Update requester with new contact
2140 let mut update = proto::UpdateContacts::default();
2141 if accept {
2142 update
2143 .contacts
2144 .push(contact_for_user(responder_id, responder_busy, &pool));
2145 }
2146 update
2147 .remove_outgoing_requests
2148 .push(responder_id.to_proto());
2149
2150 for connection_id in pool.user_connection_ids(requester_id) {
2151 session.peer.send(connection_id, update.clone())?;
2152 }
2153
2154 send_notifications(&*pool, &session.peer, notifications);
2155 }
2156
2157 response.send(proto::Ack {})?;
2158 Ok(())
2159}
2160
2161async fn remove_contact(
2162 request: proto::RemoveContact,
2163 response: Response<proto::RemoveContact>,
2164 session: Session,
2165) -> Result<()> {
2166 let requester_id = session.user_id;
2167 let responder_id = UserId::from_proto(request.user_id);
2168 let db = session.db().await;
2169 let (contact_accepted, deleted_notification_id) =
2170 db.remove_contact(requester_id, responder_id).await?;
2171
2172 let pool = session.connection_pool().await;
2173 // Update outgoing contact requests of requester
2174 let mut update = proto::UpdateContacts::default();
2175 if contact_accepted {
2176 update.remove_contacts.push(responder_id.to_proto());
2177 } else {
2178 update
2179 .remove_outgoing_requests
2180 .push(responder_id.to_proto());
2181 }
2182 for connection_id in pool.user_connection_ids(requester_id) {
2183 session.peer.send(connection_id, update.clone())?;
2184 }
2185
2186 // Update incoming contact requests of responder
2187 let mut update = proto::UpdateContacts::default();
2188 if contact_accepted {
2189 update.remove_contacts.push(requester_id.to_proto());
2190 } else {
2191 update
2192 .remove_incoming_requests
2193 .push(requester_id.to_proto());
2194 }
2195 for connection_id in pool.user_connection_ids(responder_id) {
2196 session.peer.send(connection_id, update.clone())?;
2197 if let Some(notification_id) = deleted_notification_id {
2198 session.peer.send(
2199 connection_id,
2200 proto::DeleteNotification {
2201 notification_id: notification_id.to_proto(),
2202 },
2203 )?;
2204 }
2205 }
2206
2207 response.send(proto::Ack {})?;
2208 Ok(())
2209}
2210
2211async fn create_channel(
2212 request: proto::CreateChannel,
2213 response: Response<proto::CreateChannel>,
2214 session: Session,
2215) -> Result<()> {
2216 let db = session.db().await;
2217
2218 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2219 let CreateChannelResult {
2220 channel,
2221 participants_to_update,
2222 } = db
2223 .create_channel(&request.name, parent_id, session.user_id)
2224 .await?;
2225
2226 response.send(proto::CreateChannelResponse {
2227 channel: Some(channel.to_proto()),
2228 parent_id: request.parent_id,
2229 })?;
2230
2231 let connection_pool = session.connection_pool().await;
2232 for (user_id, channels) in participants_to_update {
2233 let update = build_channels_update(channels, vec![]);
2234 for connection_id in connection_pool.user_connection_ids(user_id) {
2235 if user_id == session.user_id {
2236 continue;
2237 }
2238 session.peer.send(connection_id, update.clone())?;
2239 }
2240 }
2241
2242 Ok(())
2243}
2244
2245async fn delete_channel(
2246 request: proto::DeleteChannel,
2247 response: Response<proto::DeleteChannel>,
2248 session: Session,
2249) -> Result<()> {
2250 let db = session.db().await;
2251
2252 let channel_id = request.channel_id;
2253 let (removed_channels, member_ids) = db
2254 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2255 .await?;
2256 response.send(proto::Ack {})?;
2257
2258 // Notify members of removed channels
2259 let mut update = proto::UpdateChannels::default();
2260 update
2261 .delete_channels
2262 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2263
2264 let connection_pool = session.connection_pool().await;
2265 for member_id in member_ids {
2266 for connection_id in connection_pool.user_connection_ids(member_id) {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269 }
2270
2271 Ok(())
2272}
2273
2274async fn invite_channel_member(
2275 request: proto::InviteChannelMember,
2276 response: Response<proto::InviteChannelMember>,
2277 session: Session,
2278) -> Result<()> {
2279 let db = session.db().await;
2280 let channel_id = ChannelId::from_proto(request.channel_id);
2281 let invitee_id = UserId::from_proto(request.user_id);
2282 let InviteMemberResult {
2283 channel,
2284 notifications,
2285 } = db
2286 .invite_channel_member(
2287 channel_id,
2288 invitee_id,
2289 session.user_id,
2290 request.role().into(),
2291 )
2292 .await?;
2293
2294 let update = proto::UpdateChannels {
2295 channel_invitations: vec![channel.to_proto()],
2296 ..Default::default()
2297 };
2298
2299 let connection_pool = session.connection_pool().await;
2300 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2301 session.peer.send(connection_id, update.clone())?;
2302 }
2303
2304 send_notifications(&*connection_pool, &session.peer, notifications);
2305
2306 response.send(proto::Ack {})?;
2307 Ok(())
2308}
2309
2310async fn remove_channel_member(
2311 request: proto::RemoveChannelMember,
2312 response: Response<proto::RemoveChannelMember>,
2313 session: Session,
2314) -> Result<()> {
2315 let db = session.db().await;
2316 let channel_id = ChannelId::from_proto(request.channel_id);
2317 let member_id = UserId::from_proto(request.user_id);
2318
2319 let RemoveChannelMemberResult {
2320 membership_update,
2321 notification_id,
2322 } = db
2323 .remove_channel_member(channel_id, member_id, session.user_id)
2324 .await?;
2325
2326 let connection_pool = &session.connection_pool().await;
2327 notify_membership_updated(
2328 &connection_pool,
2329 membership_update,
2330 member_id,
2331 &session.peer,
2332 );
2333 for connection_id in connection_pool.user_connection_ids(member_id) {
2334 if let Some(notification_id) = notification_id {
2335 session
2336 .peer
2337 .send(
2338 connection_id,
2339 proto::DeleteNotification {
2340 notification_id: notification_id.to_proto(),
2341 },
2342 )
2343 .trace_err();
2344 }
2345 }
2346
2347 response.send(proto::Ack {})?;
2348 Ok(())
2349}
2350
2351async fn set_channel_visibility(
2352 request: proto::SetChannelVisibility,
2353 response: Response<proto::SetChannelVisibility>,
2354 session: Session,
2355) -> Result<()> {
2356 let db = session.db().await;
2357 let channel_id = ChannelId::from_proto(request.channel_id);
2358 let visibility = request.visibility().into();
2359
2360 let SetChannelVisibilityResult {
2361 participants_to_update,
2362 participants_to_remove,
2363 channels_to_remove,
2364 } = db
2365 .set_channel_visibility(channel_id, visibility, session.user_id)
2366 .await?;
2367
2368 let connection_pool = session.connection_pool().await;
2369 for (user_id, channels) in participants_to_update {
2370 let update = build_channels_update(channels, vec![]);
2371 for connection_id in connection_pool.user_connection_ids(user_id) {
2372 session.peer.send(connection_id, update.clone())?;
2373 }
2374 }
2375 for user_id in participants_to_remove {
2376 let update = proto::UpdateChannels {
2377 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2378 ..Default::default()
2379 };
2380 for connection_id in connection_pool.user_connection_ids(user_id) {
2381 session.peer.send(connection_id, update.clone())?;
2382 }
2383 }
2384
2385 response.send(proto::Ack {})?;
2386 Ok(())
2387}
2388
2389async fn set_channel_member_role(
2390 request: proto::SetChannelMemberRole,
2391 response: Response<proto::SetChannelMemberRole>,
2392 session: Session,
2393) -> Result<()> {
2394 let db = session.db().await;
2395 let channel_id = ChannelId::from_proto(request.channel_id);
2396 let member_id = UserId::from_proto(request.user_id);
2397 let result = db
2398 .set_channel_member_role(
2399 channel_id,
2400 session.user_id,
2401 member_id,
2402 request.role().into(),
2403 )
2404 .await?;
2405
2406 match result {
2407 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2408 let connection_pool = session.connection_pool().await;
2409 notify_membership_updated(
2410 &connection_pool,
2411 membership_update,
2412 member_id,
2413 &session.peer,
2414 )
2415 }
2416 db::SetMemberRoleResult::InviteUpdated(channel) => {
2417 let update = proto::UpdateChannels {
2418 channel_invitations: vec![channel.to_proto()],
2419 ..Default::default()
2420 };
2421
2422 for connection_id in session
2423 .connection_pool()
2424 .await
2425 .user_connection_ids(member_id)
2426 {
2427 session.peer.send(connection_id, update.clone())?;
2428 }
2429 }
2430 }
2431
2432 response.send(proto::Ack {})?;
2433 Ok(())
2434}
2435
2436async fn rename_channel(
2437 request: proto::RenameChannel,
2438 response: Response<proto::RenameChannel>,
2439 session: Session,
2440) -> Result<()> {
2441 let db = session.db().await;
2442 let channel_id = ChannelId::from_proto(request.channel_id);
2443 let RenameChannelResult {
2444 channel,
2445 participants_to_update,
2446 } = db
2447 .rename_channel(channel_id, session.user_id, &request.name)
2448 .await?;
2449
2450 response.send(proto::RenameChannelResponse {
2451 channel: Some(channel.to_proto()),
2452 })?;
2453
2454 let connection_pool = session.connection_pool().await;
2455 for (user_id, channel) in participants_to_update {
2456 for connection_id in connection_pool.user_connection_ids(user_id) {
2457 let update = proto::UpdateChannels {
2458 channels: vec![channel.to_proto()],
2459 ..Default::default()
2460 };
2461
2462 session.peer.send(connection_id, update.clone())?;
2463 }
2464 }
2465
2466 Ok(())
2467}
2468
2469async fn move_channel(
2470 request: proto::MoveChannel,
2471 response: Response<proto::MoveChannel>,
2472 session: Session,
2473) -> Result<()> {
2474 let channel_id = ChannelId::from_proto(request.channel_id);
2475 let to = request.to.map(ChannelId::from_proto);
2476
2477 let result = session
2478 .db()
2479 .await
2480 .move_channel(channel_id, to, session.user_id)
2481 .await?;
2482
2483 notify_channel_moved(result, session).await?;
2484
2485 response.send(Ack {})?;
2486 Ok(())
2487}
2488
2489async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2490 let Some(MoveChannelResult {
2491 participants_to_remove,
2492 participants_to_update,
2493 moved_channels,
2494 }) = result
2495 else {
2496 return Ok(());
2497 };
2498 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2499
2500 let connection_pool = session.connection_pool().await;
2501 for (user_id, channels) in participants_to_update {
2502 let mut update = build_channels_update(channels, vec![]);
2503 update.delete_channels = moved_channels.clone();
2504 for connection_id in connection_pool.user_connection_ids(user_id) {
2505 session.peer.send(connection_id, update.clone())?;
2506 }
2507 }
2508
2509 for user_id in participants_to_remove {
2510 let update = proto::UpdateChannels {
2511 delete_channels: moved_channels.clone(),
2512 ..Default::default()
2513 };
2514 for connection_id in connection_pool.user_connection_ids(user_id) {
2515 session.peer.send(connection_id, update.clone())?;
2516 }
2517 }
2518 Ok(())
2519}
2520
2521async fn get_channel_members(
2522 request: proto::GetChannelMembers,
2523 response: Response<proto::GetChannelMembers>,
2524 session: Session,
2525) -> Result<()> {
2526 let db = session.db().await;
2527 let channel_id = ChannelId::from_proto(request.channel_id);
2528 let members = db
2529 .get_channel_participant_details(channel_id, session.user_id)
2530 .await?;
2531 response.send(proto::GetChannelMembersResponse { members })?;
2532 Ok(())
2533}
2534
2535async fn respond_to_channel_invite(
2536 request: proto::RespondToChannelInvite,
2537 response: Response<proto::RespondToChannelInvite>,
2538 session: Session,
2539) -> Result<()> {
2540 let db = session.db().await;
2541 let channel_id = ChannelId::from_proto(request.channel_id);
2542 let RespondToChannelInvite {
2543 membership_update,
2544 notifications,
2545 } = db
2546 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2547 .await?;
2548
2549 let connection_pool = session.connection_pool().await;
2550 if let Some(membership_update) = membership_update {
2551 notify_membership_updated(
2552 &connection_pool,
2553 membership_update,
2554 session.user_id,
2555 &session.peer,
2556 );
2557 } else {
2558 let update = proto::UpdateChannels {
2559 remove_channel_invitations: vec![channel_id.to_proto()],
2560 ..Default::default()
2561 };
2562
2563 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2564 session.peer.send(connection_id, update.clone())?;
2565 }
2566 };
2567
2568 send_notifications(&*connection_pool, &session.peer, notifications);
2569
2570 response.send(proto::Ack {})?;
2571
2572 Ok(())
2573}
2574
2575async fn join_channel(
2576 request: proto::JoinChannel,
2577 response: Response<proto::JoinChannel>,
2578 session: Session,
2579) -> Result<()> {
2580 let channel_id = ChannelId::from_proto(request.channel_id);
2581 join_channel_internal(channel_id, Box::new(response), session).await
2582}
2583
2584trait JoinChannelInternalResponse {
2585 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2586}
2587impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2588 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2589 Response::<proto::JoinChannel>::send(self, result)
2590 }
2591}
2592impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2593 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2594 Response::<proto::JoinRoom>::send(self, result)
2595 }
2596}
2597
/// Joins the room belonging to `channel_id` on behalf of `session` and replies
/// through `response`.
///
/// Steps, in order:
/// 1. Leave whatever room this session is currently in.
/// 2. Join via the database, obtaining the joined room, an optional membership
///    change, and the caller's channel role.
/// 3. If a LiveKit client is configured, mint a token for the caller:
///    - guests get a guest token and `can_publish: false`;
///    - everyone else gets a room token and `can_publish: true`.
///    A token error is logged and yields no LiveKit connection info.
/// 4. Reply with a `JoinRoomResponse`.
/// 5. Push any membership change to the caller's connections and broadcast the
///    updated room to its participants.
/// 6. Broadcast the channel's participant list to channel members and refresh
///    the caller's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped block: the `db` and `connection_pool` guards taken inside are
    // dropped when it ends, before the connection pool is acquired again for
    // `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when there is no LiveKit client. Also `None` when minting the
        // token fails: `trace_err()` logs the error and the `?` then returns
        // `None` from this closure.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                // Guests get a guest token and may not publish.
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell channel members about the new participant list. The pool guard here
    // is a temporary that lives only for this call.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2678
2679async fn join_channel_buffer(
2680 request: proto::JoinChannelBuffer,
2681 response: Response<proto::JoinChannelBuffer>,
2682 session: Session,
2683) -> Result<()> {
2684 let db = session.db().await;
2685 let channel_id = ChannelId::from_proto(request.channel_id);
2686
2687 let open_response = db
2688 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2689 .await?;
2690
2691 let collaborators = open_response.collaborators.clone();
2692 response.send(open_response)?;
2693
2694 let update = UpdateChannelBufferCollaborators {
2695 channel_id: channel_id.to_proto(),
2696 collaborators: collaborators.clone(),
2697 };
2698 channel_buffer_updated(
2699 session.connection_id,
2700 collaborators
2701 .iter()
2702 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2703 &update,
2704 &session.peer,
2705 );
2706
2707 Ok(())
2708}
2709
2710async fn update_channel_buffer(
2711 request: proto::UpdateChannelBuffer,
2712 session: Session,
2713) -> Result<()> {
2714 let db = session.db().await;
2715 let channel_id = ChannelId::from_proto(request.channel_id);
2716
2717 let (collaborators, non_collaborators, epoch, version) = db
2718 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2719 .await?;
2720
2721 channel_buffer_updated(
2722 session.connection_id,
2723 collaborators,
2724 &proto::UpdateChannelBuffer {
2725 channel_id: channel_id.to_proto(),
2726 operations: request.operations,
2727 },
2728 &session.peer,
2729 );
2730
2731 let pool = &*session.connection_pool().await;
2732
2733 broadcast(
2734 None,
2735 non_collaborators
2736 .iter()
2737 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2738 |peer_id| {
2739 session.peer.send(
2740 peer_id.into(),
2741 proto::UpdateChannels {
2742 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2743 channel_id: channel_id.to_proto(),
2744 epoch: epoch as u64,
2745 version: version.clone(),
2746 }],
2747 ..Default::default()
2748 },
2749 )
2750 },
2751 );
2752
2753 Ok(())
2754}
2755
2756async fn rejoin_channel_buffers(
2757 request: proto::RejoinChannelBuffers,
2758 response: Response<proto::RejoinChannelBuffers>,
2759 session: Session,
2760) -> Result<()> {
2761 let db = session.db().await;
2762 let buffers = db
2763 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2764 .await?;
2765
2766 for rejoined_buffer in &buffers {
2767 let collaborators_to_notify = rejoined_buffer
2768 .buffer
2769 .collaborators
2770 .iter()
2771 .filter_map(|c| Some(c.peer_id?.into()));
2772 channel_buffer_updated(
2773 session.connection_id,
2774 collaborators_to_notify,
2775 &proto::UpdateChannelBufferCollaborators {
2776 channel_id: rejoined_buffer.buffer.channel_id,
2777 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2778 },
2779 &session.peer,
2780 );
2781 }
2782
2783 response.send(proto::RejoinChannelBuffersResponse {
2784 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2785 })?;
2786
2787 Ok(())
2788}
2789
2790async fn leave_channel_buffer(
2791 request: proto::LeaveChannelBuffer,
2792 response: Response<proto::LeaveChannelBuffer>,
2793 session: Session,
2794) -> Result<()> {
2795 let db = session.db().await;
2796 let channel_id = ChannelId::from_proto(request.channel_id);
2797
2798 let left_buffer = db
2799 .leave_channel_buffer(channel_id, session.connection_id)
2800 .await?;
2801
2802 response.send(Ack {})?;
2803
2804 channel_buffer_updated(
2805 session.connection_id,
2806 left_buffer.connections,
2807 &proto::UpdateChannelBufferCollaborators {
2808 channel_id: channel_id.to_proto(),
2809 collaborators: left_buffer.collaborators,
2810 },
2811 &session.peer,
2812 );
2813
2814 Ok(())
2815}
2816
2817fn channel_buffer_updated<T: EnvelopedMessage>(
2818 sender_id: ConnectionId,
2819 collaborators: impl IntoIterator<Item = ConnectionId>,
2820 message: &T,
2821 peer: &Peer,
2822) {
2823 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2824 peer.send(peer_id.into(), message.clone())
2825 });
2826}
2827
2828fn send_notifications(
2829 connection_pool: &ConnectionPool,
2830 peer: &Peer,
2831 notifications: db::NotificationBatch,
2832) {
2833 for (user_id, notification) in notifications {
2834 for connection_id in connection_pool.user_connection_ids(user_id) {
2835 if let Err(error) = peer.send(
2836 connection_id,
2837 proto::AddNotification {
2838 notification: Some(notification.clone()),
2839 },
2840 ) {
2841 tracing::error!(
2842 "failed to send notification to {:?} {}",
2843 connection_id,
2844 error
2845 );
2846 }
2847 }
2848 }
2849}
2850
2851async fn send_channel_message(
2852 request: proto::SendChannelMessage,
2853 response: Response<proto::SendChannelMessage>,
2854 session: Session,
2855) -> Result<()> {
2856 // Validate the message body.
2857 let body = request.body.trim().to_string();
2858 if body.len() > MAX_MESSAGE_LEN {
2859 return Err(anyhow!("message is too long"))?;
2860 }
2861 if body.is_empty() {
2862 return Err(anyhow!("message can't be blank"))?;
2863 }
2864
2865 // TODO: adjust mentions if body is trimmed
2866
2867 let timestamp = OffsetDateTime::now_utc();
2868 let nonce = request
2869 .nonce
2870 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2871
2872 let channel_id = ChannelId::from_proto(request.channel_id);
2873 let CreatedChannelMessage {
2874 message_id,
2875 participant_connection_ids,
2876 channel_members,
2877 notifications,
2878 } = session
2879 .db()
2880 .await
2881 .create_channel_message(
2882 channel_id,
2883 session.user_id,
2884 &body,
2885 &request.mentions,
2886 timestamp,
2887 nonce.clone().into(),
2888 )
2889 .await?;
2890 let message = proto::ChannelMessage {
2891 sender_id: session.user_id.to_proto(),
2892 id: message_id.to_proto(),
2893 body,
2894 mentions: request.mentions,
2895 timestamp: timestamp.unix_timestamp() as u64,
2896 nonce: Some(nonce),
2897 };
2898 broadcast(
2899 Some(session.connection_id),
2900 participant_connection_ids,
2901 |connection| {
2902 session.peer.send(
2903 connection,
2904 proto::ChannelMessageSent {
2905 channel_id: channel_id.to_proto(),
2906 message: Some(message.clone()),
2907 },
2908 )
2909 },
2910 );
2911 response.send(proto::SendChannelMessageResponse {
2912 message: Some(message),
2913 })?;
2914
2915 let pool = &*session.connection_pool().await;
2916 broadcast(
2917 None,
2918 channel_members
2919 .iter()
2920 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2921 |peer_id| {
2922 session.peer.send(
2923 peer_id.into(),
2924 proto::UpdateChannels {
2925 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2926 channel_id: channel_id.to_proto(),
2927 message_id: message_id.to_proto(),
2928 }],
2929 ..Default::default()
2930 },
2931 )
2932 },
2933 );
2934 send_notifications(pool, &session.peer, notifications);
2935
2936 Ok(())
2937}
2938
2939async fn remove_channel_message(
2940 request: proto::RemoveChannelMessage,
2941 response: Response<proto::RemoveChannelMessage>,
2942 session: Session,
2943) -> Result<()> {
2944 let channel_id = ChannelId::from_proto(request.channel_id);
2945 let message_id = MessageId::from_proto(request.message_id);
2946 let connection_ids = session
2947 .db()
2948 .await
2949 .remove_channel_message(channel_id, message_id, session.user_id)
2950 .await?;
2951 broadcast(Some(session.connection_id), connection_ids, |connection| {
2952 session.peer.send(connection, request.clone())
2953 });
2954 response.send(proto::Ack {})?;
2955 Ok(())
2956}
2957
2958async fn acknowledge_channel_message(
2959 request: proto::AckChannelMessage,
2960 session: Session,
2961) -> Result<()> {
2962 let channel_id = ChannelId::from_proto(request.channel_id);
2963 let message_id = MessageId::from_proto(request.message_id);
2964 let notifications = session
2965 .db()
2966 .await
2967 .observe_channel_message(channel_id, session.user_id, message_id)
2968 .await?;
2969 send_notifications(
2970 &*session.connection_pool().await,
2971 &session.peer,
2972 notifications,
2973 );
2974 Ok(())
2975}
2976
2977async fn acknowledge_buffer_version(
2978 request: proto::AckBufferOperation,
2979 session: Session,
2980) -> Result<()> {
2981 let buffer_id = BufferId::from_proto(request.buffer_id);
2982 session
2983 .db()
2984 .await
2985 .observe_buffer_version(
2986 buffer_id,
2987 session.user_id,
2988 request.epoch as i32,
2989 &request.version,
2990 )
2991 .await?;
2992 Ok(())
2993}
2994
2995async fn join_channel_chat(
2996 request: proto::JoinChannelChat,
2997 response: Response<proto::JoinChannelChat>,
2998 session: Session,
2999) -> Result<()> {
3000 let channel_id = ChannelId::from_proto(request.channel_id);
3001
3002 let db = session.db().await;
3003 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3004 .await?;
3005 let messages = db
3006 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3007 .await?;
3008 response.send(proto::JoinChannelChatResponse {
3009 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3010 messages,
3011 })?;
3012 Ok(())
3013}
3014
3015async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3016 let channel_id = ChannelId::from_proto(request.channel_id);
3017 session
3018 .db()
3019 .await
3020 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3021 .await?;
3022 Ok(())
3023}
3024
3025async fn get_channel_messages(
3026 request: proto::GetChannelMessages,
3027 response: Response<proto::GetChannelMessages>,
3028 session: Session,
3029) -> Result<()> {
3030 let channel_id = ChannelId::from_proto(request.channel_id);
3031 let messages = session
3032 .db()
3033 .await
3034 .get_channel_messages(
3035 channel_id,
3036 session.user_id,
3037 MESSAGE_COUNT_PER_PAGE,
3038 Some(MessageId::from_proto(request.before_message_id)),
3039 )
3040 .await?;
3041 response.send(proto::GetChannelMessagesResponse {
3042 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3043 messages,
3044 })?;
3045 Ok(())
3046}
3047
3048async fn get_channel_messages_by_id(
3049 request: proto::GetChannelMessagesById,
3050 response: Response<proto::GetChannelMessagesById>,
3051 session: Session,
3052) -> Result<()> {
3053 let message_ids = request
3054 .message_ids
3055 .iter()
3056 .map(|id| MessageId::from_proto(*id))
3057 .collect::<Vec<_>>();
3058 let messages = session
3059 .db()
3060 .await
3061 .get_channel_messages_by_id(session.user_id, &message_ids)
3062 .await?;
3063 response.send(proto::GetChannelMessagesResponse {
3064 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3065 messages,
3066 })?;
3067 Ok(())
3068}
3069
3070async fn get_notifications(
3071 request: proto::GetNotifications,
3072 response: Response<proto::GetNotifications>,
3073 session: Session,
3074) -> Result<()> {
3075 let notifications = session
3076 .db()
3077 .await
3078 .get_notifications(
3079 session.user_id,
3080 NOTIFICATION_COUNT_PER_PAGE,
3081 request
3082 .before_id
3083 .map(|id| db::NotificationId::from_proto(id)),
3084 )
3085 .await?;
3086 response.send(proto::GetNotificationsResponse {
3087 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3088 notifications,
3089 })?;
3090 Ok(())
3091}
3092
3093async fn mark_notification_as_read(
3094 request: proto::MarkNotificationRead,
3095 response: Response<proto::MarkNotificationRead>,
3096 session: Session,
3097) -> Result<()> {
3098 let database = &session.db().await;
3099 let notifications = database
3100 .mark_notification_as_read_by_id(
3101 session.user_id,
3102 NotificationId::from_proto(request.notification_id),
3103 )
3104 .await?;
3105 send_notifications(
3106 &*session.connection_pool().await,
3107 &session.peer,
3108 notifications,
3109 );
3110 response.send(proto::Ack {})?;
3111 Ok(())
3112}
3113
3114async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3115 let project_id = ProjectId::from_proto(request.project_id);
3116 let project_connection_ids = session
3117 .db()
3118 .await
3119 .project_connection_ids(project_id, session.connection_id)
3120 .await?;
3121 broadcast(
3122 Some(session.connection_id),
3123 project_connection_ids.iter().copied(),
3124 |connection_id| {
3125 session
3126 .peer
3127 .forward_send(session.connection_id, connection_id, request.clone())
3128 },
3129 );
3130 Ok(())
3131}
3132
3133async fn get_private_user_info(
3134 _request: proto::GetPrivateUserInfo,
3135 response: Response<proto::GetPrivateUserInfo>,
3136 session: Session,
3137) -> Result<()> {
3138 let db = session.db().await;
3139
3140 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3141 let user = db
3142 .get_user_by_id(session.user_id)
3143 .await?
3144 .ok_or_else(|| anyhow!("user not found"))?;
3145 let flags = db.get_user_flags(session.user_id).await?;
3146
3147 response.send(proto::GetPrivateUserInfoResponse {
3148 metrics_id,
3149 staff: user.admin,
3150 flags,
3151 })?;
3152 Ok(())
3153}
3154
3155fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3156 match message {
3157 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3158 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3159 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3160 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3161 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3162 code: frame.code.into(),
3163 reason: frame.reason,
3164 })),
3165 }
3166}
3167
3168fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3169 match message {
3170 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3171 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3172 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3173 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3174 AxumMessage::Close(frame) => {
3175 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3176 code: frame.code.into(),
3177 reason: frame.reason,
3178 }))
3179 }
3180 }
3181}
3182
3183fn notify_membership_updated(
3184 connection_pool: &ConnectionPool,
3185 result: MembershipUpdated,
3186 user_id: UserId,
3187 peer: &Peer,
3188) {
3189 let mut update = build_channels_update(result.new_channels, vec![]);
3190 update.delete_channels = result
3191 .removed_channels
3192 .into_iter()
3193 .map(|id| id.to_proto())
3194 .collect();
3195 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3196
3197 for connection_id in connection_pool.user_connection_ids(user_id) {
3198 peer.send(connection_id, update.clone()).trace_err();
3199 }
3200}
3201
3202fn build_channels_update(
3203 channels: ChannelsForUser,
3204 channel_invites: Vec<db::Channel>,
3205) -> proto::UpdateChannels {
3206 let mut update = proto::UpdateChannels::default();
3207
3208 for channel in channels.channels {
3209 update.channels.push(channel.to_proto());
3210 }
3211
3212 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3213 update.unseen_channel_messages = channels.channel_messages;
3214
3215 for (channel_id, participants) in channels.channel_participants {
3216 update
3217 .channel_participants
3218 .push(proto::ChannelParticipants {
3219 channel_id: channel_id.to_proto(),
3220 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3221 });
3222 }
3223
3224 for channel in channel_invites {
3225 update.channel_invitations.push(channel.to_proto());
3226 }
3227
3228 update
3229}
3230
3231fn build_initial_contacts_update(
3232 contacts: Vec<db::Contact>,
3233 pool: &ConnectionPool,
3234) -> proto::UpdateContacts {
3235 let mut update = proto::UpdateContacts::default();
3236
3237 for contact in contacts {
3238 match contact {
3239 db::Contact::Accepted { user_id, busy } => {
3240 update.contacts.push(contact_for_user(user_id, busy, &pool));
3241 }
3242 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3243 db::Contact::Incoming { user_id } => {
3244 update
3245 .incoming_requests
3246 .push(proto::IncomingContactRequest {
3247 requester_id: user_id.to_proto(),
3248 })
3249 }
3250 }
3251 }
3252
3253 update
3254}
3255
3256fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3257 proto::Contact {
3258 user_id: user_id.to_proto(),
3259 online: pool.is_user_online(user_id),
3260 busy,
3261 }
3262}
3263
3264fn room_updated(room: &proto::Room, peer: &Peer) {
3265 broadcast(
3266 None,
3267 room.participants
3268 .iter()
3269 .filter_map(|participant| Some(participant.peer_id?.into())),
3270 |peer_id| {
3271 peer.send(
3272 peer_id.into(),
3273 proto::RoomUpdated {
3274 room: Some(room.clone()),
3275 },
3276 )
3277 },
3278 );
3279}
3280
3281fn channel_updated(
3282 channel_id: ChannelId,
3283 room: &proto::Room,
3284 channel_members: &[UserId],
3285 peer: &Peer,
3286 pool: &ConnectionPool,
3287) {
3288 let participants = room
3289 .participants
3290 .iter()
3291 .map(|p| p.user_id)
3292 .collect::<Vec<_>>();
3293
3294 broadcast(
3295 None,
3296 channel_members
3297 .iter()
3298 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3299 |peer_id| {
3300 peer.send(
3301 peer_id.into(),
3302 proto::UpdateChannels {
3303 channel_participants: vec![proto::ChannelParticipants {
3304 channel_id: channel_id.to_proto(),
3305 participant_user_ids: participants.clone(),
3306 }],
3307 ..Default::default()
3308 },
3309 )
3310 },
3311 );
3312}
3313
3314async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3315 let db = session.db().await;
3316
3317 let contacts = db.get_contacts(user_id).await?;
3318 let busy = db.is_user_busy(user_id).await?;
3319
3320 let pool = session.connection_pool().await;
3321 let updated_contact = contact_for_user(user_id, busy, &pool);
3322 for contact in contacts {
3323 if let db::Contact::Accepted {
3324 user_id: contact_user_id,
3325 ..
3326 } = contact
3327 {
3328 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3329 session
3330 .peer
3331 .send(
3332 contact_conn_id,
3333 proto::UpdateContacts {
3334 contacts: vec![updated_contact.clone()],
3335 remove_contacts: Default::default(),
3336 incoming_requests: Default::default(),
3337 remove_incoming_requests: Default::default(),
3338 outgoing_requests: Default::default(),
3339 remove_outgoing_requests: Default::default(),
3340 },
3341 )
3342 .trace_err();
3343 }
3344 }
3345 }
3346 Ok(())
3347}
3348
3349async fn leave_room_for_session(session: &Session) -> Result<()> {
3350 let mut contacts_to_update = HashSet::default();
3351
3352 let room_id;
3353 let canceled_calls_to_user_ids;
3354 let live_kit_room;
3355 let delete_live_kit_room;
3356 let room;
3357 let channel_members;
3358 let channel_id;
3359
3360 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3361 contacts_to_update.insert(session.user_id);
3362
3363 for project in left_room.left_projects.values() {
3364 project_left(project, session);
3365 }
3366
3367 room_id = RoomId::from_proto(left_room.room.id);
3368 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3369 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3370 delete_live_kit_room = left_room.deleted;
3371 room = mem::take(&mut left_room.room);
3372 channel_members = mem::take(&mut left_room.channel_members);
3373 channel_id = left_room.channel_id;
3374
3375 room_updated(&room, &session.peer);
3376 } else {
3377 return Ok(());
3378 }
3379
3380 if let Some(channel_id) = channel_id {
3381 channel_updated(
3382 channel_id,
3383 &room,
3384 &channel_members,
3385 &session.peer,
3386 &*session.connection_pool().await,
3387 );
3388 }
3389
3390 {
3391 let pool = session.connection_pool().await;
3392 for canceled_user_id in canceled_calls_to_user_ids {
3393 for connection_id in pool.user_connection_ids(canceled_user_id) {
3394 session
3395 .peer
3396 .send(
3397 connection_id,
3398 proto::CallCanceled {
3399 room_id: room_id.to_proto(),
3400 },
3401 )
3402 .trace_err();
3403 }
3404 contacts_to_update.insert(canceled_user_id);
3405 }
3406 }
3407
3408 for contact_user_id in contacts_to_update {
3409 update_user_contacts(contact_user_id, &session).await?;
3410 }
3411
3412 if let Some(live_kit) = session.live_kit_client.as_ref() {
3413 live_kit
3414 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3415 .await
3416 .trace_err();
3417
3418 if delete_live_kit_room {
3419 live_kit.delete_room(live_kit_room).await.trace_err();
3420 }
3421 }
3422
3423 Ok(())
3424}
3425
3426async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3427 let left_channel_buffers = session
3428 .db()
3429 .await
3430 .leave_channel_buffers(session.connection_id)
3431 .await?;
3432
3433 for left_buffer in left_channel_buffers {
3434 channel_buffer_updated(
3435 session.connection_id,
3436 left_buffer.connections,
3437 &proto::UpdateChannelBufferCollaborators {
3438 channel_id: left_buffer.channel_id.to_proto(),
3439 collaborators: left_buffer.collaborators,
3440 },
3441 &session.peer,
3442 );
3443 }
3444
3445 Ok(())
3446}
3447
3448fn project_left(project: &db::LeftProject, session: &Session) {
3449 for connection_id in &project.connection_ids {
3450 if project.host_user_id == session.user_id {
3451 session
3452 .peer
3453 .send(
3454 *connection_id,
3455 proto::UnshareProject {
3456 project_id: project.id.to_proto(),
3457 },
3458 )
3459 .trace_err();
3460 } else {
3461 session
3462 .peer
3463 .send(
3464 *connection_id,
3465 proto::RemoveProjectCollaborator {
3466 project_id: project.id.to_proto(),
3467 peer_id: Some(session.connection_id.into()),
3468 },
3469 )
3470 .trace_err();
3471 }
3472 }
3473}
3474
3475pub trait ResultExt {
3476 type Ok;
3477
3478 fn trace_err(self) -> Option<Self::Ok>;
3479}
3480
3481impl<T, E> ResultExt for Result<T, E>
3482where
3483 E: std::fmt::Debug,
3484{
3485 type Ok = T;
3486
3487 fn trace_err(self) -> Option<T> {
3488 match self {
3489 Ok(value) => Some(value),
3490 Err(error) => {
3491 tracing::error!("{:?}", error);
3492 None
3493 }
3494 }
3495 }
3496}