1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
7 ServerId, User, UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
42 LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66
/// How long `connection_lost` waits for a disconnected client to reconnect before it
/// releases the user's room, channel buffers and calls.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the task spawned by `Server::start` sleeps before cleaning up resources
/// that stale servers left behind.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// Page size for channel message history. NOTE(review): presumably consumed by
// `get_channel_messages`, which is outside this view — TODO confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Upper bound on a channel message's length. NOTE(review): presumably enforced by
// `send_channel_message` (outside this view); confirm whether the unit is bytes or chars.
const MAX_MESSAGE_LEN: usize = 1024;
72
// Prometheus gauges exported by `handle_metrics`, which refreshes them on every scrape.
lazy_static! {
    // Set to the number of non-admin connections in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set from `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
82
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives the incoming envelope (downcast back to its concrete type inside the
/// closure built by `Server::add_handler`) plus the sender's [`Session`], and returns a
/// boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
/// Reply handle given to request handlers registered with `Server::add_request_handler`.
///
/// Consuming it with [`Response::send`] answers the request identified by `receipt`.
struct Response<R> {
    // Peer used to deliver the reply.
    peer: Arc<Peer>,
    // Identifies the request being answered.
    receipt: Receipt<R>,
    // Shared with `add_request_handler`, which checks it after the handler returns.
    responded: Arc<AtomicBool>,
}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
/// Per-connection context handed to every message handler.
///
/// Clones share the same underlying handles (all fields are `Arc`s or cheap handles), so
/// handlers spawned for the same connection see the same database mutex and pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // One mutex per connection (created in `handle_connection`), so this connection's
    // handlers take turns using the database handle.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when no LiveKit client is configured in `AppState`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
110
111impl Session {
112 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 let guard = self.db.lock().await;
116 #[cfg(test)]
117 tokio::task::yield_now().await;
118 guard
119 }
120
121 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
122 #[cfg(test)]
123 tokio::task::yield_now().await;
124 let guard = self.connection_pool.lock();
125 ConnectionPoolGuard {
126 guard,
127 _not_send: PhantomData,
128 }
129 }
130}
131
132impl fmt::Debug for Session {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("Session")
135 .field("user_id", &self.user_id)
136 .field("connection_id", &self.connection_id)
137 .finish()
138 }
139}
140
141struct DbHandle(Arc<Database>);
142
143impl Deref for DbHandle {
144 type Target = Database;
145
146 fn deref(&self) -> &Self::Target {
147 self.0.as_ref()
148 }
149}
150
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// message handlers registered in [`Server::new`].
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace the id in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Type-erased handlers keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown()`; each `handle_connection` future subscribes to it.
    teardown: watch::Sender<()>,
}
160
/// Lock guard over the connection pool that derefs to [`ConnectionPool`].
///
/// In test builds, dropping it validates the pool's invariants (see its `Drop` impl).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the guard `!Send`. Handler futures must be `Send`
    // (see `Server::add_handler`), so the guard cannot be held across an `.await` there.
    _not_send: PhantomData<Rc<()>>,
}
165
/// Serializable point-in-time view of the server's peer and connection pool.
///
/// A snapshot owns a [`ConnectionPoolGuard`], so the pool stays locked until the snapshot
/// is dropped. Fields are serialized in declaration order. `connection_pool` is written
/// out as the [`ConnectionPool`] it derefs to (via [`serialize_deref`]).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
172
173pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
174where
175 S: Serializer,
176 T: Deref<Target = U>,
177 U: Serialize,
178{
179 Serialize::serialize(value.deref(), serializer)
180}
181
182impl Server {
183 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
184 let mut server = Self {
185 id: parking_lot::Mutex::new(id),
186 peer: Peer::new(id.0 as u32),
187 app_state,
188 executor,
189 connection_pool: Default::default(),
190 handlers: Default::default(),
191 teardown: watch::channel(()).0,
192 };
193
194 server
195 .add_request_handler(ping)
196 .add_request_handler(create_room)
197 .add_request_handler(join_room)
198 .add_request_handler(rejoin_room)
199 .add_request_handler(leave_room)
200 .add_request_handler(call)
201 .add_request_handler(cancel_call)
202 .add_message_handler(decline_call)
203 .add_request_handler(update_participant_location)
204 .add_request_handler(share_project)
205 .add_message_handler(unshare_project)
206 .add_request_handler(join_project)
207 .add_message_handler(leave_project)
208 .add_request_handler(update_project)
209 .add_request_handler(update_worktree)
210 .add_message_handler(start_language_server)
211 .add_message_handler(update_language_server)
212 .add_message_handler(update_diagnostic_summary)
213 .add_message_handler(update_worktree_settings)
214 .add_message_handler(refresh_inlay_hints)
215 .add_request_handler(forward_project_request::<proto::GetHover>)
216 .add_request_handler(forward_project_request::<proto::GetDefinition>)
217 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
218 .add_request_handler(forward_project_request::<proto::GetReferences>)
219 .add_request_handler(forward_project_request::<proto::SearchProject>)
220 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
221 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
222 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
223 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
224 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
225 .add_request_handler(forward_project_request::<proto::GetCompletions>)
226 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
227 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
228 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
229 .add_request_handler(forward_project_request::<proto::PrepareRename>)
230 .add_request_handler(forward_project_request::<proto::PerformRename>)
231 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
232 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
233 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
234 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
235 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
236 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
237 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
238 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
239 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
240 .add_request_handler(forward_project_request::<proto::InlayHints>)
241 .add_message_handler(create_buffer_for_peer)
242 .add_request_handler(update_buffer)
243 .add_message_handler(update_buffer_file)
244 .add_message_handler(buffer_reloaded)
245 .add_message_handler(buffer_saved)
246 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
247 .add_request_handler(get_users)
248 .add_request_handler(fuzzy_search_users)
249 .add_request_handler(request_contact)
250 .add_request_handler(remove_contact)
251 .add_request_handler(respond_to_contact_request)
252 .add_request_handler(create_channel)
253 .add_request_handler(delete_channel)
254 .add_request_handler(invite_channel_member)
255 .add_request_handler(remove_channel_member)
256 .add_request_handler(set_channel_member_admin)
257 .add_request_handler(rename_channel)
258 .add_request_handler(join_channel_buffer)
259 .add_request_handler(leave_channel_buffer)
260 .add_message_handler(update_channel_buffer)
261 .add_request_handler(rejoin_channel_buffers)
262 .add_request_handler(get_channel_members)
263 .add_request_handler(respond_to_channel_invite)
264 .add_request_handler(join_channel)
265 .add_request_handler(join_channel_chat)
266 .add_message_handler(leave_channel_chat)
267 .add_request_handler(send_channel_message)
268 .add_request_handler(remove_channel_message)
269 .add_request_handler(get_channel_messages)
270 .add_request_handler(link_channel)
271 .add_request_handler(unlink_channel)
272 .add_request_handler(move_channel)
273 .add_request_handler(follow)
274 .add_message_handler(unfollow)
275 .add_message_handler(update_followers)
276 .add_message_handler(update_diff_base)
277 .add_request_handler(get_private_user_info)
278 .add_message_handler(acknowledge_channel_message)
279 .add_message_handler(acknowledge_buffer_version);
280
281 Arc::new(server)
282 }
283
    /// Spawns the background task that cleans up resources left behind by stale servers.
    ///
    /// The detached task first sleeps for [`CLEANUP_TIMEOUT`], then calls
    /// `stale_server_resource_ids` (with the configured `zed_environment` and this server's
    /// id) to get stale room ids and stale channel-buffer ids. It then proceeds in stages:
    ///
    /// - For each stale channel buffer, it clears stale collaborators and sends the refreshed
    ///   collaborator list to the buffer's remaining connections.
    /// - For each stale room, it clears stale participants and broadcasts the refreshed room
    ///   (and its channel, if any). It sends `CallCanceled` to users whose calls were
    ///   canceled, and pushes updated contact entries to the contacts of every affected user.
    /// - When a room ends up with no participants and a LiveKit client is configured, the
    ///   room's LiveKit room is deleted.
    /// - Finally, `delete_stale_servers` runs, whether or not the stale-id lookup succeeded.
    ///
    /// Failures inside the task are logged via `trace_err` and do not stop the remaining
    /// cleanup. This method itself only spawns the task and always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now but only awaited inside the spawned task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell everyone still
                    // connected to the buffer who remains.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room's stale participants succeeds;
                        // otherwise these defaults make the steps below no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to the connections
                        // of all their accepted contacts. The DB reads happen before the pool
                        // lock is taken, and nothing is awaited while it is held.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
435
436 pub fn teardown(&self) {
437 self.peer.teardown();
438 self.connection_pool.lock().reset();
439 let _ = self.teardown.send(());
440 }
441
442 #[cfg(test)]
443 pub fn reset(&self, id: ServerId) {
444 self.teardown();
445 *self.id.lock() = id;
446 self.peer.reset(id.0 as u32);
447 }
448
449 #[cfg(test)]
450 pub fn id(&self) -> ServerId {
451 *self.id.lock()
452 }
453
454 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
455 where
456 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
457 Fut: 'static + Send + Future<Output = Result<()>>,
458 M: EnvelopedMessage,
459 {
460 let prev_handler = self.handlers.insert(
461 TypeId::of::<M>(),
462 Box::new(move |envelope, session| {
463 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
464 let span = info_span!(
465 "handle message",
466 payload_type = envelope.payload_type_name()
467 );
468 span.in_scope(|| {
469 tracing::info!(
470 payload_type = envelope.payload_type_name(),
471 "message received"
472 );
473 });
474 let start_time = Instant::now();
475 let future = (handler)(*envelope, session);
476 async move {
477 let result = future.await;
478 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
479 match result {
480 Err(error) => {
481 tracing::error!(%error, ?duration_ms, "error handling message")
482 }
483 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
484 }
485 }
486 .instrument(span)
487 .boxed()
488 }),
489 );
490 if prev_handler.is_some() {
491 panic!("registered a handler for the same message twice");
492 }
493 self
494 }
495
496 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
497 where
498 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
499 Fut: 'static + Send + Future<Output = Result<()>>,
500 M: EnvelopedMessage,
501 {
502 self.add_handler(move |envelope, session| handler(envelope.payload, session));
503 self
504 }
505
506 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
507 where
508 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
509 Fut: Send + Future<Output = Result<()>>,
510 M: RequestMessage,
511 {
512 let handler = Arc::new(handler);
513 self.add_handler(move |envelope, session| {
514 let receipt = envelope.receipt();
515 let handler = handler.clone();
516 async move {
517 let peer = session.peer.clone();
518 let responded = Arc::new(AtomicBool::default());
519 let response = Response {
520 peer: peer.clone(),
521 responded: responded.clone(),
522 receipt,
523 };
524 match (handler)(envelope.payload, response, session).await {
525 Ok(()) => {
526 if responded.load(std::sync::atomic::Ordering::SeqCst) {
527 Ok(())
528 } else {
529 Err(anyhow!("handler did not send a response"))?
530 }
531 }
532 Err(error) => {
533 peer.respond_with_error(
534 receipt,
535 proto::Error {
536 message: error.to_string(),
537 },
538 )?;
539 Err(error)
540 }
541 }
542 }
543 })
544 }
545
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future proceeds in stages:
    ///
    /// 1. Registers the connection with the peer and sends `Hello` with the new connection
    ///    id. If `send_connection_id` is provided, the id is also sent there.
    /// 2. On the user's first ever connection, sends `ShowContacts` and records that the
    ///    user has connected once.
    /// 3. Loads contacts, channels and channel invites, adds the connection to the pool, and
    ///    sends the initial contact and channel updates, plus any pending incoming call.
    /// 4. Builds the connection's [`Session`], updates the user's contacts, and dispatches
    ///    incoming messages to the registered handlers until I/O ends, the connection
    ///    closes, or the server is torn down.
    ///
    /// At most 256 handlers per connection are in flight at once: a semaphore permit is
    /// taken before reading each message and released when its handler finishes. Background
    /// messages are spawned detached; foreground ones run in a `FuturesUnordered` polled by
    /// this loop. On server teardown the future returns `Ok(())` immediately, without running
    /// `connection_lost`. In every other exit from the loop, `connection_lost` runs.
    ///
    /// NOTE(review): the `?` returns after `pool.add_connection` (sending the initial
    /// updates, `incoming_call_for_user`, `update_user_contacts`) exit before the loop, so
    /// `connection_lost` never runs for that connection and its pool entry appears to stay
    /// behind. Confirm whether that is handled elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future starts so a teardown signalled later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // Register the connection and send the initial state while holding the pool lock.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) per connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Acquire a permit before reading, so no new message is read while 256
                // handlers are still running.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are checked top to bottom on every iteration.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only when the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancels any foreground handlers that have not finished yet.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
680
681 pub async fn invite_code_redeemed(
682 self: &Arc<Self>,
683 inviter_id: UserId,
684 invitee_id: UserId,
685 ) -> Result<()> {
686 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
687 if let Some(code) = &user.invite_code {
688 let pool = self.connection_pool.lock();
689 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
690 for connection_id in pool.user_connection_ids(inviter_id) {
691 self.peer.send(
692 connection_id,
693 proto::UpdateContacts {
694 contacts: vec![invitee_contact.clone()],
695 ..Default::default()
696 },
697 )?;
698 self.peer.send(
699 connection_id,
700 proto::UpdateInviteInfo {
701 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
702 count: user.invite_count as u32,
703 },
704 )?;
705 }
706 }
707 }
708 Ok(())
709 }
710
711 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
712 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
713 if let Some(invite_code) = &user.invite_code {
714 let pool = self.connection_pool.lock();
715 for connection_id in pool.user_connection_ids(user_id) {
716 self.peer.send(
717 connection_id,
718 proto::UpdateInviteInfo {
719 url: format!(
720 "{}{}",
721 self.app_state.config.invite_link_prefix, invite_code
722 ),
723 count: user.invite_count as u32,
724 },
725 )?;
726 }
727 }
728 }
729 Ok(())
730 }
731
732 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
733 ServerSnapshot {
734 connection_pool: ConnectionPoolGuard {
735 guard: self.connection_pool.lock(),
736 _not_send: PhantomData,
737 },
738 peer: &self.peer,
739 }
740 }
741}
742
743impl<'a> Deref for ConnectionPoolGuard<'a> {
744 type Target = ConnectionPool;
745
746 fn deref(&self) -> &Self::Target {
747 &*self.guard
748 }
749}
750
751impl<'a> DerefMut for ConnectionPoolGuard<'a> {
752 fn deref_mut(&mut self) -> &mut Self::Target {
753 &mut *self.guard
754 }
755}
756
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants on every release.
    ///
    /// This runs before the inner `guard` field is dropped, so the check happens while the
    /// pool is still locked. In non-test builds the body is empty.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
763
764fn broadcast<F>(
765 sender_id: Option<ConnectionId>,
766 receiver_ids: impl IntoIterator<Item = ConnectionId>,
767 mut f: F,
768) where
769 F: FnMut(ConnectionId) -> anyhow::Result<()>,
770{
771 for receiver_id in receiver_ids {
772 if Some(receiver_id) != sender_id {
773 if let Err(error) = f(receiver_id) {
774 tracing::error!("failed to send to {:?} {}", receiver_id, error);
775 }
776 }
777 }
778}
779
lazy_static! {
    // Name of the request header in which clients declare their RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
783
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client speaks,
/// compared against `rpc::PROTOCOL_VERSION` in `handle_websocket_request`.
pub struct ProtocolVersion(u32);
785
786impl Header for ProtocolVersion {
787 fn name() -> &'static HeaderName {
788 &ZED_PROTOCOL_VERSION
789 }
790
791 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
792 where
793 Self: Sized,
794 I: Iterator<Item = &'i axum::http::HeaderValue>,
795 {
796 let version = values
797 .next()
798 .ok_or_else(axum::headers::Error::invalid)?
799 .to_str()
800 .map_err(|_| axum::headers::Error::invalid())?
801 .parse()
802 .map_err(|_| axum::headers::Error::invalid())?;
803 Ok(Self(version))
804 }
805
806 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
807 values.extend([self.0.to_string().parse().unwrap()]);
808 }
809}
810
811pub fn routes(server: Arc<Server>) -> Router<Body> {
812 Router::new()
813 .route("/rpc", get(handle_websocket_request))
814 .layer(
815 ServiceBuilder::new()
816 .layer(Extension(server.app_state.clone()))
817 .layer(middleware::from_fn(auth::validate_header)),
818 )
819 .route("/metrics", get(handle_metrics))
820 .layer(Extension(server))
821}
822
823pub async fn handle_websocket_request(
824 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
825 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
826 Extension(server): Extension<Arc<Server>>,
827 Extension(user): Extension<User>,
828 ws: WebSocketUpgrade,
829) -> axum::response::Response {
830 if protocol_version != rpc::PROTOCOL_VERSION {
831 return (
832 StatusCode::UPGRADE_REQUIRED,
833 "client must be upgraded".to_string(),
834 )
835 .into_response();
836 }
837 let socket_address = socket_address.to_string();
838 ws.on_upgrade(move |socket| {
839 use util::ResultExt;
840 let socket = socket
841 .map_ok(to_tungstenite_message)
842 .err_into()
843 .with(|message| async move { Ok(to_axum_message(message)) });
844 let connection = Connection::new(Box::pin(socket));
845 async move {
846 server
847 .handle_connection(connection, socket_address, user, None, Executor::Production)
848 .await
849 .log_err();
850 }
851 })
852}
853
854pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
855 let connections = server
856 .connection_pool
857 .lock()
858 .connections()
859 .filter(|connection| !connection.admin)
860 .count();
861
862 METRIC_CONNECTIONS.set(connections as _);
863
864 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
865 METRIC_SHARED_PROJECTS.set(shared_projects as _);
866
867 let encoder = prometheus::TextEncoder::new();
868 let metric_families = prometheus::gather();
869 let encoded_metrics = encoder
870 .encode_to_string(&metric_families)
871 .map_err(|err| anyhow!("{}", err))?;
872 Ok(encoded_metrics)
873}
874
/// Signs a connection out after its message loop ends.
///
/// The connection is disconnected from the peer and removed from the pool right away, and
/// the database is told it was lost. The function then waits for [`RECONNECT_TIMEOUT`].
/// If the timeout elapses first, it:
///
/// - removes the session from its room and channel buffers;
/// - declines any pending call when the user has no other online connection;
/// - refreshes the user's contacts.
///
/// If the server is torn down first, it returns without doing any of that. Because the
/// select is biased, the timeout branch wins when both are ready at the same time.
///
/// # Errors
///
/// Returns an error if the connection could not be removed from the pool, or if
/// `update_user_contacts` fails. Other failures are only logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // NOTE(review): this is the only `log::` call visible in this file; everything
            // else logs through `tracing`.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call if no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
920
921async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
922 response.send(proto::Ack {})?;
923 Ok(())
924}
925
926async fn create_room(
927 _request: proto::CreateRoom,
928 response: Response<proto::CreateRoom>,
929 session: Session,
930) -> Result<()> {
931 let live_kit_room = nanoid::nanoid!(30);
932
933 let live_kit_connection_info = {
934 let live_kit_room = live_kit_room.clone();
935 let live_kit = session.live_kit_client.as_ref();
936
937 util::async_iife!({
938 let live_kit = live_kit?;
939
940 live_kit
941 .create_room(live_kit_room.clone())
942 .await
943 .trace_err()?;
944
945 let token = live_kit
946 .room_token(&live_kit_room, &session.user_id.to_string())
947 .trace_err()?;
948
949 Some(proto::LiveKitConnectionInfo {
950 server_url: live_kit.url().into(),
951 token,
952 })
953 })
954 }
955 .await;
956
957 let room = session
958 .db()
959 .await
960 .create_room(session.user_id, session.connection_id, &live_kit_room)
961 .await?;
962
963 response.send(proto::CreateRoomResponse {
964 room: Some(room.clone()),
965 live_kit_connection_info,
966 })?;
967
968 update_user_contacts(session.user_id, &session).await?;
969 Ok(())
970}
971
972async fn join_room(
973 request: proto::JoinRoom,
974 response: Response<proto::JoinRoom>,
975 session: Session,
976) -> Result<()> {
977 let room_id = RoomId::from_proto(request.id);
978 let joined_room = {
979 let room = session
980 .db()
981 .await
982 .join_room(room_id, session.user_id, session.connection_id)
983 .await?;
984 room_updated(&room.room, &session.peer);
985 room.into_inner()
986 };
987
988 if let Some(channel_id) = joined_room.channel_id {
989 channel_updated(
990 channel_id,
991 &joined_room.room,
992 &joined_room.channel_members,
993 &session.peer,
994 &*session.connection_pool().await,
995 )
996 }
997
998 for connection_id in session
999 .connection_pool()
1000 .await
1001 .user_connection_ids(session.user_id)
1002 {
1003 session
1004 .peer
1005 .send(
1006 connection_id,
1007 proto::CallCanceled {
1008 room_id: room_id.to_proto(),
1009 },
1010 )
1011 .trace_err();
1012 }
1013
1014 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1015 if let Some(token) = live_kit
1016 .room_token(
1017 &joined_room.room.live_kit_room,
1018 &session.user_id.to_string(),
1019 )
1020 .trace_err()
1021 {
1022 Some(proto::LiveKitConnectionInfo {
1023 server_url: live_kit.url().into(),
1024 token,
1025 })
1026 } else {
1027 None
1028 }
1029 } else {
1030 None
1031 };
1032
1033 response.send(proto::JoinRoomResponse {
1034 room: Some(joined_room.room),
1035 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1036 live_kit_connection_info,
1037 })?;
1038
1039 update_user_contacts(session.user_id, &session).await?;
1040 Ok(())
1041}
1042
/// Handles a `RejoinRoom` request from a client reconnecting to a room.
///
/// The database call returns the rejoined room together with `reshared_projects`
/// (presumably projects the caller hosts — confirm) and `rejoined_projects` (projects whose
/// state is re-streamed to the caller below). The handler then:
/// 1. replies with the room plus per-project worktree metadata and collaborators,
/// 2. broadcasts the room update,
/// 3. sends `UpdateProjectCollaborator` to each project's listed collaborators, mapping the
///    project's `old_connection_id` to the caller's current connection,
/// 4. for reshared projects, forwards an `UpdateProject` carrying the worktree list,
/// 5. for rejoined projects, streams worktree entries, diagnostic summaries, settings files,
///    and one `DiskBasedDiagnosticsUpdated` per language server to the caller, and
/// 6. after `into_inner()` unwraps the database result, notifies channel members (when the
///    room belongs to a channel) and refreshes the caller's contacts.
///
/// Database errors and failed sends to the caller are propagated with `?`; sends to other
/// collaborators are best-effort (`trace_err`).
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned from `rejoined_room` at the end of the inner scope so they outlive it.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // The response is sent before any of the per-project messages below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the peer previously at `old_connection_id` is now
            // reachable at the caller's current connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list, with the caller passed as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        for project in &mut rejoined_room.rejoined_projects {
            // `mem::take` moves the worktrees out so their fields can be consumed below.
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size, presumably to exercise chunking.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    // Final only when the last completed scan matches the current scan id.
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // One `DiskBasedDiagnosticsUpdated` notification per language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Runs after the inner scope has ended, i.e. after the database result was unwrapped.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1233
1234async fn leave_room(
1235 _: proto::LeaveRoom,
1236 response: Response<proto::LeaveRoom>,
1237 session: Session,
1238) -> Result<()> {
1239 leave_room_for_session(&session).await?;
1240 response.send(proto::Ack {})?;
1241 Ok(())
1242}
1243
1244async fn call(
1245 request: proto::Call,
1246 response: Response<proto::Call>,
1247 session: Session,
1248) -> Result<()> {
1249 let room_id = RoomId::from_proto(request.room_id);
1250 let calling_user_id = session.user_id;
1251 let calling_connection_id = session.connection_id;
1252 let called_user_id = UserId::from_proto(request.called_user_id);
1253 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1254 if !session
1255 .db()
1256 .await
1257 .has_contact(calling_user_id, called_user_id)
1258 .await?
1259 {
1260 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1261 }
1262
1263 let incoming_call = {
1264 let (room, incoming_call) = &mut *session
1265 .db()
1266 .await
1267 .call(
1268 room_id,
1269 calling_user_id,
1270 calling_connection_id,
1271 called_user_id,
1272 initial_project_id,
1273 )
1274 .await?;
1275 room_updated(&room, &session.peer);
1276 mem::take(incoming_call)
1277 };
1278 update_user_contacts(called_user_id, &session).await?;
1279
1280 let mut calls = session
1281 .connection_pool()
1282 .await
1283 .user_connection_ids(called_user_id)
1284 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1285 .collect::<FuturesUnordered<_>>();
1286
1287 while let Some(call_response) = calls.next().await {
1288 match call_response.as_ref() {
1289 Ok(_) => {
1290 response.send(proto::Ack {})?;
1291 return Ok(());
1292 }
1293 Err(_) => {
1294 call_response.trace_err();
1295 }
1296 }
1297 }
1298
1299 {
1300 let room = session
1301 .db()
1302 .await
1303 .call_failed(room_id, called_user_id)
1304 .await?;
1305 room_updated(&room, &session.peer);
1306 }
1307 update_user_contacts(called_user_id, &session).await?;
1308
1309 Err(anyhow!("failed to ring user"))?
1310}
1311
1312async fn cancel_call(
1313 request: proto::CancelCall,
1314 response: Response<proto::CancelCall>,
1315 session: Session,
1316) -> Result<()> {
1317 let called_user_id = UserId::from_proto(request.called_user_id);
1318 let room_id = RoomId::from_proto(request.room_id);
1319 {
1320 let room = session
1321 .db()
1322 .await
1323 .cancel_call(room_id, session.connection_id, called_user_id)
1324 .await?;
1325 room_updated(&room, &session.peer);
1326 }
1327
1328 for connection_id in session
1329 .connection_pool()
1330 .await
1331 .user_connection_ids(called_user_id)
1332 {
1333 session
1334 .peer
1335 .send(
1336 connection_id,
1337 proto::CallCanceled {
1338 room_id: room_id.to_proto(),
1339 },
1340 )
1341 .trace_err();
1342 }
1343 response.send(proto::Ack {})?;
1344
1345 update_user_contacts(called_user_id, &session).await?;
1346 Ok(())
1347}
1348
1349async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1350 let room_id = RoomId::from_proto(message.room_id);
1351 {
1352 let room = session
1353 .db()
1354 .await
1355 .decline_call(Some(room_id), session.user_id)
1356 .await?
1357 .ok_or_else(|| anyhow!("failed to decline call"))?;
1358 room_updated(&room, &session.peer);
1359 }
1360
1361 for connection_id in session
1362 .connection_pool()
1363 .await
1364 .user_connection_ids(session.user_id)
1365 {
1366 session
1367 .peer
1368 .send(
1369 connection_id,
1370 proto::CallCanceled {
1371 room_id: room_id.to_proto(),
1372 },
1373 )
1374 .trace_err();
1375 }
1376 update_user_contacts(session.user_id, &session).await?;
1377 Ok(())
1378}
1379
1380async fn update_participant_location(
1381 request: proto::UpdateParticipantLocation,
1382 response: Response<proto::UpdateParticipantLocation>,
1383 session: Session,
1384) -> Result<()> {
1385 let room_id = RoomId::from_proto(request.room_id);
1386 let location = request
1387 .location
1388 .ok_or_else(|| anyhow!("invalid location"))?;
1389
1390 let db = session.db().await;
1391 let room = db
1392 .update_room_participant_location(room_id, session.connection_id, location)
1393 .await?;
1394
1395 room_updated(&room, &session.peer);
1396 response.send(proto::Ack {})?;
1397 Ok(())
1398}
1399
1400async fn share_project(
1401 request: proto::ShareProject,
1402 response: Response<proto::ShareProject>,
1403 session: Session,
1404) -> Result<()> {
1405 let (project_id, room) = &*session
1406 .db()
1407 .await
1408 .share_project(
1409 RoomId::from_proto(request.room_id),
1410 session.connection_id,
1411 &request.worktrees,
1412 )
1413 .await?;
1414 response.send(proto::ShareProjectResponse {
1415 project_id: project_id.to_proto(),
1416 })?;
1417 room_updated(&room, &session.peer);
1418
1419 Ok(())
1420}
1421
1422async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1423 let project_id = ProjectId::from_proto(message.project_id);
1424
1425 let (room, guest_connection_ids) = &*session
1426 .db()
1427 .await
1428 .unshare_project(project_id, session.connection_id)
1429 .await?;
1430
1431 broadcast(
1432 Some(session.connection_id),
1433 guest_connection_ids.iter().copied(),
1434 |conn_id| session.peer.send(conn_id, message.clone()),
1435 );
1436 room_updated(&room, &session.peer);
1437
1438 Ok(())
1439}
1440
/// Adds the caller to project `request.project_id` and streams the project state to it.
///
/// After the database records the join and assigns the caller a replica id, this handler:
/// 1. sends `AddProjectCollaborator` to every existing collaborator other than the caller
///    (best-effort, failures are only logged),
/// 2. replies with worktree metadata, the replica id, the other collaborators, and the
///    project's language servers,
/// 3. streams each worktree's entries (split by `proto::split_worktree_update` using
///    `MAX_CHUNK_SIZE`), diagnostic summaries, and settings files to the caller, and
/// 4. sends one `DiskBasedDiagnosticsUpdated` update per language server.
///
/// Database errors and failed sends to the caller are propagated with `?`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let guest_user_id = session.user_id;

    tracing::info!(%project_id, "join project");

    let (project, replica_id) = &mut *session
        .db()
        .await
        .join_project(project_id, session.connection_id)
        .await?;

    // Everyone in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    for collaborator in &collaborators {
        session
            .peer
            .send(
                // NOTE(review): assumes `Collaborator::to_proto` always sets `peer_id`;
                // otherwise this `unwrap` panics.
                collaborator.peer_id.unwrap().into(),
                proto::AddProjectCollaborator {
                    project_id: project_id.to_proto(),
                    collaborator: Some(proto::Collaborator {
                        peer_id: Some(session.connection_id.into()),
                        replica_id: replica_id.0 as u32,
                        user_id: guest_user_id.to_proto(),
                    }),
                },
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
    })?;

    // `mem::take` moves the worktrees out so their fields can be consumed below.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Test builds use a tiny chunk size, presumably to exercise chunking.
        #[cfg(any(test, feature = "test-support"))]
        const MAX_CHUNK_SIZE: usize = 2;
        #[cfg(not(any(test, feature = "test-support")))]
        const MAX_CHUNK_SIZE: usize = 256;

        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            // Final only when the current scan id matches the last completed scan.
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                },
            )?;
        }
    }

    // One `DiskBasedDiagnosticsUpdated` notification per language server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1565
1566async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1567 let sender_id = session.connection_id;
1568 let project_id = ProjectId::from_proto(request.project_id);
1569
1570 let (room, project) = &*session
1571 .db()
1572 .await
1573 .leave_project(project_id, sender_id)
1574 .await?;
1575 tracing::info!(
1576 %project_id,
1577 host_user_id = %project.host_user_id,
1578 host_connection_id = %project.host_connection_id,
1579 "leave project"
1580 );
1581
1582 project_left(&project, &session);
1583 room_updated(&room, &session.peer);
1584
1585 Ok(())
1586}
1587
1588async fn update_project(
1589 request: proto::UpdateProject,
1590 response: Response<proto::UpdateProject>,
1591 session: Session,
1592) -> Result<()> {
1593 let project_id = ProjectId::from_proto(request.project_id);
1594 let (room, guest_connection_ids) = &*session
1595 .db()
1596 .await
1597 .update_project(project_id, session.connection_id, &request.worktrees)
1598 .await?;
1599 broadcast(
1600 Some(session.connection_id),
1601 guest_connection_ids.iter().copied(),
1602 |connection_id| {
1603 session
1604 .peer
1605 .forward_send(session.connection_id, connection_id, request.clone())
1606 },
1607 );
1608 room_updated(&room, &session.peer);
1609 response.send(proto::Ack {})?;
1610
1611 Ok(())
1612}
1613
1614async fn update_worktree(
1615 request: proto::UpdateWorktree,
1616 response: Response<proto::UpdateWorktree>,
1617 session: Session,
1618) -> Result<()> {
1619 let guest_connection_ids = session
1620 .db()
1621 .await
1622 .update_worktree(&request, session.connection_id)
1623 .await?;
1624
1625 broadcast(
1626 Some(session.connection_id),
1627 guest_connection_ids.iter().copied(),
1628 |connection_id| {
1629 session
1630 .peer
1631 .forward_send(session.connection_id, connection_id, request.clone())
1632 },
1633 );
1634 response.send(proto::Ack {})?;
1635 Ok(())
1636}
1637
1638async fn update_diagnostic_summary(
1639 message: proto::UpdateDiagnosticSummary,
1640 session: Session,
1641) -> Result<()> {
1642 let guest_connection_ids = session
1643 .db()
1644 .await
1645 .update_diagnostic_summary(&message, session.connection_id)
1646 .await?;
1647
1648 broadcast(
1649 Some(session.connection_id),
1650 guest_connection_ids.iter().copied(),
1651 |connection_id| {
1652 session
1653 .peer
1654 .forward_send(session.connection_id, connection_id, message.clone())
1655 },
1656 );
1657
1658 Ok(())
1659}
1660
1661async fn update_worktree_settings(
1662 message: proto::UpdateWorktreeSettings,
1663 session: Session,
1664) -> Result<()> {
1665 let guest_connection_ids = session
1666 .db()
1667 .await
1668 .update_worktree_settings(&message, session.connection_id)
1669 .await?;
1670
1671 broadcast(
1672 Some(session.connection_id),
1673 guest_connection_ids.iter().copied(),
1674 |connection_id| {
1675 session
1676 .peer
1677 .forward_send(session.connection_id, connection_id, message.clone())
1678 },
1679 );
1680
1681 Ok(())
1682}
1683
1684async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1685 broadcast_project_message(request.project_id, request, session).await
1686}
1687
1688async fn start_language_server(
1689 request: proto::StartLanguageServer,
1690 session: Session,
1691) -> Result<()> {
1692 let guest_connection_ids = session
1693 .db()
1694 .await
1695 .start_language_server(&request, session.connection_id)
1696 .await?;
1697
1698 broadcast(
1699 Some(session.connection_id),
1700 guest_connection_ids.iter().copied(),
1701 |connection_id| {
1702 session
1703 .peer
1704 .forward_send(session.connection_id, connection_id, request.clone())
1705 },
1706 );
1707 Ok(())
1708}
1709
1710async fn update_language_server(
1711 request: proto::UpdateLanguageServer,
1712 session: Session,
1713) -> Result<()> {
1714 session.executor.record_backtrace();
1715 let project_id = ProjectId::from_proto(request.project_id);
1716 let project_connection_ids = session
1717 .db()
1718 .await
1719 .project_connection_ids(project_id, session.connection_id)
1720 .await?;
1721 broadcast(
1722 Some(session.connection_id),
1723 project_connection_ids.iter().copied(),
1724 |connection_id| {
1725 session
1726 .peer
1727 .forward_send(session.connection_id, connection_id, request.clone())
1728 },
1729 );
1730 Ok(())
1731}
1732
1733async fn forward_project_request<T>(
1734 request: T,
1735 response: Response<T>,
1736 session: Session,
1737) -> Result<()>
1738where
1739 T: EntityMessage + RequestMessage,
1740{
1741 session.executor.record_backtrace();
1742 let project_id = ProjectId::from_proto(request.remote_entity_id());
1743 let host_connection_id = {
1744 let collaborators = session
1745 .db()
1746 .await
1747 .project_collaborators(project_id, session.connection_id)
1748 .await?;
1749 collaborators
1750 .iter()
1751 .find(|collaborator| collaborator.is_host)
1752 .ok_or_else(|| anyhow!("host not found"))?
1753 .connection_id
1754 };
1755
1756 let payload = session
1757 .peer
1758 .forward_request(session.connection_id, host_connection_id, request)
1759 .await?;
1760
1761 response.send(payload)?;
1762 Ok(())
1763}
1764
1765async fn create_buffer_for_peer(
1766 request: proto::CreateBufferForPeer,
1767 session: Session,
1768) -> Result<()> {
1769 session.executor.record_backtrace();
1770 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1771 session
1772 .peer
1773 .forward_send(session.connection_id, peer_id.into(), request)?;
1774 Ok(())
1775}
1776
1777async fn update_buffer(
1778 request: proto::UpdateBuffer,
1779 response: Response<proto::UpdateBuffer>,
1780 session: Session,
1781) -> Result<()> {
1782 session.executor.record_backtrace();
1783 let project_id = ProjectId::from_proto(request.project_id);
1784 let mut guest_connection_ids;
1785 let mut host_connection_id = None;
1786 {
1787 let collaborators = session
1788 .db()
1789 .await
1790 .project_collaborators(project_id, session.connection_id)
1791 .await?;
1792 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1793 for collaborator in collaborators.iter() {
1794 if collaborator.is_host {
1795 host_connection_id = Some(collaborator.connection_id);
1796 } else {
1797 guest_connection_ids.push(collaborator.connection_id);
1798 }
1799 }
1800 }
1801 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1802
1803 session.executor.record_backtrace();
1804 broadcast(
1805 Some(session.connection_id),
1806 guest_connection_ids,
1807 |connection_id| {
1808 session
1809 .peer
1810 .forward_send(session.connection_id, connection_id, request.clone())
1811 },
1812 );
1813 if host_connection_id != session.connection_id {
1814 session
1815 .peer
1816 .forward_request(session.connection_id, host_connection_id, request.clone())
1817 .await?;
1818 }
1819
1820 response.send(proto::Ack {})?;
1821 Ok(())
1822}
1823
1824async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1825 let project_id = ProjectId::from_proto(request.project_id);
1826 let project_connection_ids = session
1827 .db()
1828 .await
1829 .project_connection_ids(project_id, session.connection_id)
1830 .await?;
1831
1832 broadcast(
1833 Some(session.connection_id),
1834 project_connection_ids.iter().copied(),
1835 |connection_id| {
1836 session
1837 .peer
1838 .forward_send(session.connection_id, connection_id, request.clone())
1839 },
1840 );
1841 Ok(())
1842}
1843
1844async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1845 let project_id = ProjectId::from_proto(request.project_id);
1846 let project_connection_ids = session
1847 .db()
1848 .await
1849 .project_connection_ids(project_id, session.connection_id)
1850 .await?;
1851 broadcast(
1852 Some(session.connection_id),
1853 project_connection_ids.iter().copied(),
1854 |connection_id| {
1855 session
1856 .peer
1857 .forward_send(session.connection_id, connection_id, request.clone())
1858 },
1859 );
1860 Ok(())
1861}
1862
1863async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1864 broadcast_project_message(request.project_id, request, session).await
1865}
1866
1867async fn broadcast_project_message<T: EnvelopedMessage>(
1868 project_id: u64,
1869 request: T,
1870 session: Session,
1871) -> Result<()> {
1872 let project_id = ProjectId::from_proto(project_id);
1873 let project_connection_ids = session
1874 .db()
1875 .await
1876 .project_connection_ids(project_id, session.connection_id)
1877 .await?;
1878 broadcast(
1879 Some(session.connection_id),
1880 project_connection_ids.iter().copied(),
1881 |connection_id| {
1882 session
1883 .peer
1884 .forward_send(session.connection_id, connection_id, request.clone())
1885 },
1886 );
1887 Ok(())
1888}
1889
1890async fn follow(
1891 request: proto::Follow,
1892 response: Response<proto::Follow>,
1893 session: Session,
1894) -> Result<()> {
1895 let room_id = RoomId::from_proto(request.room_id);
1896 let project_id = request.project_id.map(ProjectId::from_proto);
1897 let leader_id = request
1898 .leader_id
1899 .ok_or_else(|| anyhow!("invalid leader id"))?
1900 .into();
1901 let follower_id = session.connection_id;
1902
1903 session
1904 .db()
1905 .await
1906 .check_room_participants(room_id, leader_id, session.connection_id)
1907 .await?;
1908
1909 let mut response_payload = session
1910 .peer
1911 .forward_request(session.connection_id, leader_id, request)
1912 .await?;
1913 response_payload
1914 .views
1915 .retain(|view| view.leader_id != Some(follower_id.into()));
1916 response.send(response_payload)?;
1917
1918 if let Some(project_id) = project_id {
1919 let room = session
1920 .db()
1921 .await
1922 .follow(room_id, project_id, leader_id, follower_id)
1923 .await?;
1924 room_updated(&room, &session.peer);
1925 }
1926
1927 Ok(())
1928}
1929
1930async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1931 let room_id = RoomId::from_proto(request.room_id);
1932 let project_id = request.project_id.map(ProjectId::from_proto);
1933 let leader_id = request
1934 .leader_id
1935 .ok_or_else(|| anyhow!("invalid leader id"))?
1936 .into();
1937 let follower_id = session.connection_id;
1938
1939 session
1940 .db()
1941 .await
1942 .check_room_participants(room_id, leader_id, session.connection_id)
1943 .await?;
1944
1945 session
1946 .peer
1947 .forward_send(session.connection_id, leader_id, request)?;
1948
1949 if let Some(project_id) = project_id {
1950 let room = session
1951 .db()
1952 .await
1953 .unfollow(room_id, project_id, leader_id, follower_id)
1954 .await?;
1955 room_updated(&room, &session.peer);
1956 }
1957
1958 Ok(())
1959}
1960
1961async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1962 let room_id = RoomId::from_proto(request.room_id);
1963 let database = session.db.lock().await;
1964
1965 let connection_ids = if let Some(project_id) = request.project_id {
1966 let project_id = ProjectId::from_proto(project_id);
1967 database
1968 .project_connection_ids(project_id, session.connection_id)
1969 .await?
1970 } else {
1971 database
1972 .room_connection_ids(room_id, session.connection_id)
1973 .await?
1974 };
1975
1976 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1977 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1978 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1979 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1980 });
1981 for follower_peer_id in request.follower_ids.iter().copied() {
1982 let follower_connection_id = follower_peer_id.into();
1983 if Some(follower_peer_id) != leader_id && connection_ids.contains(&follower_connection_id) {
1984 session.peer.forward_send(
1985 session.connection_id,
1986 follower_connection_id,
1987 request.clone(),
1988 )?;
1989 }
1990 }
1991 Ok(())
1992}
1993
1994async fn get_users(
1995 request: proto::GetUsers,
1996 response: Response<proto::GetUsers>,
1997 session: Session,
1998) -> Result<()> {
1999 let user_ids = request
2000 .user_ids
2001 .into_iter()
2002 .map(UserId::from_proto)
2003 .collect();
2004 let users = session
2005 .db()
2006 .await
2007 .get_users_by_ids(user_ids)
2008 .await?
2009 .into_iter()
2010 .map(|user| proto::User {
2011 id: user.id.to_proto(),
2012 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2013 github_login: user.github_login,
2014 })
2015 .collect();
2016 response.send(proto::UsersResponse { users })?;
2017 Ok(())
2018}
2019
2020async fn fuzzy_search_users(
2021 request: proto::FuzzySearchUsers,
2022 response: Response<proto::FuzzySearchUsers>,
2023 session: Session,
2024) -> Result<()> {
2025 let query = request.query;
2026 let users = match query.len() {
2027 0 => vec![],
2028 1 | 2 => session
2029 .db()
2030 .await
2031 .get_user_by_github_login(&query)
2032 .await?
2033 .into_iter()
2034 .collect(),
2035 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2036 };
2037 let users = users
2038 .into_iter()
2039 .filter(|user| user.id != session.user_id)
2040 .map(|user| proto::User {
2041 id: user.id.to_proto(),
2042 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2043 github_login: user.github_login,
2044 })
2045 .collect();
2046 response.send(proto::UsersResponse { users })?;
2047 Ok(())
2048}
2049
2050async fn request_contact(
2051 request: proto::RequestContact,
2052 response: Response<proto::RequestContact>,
2053 session: Session,
2054) -> Result<()> {
2055 let requester_id = session.user_id;
2056 let responder_id = UserId::from_proto(request.responder_id);
2057 if requester_id == responder_id {
2058 return Err(anyhow!("cannot add yourself as a contact"))?;
2059 }
2060
2061 session
2062 .db()
2063 .await
2064 .send_contact_request(requester_id, responder_id)
2065 .await?;
2066
2067 // Update outgoing contact requests of requester
2068 let mut update = proto::UpdateContacts::default();
2069 update.outgoing_requests.push(responder_id.to_proto());
2070 for connection_id in session
2071 .connection_pool()
2072 .await
2073 .user_connection_ids(requester_id)
2074 {
2075 session.peer.send(connection_id, update.clone())?;
2076 }
2077
2078 // Update incoming contact requests of responder
2079 let mut update = proto::UpdateContacts::default();
2080 update
2081 .incoming_requests
2082 .push(proto::IncomingContactRequest {
2083 requester_id: requester_id.to_proto(),
2084 should_notify: true,
2085 });
2086 for connection_id in session
2087 .connection_pool()
2088 .await
2089 .user_connection_ids(responder_id)
2090 {
2091 session.peer.send(connection_id, update.clone())?;
2092 }
2093
2094 response.send(proto::Ack {})?;
2095 Ok(())
2096}
2097
2098async fn respond_to_contact_request(
2099 request: proto::RespondToContactRequest,
2100 response: Response<proto::RespondToContactRequest>,
2101 session: Session,
2102) -> Result<()> {
2103 let responder_id = session.user_id;
2104 let requester_id = UserId::from_proto(request.requester_id);
2105 let db = session.db().await;
2106 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2107 db.dismiss_contact_notification(responder_id, requester_id)
2108 .await?;
2109 } else {
2110 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2111
2112 db.respond_to_contact_request(responder_id, requester_id, accept)
2113 .await?;
2114 let requester_busy = db.is_user_busy(requester_id).await?;
2115 let responder_busy = db.is_user_busy(responder_id).await?;
2116
2117 let pool = session.connection_pool().await;
2118 // Update responder with new contact
2119 let mut update = proto::UpdateContacts::default();
2120 if accept {
2121 update
2122 .contacts
2123 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2124 }
2125 update
2126 .remove_incoming_requests
2127 .push(requester_id.to_proto());
2128 for connection_id in pool.user_connection_ids(responder_id) {
2129 session.peer.send(connection_id, update.clone())?;
2130 }
2131
2132 // Update requester with new contact
2133 let mut update = proto::UpdateContacts::default();
2134 if accept {
2135 update
2136 .contacts
2137 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2138 }
2139 update
2140 .remove_outgoing_requests
2141 .push(responder_id.to_proto());
2142 for connection_id in pool.user_connection_ids(requester_id) {
2143 session.peer.send(connection_id, update.clone())?;
2144 }
2145 }
2146
2147 response.send(proto::Ack {})?;
2148 Ok(())
2149}
2150
2151async fn remove_contact(
2152 request: proto::RemoveContact,
2153 response: Response<proto::RemoveContact>,
2154 session: Session,
2155) -> Result<()> {
2156 let requester_id = session.user_id;
2157 let responder_id = UserId::from_proto(request.user_id);
2158 let db = session.db().await;
2159 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2160
2161 let pool = session.connection_pool().await;
2162 // Update outgoing contact requests of requester
2163 let mut update = proto::UpdateContacts::default();
2164 if contact_accepted {
2165 update.remove_contacts.push(responder_id.to_proto());
2166 } else {
2167 update
2168 .remove_outgoing_requests
2169 .push(responder_id.to_proto());
2170 }
2171 for connection_id in pool.user_connection_ids(requester_id) {
2172 session.peer.send(connection_id, update.clone())?;
2173 }
2174
2175 // Update incoming contact requests of responder
2176 let mut update = proto::UpdateContacts::default();
2177 if contact_accepted {
2178 update.remove_contacts.push(requester_id.to_proto());
2179 } else {
2180 update
2181 .remove_incoming_requests
2182 .push(requester_id.to_proto());
2183 }
2184 for connection_id in pool.user_connection_ids(responder_id) {
2185 session.peer.send(connection_id, update.clone())?;
2186 }
2187
2188 response.send(proto::Ack {})?;
2189 Ok(())
2190}
2191
2192async fn create_channel(
2193 request: proto::CreateChannel,
2194 response: Response<proto::CreateChannel>,
2195 session: Session,
2196) -> Result<()> {
2197 let db = session.db().await;
2198 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2199
2200 if let Some(live_kit) = session.live_kit_client.as_ref() {
2201 live_kit.create_room(live_kit_room.clone()).await?;
2202 }
2203
2204 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2205 let id = db
2206 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2207 .await?;
2208
2209 let channel = proto::Channel {
2210 id: id.to_proto(),
2211 name: request.name,
2212 };
2213
2214 response.send(proto::CreateChannelResponse {
2215 channel: Some(channel.clone()),
2216 parent_id: request.parent_id,
2217 })?;
2218
2219 let Some(parent_id) = parent_id else {
2220 return Ok(());
2221 };
2222
2223 let update = proto::UpdateChannels {
2224 channels: vec![channel],
2225 insert_edge: vec![ChannelEdge {
2226 parent_id: parent_id.to_proto(),
2227 channel_id: id.to_proto(),
2228 }],
2229 ..Default::default()
2230 };
2231
2232 let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2233
2234 let connection_pool = session.connection_pool().await;
2235 for user_id in user_ids_to_notify {
2236 for connection_id in connection_pool.user_connection_ids(user_id) {
2237 if user_id == session.user_id {
2238 continue;
2239 }
2240 session.peer.send(connection_id, update.clone())?;
2241 }
2242 }
2243
2244 Ok(())
2245}
2246
2247async fn delete_channel(
2248 request: proto::DeleteChannel,
2249 response: Response<proto::DeleteChannel>,
2250 session: Session,
2251) -> Result<()> {
2252 let db = session.db().await;
2253
2254 let channel_id = request.channel_id;
2255 let (removed_channels, member_ids) = db
2256 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2257 .await?;
2258 response.send(proto::Ack {})?;
2259
2260 // Notify members of removed channels
2261 let mut update = proto::UpdateChannels::default();
2262 update
2263 .delete_channels
2264 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2265
2266 let connection_pool = session.connection_pool().await;
2267 for member_id in member_ids {
2268 for connection_id in connection_pool.user_connection_ids(member_id) {
2269 session.peer.send(connection_id, update.clone())?;
2270 }
2271 }
2272
2273 Ok(())
2274}
2275
2276async fn invite_channel_member(
2277 request: proto::InviteChannelMember,
2278 response: Response<proto::InviteChannelMember>,
2279 session: Session,
2280) -> Result<()> {
2281 let db = session.db().await;
2282 let channel_id = ChannelId::from_proto(request.channel_id);
2283 let invitee_id = UserId::from_proto(request.user_id);
2284 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2285 .await?;
2286
2287 let (channel, _) = db
2288 .get_channel(channel_id, session.user_id)
2289 .await?
2290 .ok_or_else(|| anyhow!("channel not found"))?;
2291
2292 let mut update = proto::UpdateChannels::default();
2293 update.channel_invitations.push(proto::Channel {
2294 id: channel.id.to_proto(),
2295 name: channel.name,
2296 });
2297 for connection_id in session
2298 .connection_pool()
2299 .await
2300 .user_connection_ids(invitee_id)
2301 {
2302 session.peer.send(connection_id, update.clone())?;
2303 }
2304
2305 response.send(proto::Ack {})?;
2306 Ok(())
2307}
2308
2309async fn remove_channel_member(
2310 request: proto::RemoveChannelMember,
2311 response: Response<proto::RemoveChannelMember>,
2312 session: Session,
2313) -> Result<()> {
2314 let db = session.db().await;
2315 let channel_id = ChannelId::from_proto(request.channel_id);
2316 let member_id = UserId::from_proto(request.user_id);
2317
2318 db.remove_channel_member(channel_id, member_id, session.user_id)
2319 .await?;
2320
2321 let mut update = proto::UpdateChannels::default();
2322 update.delete_channels.push(channel_id.to_proto());
2323
2324 for connection_id in session
2325 .connection_pool()
2326 .await
2327 .user_connection_ids(member_id)
2328 {
2329 session.peer.send(connection_id, update.clone())?;
2330 }
2331
2332 response.send(proto::Ack {})?;
2333 Ok(())
2334}
2335
2336async fn set_channel_member_admin(
2337 request: proto::SetChannelMemberAdmin,
2338 response: Response<proto::SetChannelMemberAdmin>,
2339 session: Session,
2340) -> Result<()> {
2341 let db = session.db().await;
2342 let channel_id = ChannelId::from_proto(request.channel_id);
2343 let member_id = UserId::from_proto(request.user_id);
2344 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2345 .await?;
2346
2347 let (channel, has_accepted) = db
2348 .get_channel(channel_id, member_id)
2349 .await?
2350 .ok_or_else(|| anyhow!("channel not found"))?;
2351
2352 let mut update = proto::UpdateChannels::default();
2353 if has_accepted {
2354 update.channel_permissions.push(proto::ChannelPermission {
2355 channel_id: channel.id.to_proto(),
2356 is_admin: request.admin,
2357 });
2358 }
2359
2360 for connection_id in session
2361 .connection_pool()
2362 .await
2363 .user_connection_ids(member_id)
2364 {
2365 session.peer.send(connection_id, update.clone())?;
2366 }
2367
2368 response.send(proto::Ack {})?;
2369 Ok(())
2370}
2371
2372async fn rename_channel(
2373 request: proto::RenameChannel,
2374 response: Response<proto::RenameChannel>,
2375 session: Session,
2376) -> Result<()> {
2377 let db = session.db().await;
2378 let channel_id = ChannelId::from_proto(request.channel_id);
2379 let new_name = db
2380 .rename_channel(channel_id, session.user_id, &request.name)
2381 .await?;
2382
2383 let channel = proto::Channel {
2384 id: request.channel_id,
2385 name: new_name,
2386 };
2387 response.send(proto::RenameChannelResponse {
2388 channel: Some(channel.clone()),
2389 })?;
2390 let mut update = proto::UpdateChannels::default();
2391 update.channels.push(channel);
2392
2393 let member_ids = db.get_channel_members(channel_id).await?;
2394
2395 let connection_pool = session.connection_pool().await;
2396 for member_id in member_ids {
2397 for connection_id in connection_pool.user_connection_ids(member_id) {
2398 session.peer.send(connection_id, update.clone())?;
2399 }
2400 }
2401
2402 Ok(())
2403}
2404
2405async fn link_channel(
2406 request: proto::LinkChannel,
2407 response: Response<proto::LinkChannel>,
2408 session: Session,
2409) -> Result<()> {
2410 let db = session.db().await;
2411 let channel_id = ChannelId::from_proto(request.channel_id);
2412 let to = ChannelId::from_proto(request.to);
2413 let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2414
2415 let members = db.get_channel_members(to).await?;
2416 let connection_pool = session.connection_pool().await;
2417 let update = proto::UpdateChannels {
2418 channels: channels_to_send
2419 .channels
2420 .into_iter()
2421 .map(|channel| proto::Channel {
2422 id: channel.id.to_proto(),
2423 name: channel.name,
2424 })
2425 .collect(),
2426 insert_edge: channels_to_send.edges,
2427 ..Default::default()
2428 };
2429 for member_id in members {
2430 for connection_id in connection_pool.user_connection_ids(member_id) {
2431 session.peer.send(connection_id, update.clone())?;
2432 }
2433 }
2434
2435 response.send(Ack {})?;
2436
2437 Ok(())
2438}
2439
2440async fn unlink_channel(
2441 request: proto::UnlinkChannel,
2442 response: Response<proto::UnlinkChannel>,
2443 session: Session,
2444) -> Result<()> {
2445 let db = session.db().await;
2446 let channel_id = ChannelId::from_proto(request.channel_id);
2447 let from = ChannelId::from_proto(request.from);
2448
2449 db.unlink_channel(session.user_id, channel_id, from).await?;
2450
2451 let members = db.get_channel_members(from).await?;
2452
2453 let update = proto::UpdateChannels {
2454 delete_edge: vec![proto::ChannelEdge {
2455 channel_id: channel_id.to_proto(),
2456 parent_id: from.to_proto(),
2457 }],
2458 ..Default::default()
2459 };
2460 let connection_pool = session.connection_pool().await;
2461 for member_id in members {
2462 for connection_id in connection_pool.user_connection_ids(member_id) {
2463 session.peer.send(connection_id, update.clone())?;
2464 }
2465 }
2466
2467 response.send(Ack {})?;
2468
2469 Ok(())
2470}
2471
2472async fn move_channel(
2473 request: proto::MoveChannel,
2474 response: Response<proto::MoveChannel>,
2475 session: Session,
2476) -> Result<()> {
2477 let db = session.db().await;
2478 let channel_id = ChannelId::from_proto(request.channel_id);
2479 let from_parent = ChannelId::from_proto(request.from);
2480 let to = ChannelId::from_proto(request.to);
2481
2482 let channels_to_send = db
2483 .move_channel(session.user_id, channel_id, from_parent, to)
2484 .await?;
2485
2486 if channels_to_send.is_empty() {
2487 response.send(Ack {})?;
2488 return Ok(());
2489 }
2490
2491 let members_from = db.get_channel_members(from_parent).await?;
2492 let members_to = db.get_channel_members(to).await?;
2493
2494 let update = proto::UpdateChannels {
2495 delete_edge: vec![proto::ChannelEdge {
2496 channel_id: channel_id.to_proto(),
2497 parent_id: from_parent.to_proto(),
2498 }],
2499 ..Default::default()
2500 };
2501 let connection_pool = session.connection_pool().await;
2502 for member_id in members_from {
2503 for connection_id in connection_pool.user_connection_ids(member_id) {
2504 session.peer.send(connection_id, update.clone())?;
2505 }
2506 }
2507
2508 let update = proto::UpdateChannels {
2509 channels: channels_to_send
2510 .channels
2511 .into_iter()
2512 .map(|channel| proto::Channel {
2513 id: channel.id.to_proto(),
2514 name: channel.name,
2515 })
2516 .collect(),
2517 insert_edge: channels_to_send.edges,
2518 ..Default::default()
2519 };
2520 for member_id in members_to {
2521 for connection_id in connection_pool.user_connection_ids(member_id) {
2522 session.peer.send(connection_id, update.clone())?;
2523 }
2524 }
2525
2526 response.send(Ack {})?;
2527
2528 Ok(())
2529}
2530
2531async fn get_channel_members(
2532 request: proto::GetChannelMembers,
2533 response: Response<proto::GetChannelMembers>,
2534 session: Session,
2535) -> Result<()> {
2536 let db = session.db().await;
2537 let channel_id = ChannelId::from_proto(request.channel_id);
2538 let members = db
2539 .get_channel_member_details(channel_id, session.user_id)
2540 .await?;
2541 response.send(proto::GetChannelMembersResponse { members })?;
2542 Ok(())
2543}
2544
2545async fn respond_to_channel_invite(
2546 request: proto::RespondToChannelInvite,
2547 response: Response<proto::RespondToChannelInvite>,
2548 session: Session,
2549) -> Result<()> {
2550 let db = session.db().await;
2551 let channel_id = ChannelId::from_proto(request.channel_id);
2552 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2553 .await?;
2554
2555 let mut update = proto::UpdateChannels::default();
2556 update
2557 .remove_channel_invitations
2558 .push(channel_id.to_proto());
2559 if request.accept {
2560 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2561 update
2562 .channels
2563 .extend(
2564 result
2565 .channels
2566 .channels
2567 .into_iter()
2568 .map(|channel| proto::Channel {
2569 id: channel.id.to_proto(),
2570 name: channel.name,
2571 }),
2572 );
2573 update.unseen_channel_messages = result.channel_messages;
2574 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2575 update.insert_edge = result.channels.edges;
2576 update
2577 .channel_participants
2578 .extend(
2579 result
2580 .channel_participants
2581 .into_iter()
2582 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2583 channel_id: channel_id.to_proto(),
2584 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2585 }),
2586 );
2587 update
2588 .channel_permissions
2589 .extend(
2590 result
2591 .channels_with_admin_privileges
2592 .into_iter()
2593 .map(|channel_id| proto::ChannelPermission {
2594 channel_id: channel_id.to_proto(),
2595 is_admin: true,
2596 }),
2597 );
2598 }
2599 session.peer.send(session.connection_id, update)?;
2600 response.send(proto::Ack {})?;
2601
2602 Ok(())
2603}
2604
/// Moves the requesting connection into the room that belongs to
/// `request.channel_id`.
///
/// Any room the connection is currently in is left first. The joiner then gets
/// a `JoinRoomResponse`, which includes LiveKit credentials when a LiveKit
/// client is configured and a token could be minted. Room participants get a
/// `RoomUpdated`, channel members get the new participant list, and the
/// joiner's contacts are refreshed.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // This inner block scopes the database handle `db`, so it is dropped before
    // `channel_updated` and `update_user_contacts` run. The latter acquires the
    // database again via `session.db()`. NOTE(review): `into_inner()` presumably
    // unwraps a guard type returned by `join_room`; confirm.
    let joined_room = {
        // Runs before `db` is acquired because it takes the database handle itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // Token minting is best-effort: on failure `trace_err` logs the error and
        // the response carries no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    // Tell every connected channel member who is now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2659
2660async fn join_channel_buffer(
2661 request: proto::JoinChannelBuffer,
2662 response: Response<proto::JoinChannelBuffer>,
2663 session: Session,
2664) -> Result<()> {
2665 let db = session.db().await;
2666 let channel_id = ChannelId::from_proto(request.channel_id);
2667
2668 let open_response = db
2669 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2670 .await?;
2671
2672 let collaborators = open_response.collaborators.clone();
2673 response.send(open_response)?;
2674
2675 let update = UpdateChannelBufferCollaborators {
2676 channel_id: channel_id.to_proto(),
2677 collaborators: collaborators.clone(),
2678 };
2679 channel_buffer_updated(
2680 session.connection_id,
2681 collaborators
2682 .iter()
2683 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2684 &update,
2685 &session.peer,
2686 );
2687
2688 Ok(())
2689}
2690
2691async fn update_channel_buffer(
2692 request: proto::UpdateChannelBuffer,
2693 session: Session,
2694) -> Result<()> {
2695 let db = session.db().await;
2696 let channel_id = ChannelId::from_proto(request.channel_id);
2697
2698 let (collaborators, non_collaborators, epoch, version) = db
2699 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2700 .await?;
2701
2702 channel_buffer_updated(
2703 session.connection_id,
2704 collaborators,
2705 &proto::UpdateChannelBuffer {
2706 channel_id: channel_id.to_proto(),
2707 operations: request.operations,
2708 },
2709 &session.peer,
2710 );
2711
2712 let pool = &*session.connection_pool().await;
2713
2714 broadcast(
2715 None,
2716 non_collaborators
2717 .iter()
2718 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2719 |peer_id| {
2720 session.peer.send(
2721 peer_id.into(),
2722 proto::UpdateChannels {
2723 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2724 channel_id: channel_id.to_proto(),
2725 epoch: epoch as u64,
2726 version: version.clone(),
2727 }],
2728 ..Default::default()
2729 },
2730 )
2731 },
2732 );
2733
2734 Ok(())
2735}
2736
2737async fn rejoin_channel_buffers(
2738 request: proto::RejoinChannelBuffers,
2739 response: Response<proto::RejoinChannelBuffers>,
2740 session: Session,
2741) -> Result<()> {
2742 let db = session.db().await;
2743 let buffers = db
2744 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2745 .await?;
2746
2747 for rejoined_buffer in &buffers {
2748 let collaborators_to_notify = rejoined_buffer
2749 .buffer
2750 .collaborators
2751 .iter()
2752 .filter_map(|c| Some(c.peer_id?.into()));
2753 channel_buffer_updated(
2754 session.connection_id,
2755 collaborators_to_notify,
2756 &proto::UpdateChannelBufferCollaborators {
2757 channel_id: rejoined_buffer.buffer.channel_id,
2758 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2759 },
2760 &session.peer,
2761 );
2762 }
2763
2764 response.send(proto::RejoinChannelBuffersResponse {
2765 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2766 })?;
2767
2768 Ok(())
2769}
2770
2771async fn leave_channel_buffer(
2772 request: proto::LeaveChannelBuffer,
2773 response: Response<proto::LeaveChannelBuffer>,
2774 session: Session,
2775) -> Result<()> {
2776 let db = session.db().await;
2777 let channel_id = ChannelId::from_proto(request.channel_id);
2778
2779 let left_buffer = db
2780 .leave_channel_buffer(channel_id, session.connection_id)
2781 .await?;
2782
2783 response.send(Ack {})?;
2784
2785 channel_buffer_updated(
2786 session.connection_id,
2787 left_buffer.connections,
2788 &proto::UpdateChannelBufferCollaborators {
2789 channel_id: channel_id.to_proto(),
2790 collaborators: left_buffer.collaborators,
2791 },
2792 &session.peer,
2793 );
2794
2795 Ok(())
2796}
2797
2798fn channel_buffer_updated<T: EnvelopedMessage>(
2799 sender_id: ConnectionId,
2800 collaborators: impl IntoIterator<Item = ConnectionId>,
2801 message: &T,
2802 peer: &Peer,
2803) {
2804 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2805 peer.send(peer_id.into(), message.clone())
2806 });
2807}
2808
2809async fn send_channel_message(
2810 request: proto::SendChannelMessage,
2811 response: Response<proto::SendChannelMessage>,
2812 session: Session,
2813) -> Result<()> {
2814 // Validate the message body.
2815 let body = request.body.trim().to_string();
2816 if body.len() > MAX_MESSAGE_LEN {
2817 return Err(anyhow!("message is too long"))?;
2818 }
2819 if body.is_empty() {
2820 return Err(anyhow!("message can't be blank"))?;
2821 }
2822
2823 let timestamp = OffsetDateTime::now_utc();
2824 let nonce = request
2825 .nonce
2826 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2827
2828 let channel_id = ChannelId::from_proto(request.channel_id);
2829 let (message_id, connection_ids, non_participants) = session
2830 .db()
2831 .await
2832 .create_channel_message(
2833 channel_id,
2834 session.user_id,
2835 &body,
2836 timestamp,
2837 nonce.clone().into(),
2838 )
2839 .await?;
2840 let message = proto::ChannelMessage {
2841 sender_id: session.user_id.to_proto(),
2842 id: message_id.to_proto(),
2843 body,
2844 timestamp: timestamp.unix_timestamp() as u64,
2845 nonce: Some(nonce),
2846 };
2847 broadcast(Some(session.connection_id), connection_ids, |connection| {
2848 session.peer.send(
2849 connection,
2850 proto::ChannelMessageSent {
2851 channel_id: channel_id.to_proto(),
2852 message: Some(message.clone()),
2853 },
2854 )
2855 });
2856 response.send(proto::SendChannelMessageResponse {
2857 message: Some(message),
2858 })?;
2859
2860 let pool = &*session.connection_pool().await;
2861 broadcast(
2862 None,
2863 non_participants
2864 .iter()
2865 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2866 |peer_id| {
2867 session.peer.send(
2868 peer_id.into(),
2869 proto::UpdateChannels {
2870 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2871 channel_id: channel_id.to_proto(),
2872 message_id: message_id.to_proto(),
2873 }],
2874 ..Default::default()
2875 },
2876 )
2877 },
2878 );
2879
2880 Ok(())
2881}
2882
2883async fn remove_channel_message(
2884 request: proto::RemoveChannelMessage,
2885 response: Response<proto::RemoveChannelMessage>,
2886 session: Session,
2887) -> Result<()> {
2888 let channel_id = ChannelId::from_proto(request.channel_id);
2889 let message_id = MessageId::from_proto(request.message_id);
2890 let connection_ids = session
2891 .db()
2892 .await
2893 .remove_channel_message(channel_id, message_id, session.user_id)
2894 .await?;
2895 broadcast(Some(session.connection_id), connection_ids, |connection| {
2896 session.peer.send(connection, request.clone())
2897 });
2898 response.send(proto::Ack {})?;
2899 Ok(())
2900}
2901
2902async fn acknowledge_channel_message(
2903 request: proto::AckChannelMessage,
2904 session: Session,
2905) -> Result<()> {
2906 let channel_id = ChannelId::from_proto(request.channel_id);
2907 let message_id = MessageId::from_proto(request.message_id);
2908 session
2909 .db()
2910 .await
2911 .observe_channel_message(channel_id, session.user_id, message_id)
2912 .await?;
2913 Ok(())
2914}
2915
2916async fn acknowledge_buffer_version(
2917 request: proto::AckBufferOperation,
2918 session: Session,
2919) -> Result<()> {
2920 let buffer_id = BufferId::from_proto(request.buffer_id);
2921 session
2922 .db()
2923 .await
2924 .observe_buffer_version(
2925 buffer_id,
2926 session.user_id,
2927 request.epoch as i32,
2928 &request.version,
2929 )
2930 .await?;
2931 Ok(())
2932}
2933
2934async fn join_channel_chat(
2935 request: proto::JoinChannelChat,
2936 response: Response<proto::JoinChannelChat>,
2937 session: Session,
2938) -> Result<()> {
2939 let channel_id = ChannelId::from_proto(request.channel_id);
2940
2941 let db = session.db().await;
2942 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2943 .await?;
2944 let messages = db
2945 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2946 .await?;
2947 response.send(proto::JoinChannelChatResponse {
2948 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2949 messages,
2950 })?;
2951 Ok(())
2952}
2953
2954async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2955 let channel_id = ChannelId::from_proto(request.channel_id);
2956 session
2957 .db()
2958 .await
2959 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2960 .await?;
2961 Ok(())
2962}
2963
2964async fn get_channel_messages(
2965 request: proto::GetChannelMessages,
2966 response: Response<proto::GetChannelMessages>,
2967 session: Session,
2968) -> Result<()> {
2969 let channel_id = ChannelId::from_proto(request.channel_id);
2970 let messages = session
2971 .db()
2972 .await
2973 .get_channel_messages(
2974 channel_id,
2975 session.user_id,
2976 MESSAGE_COUNT_PER_PAGE,
2977 Some(MessageId::from_proto(request.before_message_id)),
2978 )
2979 .await?;
2980 response.send(proto::GetChannelMessagesResponse {
2981 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2982 messages,
2983 })?;
2984 Ok(())
2985}
2986
2987async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2988 let project_id = ProjectId::from_proto(request.project_id);
2989 let project_connection_ids = session
2990 .db()
2991 .await
2992 .project_connection_ids(project_id, session.connection_id)
2993 .await?;
2994 broadcast(
2995 Some(session.connection_id),
2996 project_connection_ids.iter().copied(),
2997 |connection_id| {
2998 session
2999 .peer
3000 .forward_send(session.connection_id, connection_id, request.clone())
3001 },
3002 );
3003 Ok(())
3004}
3005
3006async fn get_private_user_info(
3007 _request: proto::GetPrivateUserInfo,
3008 response: Response<proto::GetPrivateUserInfo>,
3009 session: Session,
3010) -> Result<()> {
3011 let db = session.db().await;
3012
3013 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3014 let user = db
3015 .get_user_by_id(session.user_id)
3016 .await?
3017 .ok_or_else(|| anyhow!("user not found"))?;
3018 let flags = db.get_user_flags(session.user_id).await?;
3019
3020 response.send(proto::GetPrivateUserInfoResponse {
3021 metrics_id,
3022 staff: user.admin,
3023 flags,
3024 })?;
3025 Ok(())
3026}
3027
3028fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3029 match message {
3030 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3031 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3032 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3033 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3034 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3035 code: frame.code.into(),
3036 reason: frame.reason,
3037 })),
3038 }
3039}
3040
3041fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3042 match message {
3043 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3044 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3045 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3046 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3047 AxumMessage::Close(frame) => {
3048 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3049 code: frame.code.into(),
3050 reason: frame.reason,
3051 }))
3052 }
3053 }
3054}
3055
3056fn build_initial_channels_update(
3057 channels: ChannelsForUser,
3058 channel_invites: Vec<db::Channel>,
3059) -> proto::UpdateChannels {
3060 let mut update = proto::UpdateChannels::default();
3061
3062 for channel in channels.channels.channels {
3063 update.channels.push(proto::Channel {
3064 id: channel.id.to_proto(),
3065 name: channel.name,
3066 });
3067 }
3068
3069 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3070 update.unseen_channel_messages = channels.channel_messages;
3071 update.insert_edge = channels.channels.edges;
3072
3073 for (channel_id, participants) in channels.channel_participants {
3074 update
3075 .channel_participants
3076 .push(proto::ChannelParticipants {
3077 channel_id: channel_id.to_proto(),
3078 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3079 });
3080 }
3081
3082 update
3083 .channel_permissions
3084 .extend(
3085 channels
3086 .channels_with_admin_privileges
3087 .into_iter()
3088 .map(|id| proto::ChannelPermission {
3089 channel_id: id.to_proto(),
3090 is_admin: true,
3091 }),
3092 );
3093
3094 for channel in channel_invites {
3095 update.channel_invitations.push(proto::Channel {
3096 id: channel.id.to_proto(),
3097 name: channel.name,
3098 });
3099 }
3100
3101 update
3102}
3103
3104fn build_initial_contacts_update(
3105 contacts: Vec<db::Contact>,
3106 pool: &ConnectionPool,
3107) -> proto::UpdateContacts {
3108 let mut update = proto::UpdateContacts::default();
3109
3110 for contact in contacts {
3111 match contact {
3112 db::Contact::Accepted {
3113 user_id,
3114 should_notify,
3115 busy,
3116 } => {
3117 update
3118 .contacts
3119 .push(contact_for_user(user_id, should_notify, busy, &pool));
3120 }
3121 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3122 db::Contact::Incoming {
3123 user_id,
3124 should_notify,
3125 } => update
3126 .incoming_requests
3127 .push(proto::IncomingContactRequest {
3128 requester_id: user_id.to_proto(),
3129 should_notify,
3130 }),
3131 }
3132 }
3133
3134 update
3135}
3136
3137fn contact_for_user(
3138 user_id: UserId,
3139 should_notify: bool,
3140 busy: bool,
3141 pool: &ConnectionPool,
3142) -> proto::Contact {
3143 proto::Contact {
3144 user_id: user_id.to_proto(),
3145 online: pool.is_user_online(user_id),
3146 busy,
3147 should_notify,
3148 }
3149}
3150
3151fn room_updated(room: &proto::Room, peer: &Peer) {
3152 broadcast(
3153 None,
3154 room.participants
3155 .iter()
3156 .filter_map(|participant| Some(participant.peer_id?.into())),
3157 |peer_id| {
3158 peer.send(
3159 peer_id.into(),
3160 proto::RoomUpdated {
3161 room: Some(room.clone()),
3162 },
3163 )
3164 },
3165 );
3166}
3167
3168fn channel_updated(
3169 channel_id: ChannelId,
3170 room: &proto::Room,
3171 channel_members: &[UserId],
3172 peer: &Peer,
3173 pool: &ConnectionPool,
3174) {
3175 let participants = room
3176 .participants
3177 .iter()
3178 .map(|p| p.user_id)
3179 .collect::<Vec<_>>();
3180
3181 broadcast(
3182 None,
3183 channel_members
3184 .iter()
3185 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3186 |peer_id| {
3187 peer.send(
3188 peer_id.into(),
3189 proto::UpdateChannels {
3190 channel_participants: vec![proto::ChannelParticipants {
3191 channel_id: channel_id.to_proto(),
3192 participant_user_ids: participants.clone(),
3193 }],
3194 ..Default::default()
3195 },
3196 )
3197 },
3198 );
3199}
3200
3201async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3202 let db = session.db().await;
3203
3204 let contacts = db.get_contacts(user_id).await?;
3205 let busy = db.is_user_busy(user_id).await?;
3206
3207 let pool = session.connection_pool().await;
3208 let updated_contact = contact_for_user(user_id, false, busy, &pool);
3209 for contact in contacts {
3210 if let db::Contact::Accepted {
3211 user_id: contact_user_id,
3212 ..
3213 } = contact
3214 {
3215 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3216 session
3217 .peer
3218 .send(
3219 contact_conn_id,
3220 proto::UpdateContacts {
3221 contacts: vec![updated_contact.clone()],
3222 remove_contacts: Default::default(),
3223 incoming_requests: Default::default(),
3224 remove_incoming_requests: Default::default(),
3225 outgoing_requests: Default::default(),
3226 remove_outgoing_requests: Default::default(),
3227 },
3228 )
3229 .trace_err();
3230 }
3231 }
3232 }
3233 Ok(())
3234}
3235
3236async fn leave_room_for_session(session: &Session) -> Result<()> {
3237 let mut contacts_to_update = HashSet::default();
3238
3239 let room_id;
3240 let canceled_calls_to_user_ids;
3241 let live_kit_room;
3242 let delete_live_kit_room;
3243 let room;
3244 let channel_members;
3245 let channel_id;
3246
3247 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3248 contacts_to_update.insert(session.user_id);
3249
3250 for project in left_room.left_projects.values() {
3251 project_left(project, session);
3252 }
3253
3254 room_id = RoomId::from_proto(left_room.room.id);
3255 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3256 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3257 delete_live_kit_room = left_room.deleted;
3258 room = mem::take(&mut left_room.room);
3259 channel_members = mem::take(&mut left_room.channel_members);
3260 channel_id = left_room.channel_id;
3261
3262 room_updated(&room, &session.peer);
3263 } else {
3264 return Ok(());
3265 }
3266
3267 if let Some(channel_id) = channel_id {
3268 channel_updated(
3269 channel_id,
3270 &room,
3271 &channel_members,
3272 &session.peer,
3273 &*session.connection_pool().await,
3274 );
3275 }
3276
3277 {
3278 let pool = session.connection_pool().await;
3279 for canceled_user_id in canceled_calls_to_user_ids {
3280 for connection_id in pool.user_connection_ids(canceled_user_id) {
3281 session
3282 .peer
3283 .send(
3284 connection_id,
3285 proto::CallCanceled {
3286 room_id: room_id.to_proto(),
3287 },
3288 )
3289 .trace_err();
3290 }
3291 contacts_to_update.insert(canceled_user_id);
3292 }
3293 }
3294
3295 for contact_user_id in contacts_to_update {
3296 update_user_contacts(contact_user_id, &session).await?;
3297 }
3298
3299 if let Some(live_kit) = session.live_kit_client.as_ref() {
3300 live_kit
3301 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3302 .await
3303 .trace_err();
3304
3305 if delete_live_kit_room {
3306 live_kit.delete_room(live_kit_room).await.trace_err();
3307 }
3308 }
3309
3310 Ok(())
3311}
3312
3313async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3314 let left_channel_buffers = session
3315 .db()
3316 .await
3317 .leave_channel_buffers(session.connection_id)
3318 .await?;
3319
3320 for left_buffer in left_channel_buffers {
3321 channel_buffer_updated(
3322 session.connection_id,
3323 left_buffer.connections,
3324 &proto::UpdateChannelBufferCollaborators {
3325 channel_id: left_buffer.channel_id.to_proto(),
3326 collaborators: left_buffer.collaborators,
3327 },
3328 &session.peer,
3329 );
3330 }
3331
3332 Ok(())
3333}
3334
3335fn project_left(project: &db::LeftProject, session: &Session) {
3336 for connection_id in &project.connection_ids {
3337 if project.host_user_id == session.user_id {
3338 session
3339 .peer
3340 .send(
3341 *connection_id,
3342 proto::UnshareProject {
3343 project_id: project.id.to_proto(),
3344 },
3345 )
3346 .trace_err();
3347 } else {
3348 session
3349 .peer
3350 .send(
3351 *connection_id,
3352 proto::RemoveProjectCollaborator {
3353 project_id: project.id.to_proto(),
3354 peer_id: Some(session.connection_id.into()),
3355 },
3356 )
3357 .trace_err();
3358 }
3359 }
3360}
3361
3362pub trait ResultExt {
3363 type Ok;
3364
3365 fn trace_err(self) -> Option<Self::Ok>;
3366}
3367
3368impl<T, E> ResultExt for Result<T, E>
3369where
3370 E: std::fmt::Debug,
3371{
3372 type Ok = T;
3373
3374 fn trace_err(self) -> Option<T> {
3375 match self {
3376 Ok(value) => Some(value),
3377 Err(error) => {
3378 tracing::error!("{:?}", error);
3379 None
3380 }
3381 }
3382 }
3383}