1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
7 UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
42 LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66
/// How long `connection_lost` waits after a connection drops before it
/// releases the room and channel-buffer resources held by that connection.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay that `Server::start` observes before it starts clearing stale rooms,
/// stale channel buffers and stale server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): not referenced in the visible part of this file; presumably the
// page size used when fetching channel messages — confirm at the use site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// NOTE(review): not referenced in the visible part of this file; presumably the
// maximum accepted length of a channel message body — confirm at the use site.
const MAX_MESSAGE_LEN: usize = 1024;
72
// Prometheus gauges; both are refreshed by `handle_metrics` each time it runs.
// The `unwrap` calls panic on first access if registration with the default
// registry fails.
lazy_static! {
    // Set to the number of pooled connections whose `admin` flag is false.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set to the value returned by `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
82
/// Type-erased handler stored in `Server::handlers`, keyed by the message's
/// `TypeId`. It receives the still-untyped envelope together with the
/// session of the sending connection and returns a boxed future that performs
/// the handling (errors are logged inside the future, hence the `()` output).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
/// One-shot handle that a request handler uses to answer a request of type `R`.
struct Response<R> {
    /// Peer through which the response is sent.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Flag shared with `add_request_handler`; set by `send` so the wrapper can
    /// detect handlers that return `Ok` without ever responding.
    responded: Arc<AtomicBool>,
}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
/// Per-connection context handed to every message handler.
///
/// Cloning is cheap for the shared parts: `db`, `peer` and `connection_pool`
/// are reference-counted handles.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Identifier the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to this and other connections.
    peer: Arc<Peer>,
    /// Pool of live connections shared with the `Server`; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, or `None` when no client is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
}
110
111impl Session {
112 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 let guard = self.db.lock().await;
116 #[cfg(test)]
117 tokio::task::yield_now().await;
118 guard
119 }
120
121 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
122 #[cfg(test)]
123 tokio::task::yield_now().await;
124 let guard = self.connection_pool.lock();
125 ConnectionPoolGuard {
126 guard,
127 _not_send: PhantomData,
128 }
129 }
130}
131
132impl fmt::Debug for Session {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("Session")
135 .field("user_id", &self.user_id)
136 .field("connection_id", &self.connection_id)
137 .finish()
138 }
139}
140
/// Newtype around the shared `Database`; dereferences to `Database`.
struct DbHandle(Arc<Database>);
142
143impl Deref for DbHandle {
144 type Target = Database;
145
146 fn deref(&self) -> &Self::Target {
147 self.0.as_ref()
148 }
149}
150
/// RPC server state shared by all connections.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset`
    /// replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and sends/receives messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application state (database, configuration, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Registered message handlers, keyed by the `TypeId` of the message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; connection tasks subscribe to it to stop.
    teardown: watch::Sender<()>,
}
160
/// Wrapper around a locked `ConnectionPool`.
///
/// Dereferences to the pool; in test builds its `Drop` impl runs
/// `check_invariants`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the guard itself `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
165
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The pool stays locked for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through its `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
172
173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
174where
175 S: Serializer,
176 T: Deref<Target = U>,
177 U: Serialize,
178{
179 Serialize::serialize(value.deref(), serializer)
180}
181
182impl Server {
    /// Creates a server with the given id and registers a handler for every
    /// supported message type.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; receivers are created on demand via
            // `subscribe` (see `handle_connection`).
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` is used for messages that expect a response,
        // `add_message_handler` for fire-and-forget messages.
        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            // Requests handled by the generic `forward_project_request`.
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message);

        Arc::new(server)
    }
282
    /// Schedules the cleanup of resources left behind by stale servers.
    ///
    /// A detached task is spawned that first waits for `CLEANUP_TIMEOUT` and
    /// then, for the configured environment and the current server id:
    ///
    /// 1. clears stale channel-buffer collaborators and notifies the
    ///    connections reported by the database,
    /// 2. clears stale room participants, notifies the affected rooms,
    ///    channels, callees and contacts, and deletes LiveKit rooms that ended
    ///    up empty,
    /// 3. deletes the stale server records.
    ///
    /// Every failure inside the task is reported through `trace_err` and
    /// otherwise ignored. This method itself returns `Ok(())` right after
    /// spawning the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The timer is created here, before the task is spawned.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the
                    // refreshed collaborator list to every connection id the
                    // database returned.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their empty/false defaults when clearing
                        // the room fails, which turns the follow-up steps
                        // below into no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only rooms left without participants get their
                            // LiveKit room deleted.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that the call
                        // was canceled. The pool lock lives only inside this
                        // block, so it is not held across an `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of every affected
                        // user to the connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip the user unless both queries succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale resource ids could not be fetched.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
434
435 pub fn teardown(&self) {
436 self.peer.teardown();
437 self.connection_pool.lock().reset();
438 let _ = self.teardown.send(());
439 }
440
441 #[cfg(test)]
442 pub fn reset(&self, id: ServerId) {
443 self.teardown();
444 *self.id.lock() = id;
445 self.peer.reset(id.0 as u32);
446 }
447
448 #[cfg(test)]
449 pub fn id(&self) -> ServerId {
450 *self.id.lock()
451 }
452
453 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
454 where
455 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
456 Fut: 'static + Send + Future<Output = Result<()>>,
457 M: EnvelopedMessage,
458 {
459 let prev_handler = self.handlers.insert(
460 TypeId::of::<M>(),
461 Box::new(move |envelope, session| {
462 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
463 let span = info_span!(
464 "handle message",
465 payload_type = envelope.payload_type_name()
466 );
467 span.in_scope(|| {
468 tracing::info!(
469 payload_type = envelope.payload_type_name(),
470 "message received"
471 );
472 });
473 let start_time = Instant::now();
474 let future = (handler)(*envelope, session);
475 async move {
476 let result = future.await;
477 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
478 match result {
479 Err(error) => {
480 tracing::error!(%error, ?duration_ms, "error handling message")
481 }
482 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
483 }
484 }
485 .instrument(span)
486 .boxed()
487 }),
488 );
489 if prev_handler.is_some() {
490 panic!("registered a handler for the same message twice");
491 }
492 self
493 }
494
495 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
496 where
497 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
498 Fut: 'static + Send + Future<Output = Result<()>>,
499 M: EnvelopedMessage,
500 {
501 self.add_handler(move |envelope, session| handler(envelope.payload, session));
502 self
503 }
504
505 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
506 where
507 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
508 Fut: Send + Future<Output = Result<()>>,
509 M: RequestMessage,
510 {
511 let handler = Arc::new(handler);
512 self.add_handler(move |envelope, session| {
513 let receipt = envelope.receipt();
514 let handler = handler.clone();
515 async move {
516 let peer = session.peer.clone();
517 let responded = Arc::new(AtomicBool::default());
518 let response = Response {
519 peer: peer.clone(),
520 responded: responded.clone(),
521 receipt,
522 };
523 match (handler)(envelope.payload, response, session).await {
524 Ok(()) => {
525 if responded.load(std::sync::atomic::Ordering::SeqCst) {
526 Ok(())
527 } else {
528 Err(anyhow!("handler did not send a response"))?
529 }
530 }
531 Err(error) => {
532 peer.respond_with_error(
533 receipt,
534 proto::Error {
535 message: error.to_string(),
536 },
537 )?;
538 Err(error)
539 }
540 }
541 }
542 })
543 }
544
    /// Returns a future that drives a single client connection from handshake
    /// to sign-out.
    ///
    /// The future:
    /// 1. registers `connection` with the peer and sends `Hello`,
    /// 2. reports the new connection id through `send_connection_id`, if given,
    /// 3. sends `ShowContacts` and records the first connection for users that
    ///    have never connected before,
    /// 4. adds the connection to the pool and sends the initial contacts and
    ///    channels updates, plus a pending incoming call if there is one,
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O finishes or fails, or the server is torn down,
    /// 6. runs `connection_lost` (skipped when the loop ended because of a
    ///    server teardown).
    ///
    /// `address` is only used for logging.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed eagerly, i.e. when this method is called rather than when
        // the returned future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three queries are awaited together; the first failure aborts.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // Scoped so the pool lock is released before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection: a
            // permit is acquired before a message is read and released when
            // its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written:
                // teardown, I/O, running foreground handlers, next message.
                futures::select_biased! {
                    // Server teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    // Drives the pending foreground handlers; their results
                    // were already logged by the handler wrapper.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run as detached tasks;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set discards any foreground handlers still pending.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
679
680 pub async fn invite_code_redeemed(
681 self: &Arc<Self>,
682 inviter_id: UserId,
683 invitee_id: UserId,
684 ) -> Result<()> {
685 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
686 if let Some(code) = &user.invite_code {
687 let pool = self.connection_pool.lock();
688 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
689 for connection_id in pool.user_connection_ids(inviter_id) {
690 self.peer.send(
691 connection_id,
692 proto::UpdateContacts {
693 contacts: vec![invitee_contact.clone()],
694 ..Default::default()
695 },
696 )?;
697 self.peer.send(
698 connection_id,
699 proto::UpdateInviteInfo {
700 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
701 count: user.invite_count as u32,
702 },
703 )?;
704 }
705 }
706 }
707 Ok(())
708 }
709
710 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
711 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
712 if let Some(invite_code) = &user.invite_code {
713 let pool = self.connection_pool.lock();
714 for connection_id in pool.user_connection_ids(user_id) {
715 self.peer.send(
716 connection_id,
717 proto::UpdateInviteInfo {
718 url: format!(
719 "{}{}",
720 self.app_state.config.invite_link_prefix, invite_code
721 ),
722 count: user.invite_count as u32,
723 },
724 )?;
725 }
726 }
727 }
728 Ok(())
729 }
730
731 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
732 ServerSnapshot {
733 connection_pool: ConnectionPoolGuard {
734 guard: self.connection_pool.lock(),
735 _not_send: PhantomData,
736 },
737 peer: &self.peer,
738 }
739 }
740}
741
742impl<'a> Deref for ConnectionPoolGuard<'a> {
743 type Target = ConnectionPool;
744
745 fn deref(&self) -> &Self::Target {
746 &*self.guard
747 }
748}
749
750impl<'a> DerefMut for ConnectionPoolGuard<'a> {
751 fn deref_mut(&mut self) -> &mut Self::Target {
752 &mut *self.guard
753 }
754}
755
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants whenever a guard is
    /// released. In other builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
762
763fn broadcast<F>(
764 sender_id: Option<ConnectionId>,
765 receiver_ids: impl IntoIterator<Item = ConnectionId>,
766 mut f: F,
767) where
768 F: FnMut(ConnectionId) -> anyhow::Result<()>,
769{
770 for receiver_id in receiver_ids {
771 if Some(receiver_id) != sender_id {
772 if let Err(error) = f(receiver_id) {
773 tracing::error!("failed to send to {:?} {}", receiver_id, error);
774 }
775 }
776 }
777}
778
lazy_static! {
    // Name of the HTTP header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
782
/// Typed form of the `x-zed-protocol-version` header: the RPC protocol version
/// announced by a client.
pub struct ProtocolVersion(u32);
784
785impl Header for ProtocolVersion {
786 fn name() -> &'static HeaderName {
787 &ZED_PROTOCOL_VERSION
788 }
789
790 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
791 where
792 Self: Sized,
793 I: Iterator<Item = &'i axum::http::HeaderValue>,
794 {
795 let version = values
796 .next()
797 .ok_or_else(axum::headers::Error::invalid)?
798 .to_str()
799 .map_err(|_| axum::headers::Error::invalid())?
800 .parse()
801 .map_err(|_| axum::headers::Error::invalid())?;
802 Ok(Self(version))
803 }
804
805 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
806 values.extend([self.0.to_string().parse().unwrap()]);
807 }
808}
809
810pub fn routes(server: Arc<Server>) -> Router<Body> {
811 Router::new()
812 .route("/rpc", get(handle_websocket_request))
813 .layer(
814 ServiceBuilder::new()
815 .layer(Extension(server.app_state.clone()))
816 .layer(middleware::from_fn(auth::validate_header)),
817 )
818 .route("/metrics", get(handle_metrics))
819 .layer(Extension(server))
820}
821
822pub async fn handle_websocket_request(
823 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
824 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
825 Extension(server): Extension<Arc<Server>>,
826 Extension(user): Extension<User>,
827 ws: WebSocketUpgrade,
828) -> axum::response::Response {
829 if protocol_version != rpc::PROTOCOL_VERSION {
830 return (
831 StatusCode::UPGRADE_REQUIRED,
832 "client must be upgraded".to_string(),
833 )
834 .into_response();
835 }
836 let socket_address = socket_address.to_string();
837 ws.on_upgrade(move |socket| {
838 use util::ResultExt;
839 let socket = socket
840 .map_ok(to_tungstenite_message)
841 .err_into()
842 .with(|message| async move { Ok(to_axum_message(message)) });
843 let connection = Connection::new(Box::pin(socket));
844 async move {
845 server
846 .handle_connection(connection, socket_address, user, None, Executor::Production)
847 .await
848 .log_err();
849 }
850 })
851}
852
853pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
854 let connections = server
855 .connection_pool
856 .lock()
857 .connections()
858 .filter(|connection| !connection.admin)
859 .count();
860
861 METRIC_CONNECTIONS.set(connections as _);
862
863 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
864 METRIC_SHARED_PROJECTS.set(shared_projects as _);
865
866 let encoder = prometheus::TextEncoder::new();
867 let metric_families = prometheus::gather();
868 let encoded_metrics = encoder
869 .encode_to_string(&metric_families)
870 .map_err(|err| anyhow!("{}", err))?;
871 Ok(encoded_metrics)
872}
873
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer, removed from the pool and
/// marked as lost in the database right away. The function then waits for
/// whichever comes first:
///
/// * `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers, any call to the user is declined if the user has no other
///   connection online, and the user's contacts are updated;
/// * the server is torn down: nothing further is done.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool or if the final
/// contacts update fails; other failures are only reported via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls in the order written, so the timer wins if both
    // branches are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Decline a pending call only when the user has no connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
919
920async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
921 response.send(proto::Ack {})?;
922 Ok(())
923}
924
/// Creates a new room owned by the requesting user.
///
/// A random 30-character LiveKit room name is generated first. When a LiveKit
/// client is configured, the LiveKit room is created and a token is issued for
/// the user; if either step fails the error is traced and the response simply
/// carries no LiveKit connection info. The room is then stored in the
/// database, the response is sent, and the user's contacts are updated.
///
/// # Errors
///
/// Fails if the room cannot be created in the database, the response cannot
/// be sent, or the contacts update fails.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    let live_kit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = {
        let live_kit_room = live_kit_room.clone();
        let live_kit = session.live_kit_client.as_ref();

        // Each `?` below operates on an `Option`, so a missing client or a
        // traced failure makes the whole block evaluate to `None`.
        // NOTE(review): relies on `util::async_iife!` wrapping the block in an
        // immediately-awaitable async block — confirm against the macro.
        util::async_iife!({
            let live_kit = live_kit?;

            live_kit
                .create_room(live_kit_room.clone())
                .await
                .trace_err()?;

            let token = live_kit
                .room_token(&live_kit_room, &session.user_id.to_string())
                .trace_err()?;

            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        })
    }
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id, session.connection_id, &live_kit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
970
971async fn join_room(
972 request: proto::JoinRoom,
973 response: Response<proto::JoinRoom>,
974 session: Session,
975) -> Result<()> {
976 let room_id = RoomId::from_proto(request.id);
977 let joined_room = {
978 let room = session
979 .db()
980 .await
981 .join_room(room_id, session.user_id, session.connection_id)
982 .await?;
983 room_updated(&room.room, &session.peer);
984 room.into_inner()
985 };
986
987 if let Some(channel_id) = joined_room.channel_id {
988 channel_updated(
989 channel_id,
990 &joined_room.room,
991 &joined_room.channel_members,
992 &session.peer,
993 &*session.connection_pool().await,
994 )
995 }
996
997 for connection_id in session
998 .connection_pool()
999 .await
1000 .user_connection_ids(session.user_id)
1001 {
1002 session
1003 .peer
1004 .send(
1005 connection_id,
1006 proto::CallCanceled {
1007 room_id: room_id.to_proto(),
1008 },
1009 )
1010 .trace_err();
1011 }
1012
1013 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1014 if let Some(token) = live_kit
1015 .room_token(
1016 &joined_room.room.live_kit_room,
1017 &session.user_id.to_string(),
1018 )
1019 .trace_err()
1020 {
1021 Some(proto::LiveKitConnectionInfo {
1022 server_url: live_kit.url().into(),
1023 token,
1024 })
1025 } else {
1026 None
1027 }
1028 } else {
1029 None
1030 };
1031
1032 response.send(proto::JoinRoomResponse {
1033 room: Some(joined_room.room),
1034 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1035 live_kit_connection_info,
1036 })?;
1037
1038 update_user_contacts(session.user_id, &session).await?;
1039 Ok(())
1040}
1041
1042async fn rejoin_room(
1043 request: proto::RejoinRoom,
1044 response: Response<proto::RejoinRoom>,
1045 session: Session,
1046) -> Result<()> {
1047 let room;
1048 let channel_id;
1049 let channel_members;
1050 {
1051 let mut rejoined_room = session
1052 .db()
1053 .await
1054 .rejoin_room(request, session.user_id, session.connection_id)
1055 .await?;
1056
1057 response.send(proto::RejoinRoomResponse {
1058 room: Some(rejoined_room.room.clone()),
1059 reshared_projects: rejoined_room
1060 .reshared_projects
1061 .iter()
1062 .map(|project| proto::ResharedProject {
1063 id: project.id.to_proto(),
1064 collaborators: project
1065 .collaborators
1066 .iter()
1067 .map(|collaborator| collaborator.to_proto())
1068 .collect(),
1069 })
1070 .collect(),
1071 rejoined_projects: rejoined_room
1072 .rejoined_projects
1073 .iter()
1074 .map(|rejoined_project| proto::RejoinedProject {
1075 id: rejoined_project.id.to_proto(),
1076 worktrees: rejoined_project
1077 .worktrees
1078 .iter()
1079 .map(|worktree| proto::WorktreeMetadata {
1080 id: worktree.id,
1081 root_name: worktree.root_name.clone(),
1082 visible: worktree.visible,
1083 abs_path: worktree.abs_path.clone(),
1084 })
1085 .collect(),
1086 collaborators: rejoined_project
1087 .collaborators
1088 .iter()
1089 .map(|collaborator| collaborator.to_proto())
1090 .collect(),
1091 language_servers: rejoined_project.language_servers.clone(),
1092 })
1093 .collect(),
1094 })?;
1095 room_updated(&rejoined_room.room, &session.peer);
1096
1097 for project in &rejoined_room.reshared_projects {
1098 for collaborator in &project.collaborators {
1099 session
1100 .peer
1101 .send(
1102 collaborator.connection_id,
1103 proto::UpdateProjectCollaborator {
1104 project_id: project.id.to_proto(),
1105 old_peer_id: Some(project.old_connection_id.into()),
1106 new_peer_id: Some(session.connection_id.into()),
1107 },
1108 )
1109 .trace_err();
1110 }
1111
1112 broadcast(
1113 Some(session.connection_id),
1114 project
1115 .collaborators
1116 .iter()
1117 .map(|collaborator| collaborator.connection_id),
1118 |connection_id| {
1119 session.peer.forward_send(
1120 session.connection_id,
1121 connection_id,
1122 proto::UpdateProject {
1123 project_id: project.id.to_proto(),
1124 worktrees: project.worktrees.clone(),
1125 },
1126 )
1127 },
1128 );
1129 }
1130
1131 for project in &rejoined_room.rejoined_projects {
1132 for collaborator in &project.collaborators {
1133 session
1134 .peer
1135 .send(
1136 collaborator.connection_id,
1137 proto::UpdateProjectCollaborator {
1138 project_id: project.id.to_proto(),
1139 old_peer_id: Some(project.old_connection_id.into()),
1140 new_peer_id: Some(session.connection_id.into()),
1141 },
1142 )
1143 .trace_err();
1144 }
1145 }
1146
1147 for project in &mut rejoined_room.rejoined_projects {
1148 for worktree in mem::take(&mut project.worktrees) {
1149 #[cfg(any(test, feature = "test-support"))]
1150 const MAX_CHUNK_SIZE: usize = 2;
1151 #[cfg(not(any(test, feature = "test-support")))]
1152 const MAX_CHUNK_SIZE: usize = 256;
1153
1154 // Stream this worktree's entries.
1155 let message = proto::UpdateWorktree {
1156 project_id: project.id.to_proto(),
1157 worktree_id: worktree.id,
1158 abs_path: worktree.abs_path.clone(),
1159 root_name: worktree.root_name,
1160 updated_entries: worktree.updated_entries,
1161 removed_entries: worktree.removed_entries,
1162 scan_id: worktree.scan_id,
1163 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1164 updated_repositories: worktree.updated_repositories,
1165 removed_repositories: worktree.removed_repositories,
1166 };
1167 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1168 session.peer.send(session.connection_id, update.clone())?;
1169 }
1170
1171 // Stream this worktree's diagnostics.
1172 for summary in worktree.diagnostic_summaries {
1173 session.peer.send(
1174 session.connection_id,
1175 proto::UpdateDiagnosticSummary {
1176 project_id: project.id.to_proto(),
1177 worktree_id: worktree.id,
1178 summary: Some(summary),
1179 },
1180 )?;
1181 }
1182
1183 for settings_file in worktree.settings_files {
1184 session.peer.send(
1185 session.connection_id,
1186 proto::UpdateWorktreeSettings {
1187 project_id: project.id.to_proto(),
1188 worktree_id: worktree.id,
1189 path: settings_file.path,
1190 content: Some(settings_file.content),
1191 },
1192 )?;
1193 }
1194 }
1195
1196 for language_server in &project.language_servers {
1197 session.peer.send(
1198 session.connection_id,
1199 proto::UpdateLanguageServer {
1200 project_id: project.id.to_proto(),
1201 language_server_id: language_server.id,
1202 variant: Some(
1203 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1204 proto::LspDiskBasedDiagnosticsUpdated {},
1205 ),
1206 ),
1207 },
1208 )?;
1209 }
1210 }
1211
1212 let rejoined_room = rejoined_room.into_inner();
1213
1214 room = rejoined_room.room;
1215 channel_id = rejoined_room.channel_id;
1216 channel_members = rejoined_room.channel_members;
1217 }
1218
1219 if let Some(channel_id) = channel_id {
1220 channel_updated(
1221 channel_id,
1222 &room,
1223 &channel_members,
1224 &session.peer,
1225 &*session.connection_pool().await,
1226 );
1227 }
1228
1229 update_user_contacts(session.user_id, &session).await?;
1230 Ok(())
1231}
1232
1233async fn leave_room(
1234 _: proto::LeaveRoom,
1235 response: Response<proto::LeaveRoom>,
1236 session: Session,
1237) -> Result<()> {
1238 leave_room_for_session(&session).await?;
1239 response.send(proto::Ack {})?;
1240 Ok(())
1241}
1242
1243async fn call(
1244 request: proto::Call,
1245 response: Response<proto::Call>,
1246 session: Session,
1247) -> Result<()> {
1248 let room_id = RoomId::from_proto(request.room_id);
1249 let calling_user_id = session.user_id;
1250 let calling_connection_id = session.connection_id;
1251 let called_user_id = UserId::from_proto(request.called_user_id);
1252 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1253 if !session
1254 .db()
1255 .await
1256 .has_contact(calling_user_id, called_user_id)
1257 .await?
1258 {
1259 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1260 }
1261
1262 let incoming_call = {
1263 let (room, incoming_call) = &mut *session
1264 .db()
1265 .await
1266 .call(
1267 room_id,
1268 calling_user_id,
1269 calling_connection_id,
1270 called_user_id,
1271 initial_project_id,
1272 )
1273 .await?;
1274 room_updated(&room, &session.peer);
1275 mem::take(incoming_call)
1276 };
1277 update_user_contacts(called_user_id, &session).await?;
1278
1279 let mut calls = session
1280 .connection_pool()
1281 .await
1282 .user_connection_ids(called_user_id)
1283 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1284 .collect::<FuturesUnordered<_>>();
1285
1286 while let Some(call_response) = calls.next().await {
1287 match call_response.as_ref() {
1288 Ok(_) => {
1289 response.send(proto::Ack {})?;
1290 return Ok(());
1291 }
1292 Err(_) => {
1293 call_response.trace_err();
1294 }
1295 }
1296 }
1297
1298 {
1299 let room = session
1300 .db()
1301 .await
1302 .call_failed(room_id, called_user_id)
1303 .await?;
1304 room_updated(&room, &session.peer);
1305 }
1306 update_user_contacts(called_user_id, &session).await?;
1307
1308 Err(anyhow!("failed to ring user"))?
1309}
1310
1311async fn cancel_call(
1312 request: proto::CancelCall,
1313 response: Response<proto::CancelCall>,
1314 session: Session,
1315) -> Result<()> {
1316 let called_user_id = UserId::from_proto(request.called_user_id);
1317 let room_id = RoomId::from_proto(request.room_id);
1318 {
1319 let room = session
1320 .db()
1321 .await
1322 .cancel_call(room_id, session.connection_id, called_user_id)
1323 .await?;
1324 room_updated(&room, &session.peer);
1325 }
1326
1327 for connection_id in session
1328 .connection_pool()
1329 .await
1330 .user_connection_ids(called_user_id)
1331 {
1332 session
1333 .peer
1334 .send(
1335 connection_id,
1336 proto::CallCanceled {
1337 room_id: room_id.to_proto(),
1338 },
1339 )
1340 .trace_err();
1341 }
1342 response.send(proto::Ack {})?;
1343
1344 update_user_contacts(called_user_id, &session).await?;
1345 Ok(())
1346}
1347
1348async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1349 let room_id = RoomId::from_proto(message.room_id);
1350 {
1351 let room = session
1352 .db()
1353 .await
1354 .decline_call(Some(room_id), session.user_id)
1355 .await?
1356 .ok_or_else(|| anyhow!("failed to decline call"))?;
1357 room_updated(&room, &session.peer);
1358 }
1359
1360 for connection_id in session
1361 .connection_pool()
1362 .await
1363 .user_connection_ids(session.user_id)
1364 {
1365 session
1366 .peer
1367 .send(
1368 connection_id,
1369 proto::CallCanceled {
1370 room_id: room_id.to_proto(),
1371 },
1372 )
1373 .trace_err();
1374 }
1375 update_user_contacts(session.user_id, &session).await?;
1376 Ok(())
1377}
1378
1379async fn update_participant_location(
1380 request: proto::UpdateParticipantLocation,
1381 response: Response<proto::UpdateParticipantLocation>,
1382 session: Session,
1383) -> Result<()> {
1384 let room_id = RoomId::from_proto(request.room_id);
1385 let location = request
1386 .location
1387 .ok_or_else(|| anyhow!("invalid location"))?;
1388
1389 let db = session.db().await;
1390 let room = db
1391 .update_room_participant_location(room_id, session.connection_id, location)
1392 .await?;
1393
1394 room_updated(&room, &session.peer);
1395 response.send(proto::Ack {})?;
1396 Ok(())
1397}
1398
1399async fn share_project(
1400 request: proto::ShareProject,
1401 response: Response<proto::ShareProject>,
1402 session: Session,
1403) -> Result<()> {
1404 let (project_id, room) = &*session
1405 .db()
1406 .await
1407 .share_project(
1408 RoomId::from_proto(request.room_id),
1409 session.connection_id,
1410 &request.worktrees,
1411 )
1412 .await?;
1413 response.send(proto::ShareProjectResponse {
1414 project_id: project_id.to_proto(),
1415 })?;
1416 room_updated(&room, &session.peer);
1417
1418 Ok(())
1419}
1420
1421async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1422 let project_id = ProjectId::from_proto(message.project_id);
1423
1424 let (room, guest_connection_ids) = &*session
1425 .db()
1426 .await
1427 .unshare_project(project_id, session.connection_id)
1428 .await?;
1429
1430 broadcast(
1431 Some(session.connection_id),
1432 guest_connection_ids.iter().copied(),
1433 |conn_id| session.peer.send(conn_id, message.clone()),
1434 );
1435 room_updated(&room, &session.peer);
1436
1437 Ok(())
1438}
1439
1440async fn join_project(
1441 request: proto::JoinProject,
1442 response: Response<proto::JoinProject>,
1443 session: Session,
1444) -> Result<()> {
1445 let project_id = ProjectId::from_proto(request.project_id);
1446 let guest_user_id = session.user_id;
1447
1448 tracing::info!(%project_id, "join project");
1449
1450 let (project, replica_id) = &mut *session
1451 .db()
1452 .await
1453 .join_project(project_id, session.connection_id)
1454 .await?;
1455
1456 let collaborators = project
1457 .collaborators
1458 .iter()
1459 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1460 .map(|collaborator| collaborator.to_proto())
1461 .collect::<Vec<_>>();
1462
1463 let worktrees = project
1464 .worktrees
1465 .iter()
1466 .map(|(id, worktree)| proto::WorktreeMetadata {
1467 id: *id,
1468 root_name: worktree.root_name.clone(),
1469 visible: worktree.visible,
1470 abs_path: worktree.abs_path.clone(),
1471 })
1472 .collect::<Vec<_>>();
1473
1474 for collaborator in &collaborators {
1475 session
1476 .peer
1477 .send(
1478 collaborator.peer_id.unwrap().into(),
1479 proto::AddProjectCollaborator {
1480 project_id: project_id.to_proto(),
1481 collaborator: Some(proto::Collaborator {
1482 peer_id: Some(session.connection_id.into()),
1483 replica_id: replica_id.0 as u32,
1484 user_id: guest_user_id.to_proto(),
1485 }),
1486 },
1487 )
1488 .trace_err();
1489 }
1490
1491 // First, we send the metadata associated with each worktree.
1492 response.send(proto::JoinProjectResponse {
1493 worktrees: worktrees.clone(),
1494 replica_id: replica_id.0 as u32,
1495 collaborators: collaborators.clone(),
1496 language_servers: project.language_servers.clone(),
1497 })?;
1498
1499 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1500 #[cfg(any(test, feature = "test-support"))]
1501 const MAX_CHUNK_SIZE: usize = 2;
1502 #[cfg(not(any(test, feature = "test-support")))]
1503 const MAX_CHUNK_SIZE: usize = 256;
1504
1505 // Stream this worktree's entries.
1506 let message = proto::UpdateWorktree {
1507 project_id: project_id.to_proto(),
1508 worktree_id,
1509 abs_path: worktree.abs_path.clone(),
1510 root_name: worktree.root_name,
1511 updated_entries: worktree.entries,
1512 removed_entries: Default::default(),
1513 scan_id: worktree.scan_id,
1514 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1515 updated_repositories: worktree.repository_entries.into_values().collect(),
1516 removed_repositories: Default::default(),
1517 };
1518 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1519 session.peer.send(session.connection_id, update.clone())?;
1520 }
1521
1522 // Stream this worktree's diagnostics.
1523 for summary in worktree.diagnostic_summaries {
1524 session.peer.send(
1525 session.connection_id,
1526 proto::UpdateDiagnosticSummary {
1527 project_id: project_id.to_proto(),
1528 worktree_id: worktree.id,
1529 summary: Some(summary),
1530 },
1531 )?;
1532 }
1533
1534 for settings_file in worktree.settings_files {
1535 session.peer.send(
1536 session.connection_id,
1537 proto::UpdateWorktreeSettings {
1538 project_id: project_id.to_proto(),
1539 worktree_id: worktree.id,
1540 path: settings_file.path,
1541 content: Some(settings_file.content),
1542 },
1543 )?;
1544 }
1545 }
1546
1547 for language_server in &project.language_servers {
1548 session.peer.send(
1549 session.connection_id,
1550 proto::UpdateLanguageServer {
1551 project_id: project_id.to_proto(),
1552 language_server_id: language_server.id,
1553 variant: Some(
1554 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1555 proto::LspDiskBasedDiagnosticsUpdated {},
1556 ),
1557 ),
1558 },
1559 )?;
1560 }
1561
1562 Ok(())
1563}
1564
1565async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1566 let sender_id = session.connection_id;
1567 let project_id = ProjectId::from_proto(request.project_id);
1568
1569 let (room, project) = &*session
1570 .db()
1571 .await
1572 .leave_project(project_id, sender_id)
1573 .await?;
1574 tracing::info!(
1575 %project_id,
1576 host_user_id = %project.host_user_id,
1577 host_connection_id = %project.host_connection_id,
1578 "leave project"
1579 );
1580
1581 project_left(&project, &session);
1582 room_updated(&room, &session.peer);
1583
1584 Ok(())
1585}
1586
1587async fn update_project(
1588 request: proto::UpdateProject,
1589 response: Response<proto::UpdateProject>,
1590 session: Session,
1591) -> Result<()> {
1592 let project_id = ProjectId::from_proto(request.project_id);
1593 let (room, guest_connection_ids) = &*session
1594 .db()
1595 .await
1596 .update_project(project_id, session.connection_id, &request.worktrees)
1597 .await?;
1598 broadcast(
1599 Some(session.connection_id),
1600 guest_connection_ids.iter().copied(),
1601 |connection_id| {
1602 session
1603 .peer
1604 .forward_send(session.connection_id, connection_id, request.clone())
1605 },
1606 );
1607 room_updated(&room, &session.peer);
1608 response.send(proto::Ack {})?;
1609
1610 Ok(())
1611}
1612
1613async fn update_worktree(
1614 request: proto::UpdateWorktree,
1615 response: Response<proto::UpdateWorktree>,
1616 session: Session,
1617) -> Result<()> {
1618 let guest_connection_ids = session
1619 .db()
1620 .await
1621 .update_worktree(&request, session.connection_id)
1622 .await?;
1623
1624 broadcast(
1625 Some(session.connection_id),
1626 guest_connection_ids.iter().copied(),
1627 |connection_id| {
1628 session
1629 .peer
1630 .forward_send(session.connection_id, connection_id, request.clone())
1631 },
1632 );
1633 response.send(proto::Ack {})?;
1634 Ok(())
1635}
1636
1637async fn update_diagnostic_summary(
1638 message: proto::UpdateDiagnosticSummary,
1639 session: Session,
1640) -> Result<()> {
1641 let guest_connection_ids = session
1642 .db()
1643 .await
1644 .update_diagnostic_summary(&message, session.connection_id)
1645 .await?;
1646
1647 broadcast(
1648 Some(session.connection_id),
1649 guest_connection_ids.iter().copied(),
1650 |connection_id| {
1651 session
1652 .peer
1653 .forward_send(session.connection_id, connection_id, message.clone())
1654 },
1655 );
1656
1657 Ok(())
1658}
1659
1660async fn update_worktree_settings(
1661 message: proto::UpdateWorktreeSettings,
1662 session: Session,
1663) -> Result<()> {
1664 let guest_connection_ids = session
1665 .db()
1666 .await
1667 .update_worktree_settings(&message, session.connection_id)
1668 .await?;
1669
1670 broadcast(
1671 Some(session.connection_id),
1672 guest_connection_ids.iter().copied(),
1673 |connection_id| {
1674 session
1675 .peer
1676 .forward_send(session.connection_id, connection_id, message.clone())
1677 },
1678 );
1679
1680 Ok(())
1681}
1682
1683async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1684 broadcast_project_message(request.project_id, request, session).await
1685}
1686
1687async fn start_language_server(
1688 request: proto::StartLanguageServer,
1689 session: Session,
1690) -> Result<()> {
1691 let guest_connection_ids = session
1692 .db()
1693 .await
1694 .start_language_server(&request, session.connection_id)
1695 .await?;
1696
1697 broadcast(
1698 Some(session.connection_id),
1699 guest_connection_ids.iter().copied(),
1700 |connection_id| {
1701 session
1702 .peer
1703 .forward_send(session.connection_id, connection_id, request.clone())
1704 },
1705 );
1706 Ok(())
1707}
1708
1709async fn update_language_server(
1710 request: proto::UpdateLanguageServer,
1711 session: Session,
1712) -> Result<()> {
1713 session.executor.record_backtrace();
1714 let project_id = ProjectId::from_proto(request.project_id);
1715 let project_connection_ids = session
1716 .db()
1717 .await
1718 .project_connection_ids(project_id, session.connection_id)
1719 .await?;
1720 broadcast(
1721 Some(session.connection_id),
1722 project_connection_ids.iter().copied(),
1723 |connection_id| {
1724 session
1725 .peer
1726 .forward_send(session.connection_id, connection_id, request.clone())
1727 },
1728 );
1729 Ok(())
1730}
1731
1732async fn forward_project_request<T>(
1733 request: T,
1734 response: Response<T>,
1735 session: Session,
1736) -> Result<()>
1737where
1738 T: EntityMessage + RequestMessage,
1739{
1740 session.executor.record_backtrace();
1741 let project_id = ProjectId::from_proto(request.remote_entity_id());
1742 let host_connection_id = {
1743 let collaborators = session
1744 .db()
1745 .await
1746 .project_collaborators(project_id, session.connection_id)
1747 .await?;
1748 collaborators
1749 .iter()
1750 .find(|collaborator| collaborator.is_host)
1751 .ok_or_else(|| anyhow!("host not found"))?
1752 .connection_id
1753 };
1754
1755 let payload = session
1756 .peer
1757 .forward_request(session.connection_id, host_connection_id, request)
1758 .await?;
1759
1760 response.send(payload)?;
1761 Ok(())
1762}
1763
1764async fn create_buffer_for_peer(
1765 request: proto::CreateBufferForPeer,
1766 session: Session,
1767) -> Result<()> {
1768 session.executor.record_backtrace();
1769 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1770 session
1771 .peer
1772 .forward_send(session.connection_id, peer_id.into(), request)?;
1773 Ok(())
1774}
1775
1776async fn update_buffer(
1777 request: proto::UpdateBuffer,
1778 response: Response<proto::UpdateBuffer>,
1779 session: Session,
1780) -> Result<()> {
1781 session.executor.record_backtrace();
1782 let project_id = ProjectId::from_proto(request.project_id);
1783 let mut guest_connection_ids;
1784 let mut host_connection_id = None;
1785 {
1786 let collaborators = session
1787 .db()
1788 .await
1789 .project_collaborators(project_id, session.connection_id)
1790 .await?;
1791 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1792 for collaborator in collaborators.iter() {
1793 if collaborator.is_host {
1794 host_connection_id = Some(collaborator.connection_id);
1795 } else {
1796 guest_connection_ids.push(collaborator.connection_id);
1797 }
1798 }
1799 }
1800 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1801
1802 session.executor.record_backtrace();
1803 broadcast(
1804 Some(session.connection_id),
1805 guest_connection_ids,
1806 |connection_id| {
1807 session
1808 .peer
1809 .forward_send(session.connection_id, connection_id, request.clone())
1810 },
1811 );
1812 if host_connection_id != session.connection_id {
1813 session
1814 .peer
1815 .forward_request(session.connection_id, host_connection_id, request.clone())
1816 .await?;
1817 }
1818
1819 response.send(proto::Ack {})?;
1820 Ok(())
1821}
1822
1823async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1824 let project_id = ProjectId::from_proto(request.project_id);
1825 let project_connection_ids = session
1826 .db()
1827 .await
1828 .project_connection_ids(project_id, session.connection_id)
1829 .await?;
1830
1831 broadcast(
1832 Some(session.connection_id),
1833 project_connection_ids.iter().copied(),
1834 |connection_id| {
1835 session
1836 .peer
1837 .forward_send(session.connection_id, connection_id, request.clone())
1838 },
1839 );
1840 Ok(())
1841}
1842
1843async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1844 let project_id = ProjectId::from_proto(request.project_id);
1845 let project_connection_ids = session
1846 .db()
1847 .await
1848 .project_connection_ids(project_id, session.connection_id)
1849 .await?;
1850 broadcast(
1851 Some(session.connection_id),
1852 project_connection_ids.iter().copied(),
1853 |connection_id| {
1854 session
1855 .peer
1856 .forward_send(session.connection_id, connection_id, request.clone())
1857 },
1858 );
1859 Ok(())
1860}
1861
1862async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1863 broadcast_project_message(request.project_id, request, session).await
1864}
1865
1866async fn broadcast_project_message<T: EnvelopedMessage>(
1867 project_id: u64,
1868 request: T,
1869 session: Session,
1870) -> Result<()> {
1871 let project_id = ProjectId::from_proto(project_id);
1872 let project_connection_ids = session
1873 .db()
1874 .await
1875 .project_connection_ids(project_id, session.connection_id)
1876 .await?;
1877 broadcast(
1878 Some(session.connection_id),
1879 project_connection_ids.iter().copied(),
1880 |connection_id| {
1881 session
1882 .peer
1883 .forward_send(session.connection_id, connection_id, request.clone())
1884 },
1885 );
1886 Ok(())
1887}
1888
1889async fn follow(
1890 request: proto::Follow,
1891 response: Response<proto::Follow>,
1892 session: Session,
1893) -> Result<()> {
1894 let room_id = RoomId::from_proto(request.room_id);
1895 let project_id = request.project_id.map(ProjectId::from_proto);
1896 let leader_id = request
1897 .leader_id
1898 .ok_or_else(|| anyhow!("invalid leader id"))?
1899 .into();
1900 let follower_id = session.connection_id;
1901
1902 session
1903 .db()
1904 .await
1905 .check_room_participants(room_id, leader_id, session.connection_id)
1906 .await?;
1907
1908 let mut response_payload = session
1909 .peer
1910 .forward_request(session.connection_id, leader_id, request)
1911 .await?;
1912 response_payload
1913 .views
1914 .retain(|view| view.leader_id != Some(follower_id.into()));
1915 response.send(response_payload)?;
1916
1917 if let Some(project_id) = project_id {
1918 let room = session
1919 .db()
1920 .await
1921 .follow(room_id, project_id, leader_id, follower_id)
1922 .await?;
1923 room_updated(&room, &session.peer);
1924 }
1925
1926 Ok(())
1927}
1928
1929async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1930 let room_id = RoomId::from_proto(request.room_id);
1931 let project_id = request.project_id.map(ProjectId::from_proto);
1932 let leader_id = request
1933 .leader_id
1934 .ok_or_else(|| anyhow!("invalid leader id"))?
1935 .into();
1936 let follower_id = session.connection_id;
1937
1938 session
1939 .db()
1940 .await
1941 .check_room_participants(room_id, leader_id, session.connection_id)
1942 .await?;
1943
1944 session
1945 .peer
1946 .forward_send(session.connection_id, leader_id, request)?;
1947
1948 if let Some(project_id) = project_id {
1949 let room = session
1950 .db()
1951 .await
1952 .unfollow(room_id, project_id, leader_id, follower_id)
1953 .await?;
1954 room_updated(&room, &session.peer);
1955 }
1956
1957 Ok(())
1958}
1959
1960async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1961 let room_id = RoomId::from_proto(request.room_id);
1962 let database = session.db.lock().await;
1963
1964 let connection_ids = if let Some(project_id) = request.project_id {
1965 let project_id = ProjectId::from_proto(project_id);
1966 database
1967 .project_connection_ids(project_id, session.connection_id)
1968 .await?
1969 } else {
1970 database
1971 .room_connection_ids(room_id, session.connection_id)
1972 .await?
1973 };
1974
1975 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1976 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1977 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1978 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1979 });
1980 for follower_peer_id in request.follower_ids.iter().copied() {
1981 let follower_connection_id = follower_peer_id.into();
1982 if Some(follower_peer_id) != leader_id && connection_ids.contains(&follower_connection_id) {
1983 session.peer.forward_send(
1984 session.connection_id,
1985 follower_connection_id,
1986 request.clone(),
1987 )?;
1988 }
1989 }
1990 Ok(())
1991}
1992
1993async fn get_users(
1994 request: proto::GetUsers,
1995 response: Response<proto::GetUsers>,
1996 session: Session,
1997) -> Result<()> {
1998 let user_ids = request
1999 .user_ids
2000 .into_iter()
2001 .map(UserId::from_proto)
2002 .collect();
2003 let users = session
2004 .db()
2005 .await
2006 .get_users_by_ids(user_ids)
2007 .await?
2008 .into_iter()
2009 .map(|user| proto::User {
2010 id: user.id.to_proto(),
2011 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2012 github_login: user.github_login,
2013 })
2014 .collect();
2015 response.send(proto::UsersResponse { users })?;
2016 Ok(())
2017}
2018
2019async fn fuzzy_search_users(
2020 request: proto::FuzzySearchUsers,
2021 response: Response<proto::FuzzySearchUsers>,
2022 session: Session,
2023) -> Result<()> {
2024 let query = request.query;
2025 let users = match query.len() {
2026 0 => vec![],
2027 1 | 2 => session
2028 .db()
2029 .await
2030 .get_user_by_github_login(&query)
2031 .await?
2032 .into_iter()
2033 .collect(),
2034 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2035 };
2036 let users = users
2037 .into_iter()
2038 .filter(|user| user.id != session.user_id)
2039 .map(|user| proto::User {
2040 id: user.id.to_proto(),
2041 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2042 github_login: user.github_login,
2043 })
2044 .collect();
2045 response.send(proto::UsersResponse { users })?;
2046 Ok(())
2047}
2048
2049async fn request_contact(
2050 request: proto::RequestContact,
2051 response: Response<proto::RequestContact>,
2052 session: Session,
2053) -> Result<()> {
2054 let requester_id = session.user_id;
2055 let responder_id = UserId::from_proto(request.responder_id);
2056 if requester_id == responder_id {
2057 return Err(anyhow!("cannot add yourself as a contact"))?;
2058 }
2059
2060 session
2061 .db()
2062 .await
2063 .send_contact_request(requester_id, responder_id)
2064 .await?;
2065
2066 // Update outgoing contact requests of requester
2067 let mut update = proto::UpdateContacts::default();
2068 update.outgoing_requests.push(responder_id.to_proto());
2069 for connection_id in session
2070 .connection_pool()
2071 .await
2072 .user_connection_ids(requester_id)
2073 {
2074 session.peer.send(connection_id, update.clone())?;
2075 }
2076
2077 // Update incoming contact requests of responder
2078 let mut update = proto::UpdateContacts::default();
2079 update
2080 .incoming_requests
2081 .push(proto::IncomingContactRequest {
2082 requester_id: requester_id.to_proto(),
2083 should_notify: true,
2084 });
2085 for connection_id in session
2086 .connection_pool()
2087 .await
2088 .user_connection_ids(responder_id)
2089 {
2090 session.peer.send(connection_id, update.clone())?;
2091 }
2092
2093 response.send(proto::Ack {})?;
2094 Ok(())
2095}
2096
2097async fn respond_to_contact_request(
2098 request: proto::RespondToContactRequest,
2099 response: Response<proto::RespondToContactRequest>,
2100 session: Session,
2101) -> Result<()> {
2102 let responder_id = session.user_id;
2103 let requester_id = UserId::from_proto(request.requester_id);
2104 let db = session.db().await;
2105 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2106 db.dismiss_contact_notification(responder_id, requester_id)
2107 .await?;
2108 } else {
2109 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2110
2111 db.respond_to_contact_request(responder_id, requester_id, accept)
2112 .await?;
2113 let requester_busy = db.is_user_busy(requester_id).await?;
2114 let responder_busy = db.is_user_busy(responder_id).await?;
2115
2116 let pool = session.connection_pool().await;
2117 // Update responder with new contact
2118 let mut update = proto::UpdateContacts::default();
2119 if accept {
2120 update
2121 .contacts
2122 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2123 }
2124 update
2125 .remove_incoming_requests
2126 .push(requester_id.to_proto());
2127 for connection_id in pool.user_connection_ids(responder_id) {
2128 session.peer.send(connection_id, update.clone())?;
2129 }
2130
2131 // Update requester with new contact
2132 let mut update = proto::UpdateContacts::default();
2133 if accept {
2134 update
2135 .contacts
2136 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2137 }
2138 update
2139 .remove_outgoing_requests
2140 .push(responder_id.to_proto());
2141 for connection_id in pool.user_connection_ids(requester_id) {
2142 session.peer.send(connection_id, update.clone())?;
2143 }
2144 }
2145
2146 response.send(proto::Ack {})?;
2147 Ok(())
2148}
2149
2150async fn remove_contact(
2151 request: proto::RemoveContact,
2152 response: Response<proto::RemoveContact>,
2153 session: Session,
2154) -> Result<()> {
2155 let requester_id = session.user_id;
2156 let responder_id = UserId::from_proto(request.user_id);
2157 let db = session.db().await;
2158 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2159
2160 let pool = session.connection_pool().await;
2161 // Update outgoing contact requests of requester
2162 let mut update = proto::UpdateContacts::default();
2163 if contact_accepted {
2164 update.remove_contacts.push(responder_id.to_proto());
2165 } else {
2166 update
2167 .remove_outgoing_requests
2168 .push(responder_id.to_proto());
2169 }
2170 for connection_id in pool.user_connection_ids(requester_id) {
2171 session.peer.send(connection_id, update.clone())?;
2172 }
2173
2174 // Update incoming contact requests of responder
2175 let mut update = proto::UpdateContacts::default();
2176 if contact_accepted {
2177 update.remove_contacts.push(requester_id.to_proto());
2178 } else {
2179 update
2180 .remove_incoming_requests
2181 .push(requester_id.to_proto());
2182 }
2183 for connection_id in pool.user_connection_ids(responder_id) {
2184 session.peer.send(connection_id, update.clone())?;
2185 }
2186
2187 response.send(proto::Ack {})?;
2188 Ok(())
2189}
2190
2191async fn create_channel(
2192 request: proto::CreateChannel,
2193 response: Response<proto::CreateChannel>,
2194 session: Session,
2195) -> Result<()> {
2196 let db = session.db().await;
2197 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2198
2199 if let Some(live_kit) = session.live_kit_client.as_ref() {
2200 live_kit.create_room(live_kit_room.clone()).await?;
2201 }
2202
2203 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2204 let id = db
2205 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2206 .await?;
2207
2208 let channel = proto::Channel {
2209 id: id.to_proto(),
2210 name: request.name,
2211 };
2212
2213 response.send(proto::CreateChannelResponse {
2214 channel: Some(channel.clone()),
2215 parent_id: request.parent_id,
2216 })?;
2217
2218 let Some(parent_id) = parent_id else {
2219 return Ok(());
2220 };
2221
2222 let update = proto::UpdateChannels {
2223 channels: vec![channel],
2224 insert_edge: vec![ChannelEdge {
2225 parent_id: parent_id.to_proto(),
2226 channel_id: id.to_proto(),
2227 }],
2228 ..Default::default()
2229 };
2230
2231 let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2232
2233 let connection_pool = session.connection_pool().await;
2234 for user_id in user_ids_to_notify {
2235 for connection_id in connection_pool.user_connection_ids(user_id) {
2236 if user_id == session.user_id {
2237 continue;
2238 }
2239 session.peer.send(connection_id, update.clone())?;
2240 }
2241 }
2242
2243 Ok(())
2244}
2245
2246async fn delete_channel(
2247 request: proto::DeleteChannel,
2248 response: Response<proto::DeleteChannel>,
2249 session: Session,
2250) -> Result<()> {
2251 let db = session.db().await;
2252
2253 let channel_id = request.channel_id;
2254 let (removed_channels, member_ids) = db
2255 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2256 .await?;
2257 response.send(proto::Ack {})?;
2258
2259 // Notify members of removed channels
2260 let mut update = proto::UpdateChannels::default();
2261 update
2262 .delete_channels
2263 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2264
2265 let connection_pool = session.connection_pool().await;
2266 for member_id in member_ids {
2267 for connection_id in connection_pool.user_connection_ids(member_id) {
2268 session.peer.send(connection_id, update.clone())?;
2269 }
2270 }
2271
2272 Ok(())
2273}
2274
2275async fn invite_channel_member(
2276 request: proto::InviteChannelMember,
2277 response: Response<proto::InviteChannelMember>,
2278 session: Session,
2279) -> Result<()> {
2280 let db = session.db().await;
2281 let channel_id = ChannelId::from_proto(request.channel_id);
2282 let invitee_id = UserId::from_proto(request.user_id);
2283 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2284 .await?;
2285
2286 let (channel, _) = db
2287 .get_channel(channel_id, session.user_id)
2288 .await?
2289 .ok_or_else(|| anyhow!("channel not found"))?;
2290
2291 let mut update = proto::UpdateChannels::default();
2292 update.channel_invitations.push(proto::Channel {
2293 id: channel.id.to_proto(),
2294 name: channel.name,
2295 });
2296 for connection_id in session
2297 .connection_pool()
2298 .await
2299 .user_connection_ids(invitee_id)
2300 {
2301 session.peer.send(connection_id, update.clone())?;
2302 }
2303
2304 response.send(proto::Ack {})?;
2305 Ok(())
2306}
2307
2308async fn remove_channel_member(
2309 request: proto::RemoveChannelMember,
2310 response: Response<proto::RemoveChannelMember>,
2311 session: Session,
2312) -> Result<()> {
2313 let db = session.db().await;
2314 let channel_id = ChannelId::from_proto(request.channel_id);
2315 let member_id = UserId::from_proto(request.user_id);
2316
2317 db.remove_channel_member(channel_id, member_id, session.user_id)
2318 .await?;
2319
2320 let mut update = proto::UpdateChannels::default();
2321 update.delete_channels.push(channel_id.to_proto());
2322
2323 for connection_id in session
2324 .connection_pool()
2325 .await
2326 .user_connection_ids(member_id)
2327 {
2328 session.peer.send(connection_id, update.clone())?;
2329 }
2330
2331 response.send(proto::Ack {})?;
2332 Ok(())
2333}
2334
2335async fn set_channel_member_admin(
2336 request: proto::SetChannelMemberAdmin,
2337 response: Response<proto::SetChannelMemberAdmin>,
2338 session: Session,
2339) -> Result<()> {
2340 let db = session.db().await;
2341 let channel_id = ChannelId::from_proto(request.channel_id);
2342 let member_id = UserId::from_proto(request.user_id);
2343 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2344 .await?;
2345
2346 let (channel, has_accepted) = db
2347 .get_channel(channel_id, member_id)
2348 .await?
2349 .ok_or_else(|| anyhow!("channel not found"))?;
2350
2351 let mut update = proto::UpdateChannels::default();
2352 if has_accepted {
2353 update.channel_permissions.push(proto::ChannelPermission {
2354 channel_id: channel.id.to_proto(),
2355 is_admin: request.admin,
2356 });
2357 }
2358
2359 for connection_id in session
2360 .connection_pool()
2361 .await
2362 .user_connection_ids(member_id)
2363 {
2364 session.peer.send(connection_id, update.clone())?;
2365 }
2366
2367 response.send(proto::Ack {})?;
2368 Ok(())
2369}
2370
2371async fn rename_channel(
2372 request: proto::RenameChannel,
2373 response: Response<proto::RenameChannel>,
2374 session: Session,
2375) -> Result<()> {
2376 let db = session.db().await;
2377 let channel_id = ChannelId::from_proto(request.channel_id);
2378 let new_name = db
2379 .rename_channel(channel_id, session.user_id, &request.name)
2380 .await?;
2381
2382 let channel = proto::Channel {
2383 id: request.channel_id,
2384 name: new_name,
2385 };
2386 response.send(proto::RenameChannelResponse {
2387 channel: Some(channel.clone()),
2388 })?;
2389 let mut update = proto::UpdateChannels::default();
2390 update.channels.push(channel);
2391
2392 let member_ids = db.get_channel_members(channel_id).await?;
2393
2394 let connection_pool = session.connection_pool().await;
2395 for member_id in member_ids {
2396 for connection_id in connection_pool.user_connection_ids(member_id) {
2397 session.peer.send(connection_id, update.clone())?;
2398 }
2399 }
2400
2401 Ok(())
2402}
2403
2404async fn link_channel(
2405 request: proto::LinkChannel,
2406 response: Response<proto::LinkChannel>,
2407 session: Session,
2408) -> Result<()> {
2409 let db = session.db().await;
2410 let channel_id = ChannelId::from_proto(request.channel_id);
2411 let to = ChannelId::from_proto(request.to);
2412 let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2413
2414 let members = db.get_channel_members(to).await?;
2415 let connection_pool = session.connection_pool().await;
2416 let update = proto::UpdateChannels {
2417 channels: channels_to_send
2418 .channels
2419 .into_iter()
2420 .map(|channel| proto::Channel {
2421 id: channel.id.to_proto(),
2422 name: channel.name,
2423 })
2424 .collect(),
2425 insert_edge: channels_to_send.edges,
2426 ..Default::default()
2427 };
2428 for member_id in members {
2429 for connection_id in connection_pool.user_connection_ids(member_id) {
2430 session.peer.send(connection_id, update.clone())?;
2431 }
2432 }
2433
2434 response.send(Ack {})?;
2435
2436 Ok(())
2437}
2438
2439async fn unlink_channel(
2440 request: proto::UnlinkChannel,
2441 response: Response<proto::UnlinkChannel>,
2442 session: Session,
2443) -> Result<()> {
2444 let db = session.db().await;
2445 let channel_id = ChannelId::from_proto(request.channel_id);
2446 let from = ChannelId::from_proto(request.from);
2447
2448 db.unlink_channel(session.user_id, channel_id, from).await?;
2449
2450 let members = db.get_channel_members(from).await?;
2451
2452 let update = proto::UpdateChannels {
2453 delete_edge: vec![proto::ChannelEdge {
2454 channel_id: channel_id.to_proto(),
2455 parent_id: from.to_proto(),
2456 }],
2457 ..Default::default()
2458 };
2459 let connection_pool = session.connection_pool().await;
2460 for member_id in members {
2461 for connection_id in connection_pool.user_connection_ids(member_id) {
2462 session.peer.send(connection_id, update.clone())?;
2463 }
2464 }
2465
2466 response.send(Ack {})?;
2467
2468 Ok(())
2469}
2470
2471async fn move_channel(
2472 request: proto::MoveChannel,
2473 response: Response<proto::MoveChannel>,
2474 session: Session,
2475) -> Result<()> {
2476 let db = session.db().await;
2477 let channel_id = ChannelId::from_proto(request.channel_id);
2478 let from_parent = ChannelId::from_proto(request.from);
2479 let to = ChannelId::from_proto(request.to);
2480
2481 let channels_to_send = db
2482 .move_channel(session.user_id, channel_id, from_parent, to)
2483 .await?;
2484
2485 if channels_to_send.is_empty() {
2486 response.send(Ack {})?;
2487 return Ok(());
2488 }
2489
2490 let members_from = db.get_channel_members(from_parent).await?;
2491 let members_to = db.get_channel_members(to).await?;
2492
2493 let update = proto::UpdateChannels {
2494 delete_edge: vec![proto::ChannelEdge {
2495 channel_id: channel_id.to_proto(),
2496 parent_id: from_parent.to_proto(),
2497 }],
2498 ..Default::default()
2499 };
2500 let connection_pool = session.connection_pool().await;
2501 for member_id in members_from {
2502 for connection_id in connection_pool.user_connection_ids(member_id) {
2503 session.peer.send(connection_id, update.clone())?;
2504 }
2505 }
2506
2507 let update = proto::UpdateChannels {
2508 channels: channels_to_send
2509 .channels
2510 .into_iter()
2511 .map(|channel| proto::Channel {
2512 id: channel.id.to_proto(),
2513 name: channel.name,
2514 })
2515 .collect(),
2516 insert_edge: channels_to_send.edges,
2517 ..Default::default()
2518 };
2519 for member_id in members_to {
2520 for connection_id in connection_pool.user_connection_ids(member_id) {
2521 session.peer.send(connection_id, update.clone())?;
2522 }
2523 }
2524
2525 response.send(Ack {})?;
2526
2527 Ok(())
2528}
2529
2530async fn get_channel_members(
2531 request: proto::GetChannelMembers,
2532 response: Response<proto::GetChannelMembers>,
2533 session: Session,
2534) -> Result<()> {
2535 let db = session.db().await;
2536 let channel_id = ChannelId::from_proto(request.channel_id);
2537 let members = db
2538 .get_channel_member_details(channel_id, session.user_id)
2539 .await?;
2540 response.send(proto::GetChannelMembersResponse { members })?;
2541 Ok(())
2542}
2543
2544async fn respond_to_channel_invite(
2545 request: proto::RespondToChannelInvite,
2546 response: Response<proto::RespondToChannelInvite>,
2547 session: Session,
2548) -> Result<()> {
2549 let db = session.db().await;
2550 let channel_id = ChannelId::from_proto(request.channel_id);
2551 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2552 .await?;
2553
2554 let mut update = proto::UpdateChannels::default();
2555 update
2556 .remove_channel_invitations
2557 .push(channel_id.to_proto());
2558 if request.accept {
2559 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2560 update
2561 .channels
2562 .extend(
2563 result
2564 .channels
2565 .channels
2566 .into_iter()
2567 .map(|channel| proto::Channel {
2568 id: channel.id.to_proto(),
2569 name: channel.name,
2570 }),
2571 );
2572 update.unseen_channel_messages = result.channel_messages;
2573 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2574 update.insert_edge = result.channels.edges;
2575 update
2576 .channel_participants
2577 .extend(
2578 result
2579 .channel_participants
2580 .into_iter()
2581 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2582 channel_id: channel_id.to_proto(),
2583 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2584 }),
2585 );
2586 update
2587 .channel_permissions
2588 .extend(
2589 result
2590 .channels_with_admin_privileges
2591 .into_iter()
2592 .map(|channel_id| proto::ChannelPermission {
2593 channel_id: channel_id.to_proto(),
2594 is_admin: true,
2595 }),
2596 );
2597 }
2598 session.peer.send(session.connection_id, update)?;
2599 response.send(proto::Ack {})?;
2600
2601 Ok(())
2602}
2603
/// Handles a `JoinChannel` request.
///
/// The session first leaves whatever room it is currently in, then joins the
/// room that the database associates with the channel. The requester receives
/// a `JoinRoomResponse`; room participants, channel members and the user's
/// contacts are notified afterwards.
///
/// Returns an error if leaving the previous room, any database call, the
/// response send, or the contacts update fails.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The inner block bounds the lifetime of the database handle and of the
    // value returned by `join_room`; only the result of `into_inner()` escapes.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // Connection info is `None` when no LiveKit client is configured or
        // when creating the token fails (that failure is logged by
        // `trace_err`, it does not abort the join).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Tell the room's participants about the updated room.
        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    // Tell all channel members who is now participating in the channel's room.
    // The pool guard is a temporary that lives only for this call.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2658
2659async fn join_channel_buffer(
2660 request: proto::JoinChannelBuffer,
2661 response: Response<proto::JoinChannelBuffer>,
2662 session: Session,
2663) -> Result<()> {
2664 let db = session.db().await;
2665 let channel_id = ChannelId::from_proto(request.channel_id);
2666
2667 let open_response = db
2668 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2669 .await?;
2670
2671 let collaborators = open_response.collaborators.clone();
2672 response.send(open_response)?;
2673
2674 let update = UpdateChannelBufferCollaborators {
2675 channel_id: channel_id.to_proto(),
2676 collaborators: collaborators.clone(),
2677 };
2678 channel_buffer_updated(
2679 session.connection_id,
2680 collaborators
2681 .iter()
2682 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2683 &update,
2684 &session.peer,
2685 );
2686
2687 Ok(())
2688}
2689
2690async fn update_channel_buffer(
2691 request: proto::UpdateChannelBuffer,
2692 session: Session,
2693) -> Result<()> {
2694 let db = session.db().await;
2695 let channel_id = ChannelId::from_proto(request.channel_id);
2696
2697 let (collaborators, non_collaborators, epoch, version) = db
2698 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2699 .await?;
2700
2701 channel_buffer_updated(
2702 session.connection_id,
2703 collaborators,
2704 &proto::UpdateChannelBuffer {
2705 channel_id: channel_id.to_proto(),
2706 operations: request.operations,
2707 },
2708 &session.peer,
2709 );
2710
2711 let pool = &*session.connection_pool().await;
2712
2713 broadcast(
2714 None,
2715 non_collaborators
2716 .iter()
2717 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2718 |peer_id| {
2719 session.peer.send(
2720 peer_id.into(),
2721 proto::UpdateChannels {
2722 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2723 channel_id: channel_id.to_proto(),
2724 epoch: epoch as u64,
2725 version: version.clone(),
2726 }],
2727 ..Default::default()
2728 },
2729 )
2730 },
2731 );
2732
2733 Ok(())
2734}
2735
2736async fn rejoin_channel_buffers(
2737 request: proto::RejoinChannelBuffers,
2738 response: Response<proto::RejoinChannelBuffers>,
2739 session: Session,
2740) -> Result<()> {
2741 let db = session.db().await;
2742 let buffers = db
2743 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2744 .await?;
2745
2746 for rejoined_buffer in &buffers {
2747 let collaborators_to_notify = rejoined_buffer
2748 .buffer
2749 .collaborators
2750 .iter()
2751 .filter_map(|c| Some(c.peer_id?.into()));
2752 channel_buffer_updated(
2753 session.connection_id,
2754 collaborators_to_notify,
2755 &proto::UpdateChannelBufferCollaborators {
2756 channel_id: rejoined_buffer.buffer.channel_id,
2757 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2758 },
2759 &session.peer,
2760 );
2761 }
2762
2763 response.send(proto::RejoinChannelBuffersResponse {
2764 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2765 })?;
2766
2767 Ok(())
2768}
2769
2770async fn leave_channel_buffer(
2771 request: proto::LeaveChannelBuffer,
2772 response: Response<proto::LeaveChannelBuffer>,
2773 session: Session,
2774) -> Result<()> {
2775 let db = session.db().await;
2776 let channel_id = ChannelId::from_proto(request.channel_id);
2777
2778 let left_buffer = db
2779 .leave_channel_buffer(channel_id, session.connection_id)
2780 .await?;
2781
2782 response.send(Ack {})?;
2783
2784 channel_buffer_updated(
2785 session.connection_id,
2786 left_buffer.connections,
2787 &proto::UpdateChannelBufferCollaborators {
2788 channel_id: channel_id.to_proto(),
2789 collaborators: left_buffer.collaborators,
2790 },
2791 &session.peer,
2792 );
2793
2794 Ok(())
2795}
2796
2797fn channel_buffer_updated<T: EnvelopedMessage>(
2798 sender_id: ConnectionId,
2799 collaborators: impl IntoIterator<Item = ConnectionId>,
2800 message: &T,
2801 peer: &Peer,
2802) {
2803 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2804 peer.send(peer_id.into(), message.clone())
2805 });
2806}
2807
2808async fn send_channel_message(
2809 request: proto::SendChannelMessage,
2810 response: Response<proto::SendChannelMessage>,
2811 session: Session,
2812) -> Result<()> {
2813 // Validate the message body.
2814 let body = request.body.trim().to_string();
2815 if body.len() > MAX_MESSAGE_LEN {
2816 return Err(anyhow!("message is too long"))?;
2817 }
2818 if body.is_empty() {
2819 return Err(anyhow!("message can't be blank"))?;
2820 }
2821
2822 let timestamp = OffsetDateTime::now_utc();
2823 let nonce = request
2824 .nonce
2825 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2826
2827 let channel_id = ChannelId::from_proto(request.channel_id);
2828 let (message_id, connection_ids, non_participants) = session
2829 .db()
2830 .await
2831 .create_channel_message(
2832 channel_id,
2833 session.user_id,
2834 &body,
2835 timestamp,
2836 nonce.clone().into(),
2837 )
2838 .await?;
2839 let message = proto::ChannelMessage {
2840 sender_id: session.user_id.to_proto(),
2841 id: message_id.to_proto(),
2842 body,
2843 timestamp: timestamp.unix_timestamp() as u64,
2844 nonce: Some(nonce),
2845 };
2846 broadcast(Some(session.connection_id), connection_ids, |connection| {
2847 session.peer.send(
2848 connection,
2849 proto::ChannelMessageSent {
2850 channel_id: channel_id.to_proto(),
2851 message: Some(message.clone()),
2852 },
2853 )
2854 });
2855 response.send(proto::SendChannelMessageResponse {
2856 message: Some(message),
2857 })?;
2858
2859 let pool = &*session.connection_pool().await;
2860 broadcast(
2861 None,
2862 non_participants
2863 .iter()
2864 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2865 |peer_id| {
2866 session.peer.send(
2867 peer_id.into(),
2868 proto::UpdateChannels {
2869 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2870 channel_id: channel_id.to_proto(),
2871 message_id: message_id.to_proto(),
2872 }],
2873 ..Default::default()
2874 },
2875 )
2876 },
2877 );
2878
2879 Ok(())
2880}
2881
2882async fn remove_channel_message(
2883 request: proto::RemoveChannelMessage,
2884 response: Response<proto::RemoveChannelMessage>,
2885 session: Session,
2886) -> Result<()> {
2887 let channel_id = ChannelId::from_proto(request.channel_id);
2888 let message_id = MessageId::from_proto(request.message_id);
2889 let connection_ids = session
2890 .db()
2891 .await
2892 .remove_channel_message(channel_id, message_id, session.user_id)
2893 .await?;
2894 broadcast(Some(session.connection_id), connection_ids, |connection| {
2895 session.peer.send(connection, request.clone())
2896 });
2897 response.send(proto::Ack {})?;
2898 Ok(())
2899}
2900
2901async fn acknowledge_channel_message(
2902 request: proto::AckChannelMessage,
2903 session: Session,
2904) -> Result<()> {
2905 let channel_id = ChannelId::from_proto(request.channel_id);
2906 let message_id = MessageId::from_proto(request.message_id);
2907 session
2908 .db()
2909 .await
2910 .observe_channel_message(channel_id, session.user_id, message_id)
2911 .await?;
2912 Ok(())
2913}
2914
2915async fn join_channel_chat(
2916 request: proto::JoinChannelChat,
2917 response: Response<proto::JoinChannelChat>,
2918 session: Session,
2919) -> Result<()> {
2920 let channel_id = ChannelId::from_proto(request.channel_id);
2921
2922 let db = session.db().await;
2923 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2924 .await?;
2925 let messages = db
2926 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2927 .await?;
2928 response.send(proto::JoinChannelChatResponse {
2929 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2930 messages,
2931 })?;
2932 Ok(())
2933}
2934
2935async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2936 let channel_id = ChannelId::from_proto(request.channel_id);
2937 session
2938 .db()
2939 .await
2940 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2941 .await?;
2942 Ok(())
2943}
2944
2945async fn get_channel_messages(
2946 request: proto::GetChannelMessages,
2947 response: Response<proto::GetChannelMessages>,
2948 session: Session,
2949) -> Result<()> {
2950 let channel_id = ChannelId::from_proto(request.channel_id);
2951 let messages = session
2952 .db()
2953 .await
2954 .get_channel_messages(
2955 channel_id,
2956 session.user_id,
2957 MESSAGE_COUNT_PER_PAGE,
2958 Some(MessageId::from_proto(request.before_message_id)),
2959 )
2960 .await?;
2961 response.send(proto::GetChannelMessagesResponse {
2962 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2963 messages,
2964 })?;
2965 Ok(())
2966}
2967
2968async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2969 let project_id = ProjectId::from_proto(request.project_id);
2970 let project_connection_ids = session
2971 .db()
2972 .await
2973 .project_connection_ids(project_id, session.connection_id)
2974 .await?;
2975 broadcast(
2976 Some(session.connection_id),
2977 project_connection_ids.iter().copied(),
2978 |connection_id| {
2979 session
2980 .peer
2981 .forward_send(session.connection_id, connection_id, request.clone())
2982 },
2983 );
2984 Ok(())
2985}
2986
2987async fn get_private_user_info(
2988 _request: proto::GetPrivateUserInfo,
2989 response: Response<proto::GetPrivateUserInfo>,
2990 session: Session,
2991) -> Result<()> {
2992 let db = session.db().await;
2993
2994 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2995 let user = db
2996 .get_user_by_id(session.user_id)
2997 .await?
2998 .ok_or_else(|| anyhow!("user not found"))?;
2999 let flags = db.get_user_flags(session.user_id).await?;
3000
3001 response.send(proto::GetPrivateUserInfoResponse {
3002 metrics_id,
3003 staff: user.admin,
3004 flags,
3005 })?;
3006 Ok(())
3007}
3008
3009fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3010 match message {
3011 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3012 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3013 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3014 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3015 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3016 code: frame.code.into(),
3017 reason: frame.reason,
3018 })),
3019 }
3020}
3021
3022fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3023 match message {
3024 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3025 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3026 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3027 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3028 AxumMessage::Close(frame) => {
3029 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3030 code: frame.code.into(),
3031 reason: frame.reason,
3032 }))
3033 }
3034 }
3035}
3036
3037fn build_initial_channels_update(
3038 channels: ChannelsForUser,
3039 channel_invites: Vec<db::Channel>,
3040) -> proto::UpdateChannels {
3041 let mut update = proto::UpdateChannels::default();
3042
3043 for channel in channels.channels.channels {
3044 update.channels.push(proto::Channel {
3045 id: channel.id.to_proto(),
3046 name: channel.name,
3047 });
3048 }
3049
3050 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3051 update.unseen_channel_messages = channels.channel_messages;
3052 update.insert_edge = channels.channels.edges;
3053
3054 for (channel_id, participants) in channels.channel_participants {
3055 update
3056 .channel_participants
3057 .push(proto::ChannelParticipants {
3058 channel_id: channel_id.to_proto(),
3059 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3060 });
3061 }
3062
3063 update
3064 .channel_permissions
3065 .extend(
3066 channels
3067 .channels_with_admin_privileges
3068 .into_iter()
3069 .map(|id| proto::ChannelPermission {
3070 channel_id: id.to_proto(),
3071 is_admin: true,
3072 }),
3073 );
3074
3075 for channel in channel_invites {
3076 update.channel_invitations.push(proto::Channel {
3077 id: channel.id.to_proto(),
3078 name: channel.name,
3079 });
3080 }
3081
3082 update
3083}
3084
3085fn build_initial_contacts_update(
3086 contacts: Vec<db::Contact>,
3087 pool: &ConnectionPool,
3088) -> proto::UpdateContacts {
3089 let mut update = proto::UpdateContacts::default();
3090
3091 for contact in contacts {
3092 match contact {
3093 db::Contact::Accepted {
3094 user_id,
3095 should_notify,
3096 busy,
3097 } => {
3098 update
3099 .contacts
3100 .push(contact_for_user(user_id, should_notify, busy, &pool));
3101 }
3102 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3103 db::Contact::Incoming {
3104 user_id,
3105 should_notify,
3106 } => update
3107 .incoming_requests
3108 .push(proto::IncomingContactRequest {
3109 requester_id: user_id.to_proto(),
3110 should_notify,
3111 }),
3112 }
3113 }
3114
3115 update
3116}
3117
3118fn contact_for_user(
3119 user_id: UserId,
3120 should_notify: bool,
3121 busy: bool,
3122 pool: &ConnectionPool,
3123) -> proto::Contact {
3124 proto::Contact {
3125 user_id: user_id.to_proto(),
3126 online: pool.is_user_online(user_id),
3127 busy,
3128 should_notify,
3129 }
3130}
3131
3132fn room_updated(room: &proto::Room, peer: &Peer) {
3133 broadcast(
3134 None,
3135 room.participants
3136 .iter()
3137 .filter_map(|participant| Some(participant.peer_id?.into())),
3138 |peer_id| {
3139 peer.send(
3140 peer_id.into(),
3141 proto::RoomUpdated {
3142 room: Some(room.clone()),
3143 },
3144 )
3145 },
3146 );
3147}
3148
3149fn channel_updated(
3150 channel_id: ChannelId,
3151 room: &proto::Room,
3152 channel_members: &[UserId],
3153 peer: &Peer,
3154 pool: &ConnectionPool,
3155) {
3156 let participants = room
3157 .participants
3158 .iter()
3159 .map(|p| p.user_id)
3160 .collect::<Vec<_>>();
3161
3162 broadcast(
3163 None,
3164 channel_members
3165 .iter()
3166 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3167 |peer_id| {
3168 peer.send(
3169 peer_id.into(),
3170 proto::UpdateChannels {
3171 channel_participants: vec![proto::ChannelParticipants {
3172 channel_id: channel_id.to_proto(),
3173 participant_user_ids: participants.clone(),
3174 }],
3175 ..Default::default()
3176 },
3177 )
3178 },
3179 );
3180}
3181
3182async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3183 let db = session.db().await;
3184
3185 let contacts = db.get_contacts(user_id).await?;
3186 let busy = db.is_user_busy(user_id).await?;
3187
3188 let pool = session.connection_pool().await;
3189 let updated_contact = contact_for_user(user_id, false, busy, &pool);
3190 for contact in contacts {
3191 if let db::Contact::Accepted {
3192 user_id: contact_user_id,
3193 ..
3194 } = contact
3195 {
3196 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3197 session
3198 .peer
3199 .send(
3200 contact_conn_id,
3201 proto::UpdateContacts {
3202 contacts: vec![updated_contact.clone()],
3203 remove_contacts: Default::default(),
3204 incoming_requests: Default::default(),
3205 remove_incoming_requests: Default::default(),
3206 outgoing_requests: Default::default(),
3207 remove_outgoing_requests: Default::default(),
3208 },
3209 )
3210 .trace_err();
3211 }
3212 }
3213 }
3214 Ok(())
3215}
3216
/// Removes this session's connection from the room it is currently in, if
/// any, and notifies everyone affected.
///
/// When the database reports that no room was left, this returns `Ok(())`
/// without doing anything else. Otherwise it notifies the left projects, the
/// room's participants, the channel's members (when the room belongs to a
/// channel), users whose pending calls were canceled, and the contacts of
/// every user whose state changed. Finally, if a LiveKit client is
/// configured, the user is removed from the LiveKit room, and the room is
/// deleted when the database marked it as deleted.
///
/// Message-send and LiveKit failures are logged via `trace_err`; database
/// errors and `update_user_contacts` errors are returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contacts must be told about a state change.
    let mut contacts_to_update = HashSet::default();

    // Declared up front so the values outlive the `if let` below, whose
    // scrutinee holds the database handle and the `left_room` value.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room`
        // before `room` itself is taken, so the `room` broadcast below carries
        // a default `live_kit_room` value — confirm that this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The block scopes the pool guard so it is released before the awaits
    // that follow.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3293
3294async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3295 let left_channel_buffers = session
3296 .db()
3297 .await
3298 .leave_channel_buffers(session.connection_id)
3299 .await?;
3300
3301 for left_buffer in left_channel_buffers {
3302 channel_buffer_updated(
3303 session.connection_id,
3304 left_buffer.connections,
3305 &proto::UpdateChannelBufferCollaborators {
3306 channel_id: left_buffer.channel_id.to_proto(),
3307 collaborators: left_buffer.collaborators,
3308 },
3309 &session.peer,
3310 );
3311 }
3312
3313 Ok(())
3314}
3315
3316fn project_left(project: &db::LeftProject, session: &Session) {
3317 for connection_id in &project.connection_ids {
3318 if project.host_user_id == session.user_id {
3319 session
3320 .peer
3321 .send(
3322 *connection_id,
3323 proto::UnshareProject {
3324 project_id: project.id.to_proto(),
3325 },
3326 )
3327 .trace_err();
3328 } else {
3329 session
3330 .peer
3331 .send(
3332 *connection_id,
3333 proto::RemoveProjectCollaborator {
3334 project_id: project.id.to_proto(),
3335 peer_id: Some(session.connection_id.into()),
3336 },
3337 )
3338 .trace_err();
3339 }
3340 }
3341}
3342
3343pub trait ResultExt {
3344 type Ok;
3345
3346 fn trace_err(self) -> Option<Self::Ok>;
3347}
3348
3349impl<T, E> ResultExt for Result<T, E>
3350where
3351 E: std::fmt::Debug,
3352{
3353 type Ok = T;
3354
3355 fn trace_err(self) -> Option<T> {
3356 match self {
3357 Ok(value) => Some(value),
3358 Err(error) => {
3359 tracing::error!("{:?}", error);
3360 None
3361 }
3362 }
3363 }
3364}