1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
7 ServerId, User, UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
42 LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66use util::channel::RELEASE_CHANNEL_NAME;
67
/// How long `connection_lost` waits after a connection drops before it releases the
/// session's room and channel-buffer state (unless the server is torn down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the task spawned by `Server::start` sleeps before it begins retrieving and
/// clearing stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): not referenced in this part of the file; presumably the page size used when
// fetching channel messages — confirm against the handlers that use it.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// NOTE(review): presumably the maximum accepted length of a channel message body — confirm,
// including whether it is measured in bytes or characters.
const MAX_MESSAGE_LEN: usize = 1024;
73
lazy_static! {
    /// Gauge registered under the name `connections`; `handle_metrics` sets it to the number
    /// of non-admin connections currently in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge registered under the name `shared_projects`; `handle_metrics` sets it from the
    /// database's `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
83
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the incoming envelope (still boxed as `dyn AnyTypedEnvelope`) together with
/// the session of the connection it arrived on, and returns a boxed future that performs the
/// handling. The future yields `()`: `Server::add_handler` wraps each handler so that its
/// outcome is logged rather than returned.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
86
87struct Response<R> {
88 peer: Arc<Peer>,
89 receipt: Receipt<R>,
90 responded: Arc<AtomicBool>,
91}
92
93impl<R: RequestMessage> Response<R> {
94 fn send(self, payload: R::Response) -> Result<()> {
95 self.responded.store(true, SeqCst);
96 self.peer.respond(self.receipt, payload)?;
97 Ok(())
98 }
99}
100
101#[derive(Clone)]
102struct Session {
103 user_id: UserId,
104 connection_id: ConnectionId,
105 db: Arc<tokio::sync::Mutex<DbHandle>>,
106 peer: Arc<Peer>,
107 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
108 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
109 executor: Executor,
110}
111
112impl Session {
113 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
114 #[cfg(test)]
115 tokio::task::yield_now().await;
116 let guard = self.db.lock().await;
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 guard
120 }
121
122 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
123 #[cfg(test)]
124 tokio::task::yield_now().await;
125 let guard = self.connection_pool.lock();
126 ConnectionPoolGuard {
127 guard,
128 _not_send: PhantomData,
129 }
130 }
131}
132
133impl fmt::Debug for Session {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.debug_struct("Session")
136 .field("user_id", &self.user_id)
137 .field("connection_id", &self.connection_id)
138 .finish()
139 }
140}
141
142struct DbHandle(Arc<Database>);
143
144impl Deref for DbHandle {
145 type Target = Database;
146
147 fn deref(&self) -> &Self::Target {
148 self.0.as_ref()
149 }
150}
151
/// The RPC server: owns the peer, the pool of live connections and the table of message
/// handlers used to dispatch incoming messages.
pub struct Server {
    /// Id of this server instance. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; `handle_connection` subscribes to it and hands the receiver
    /// on to `connection_lost`.
    teardown: watch::Sender<()>,
}
161
/// Wrapper around the locked `ConnectionPool`.
///
/// It dereferences to the pool and, in test builds, checks the pool's invariants when dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the guard `!Send`.
    // NOTE(review): presumably meant to stop the guard being held across an `.await` in a
    // future that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
166
/// Serializable view of the server produced by `Server::snapshot`.
///
/// It holds the connection pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself) rather than the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
173
174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
175where
176 S: Serializer,
177 T: Deref<Target = U>,
178 U: Serialize,
179{
180 Serialize::serialize(value.deref(), serializer)
181}
182
183impl Server {
184 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
185 let mut server = Self {
186 id: parking_lot::Mutex::new(id),
187 peer: Peer::new(id.0 as u32),
188 app_state,
189 executor,
190 connection_pool: Default::default(),
191 handlers: Default::default(),
192 teardown: watch::channel(()).0,
193 };
194
195 server
196 .add_request_handler(ping)
197 .add_request_handler(create_room)
198 .add_request_handler(join_room)
199 .add_request_handler(rejoin_room)
200 .add_request_handler(leave_room)
201 .add_request_handler(call)
202 .add_request_handler(cancel_call)
203 .add_message_handler(decline_call)
204 .add_request_handler(update_participant_location)
205 .add_request_handler(share_project)
206 .add_message_handler(unshare_project)
207 .add_request_handler(join_project)
208 .add_message_handler(leave_project)
209 .add_request_handler(update_project)
210 .add_request_handler(update_worktree)
211 .add_message_handler(start_language_server)
212 .add_message_handler(update_language_server)
213 .add_message_handler(update_diagnostic_summary)
214 .add_message_handler(update_worktree_settings)
215 .add_message_handler(refresh_inlay_hints)
216 .add_request_handler(forward_project_request::<proto::GetHover>)
217 .add_request_handler(forward_project_request::<proto::GetDefinition>)
218 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
219 .add_request_handler(forward_project_request::<proto::GetReferences>)
220 .add_request_handler(forward_project_request::<proto::SearchProject>)
221 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
222 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
223 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
224 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
225 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
226 .add_request_handler(forward_project_request::<proto::GetCompletions>)
227 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
228 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
229 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
230 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
231 .add_request_handler(forward_project_request::<proto::PrepareRename>)
232 .add_request_handler(forward_project_request::<proto::PerformRename>)
233 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
234 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
235 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
236 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
237 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
238 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
239 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
240 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
242 .add_request_handler(forward_project_request::<proto::InlayHints>)
243 .add_message_handler(create_buffer_for_peer)
244 .add_request_handler(update_buffer)
245 .add_message_handler(update_buffer_file)
246 .add_message_handler(buffer_reloaded)
247 .add_message_handler(buffer_saved)
248 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
249 .add_request_handler(get_users)
250 .add_request_handler(fuzzy_search_users)
251 .add_request_handler(request_contact)
252 .add_request_handler(remove_contact)
253 .add_request_handler(respond_to_contact_request)
254 .add_request_handler(create_channel)
255 .add_request_handler(delete_channel)
256 .add_request_handler(invite_channel_member)
257 .add_request_handler(remove_channel_member)
258 .add_request_handler(set_channel_member_admin)
259 .add_request_handler(rename_channel)
260 .add_request_handler(join_channel_buffer)
261 .add_request_handler(leave_channel_buffer)
262 .add_message_handler(update_channel_buffer)
263 .add_request_handler(rejoin_channel_buffers)
264 .add_request_handler(get_channel_members)
265 .add_request_handler(respond_to_channel_invite)
266 .add_request_handler(join_channel)
267 .add_request_handler(join_channel_chat)
268 .add_message_handler(leave_channel_chat)
269 .add_request_handler(send_channel_message)
270 .add_request_handler(remove_channel_message)
271 .add_request_handler(get_channel_messages)
272 .add_request_handler(link_channel)
273 .add_request_handler(unlink_channel)
274 .add_request_handler(move_channel)
275 .add_request_handler(follow)
276 .add_message_handler(unfollow)
277 .add_message_handler(update_followers)
278 .add_message_handler(update_diff_base)
279 .add_request_handler(get_private_user_info)
280 .add_message_handler(acknowledge_channel_message)
281 .add_message_handler(acknowledge_buffer_version);
282
283 Arc::new(server)
284 }
285
    /// Schedules the cleanup of stale rooms, channel buffers and servers.
    ///
    /// The work runs in a detached task on the server's executor, so this method returns
    /// `Ok(())` as soon as the task has been spawned. The task:
    ///
    /// 1. sleeps for `CLEANUP_TIMEOUT`;
    /// 2. asks the database for stale room and channel ids for this environment and server;
    /// 3. clears stale collaborators from each channel buffer and notifies the connections
    ///    the database reports for it;
    /// 4. clears stale participants from each room, then notifies room members, cancelled
    ///    callees and the accepted contacts of every affected user, and deletes the LiveKit
    ///    room when no participants remain;
    /// 5. deletes stale server records.
    ///
    /// Every failure inside the task is logged through `trace_err` and otherwise ignored.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and broadcast the refreshed
                    // collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay at their empty defaults when refreshing the room fails,
                        // which turns the notification steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so it is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only rooms left without participants have their LiveKit room deleted.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
437
438 pub fn teardown(&self) {
439 self.peer.teardown();
440 self.connection_pool.lock().reset();
441 let _ = self.teardown.send(());
442 }
443
444 #[cfg(test)]
445 pub fn reset(&self, id: ServerId) {
446 self.teardown();
447 *self.id.lock() = id;
448 self.peer.reset(id.0 as u32);
449 }
450
451 #[cfg(test)]
452 pub fn id(&self) -> ServerId {
453 *self.id.lock()
454 }
455
456 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
457 where
458 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
459 Fut: 'static + Send + Future<Output = Result<()>>,
460 M: EnvelopedMessage,
461 {
462 let prev_handler = self.handlers.insert(
463 TypeId::of::<M>(),
464 Box::new(move |envelope, session| {
465 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
466 let span = info_span!(
467 "handle message",
468 payload_type = envelope.payload_type_name()
469 );
470 span.in_scope(|| {
471 tracing::info!(
472 payload_type = envelope.payload_type_name(),
473 "message received"
474 );
475 });
476 let start_time = Instant::now();
477 let future = (handler)(*envelope, session);
478 async move {
479 let result = future.await;
480 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
481 match result {
482 Err(error) => {
483 tracing::error!(%error, ?duration_ms, "error handling message")
484 }
485 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
486 }
487 }
488 .instrument(span)
489 .boxed()
490 }),
491 );
492 if prev_handler.is_some() {
493 panic!("registered a handler for the same message twice");
494 }
495 self
496 }
497
498 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
499 where
500 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
501 Fut: 'static + Send + Future<Output = Result<()>>,
502 M: EnvelopedMessage,
503 {
504 self.add_handler(move |envelope, session| handler(envelope.payload, session));
505 self
506 }
507
508 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
509 where
510 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
511 Fut: Send + Future<Output = Result<()>>,
512 M: RequestMessage,
513 {
514 let handler = Arc::new(handler);
515 self.add_handler(move |envelope, session| {
516 let receipt = envelope.receipt();
517 let handler = handler.clone();
518 async move {
519 let peer = session.peer.clone();
520 let responded = Arc::new(AtomicBool::default());
521 let response = Response {
522 peer: peer.clone(),
523 responded: responded.clone(),
524 receipt,
525 };
526 match (handler)(envelope.payload, response, session).await {
527 Ok(()) => {
528 if responded.load(std::sync::atomic::Ordering::SeqCst) {
529 Ok(())
530 } else {
531 Err(anyhow!("handler did not send a response"))?
532 }
533 }
534 Err(error) => {
535 peer.respond_with_error(
536 receipt,
537 proto::Error {
538 message: error.to_string(),
539 },
540 )?;
541 Err(error)
542 }
543 }
544 }
545 })
546 }
547
    /// Returns a future that serves one client connection from handshake to sign-out.
    ///
    /// The future:
    ///
    /// 1. registers the connection with the peer and sends `Hello`, passing the new
    ///    connection id through `send_connection_id` when a sender was supplied;
    /// 2. on the user's first ever connection, sends `ShowContacts` and records that the
    ///    user has connected once;
    /// 3. loads contacts, channels and channel invites, adds the connection to the pool and
    ///    sends the initial contacts/channels updates plus any pending incoming call;
    /// 4. loops, dispatching each incoming message to its registered handler, until the I/O
    ///    future finishes, the incoming stream ends, or the server is torn down;
    /// 5. unless it stopped because of teardown, runs `connection_lost` for the session.
    ///
    /// NOTE(review): every `?` between `pool.add_connection` and the message loop returns
    /// early without running `connection_lost`, which is the only place in this function that
    /// removes the connection from the pool again — confirm that such failures are cleaned up
    /// elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The caller may have dropped the receiving end; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is confined to this block so it is released before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but
            // fall back to processing messages that arrived later in the spirit of making
            // progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection may be in flight: a permit is acquired
            // before the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: teardown, then
                // I/O, then running foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Returning here skips `connection_lost` below.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before it is moved into `instrument` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run as detached tasks; foreground ones
                                // are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
682
683 pub async fn invite_code_redeemed(
684 self: &Arc<Self>,
685 inviter_id: UserId,
686 invitee_id: UserId,
687 ) -> Result<()> {
688 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
689 if let Some(code) = &user.invite_code {
690 let pool = self.connection_pool.lock();
691 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
692 for connection_id in pool.user_connection_ids(inviter_id) {
693 self.peer.send(
694 connection_id,
695 proto::UpdateContacts {
696 contacts: vec![invitee_contact.clone()],
697 ..Default::default()
698 },
699 )?;
700 self.peer.send(
701 connection_id,
702 proto::UpdateInviteInfo {
703 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
704 count: user.invite_count as u32,
705 },
706 )?;
707 }
708 }
709 }
710 Ok(())
711 }
712
713 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
714 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
715 if let Some(invite_code) = &user.invite_code {
716 let pool = self.connection_pool.lock();
717 for connection_id in pool.user_connection_ids(user_id) {
718 self.peer.send(
719 connection_id,
720 proto::UpdateInviteInfo {
721 url: format!(
722 "{}{}",
723 self.app_state.config.invite_link_prefix, invite_code
724 ),
725 count: user.invite_count as u32,
726 },
727 )?;
728 }
729 }
730 }
731 Ok(())
732 }
733
734 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
735 ServerSnapshot {
736 connection_pool: ConnectionPoolGuard {
737 guard: self.connection_pool.lock(),
738 _not_send: PhantomData,
739 },
740 peer: &self.peer,
741 }
742 }
743}
744
745impl<'a> Deref for ConnectionPoolGuard<'a> {
746 type Target = ConnectionPool;
747
748 fn deref(&self) -> &Self::Target {
749 &*self.guard
750 }
751}
752
753impl<'a> DerefMut for ConnectionPoolGuard<'a> {
754 fn deref_mut(&mut self) -> &mut Self::Target {
755 &mut *self.guard
756 }
757}
758
759impl<'a> Drop for ConnectionPoolGuard<'a> {
760 fn drop(&mut self) {
761 #[cfg(test)]
762 self.check_invariants();
763 }
764}
765
766fn broadcast<F>(
767 sender_id: Option<ConnectionId>,
768 receiver_ids: impl IntoIterator<Item = ConnectionId>,
769 mut f: F,
770) where
771 F: FnMut(ConnectionId) -> anyhow::Result<()>,
772{
773 for receiver_id in receiver_ids {
774 if Some(receiver_id) != sender_id {
775 if let Err(error) = f(receiver_id) {
776 tracing::error!("failed to send to {:?} {}", receiver_id, error);
777 }
778 }
779 }
780}
781
lazy_static! {
    /// Name of the HTTP header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
785
786pub struct ProtocolVersion(u32);
787
788impl Header for ProtocolVersion {
789 fn name() -> &'static HeaderName {
790 &ZED_PROTOCOL_VERSION
791 }
792
793 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
794 where
795 Self: Sized,
796 I: Iterator<Item = &'i axum::http::HeaderValue>,
797 {
798 let version = values
799 .next()
800 .ok_or_else(axum::headers::Error::invalid)?
801 .to_str()
802 .map_err(|_| axum::headers::Error::invalid())?
803 .parse()
804 .map_err(|_| axum::headers::Error::invalid())?;
805 Ok(Self(version))
806 }
807
808 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
809 values.extend([self.0.to_string().parse().unwrap()]);
810 }
811}
812
813pub fn routes(server: Arc<Server>) -> Router<Body> {
814 Router::new()
815 .route("/rpc", get(handle_websocket_request))
816 .layer(
817 ServiceBuilder::new()
818 .layer(Extension(server.app_state.clone()))
819 .layer(middleware::from_fn(auth::validate_header)),
820 )
821 .route("/metrics", get(handle_metrics))
822 .layer(Extension(server))
823}
824
825pub async fn handle_websocket_request(
826 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
827 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
828 Extension(server): Extension<Arc<Server>>,
829 Extension(user): Extension<User>,
830 ws: WebSocketUpgrade,
831) -> axum::response::Response {
832 if protocol_version != rpc::PROTOCOL_VERSION {
833 return (
834 StatusCode::UPGRADE_REQUIRED,
835 "client must be upgraded".to_string(),
836 )
837 .into_response();
838 }
839 let socket_address = socket_address.to_string();
840 ws.on_upgrade(move |socket| {
841 use util::ResultExt;
842 let socket = socket
843 .map_ok(to_tungstenite_message)
844 .err_into()
845 .with(|message| async move { Ok(to_axum_message(message)) });
846 let connection = Connection::new(Box::pin(socket));
847 async move {
848 server
849 .handle_connection(connection, socket_address, user, None, Executor::Production)
850 .await
851 .log_err();
852 }
853 })
854}
855
856pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
857 let connections = server
858 .connection_pool
859 .lock()
860 .connections()
861 .filter(|connection| !connection.admin)
862 .count();
863
864 METRIC_CONNECTIONS.set(connections as _);
865
866 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
867 METRIC_SHARED_PROJECTS.set(shared_projects as _);
868
869 let encoder = prometheus::TextEncoder::new();
870 let metric_families = prometheus::gather();
871 let encoded_metrics = encoder
872 .encode_to_string(&metric_families)
873 .map_err(|err| anyhow!("{}", err))?;
874 Ok(encoded_metrics)
875}
876
877#[instrument(err, skip(executor))]
878async fn connection_lost(
879 session: Session,
880 mut teardown: watch::Receiver<()>,
881 executor: Executor,
882) -> Result<()> {
883 session.peer.disconnect(session.connection_id);
884 session
885 .connection_pool()
886 .await
887 .remove_connection(session.connection_id)?;
888
889 session
890 .db()
891 .await
892 .connection_lost(session.connection_id)
893 .await
894 .trace_err();
895
896 futures::select_biased! {
897 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
898 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
899 leave_room_for_session(&session).await.trace_err();
900 leave_channel_buffers_for_session(&session)
901 .await
902 .trace_err();
903
904 if !session
905 .connection_pool()
906 .await
907 .is_user_online(session.user_id)
908 {
909 let db = session.db().await;
910 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
911 room_updated(&room, &session.peer);
912 }
913 }
914
915 update_user_contacts(session.user_id, &session).await?;
916 }
917 _ = teardown.changed().fuse() => {}
918 }
919
920 Ok(())
921}
922
923async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
924 response.send(proto::Ack {})?;
925 Ok(())
926}
927
928async fn create_room(
929 _request: proto::CreateRoom,
930 response: Response<proto::CreateRoom>,
931 session: Session,
932) -> Result<()> {
933 let live_kit_room = nanoid::nanoid!(30);
934
935 let live_kit_connection_info = {
936 let live_kit_room = live_kit_room.clone();
937 let live_kit = session.live_kit_client.as_ref();
938
939 util::async_iife!({
940 let live_kit = live_kit?;
941
942 let token = live_kit
943 .room_token(&live_kit_room, &session.user_id.to_string())
944 .trace_err()?;
945
946 Some(proto::LiveKitConnectionInfo {
947 server_url: live_kit.url().into(),
948 token,
949 })
950 })
951 }
952 .await;
953
954 let room = session
955 .db()
956 .await
957 .create_room(
958 session.user_id,
959 session.connection_id,
960 &live_kit_room,
961 RELEASE_CHANNEL_NAME.as_str(),
962 )
963 .await?;
964
965 response.send(proto::CreateRoomResponse {
966 room: Some(room.clone()),
967 live_kit_connection_info,
968 })?;
969
970 update_user_contacts(session.user_id, &session).await?;
971 Ok(())
972}
973
974async fn join_room(
975 request: proto::JoinRoom,
976 response: Response<proto::JoinRoom>,
977 session: Session,
978) -> Result<()> {
979 let room_id = RoomId::from_proto(request.id);
980 let joined_room = {
981 let room = session
982 .db()
983 .await
984 .join_room(
985 room_id,
986 session.user_id,
987 session.connection_id,
988 RELEASE_CHANNEL_NAME.as_str(),
989 )
990 .await?;
991 room_updated(&room.room, &session.peer);
992 room.into_inner()
993 };
994
995 if let Some(channel_id) = joined_room.channel_id {
996 channel_updated(
997 channel_id,
998 &joined_room.room,
999 &joined_room.channel_members,
1000 &session.peer,
1001 &*session.connection_pool().await,
1002 )
1003 }
1004
1005 for connection_id in session
1006 .connection_pool()
1007 .await
1008 .user_connection_ids(session.user_id)
1009 {
1010 session
1011 .peer
1012 .send(
1013 connection_id,
1014 proto::CallCanceled {
1015 room_id: room_id.to_proto(),
1016 },
1017 )
1018 .trace_err();
1019 }
1020
1021 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1022 if let Some(token) = live_kit
1023 .room_token(
1024 &joined_room.room.live_kit_room,
1025 &session.user_id.to_string(),
1026 )
1027 .trace_err()
1028 {
1029 Some(proto::LiveKitConnectionInfo {
1030 server_url: live_kit.url().into(),
1031 token,
1032 })
1033 } else {
1034 None
1035 }
1036 } else {
1037 None
1038 };
1039
1040 response.send(proto::JoinRoomResponse {
1041 room: Some(joined_room.room),
1042 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1043 live_kit_connection_info,
1044 })?;
1045
1046 update_user_contacts(session.user_id, &session).await?;
1047 Ok(())
1048}
1049
1050async fn rejoin_room(
1051 request: proto::RejoinRoom,
1052 response: Response<proto::RejoinRoom>,
1053 session: Session,
1054) -> Result<()> {
1055 let room;
1056 let channel_id;
1057 let channel_members;
1058 {
1059 let mut rejoined_room = session
1060 .db()
1061 .await
1062 .rejoin_room(request, session.user_id, session.connection_id)
1063 .await?;
1064
1065 response.send(proto::RejoinRoomResponse {
1066 room: Some(rejoined_room.room.clone()),
1067 reshared_projects: rejoined_room
1068 .reshared_projects
1069 .iter()
1070 .map(|project| proto::ResharedProject {
1071 id: project.id.to_proto(),
1072 collaborators: project
1073 .collaborators
1074 .iter()
1075 .map(|collaborator| collaborator.to_proto())
1076 .collect(),
1077 })
1078 .collect(),
1079 rejoined_projects: rejoined_room
1080 .rejoined_projects
1081 .iter()
1082 .map(|rejoined_project| proto::RejoinedProject {
1083 id: rejoined_project.id.to_proto(),
1084 worktrees: rejoined_project
1085 .worktrees
1086 .iter()
1087 .map(|worktree| proto::WorktreeMetadata {
1088 id: worktree.id,
1089 root_name: worktree.root_name.clone(),
1090 visible: worktree.visible,
1091 abs_path: worktree.abs_path.clone(),
1092 })
1093 .collect(),
1094 collaborators: rejoined_project
1095 .collaborators
1096 .iter()
1097 .map(|collaborator| collaborator.to_proto())
1098 .collect(),
1099 language_servers: rejoined_project.language_servers.clone(),
1100 })
1101 .collect(),
1102 })?;
1103 room_updated(&rejoined_room.room, &session.peer);
1104
1105 for project in &rejoined_room.reshared_projects {
1106 for collaborator in &project.collaborators {
1107 session
1108 .peer
1109 .send(
1110 collaborator.connection_id,
1111 proto::UpdateProjectCollaborator {
1112 project_id: project.id.to_proto(),
1113 old_peer_id: Some(project.old_connection_id.into()),
1114 new_peer_id: Some(session.connection_id.into()),
1115 },
1116 )
1117 .trace_err();
1118 }
1119
1120 broadcast(
1121 Some(session.connection_id),
1122 project
1123 .collaborators
1124 .iter()
1125 .map(|collaborator| collaborator.connection_id),
1126 |connection_id| {
1127 session.peer.forward_send(
1128 session.connection_id,
1129 connection_id,
1130 proto::UpdateProject {
1131 project_id: project.id.to_proto(),
1132 worktrees: project.worktrees.clone(),
1133 },
1134 )
1135 },
1136 );
1137 }
1138
1139 for project in &rejoined_room.rejoined_projects {
1140 for collaborator in &project.collaborators {
1141 session
1142 .peer
1143 .send(
1144 collaborator.connection_id,
1145 proto::UpdateProjectCollaborator {
1146 project_id: project.id.to_proto(),
1147 old_peer_id: Some(project.old_connection_id.into()),
1148 new_peer_id: Some(session.connection_id.into()),
1149 },
1150 )
1151 .trace_err();
1152 }
1153 }
1154
1155 for project in &mut rejoined_room.rejoined_projects {
1156 for worktree in mem::take(&mut project.worktrees) {
1157 #[cfg(any(test, feature = "test-support"))]
1158 const MAX_CHUNK_SIZE: usize = 2;
1159 #[cfg(not(any(test, feature = "test-support")))]
1160 const MAX_CHUNK_SIZE: usize = 256;
1161
1162 // Stream this worktree's entries.
1163 let message = proto::UpdateWorktree {
1164 project_id: project.id.to_proto(),
1165 worktree_id: worktree.id,
1166 abs_path: worktree.abs_path.clone(),
1167 root_name: worktree.root_name,
1168 updated_entries: worktree.updated_entries,
1169 removed_entries: worktree.removed_entries,
1170 scan_id: worktree.scan_id,
1171 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1172 updated_repositories: worktree.updated_repositories,
1173 removed_repositories: worktree.removed_repositories,
1174 };
1175 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1176 session.peer.send(session.connection_id, update.clone())?;
1177 }
1178
1179 // Stream this worktree's diagnostics.
1180 for summary in worktree.diagnostic_summaries {
1181 session.peer.send(
1182 session.connection_id,
1183 proto::UpdateDiagnosticSummary {
1184 project_id: project.id.to_proto(),
1185 worktree_id: worktree.id,
1186 summary: Some(summary),
1187 },
1188 )?;
1189 }
1190
1191 for settings_file in worktree.settings_files {
1192 session.peer.send(
1193 session.connection_id,
1194 proto::UpdateWorktreeSettings {
1195 project_id: project.id.to_proto(),
1196 worktree_id: worktree.id,
1197 path: settings_file.path,
1198 content: Some(settings_file.content),
1199 },
1200 )?;
1201 }
1202 }
1203
1204 for language_server in &project.language_servers {
1205 session.peer.send(
1206 session.connection_id,
1207 proto::UpdateLanguageServer {
1208 project_id: project.id.to_proto(),
1209 language_server_id: language_server.id,
1210 variant: Some(
1211 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1212 proto::LspDiskBasedDiagnosticsUpdated {},
1213 ),
1214 ),
1215 },
1216 )?;
1217 }
1218 }
1219
1220 let rejoined_room = rejoined_room.into_inner();
1221
1222 room = rejoined_room.room;
1223 channel_id = rejoined_room.channel_id;
1224 channel_members = rejoined_room.channel_members;
1225 }
1226
1227 if let Some(channel_id) = channel_id {
1228 channel_updated(
1229 channel_id,
1230 &room,
1231 &channel_members,
1232 &session.peer,
1233 &*session.connection_pool().await,
1234 );
1235 }
1236
1237 update_user_contacts(session.user_id, &session).await?;
1238 Ok(())
1239}
1240
1241async fn leave_room(
1242 _: proto::LeaveRoom,
1243 response: Response<proto::LeaveRoom>,
1244 session: Session,
1245) -> Result<()> {
1246 leave_room_for_session(&session).await?;
1247 response.send(proto::Ack {})?;
1248 Ok(())
1249}
1250
1251async fn call(
1252 request: proto::Call,
1253 response: Response<proto::Call>,
1254 session: Session,
1255) -> Result<()> {
1256 let room_id = RoomId::from_proto(request.room_id);
1257 let calling_user_id = session.user_id;
1258 let calling_connection_id = session.connection_id;
1259 let called_user_id = UserId::from_proto(request.called_user_id);
1260 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1261 if !session
1262 .db()
1263 .await
1264 .has_contact(calling_user_id, called_user_id)
1265 .await?
1266 {
1267 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1268 }
1269
1270 let incoming_call = {
1271 let (room, incoming_call) = &mut *session
1272 .db()
1273 .await
1274 .call(
1275 room_id,
1276 calling_user_id,
1277 calling_connection_id,
1278 called_user_id,
1279 initial_project_id,
1280 )
1281 .await?;
1282 room_updated(&room, &session.peer);
1283 mem::take(incoming_call)
1284 };
1285 update_user_contacts(called_user_id, &session).await?;
1286
1287 let mut calls = session
1288 .connection_pool()
1289 .await
1290 .user_connection_ids(called_user_id)
1291 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1292 .collect::<FuturesUnordered<_>>();
1293
1294 while let Some(call_response) = calls.next().await {
1295 match call_response.as_ref() {
1296 Ok(_) => {
1297 response.send(proto::Ack {})?;
1298 return Ok(());
1299 }
1300 Err(_) => {
1301 call_response.trace_err();
1302 }
1303 }
1304 }
1305
1306 {
1307 let room = session
1308 .db()
1309 .await
1310 .call_failed(room_id, called_user_id)
1311 .await?;
1312 room_updated(&room, &session.peer);
1313 }
1314 update_user_contacts(called_user_id, &session).await?;
1315
1316 Err(anyhow!("failed to ring user"))?
1317}
1318
1319async fn cancel_call(
1320 request: proto::CancelCall,
1321 response: Response<proto::CancelCall>,
1322 session: Session,
1323) -> Result<()> {
1324 let called_user_id = UserId::from_proto(request.called_user_id);
1325 let room_id = RoomId::from_proto(request.room_id);
1326 {
1327 let room = session
1328 .db()
1329 .await
1330 .cancel_call(room_id, session.connection_id, called_user_id)
1331 .await?;
1332 room_updated(&room, &session.peer);
1333 }
1334
1335 for connection_id in session
1336 .connection_pool()
1337 .await
1338 .user_connection_ids(called_user_id)
1339 {
1340 session
1341 .peer
1342 .send(
1343 connection_id,
1344 proto::CallCanceled {
1345 room_id: room_id.to_proto(),
1346 },
1347 )
1348 .trace_err();
1349 }
1350 response.send(proto::Ack {})?;
1351
1352 update_user_contacts(called_user_id, &session).await?;
1353 Ok(())
1354}
1355
1356async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1357 let room_id = RoomId::from_proto(message.room_id);
1358 {
1359 let room = session
1360 .db()
1361 .await
1362 .decline_call(Some(room_id), session.user_id)
1363 .await?
1364 .ok_or_else(|| anyhow!("failed to decline call"))?;
1365 room_updated(&room, &session.peer);
1366 }
1367
1368 for connection_id in session
1369 .connection_pool()
1370 .await
1371 .user_connection_ids(session.user_id)
1372 {
1373 session
1374 .peer
1375 .send(
1376 connection_id,
1377 proto::CallCanceled {
1378 room_id: room_id.to_proto(),
1379 },
1380 )
1381 .trace_err();
1382 }
1383 update_user_contacts(session.user_id, &session).await?;
1384 Ok(())
1385}
1386
1387async fn update_participant_location(
1388 request: proto::UpdateParticipantLocation,
1389 response: Response<proto::UpdateParticipantLocation>,
1390 session: Session,
1391) -> Result<()> {
1392 let room_id = RoomId::from_proto(request.room_id);
1393 let location = request
1394 .location
1395 .ok_or_else(|| anyhow!("invalid location"))?;
1396
1397 let db = session.db().await;
1398 let room = db
1399 .update_room_participant_location(room_id, session.connection_id, location)
1400 .await?;
1401
1402 room_updated(&room, &session.peer);
1403 response.send(proto::Ack {})?;
1404 Ok(())
1405}
1406
1407async fn share_project(
1408 request: proto::ShareProject,
1409 response: Response<proto::ShareProject>,
1410 session: Session,
1411) -> Result<()> {
1412 let (project_id, room) = &*session
1413 .db()
1414 .await
1415 .share_project(
1416 RoomId::from_proto(request.room_id),
1417 session.connection_id,
1418 &request.worktrees,
1419 )
1420 .await?;
1421 response.send(proto::ShareProjectResponse {
1422 project_id: project_id.to_proto(),
1423 })?;
1424 room_updated(&room, &session.peer);
1425
1426 Ok(())
1427}
1428
1429async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1430 let project_id = ProjectId::from_proto(message.project_id);
1431
1432 let (room, guest_connection_ids) = &*session
1433 .db()
1434 .await
1435 .unshare_project(project_id, session.connection_id)
1436 .await?;
1437
1438 broadcast(
1439 Some(session.connection_id),
1440 guest_connection_ids.iter().copied(),
1441 |conn_id| session.peer.send(conn_id, message.clone()),
1442 );
1443 room_updated(&room, &session.peer);
1444
1445 Ok(())
1446}
1447
1448async fn join_project(
1449 request: proto::JoinProject,
1450 response: Response<proto::JoinProject>,
1451 session: Session,
1452) -> Result<()> {
1453 let project_id = ProjectId::from_proto(request.project_id);
1454 let guest_user_id = session.user_id;
1455
1456 tracing::info!(%project_id, "join project");
1457
1458 let (project, replica_id) = &mut *session
1459 .db()
1460 .await
1461 .join_project(project_id, session.connection_id)
1462 .await?;
1463
1464 let collaborators = project
1465 .collaborators
1466 .iter()
1467 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1468 .map(|collaborator| collaborator.to_proto())
1469 .collect::<Vec<_>>();
1470
1471 let worktrees = project
1472 .worktrees
1473 .iter()
1474 .map(|(id, worktree)| proto::WorktreeMetadata {
1475 id: *id,
1476 root_name: worktree.root_name.clone(),
1477 visible: worktree.visible,
1478 abs_path: worktree.abs_path.clone(),
1479 })
1480 .collect::<Vec<_>>();
1481
1482 for collaborator in &collaborators {
1483 session
1484 .peer
1485 .send(
1486 collaborator.peer_id.unwrap().into(),
1487 proto::AddProjectCollaborator {
1488 project_id: project_id.to_proto(),
1489 collaborator: Some(proto::Collaborator {
1490 peer_id: Some(session.connection_id.into()),
1491 replica_id: replica_id.0 as u32,
1492 user_id: guest_user_id.to_proto(),
1493 }),
1494 },
1495 )
1496 .trace_err();
1497 }
1498
1499 // First, we send the metadata associated with each worktree.
1500 response.send(proto::JoinProjectResponse {
1501 worktrees: worktrees.clone(),
1502 replica_id: replica_id.0 as u32,
1503 collaborators: collaborators.clone(),
1504 language_servers: project.language_servers.clone(),
1505 })?;
1506
1507 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1508 #[cfg(any(test, feature = "test-support"))]
1509 const MAX_CHUNK_SIZE: usize = 2;
1510 #[cfg(not(any(test, feature = "test-support")))]
1511 const MAX_CHUNK_SIZE: usize = 256;
1512
1513 // Stream this worktree's entries.
1514 let message = proto::UpdateWorktree {
1515 project_id: project_id.to_proto(),
1516 worktree_id,
1517 abs_path: worktree.abs_path.clone(),
1518 root_name: worktree.root_name,
1519 updated_entries: worktree.entries,
1520 removed_entries: Default::default(),
1521 scan_id: worktree.scan_id,
1522 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1523 updated_repositories: worktree.repository_entries.into_values().collect(),
1524 removed_repositories: Default::default(),
1525 };
1526 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1527 session.peer.send(session.connection_id, update.clone())?;
1528 }
1529
1530 // Stream this worktree's diagnostics.
1531 for summary in worktree.diagnostic_summaries {
1532 session.peer.send(
1533 session.connection_id,
1534 proto::UpdateDiagnosticSummary {
1535 project_id: project_id.to_proto(),
1536 worktree_id: worktree.id,
1537 summary: Some(summary),
1538 },
1539 )?;
1540 }
1541
1542 for settings_file in worktree.settings_files {
1543 session.peer.send(
1544 session.connection_id,
1545 proto::UpdateWorktreeSettings {
1546 project_id: project_id.to_proto(),
1547 worktree_id: worktree.id,
1548 path: settings_file.path,
1549 content: Some(settings_file.content),
1550 },
1551 )?;
1552 }
1553 }
1554
1555 for language_server in &project.language_servers {
1556 session.peer.send(
1557 session.connection_id,
1558 proto::UpdateLanguageServer {
1559 project_id: project_id.to_proto(),
1560 language_server_id: language_server.id,
1561 variant: Some(
1562 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1563 proto::LspDiskBasedDiagnosticsUpdated {},
1564 ),
1565 ),
1566 },
1567 )?;
1568 }
1569
1570 Ok(())
1571}
1572
1573async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1574 let sender_id = session.connection_id;
1575 let project_id = ProjectId::from_proto(request.project_id);
1576
1577 let (room, project) = &*session
1578 .db()
1579 .await
1580 .leave_project(project_id, sender_id)
1581 .await?;
1582 tracing::info!(
1583 %project_id,
1584 host_user_id = %project.host_user_id,
1585 host_connection_id = %project.host_connection_id,
1586 "leave project"
1587 );
1588
1589 project_left(&project, &session);
1590 room_updated(&room, &session.peer);
1591
1592 Ok(())
1593}
1594
1595async fn update_project(
1596 request: proto::UpdateProject,
1597 response: Response<proto::UpdateProject>,
1598 session: Session,
1599) -> Result<()> {
1600 let project_id = ProjectId::from_proto(request.project_id);
1601 let (room, guest_connection_ids) = &*session
1602 .db()
1603 .await
1604 .update_project(project_id, session.connection_id, &request.worktrees)
1605 .await?;
1606 broadcast(
1607 Some(session.connection_id),
1608 guest_connection_ids.iter().copied(),
1609 |connection_id| {
1610 session
1611 .peer
1612 .forward_send(session.connection_id, connection_id, request.clone())
1613 },
1614 );
1615 room_updated(&room, &session.peer);
1616 response.send(proto::Ack {})?;
1617
1618 Ok(())
1619}
1620
1621async fn update_worktree(
1622 request: proto::UpdateWorktree,
1623 response: Response<proto::UpdateWorktree>,
1624 session: Session,
1625) -> Result<()> {
1626 let guest_connection_ids = session
1627 .db()
1628 .await
1629 .update_worktree(&request, session.connection_id)
1630 .await?;
1631
1632 broadcast(
1633 Some(session.connection_id),
1634 guest_connection_ids.iter().copied(),
1635 |connection_id| {
1636 session
1637 .peer
1638 .forward_send(session.connection_id, connection_id, request.clone())
1639 },
1640 );
1641 response.send(proto::Ack {})?;
1642 Ok(())
1643}
1644
1645async fn update_diagnostic_summary(
1646 message: proto::UpdateDiagnosticSummary,
1647 session: Session,
1648) -> Result<()> {
1649 let guest_connection_ids = session
1650 .db()
1651 .await
1652 .update_diagnostic_summary(&message, session.connection_id)
1653 .await?;
1654
1655 broadcast(
1656 Some(session.connection_id),
1657 guest_connection_ids.iter().copied(),
1658 |connection_id| {
1659 session
1660 .peer
1661 .forward_send(session.connection_id, connection_id, message.clone())
1662 },
1663 );
1664
1665 Ok(())
1666}
1667
1668async fn update_worktree_settings(
1669 message: proto::UpdateWorktreeSettings,
1670 session: Session,
1671) -> Result<()> {
1672 let guest_connection_ids = session
1673 .db()
1674 .await
1675 .update_worktree_settings(&message, session.connection_id)
1676 .await?;
1677
1678 broadcast(
1679 Some(session.connection_id),
1680 guest_connection_ids.iter().copied(),
1681 |connection_id| {
1682 session
1683 .peer
1684 .forward_send(session.connection_id, connection_id, message.clone())
1685 },
1686 );
1687
1688 Ok(())
1689}
1690
1691async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1692 broadcast_project_message(request.project_id, request, session).await
1693}
1694
1695async fn start_language_server(
1696 request: proto::StartLanguageServer,
1697 session: Session,
1698) -> Result<()> {
1699 let guest_connection_ids = session
1700 .db()
1701 .await
1702 .start_language_server(&request, session.connection_id)
1703 .await?;
1704
1705 broadcast(
1706 Some(session.connection_id),
1707 guest_connection_ids.iter().copied(),
1708 |connection_id| {
1709 session
1710 .peer
1711 .forward_send(session.connection_id, connection_id, request.clone())
1712 },
1713 );
1714 Ok(())
1715}
1716
1717async fn update_language_server(
1718 request: proto::UpdateLanguageServer,
1719 session: Session,
1720) -> Result<()> {
1721 session.executor.record_backtrace();
1722 let project_id = ProjectId::from_proto(request.project_id);
1723 let project_connection_ids = session
1724 .db()
1725 .await
1726 .project_connection_ids(project_id, session.connection_id)
1727 .await?;
1728 broadcast(
1729 Some(session.connection_id),
1730 project_connection_ids.iter().copied(),
1731 |connection_id| {
1732 session
1733 .peer
1734 .forward_send(session.connection_id, connection_id, request.clone())
1735 },
1736 );
1737 Ok(())
1738}
1739
1740async fn forward_project_request<T>(
1741 request: T,
1742 response: Response<T>,
1743 session: Session,
1744) -> Result<()>
1745where
1746 T: EntityMessage + RequestMessage,
1747{
1748 session.executor.record_backtrace();
1749 let project_id = ProjectId::from_proto(request.remote_entity_id());
1750 let host_connection_id = {
1751 let collaborators = session
1752 .db()
1753 .await
1754 .project_collaborators(project_id, session.connection_id)
1755 .await?;
1756 collaborators
1757 .iter()
1758 .find(|collaborator| collaborator.is_host)
1759 .ok_or_else(|| anyhow!("host not found"))?
1760 .connection_id
1761 };
1762
1763 let payload = session
1764 .peer
1765 .forward_request(session.connection_id, host_connection_id, request)
1766 .await?;
1767
1768 response.send(payload)?;
1769 Ok(())
1770}
1771
1772async fn create_buffer_for_peer(
1773 request: proto::CreateBufferForPeer,
1774 session: Session,
1775) -> Result<()> {
1776 session.executor.record_backtrace();
1777 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1778 session
1779 .peer
1780 .forward_send(session.connection_id, peer_id.into(), request)?;
1781 Ok(())
1782}
1783
1784async fn update_buffer(
1785 request: proto::UpdateBuffer,
1786 response: Response<proto::UpdateBuffer>,
1787 session: Session,
1788) -> Result<()> {
1789 session.executor.record_backtrace();
1790 let project_id = ProjectId::from_proto(request.project_id);
1791 let mut guest_connection_ids;
1792 let mut host_connection_id = None;
1793 {
1794 let collaborators = session
1795 .db()
1796 .await
1797 .project_collaborators(project_id, session.connection_id)
1798 .await?;
1799 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1800 for collaborator in collaborators.iter() {
1801 if collaborator.is_host {
1802 host_connection_id = Some(collaborator.connection_id);
1803 } else {
1804 guest_connection_ids.push(collaborator.connection_id);
1805 }
1806 }
1807 }
1808 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1809
1810 session.executor.record_backtrace();
1811 broadcast(
1812 Some(session.connection_id),
1813 guest_connection_ids,
1814 |connection_id| {
1815 session
1816 .peer
1817 .forward_send(session.connection_id, connection_id, request.clone())
1818 },
1819 );
1820 if host_connection_id != session.connection_id {
1821 session
1822 .peer
1823 .forward_request(session.connection_id, host_connection_id, request.clone())
1824 .await?;
1825 }
1826
1827 response.send(proto::Ack {})?;
1828 Ok(())
1829}
1830
1831async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1832 let project_id = ProjectId::from_proto(request.project_id);
1833 let project_connection_ids = session
1834 .db()
1835 .await
1836 .project_connection_ids(project_id, session.connection_id)
1837 .await?;
1838
1839 broadcast(
1840 Some(session.connection_id),
1841 project_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1852 let project_id = ProjectId::from_proto(request.project_id);
1853 let project_connection_ids = session
1854 .db()
1855 .await
1856 .project_connection_ids(project_id, session.connection_id)
1857 .await?;
1858 broadcast(
1859 Some(session.connection_id),
1860 project_connection_ids.iter().copied(),
1861 |connection_id| {
1862 session
1863 .peer
1864 .forward_send(session.connection_id, connection_id, request.clone())
1865 },
1866 );
1867 Ok(())
1868}
1869
1870async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1871 broadcast_project_message(request.project_id, request, session).await
1872}
1873
1874async fn broadcast_project_message<T: EnvelopedMessage>(
1875 project_id: u64,
1876 request: T,
1877 session: Session,
1878) -> Result<()> {
1879 let project_id = ProjectId::from_proto(project_id);
1880 let project_connection_ids = session
1881 .db()
1882 .await
1883 .project_connection_ids(project_id, session.connection_id)
1884 .await?;
1885 broadcast(
1886 Some(session.connection_id),
1887 project_connection_ids.iter().copied(),
1888 |connection_id| {
1889 session
1890 .peer
1891 .forward_send(session.connection_id, connection_id, request.clone())
1892 },
1893 );
1894 Ok(())
1895}
1896
1897async fn follow(
1898 request: proto::Follow,
1899 response: Response<proto::Follow>,
1900 session: Session,
1901) -> Result<()> {
1902 let room_id = RoomId::from_proto(request.room_id);
1903 let project_id = request.project_id.map(ProjectId::from_proto);
1904 let leader_id = request
1905 .leader_id
1906 .ok_or_else(|| anyhow!("invalid leader id"))?
1907 .into();
1908 let follower_id = session.connection_id;
1909
1910 session
1911 .db()
1912 .await
1913 .check_room_participants(room_id, leader_id, session.connection_id)
1914 .await?;
1915
1916 let response_payload = session
1917 .peer
1918 .forward_request(session.connection_id, leader_id, request)
1919 .await?;
1920 response.send(response_payload)?;
1921
1922 if let Some(project_id) = project_id {
1923 let room = session
1924 .db()
1925 .await
1926 .follow(room_id, project_id, leader_id, follower_id)
1927 .await?;
1928 room_updated(&room, &session.peer);
1929 }
1930
1931 Ok(())
1932}
1933
1934async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1935 let room_id = RoomId::from_proto(request.room_id);
1936 let project_id = request.project_id.map(ProjectId::from_proto);
1937 let leader_id = request
1938 .leader_id
1939 .ok_or_else(|| anyhow!("invalid leader id"))?
1940 .into();
1941 let follower_id = session.connection_id;
1942
1943 session
1944 .db()
1945 .await
1946 .check_room_participants(room_id, leader_id, session.connection_id)
1947 .await?;
1948
1949 session
1950 .peer
1951 .forward_send(session.connection_id, leader_id, request)?;
1952
1953 if let Some(project_id) = project_id {
1954 let room = session
1955 .db()
1956 .await
1957 .unfollow(room_id, project_id, leader_id, follower_id)
1958 .await?;
1959 room_updated(&room, &session.peer);
1960 }
1961
1962 Ok(())
1963}
1964
1965async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1966 let room_id = RoomId::from_proto(request.room_id);
1967 let database = session.db.lock().await;
1968
1969 let connection_ids = if let Some(project_id) = request.project_id {
1970 let project_id = ProjectId::from_proto(project_id);
1971 database
1972 .project_connection_ids(project_id, session.connection_id)
1973 .await?
1974 } else {
1975 database
1976 .room_connection_ids(room_id, session.connection_id)
1977 .await?
1978 };
1979
1980 // For now, don't send view update messages back to that view's current leader.
1981 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1982 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1983 _ => None,
1984 });
1985
1986 for follower_peer_id in request.follower_ids.iter().copied() {
1987 let follower_connection_id = follower_peer_id.into();
1988 if Some(follower_peer_id) != connection_id_to_omit
1989 && connection_ids.contains(&follower_connection_id)
1990 {
1991 session.peer.forward_send(
1992 session.connection_id,
1993 follower_connection_id,
1994 request.clone(),
1995 )?;
1996 }
1997 }
1998 Ok(())
1999}
2000
2001async fn get_users(
2002 request: proto::GetUsers,
2003 response: Response<proto::GetUsers>,
2004 session: Session,
2005) -> Result<()> {
2006 let user_ids = request
2007 .user_ids
2008 .into_iter()
2009 .map(UserId::from_proto)
2010 .collect();
2011 let users = session
2012 .db()
2013 .await
2014 .get_users_by_ids(user_ids)
2015 .await?
2016 .into_iter()
2017 .map(|user| proto::User {
2018 id: user.id.to_proto(),
2019 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2020 github_login: user.github_login,
2021 })
2022 .collect();
2023 response.send(proto::UsersResponse { users })?;
2024 Ok(())
2025}
2026
2027async fn fuzzy_search_users(
2028 request: proto::FuzzySearchUsers,
2029 response: Response<proto::FuzzySearchUsers>,
2030 session: Session,
2031) -> Result<()> {
2032 let query = request.query;
2033 let users = match query.len() {
2034 0 => vec![],
2035 1 | 2 => session
2036 .db()
2037 .await
2038 .get_user_by_github_login(&query)
2039 .await?
2040 .into_iter()
2041 .collect(),
2042 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2043 };
2044 let users = users
2045 .into_iter()
2046 .filter(|user| user.id != session.user_id)
2047 .map(|user| proto::User {
2048 id: user.id.to_proto(),
2049 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2050 github_login: user.github_login,
2051 })
2052 .collect();
2053 response.send(proto::UsersResponse { users })?;
2054 Ok(())
2055}
2056
2057async fn request_contact(
2058 request: proto::RequestContact,
2059 response: Response<proto::RequestContact>,
2060 session: Session,
2061) -> Result<()> {
2062 let requester_id = session.user_id;
2063 let responder_id = UserId::from_proto(request.responder_id);
2064 if requester_id == responder_id {
2065 return Err(anyhow!("cannot add yourself as a contact"))?;
2066 }
2067
2068 session
2069 .db()
2070 .await
2071 .send_contact_request(requester_id, responder_id)
2072 .await?;
2073
2074 // Update outgoing contact requests of requester
2075 let mut update = proto::UpdateContacts::default();
2076 update.outgoing_requests.push(responder_id.to_proto());
2077 for connection_id in session
2078 .connection_pool()
2079 .await
2080 .user_connection_ids(requester_id)
2081 {
2082 session.peer.send(connection_id, update.clone())?;
2083 }
2084
2085 // Update incoming contact requests of responder
2086 let mut update = proto::UpdateContacts::default();
2087 update
2088 .incoming_requests
2089 .push(proto::IncomingContactRequest {
2090 requester_id: requester_id.to_proto(),
2091 should_notify: true,
2092 });
2093 for connection_id in session
2094 .connection_pool()
2095 .await
2096 .user_connection_ids(responder_id)
2097 {
2098 session.peer.send(connection_id, update.clone())?;
2099 }
2100
2101 response.send(proto::Ack {})?;
2102 Ok(())
2103}
2104
2105async fn respond_to_contact_request(
2106 request: proto::RespondToContactRequest,
2107 response: Response<proto::RespondToContactRequest>,
2108 session: Session,
2109) -> Result<()> {
2110 let responder_id = session.user_id;
2111 let requester_id = UserId::from_proto(request.requester_id);
2112 let db = session.db().await;
2113 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2114 db.dismiss_contact_notification(responder_id, requester_id)
2115 .await?;
2116 } else {
2117 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2118
2119 db.respond_to_contact_request(responder_id, requester_id, accept)
2120 .await?;
2121 let requester_busy = db.is_user_busy(requester_id).await?;
2122 let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124 let pool = session.connection_pool().await;
2125 // Update responder with new contact
2126 let mut update = proto::UpdateContacts::default();
2127 if accept {
2128 update
2129 .contacts
2130 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2131 }
2132 update
2133 .remove_incoming_requests
2134 .push(requester_id.to_proto());
2135 for connection_id in pool.user_connection_ids(responder_id) {
2136 session.peer.send(connection_id, update.clone())?;
2137 }
2138
2139 // Update requester with new contact
2140 let mut update = proto::UpdateContacts::default();
2141 if accept {
2142 update
2143 .contacts
2144 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2145 }
2146 update
2147 .remove_outgoing_requests
2148 .push(responder_id.to_proto());
2149 for connection_id in pool.user_connection_ids(requester_id) {
2150 session.peer.send(connection_id, update.clone())?;
2151 }
2152 }
2153
2154 response.send(proto::Ack {})?;
2155 Ok(())
2156}
2157
2158async fn remove_contact(
2159 request: proto::RemoveContact,
2160 response: Response<proto::RemoveContact>,
2161 session: Session,
2162) -> Result<()> {
2163 let requester_id = session.user_id;
2164 let responder_id = UserId::from_proto(request.user_id);
2165 let db = session.db().await;
2166 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2167
2168 let pool = session.connection_pool().await;
2169 // Update outgoing contact requests of requester
2170 let mut update = proto::UpdateContacts::default();
2171 if contact_accepted {
2172 update.remove_contacts.push(responder_id.to_proto());
2173 } else {
2174 update
2175 .remove_outgoing_requests
2176 .push(responder_id.to_proto());
2177 }
2178 for connection_id in pool.user_connection_ids(requester_id) {
2179 session.peer.send(connection_id, update.clone())?;
2180 }
2181
2182 // Update incoming contact requests of responder
2183 let mut update = proto::UpdateContacts::default();
2184 if contact_accepted {
2185 update.remove_contacts.push(requester_id.to_proto());
2186 } else {
2187 update
2188 .remove_incoming_requests
2189 .push(requester_id.to_proto());
2190 }
2191 for connection_id in pool.user_connection_ids(responder_id) {
2192 session.peer.send(connection_id, update.clone())?;
2193 }
2194
2195 response.send(proto::Ack {})?;
2196 Ok(())
2197}
2198
2199async fn create_channel(
2200 request: proto::CreateChannel,
2201 response: Response<proto::CreateChannel>,
2202 session: Session,
2203) -> Result<()> {
2204 let db = session.db().await;
2205
2206 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2207 let id = db
2208 .create_channel(&request.name, parent_id, session.user_id)
2209 .await?;
2210
2211 let channel = proto::Channel {
2212 id: id.to_proto(),
2213 name: request.name,
2214 };
2215
2216 response.send(proto::CreateChannelResponse {
2217 channel: Some(channel.clone()),
2218 parent_id: request.parent_id,
2219 })?;
2220
2221 let Some(parent_id) = parent_id else {
2222 return Ok(());
2223 };
2224
2225 let update = proto::UpdateChannels {
2226 channels: vec![channel],
2227 insert_edge: vec![ChannelEdge {
2228 parent_id: parent_id.to_proto(),
2229 channel_id: id.to_proto(),
2230 }],
2231 ..Default::default()
2232 };
2233
2234 let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2235
2236 let connection_pool = session.connection_pool().await;
2237 for user_id in user_ids_to_notify {
2238 for connection_id in connection_pool.user_connection_ids(user_id) {
2239 if user_id == session.user_id {
2240 continue;
2241 }
2242 session.peer.send(connection_id, update.clone())?;
2243 }
2244 }
2245
2246 Ok(())
2247}
2248
2249async fn delete_channel(
2250 request: proto::DeleteChannel,
2251 response: Response<proto::DeleteChannel>,
2252 session: Session,
2253) -> Result<()> {
2254 let db = session.db().await;
2255
2256 let channel_id = request.channel_id;
2257 let (removed_channels, member_ids) = db
2258 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2259 .await?;
2260 response.send(proto::Ack {})?;
2261
2262 // Notify members of removed channels
2263 let mut update = proto::UpdateChannels::default();
2264 update
2265 .delete_channels
2266 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2267
2268 let connection_pool = session.connection_pool().await;
2269 for member_id in member_ids {
2270 for connection_id in connection_pool.user_connection_ids(member_id) {
2271 session.peer.send(connection_id, update.clone())?;
2272 }
2273 }
2274
2275 Ok(())
2276}
2277
2278async fn invite_channel_member(
2279 request: proto::InviteChannelMember,
2280 response: Response<proto::InviteChannelMember>,
2281 session: Session,
2282) -> Result<()> {
2283 let db = session.db().await;
2284 let channel_id = ChannelId::from_proto(request.channel_id);
2285 let invitee_id = UserId::from_proto(request.user_id);
2286 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2287 .await?;
2288
2289 let (channel, _) = db
2290 .get_channel(channel_id, session.user_id)
2291 .await?
2292 .ok_or_else(|| anyhow!("channel not found"))?;
2293
2294 let mut update = proto::UpdateChannels::default();
2295 update.channel_invitations.push(proto::Channel {
2296 id: channel.id.to_proto(),
2297 name: channel.name,
2298 });
2299 for connection_id in session
2300 .connection_pool()
2301 .await
2302 .user_connection_ids(invitee_id)
2303 {
2304 session.peer.send(connection_id, update.clone())?;
2305 }
2306
2307 response.send(proto::Ack {})?;
2308 Ok(())
2309}
2310
2311async fn remove_channel_member(
2312 request: proto::RemoveChannelMember,
2313 response: Response<proto::RemoveChannelMember>,
2314 session: Session,
2315) -> Result<()> {
2316 let db = session.db().await;
2317 let channel_id = ChannelId::from_proto(request.channel_id);
2318 let member_id = UserId::from_proto(request.user_id);
2319
2320 db.remove_channel_member(channel_id, member_id, session.user_id)
2321 .await?;
2322
2323 let mut update = proto::UpdateChannels::default();
2324 update.delete_channels.push(channel_id.to_proto());
2325
2326 for connection_id in session
2327 .connection_pool()
2328 .await
2329 .user_connection_ids(member_id)
2330 {
2331 session.peer.send(connection_id, update.clone())?;
2332 }
2333
2334 response.send(proto::Ack {})?;
2335 Ok(())
2336}
2337
2338async fn set_channel_member_admin(
2339 request: proto::SetChannelMemberAdmin,
2340 response: Response<proto::SetChannelMemberAdmin>,
2341 session: Session,
2342) -> Result<()> {
2343 let db = session.db().await;
2344 let channel_id = ChannelId::from_proto(request.channel_id);
2345 let member_id = UserId::from_proto(request.user_id);
2346 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2347 .await?;
2348
2349 let (channel, has_accepted) = db
2350 .get_channel(channel_id, member_id)
2351 .await?
2352 .ok_or_else(|| anyhow!("channel not found"))?;
2353
2354 let mut update = proto::UpdateChannels::default();
2355 if has_accepted {
2356 update.channel_permissions.push(proto::ChannelPermission {
2357 channel_id: channel.id.to_proto(),
2358 is_admin: request.admin,
2359 });
2360 }
2361
2362 for connection_id in session
2363 .connection_pool()
2364 .await
2365 .user_connection_ids(member_id)
2366 {
2367 session.peer.send(connection_id, update.clone())?;
2368 }
2369
2370 response.send(proto::Ack {})?;
2371 Ok(())
2372}
2373
2374async fn rename_channel(
2375 request: proto::RenameChannel,
2376 response: Response<proto::RenameChannel>,
2377 session: Session,
2378) -> Result<()> {
2379 let db = session.db().await;
2380 let channel_id = ChannelId::from_proto(request.channel_id);
2381 let new_name = db
2382 .rename_channel(channel_id, session.user_id, &request.name)
2383 .await?;
2384
2385 let channel = proto::Channel {
2386 id: request.channel_id,
2387 name: new_name,
2388 };
2389 response.send(proto::RenameChannelResponse {
2390 channel: Some(channel.clone()),
2391 })?;
2392 let mut update = proto::UpdateChannels::default();
2393 update.channels.push(channel);
2394
2395 let member_ids = db.get_channel_members(channel_id).await?;
2396
2397 let connection_pool = session.connection_pool().await;
2398 for member_id in member_ids {
2399 for connection_id in connection_pool.user_connection_ids(member_id) {
2400 session.peer.send(connection_id, update.clone())?;
2401 }
2402 }
2403
2404 Ok(())
2405}
2406
2407async fn link_channel(
2408 request: proto::LinkChannel,
2409 response: Response<proto::LinkChannel>,
2410 session: Session,
2411) -> Result<()> {
2412 let db = session.db().await;
2413 let channel_id = ChannelId::from_proto(request.channel_id);
2414 let to = ChannelId::from_proto(request.to);
2415 let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2416
2417 let members = db.get_channel_members(to).await?;
2418 let connection_pool = session.connection_pool().await;
2419 let update = proto::UpdateChannels {
2420 channels: channels_to_send
2421 .channels
2422 .into_iter()
2423 .map(|channel| proto::Channel {
2424 id: channel.id.to_proto(),
2425 name: channel.name,
2426 })
2427 .collect(),
2428 insert_edge: channels_to_send.edges,
2429 ..Default::default()
2430 };
2431 for member_id in members {
2432 for connection_id in connection_pool.user_connection_ids(member_id) {
2433 session.peer.send(connection_id, update.clone())?;
2434 }
2435 }
2436
2437 response.send(Ack {})?;
2438
2439 Ok(())
2440}
2441
2442async fn unlink_channel(
2443 request: proto::UnlinkChannel,
2444 response: Response<proto::UnlinkChannel>,
2445 session: Session,
2446) -> Result<()> {
2447 let db = session.db().await;
2448 let channel_id = ChannelId::from_proto(request.channel_id);
2449 let from = ChannelId::from_proto(request.from);
2450
2451 db.unlink_channel(session.user_id, channel_id, from).await?;
2452
2453 let members = db.get_channel_members(from).await?;
2454
2455 let update = proto::UpdateChannels {
2456 delete_edge: vec![proto::ChannelEdge {
2457 channel_id: channel_id.to_proto(),
2458 parent_id: from.to_proto(),
2459 }],
2460 ..Default::default()
2461 };
2462 let connection_pool = session.connection_pool().await;
2463 for member_id in members {
2464 for connection_id in connection_pool.user_connection_ids(member_id) {
2465 session.peer.send(connection_id, update.clone())?;
2466 }
2467 }
2468
2469 response.send(Ack {})?;
2470
2471 Ok(())
2472}
2473
2474async fn move_channel(
2475 request: proto::MoveChannel,
2476 response: Response<proto::MoveChannel>,
2477 session: Session,
2478) -> Result<()> {
2479 let db = session.db().await;
2480 let channel_id = ChannelId::from_proto(request.channel_id);
2481 let from_parent = ChannelId::from_proto(request.from);
2482 let to = ChannelId::from_proto(request.to);
2483
2484 let channels_to_send = db
2485 .move_channel(session.user_id, channel_id, from_parent, to)
2486 .await?;
2487
2488 if channels_to_send.is_empty() {
2489 response.send(Ack {})?;
2490 return Ok(());
2491 }
2492
2493 let members_from = db.get_channel_members(from_parent).await?;
2494 let members_to = db.get_channel_members(to).await?;
2495
2496 let update = proto::UpdateChannels {
2497 delete_edge: vec![proto::ChannelEdge {
2498 channel_id: channel_id.to_proto(),
2499 parent_id: from_parent.to_proto(),
2500 }],
2501 ..Default::default()
2502 };
2503 let connection_pool = session.connection_pool().await;
2504 for member_id in members_from {
2505 for connection_id in connection_pool.user_connection_ids(member_id) {
2506 session.peer.send(connection_id, update.clone())?;
2507 }
2508 }
2509
2510 let update = proto::UpdateChannels {
2511 channels: channels_to_send
2512 .channels
2513 .into_iter()
2514 .map(|channel| proto::Channel {
2515 id: channel.id.to_proto(),
2516 name: channel.name,
2517 })
2518 .collect(),
2519 insert_edge: channels_to_send.edges,
2520 ..Default::default()
2521 };
2522 for member_id in members_to {
2523 for connection_id in connection_pool.user_connection_ids(member_id) {
2524 session.peer.send(connection_id, update.clone())?;
2525 }
2526 }
2527
2528 response.send(Ack {})?;
2529
2530 Ok(())
2531}
2532
2533async fn get_channel_members(
2534 request: proto::GetChannelMembers,
2535 response: Response<proto::GetChannelMembers>,
2536 session: Session,
2537) -> Result<()> {
2538 let db = session.db().await;
2539 let channel_id = ChannelId::from_proto(request.channel_id);
2540 let members = db
2541 .get_channel_member_details(channel_id, session.user_id)
2542 .await?;
2543 response.send(proto::GetChannelMembersResponse { members })?;
2544 Ok(())
2545}
2546
2547async fn respond_to_channel_invite(
2548 request: proto::RespondToChannelInvite,
2549 response: Response<proto::RespondToChannelInvite>,
2550 session: Session,
2551) -> Result<()> {
2552 let db = session.db().await;
2553 let channel_id = ChannelId::from_proto(request.channel_id);
2554 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2555 .await?;
2556
2557 let mut update = proto::UpdateChannels::default();
2558 update
2559 .remove_channel_invitations
2560 .push(channel_id.to_proto());
2561 if request.accept {
2562 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2563 update
2564 .channels
2565 .extend(
2566 result
2567 .channels
2568 .channels
2569 .into_iter()
2570 .map(|channel| proto::Channel {
2571 id: channel.id.to_proto(),
2572 name: channel.name,
2573 }),
2574 );
2575 update.unseen_channel_messages = result.channel_messages;
2576 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2577 update.insert_edge = result.channels.edges;
2578 update
2579 .channel_participants
2580 .extend(
2581 result
2582 .channel_participants
2583 .into_iter()
2584 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2585 channel_id: channel_id.to_proto(),
2586 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2587 }),
2588 );
2589 update
2590 .channel_permissions
2591 .extend(
2592 result
2593 .channels_with_admin_privileges
2594 .into_iter()
2595 .map(|channel_id| proto::ChannelPermission {
2596 channel_id: channel_id.to_proto(),
2597 is_admin: true,
2598 }),
2599 );
2600 }
2601 session.peer.send(session.connection_id, update)?;
2602 response.send(proto::Ack {})?;
2603
2604 Ok(())
2605}
2606
/// Handles a `JoinChannel` request: moves the session into the room backing a channel.
///
/// Steps, in order:
/// 1. leave whatever room the session is currently in (`leave_room_for_session`);
/// 2. get or create the channel's room and join it in the database;
/// 3. if a LiveKit client is configured, try to mint a room token (failures are logged via
///    `trace_err` and result in no connection info rather than an error);
/// 4. reply with a `JoinRoomResponse`, then broadcast `RoomUpdated` to the room;
/// 5. broadcast the channel's participant list to channel members and refresh the user's
///    contact entry for their contacts.
///
/// # Errors
/// Propagates errors from leaving the previous room, from the database calls, from sending
/// the response, and from `update_user_contacts`.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Candidate LiveKit room name handed to `get_or_create_channel_room`.
    // NOTE(review): presumably only used when the room does not exist yet — confirm in db.
    let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));

    // The block limits how long `db` and the value returned by `join_room` stay alive;
    // only the result of `into_inner()` escapes it.
    let joined_room = {
        // Must happen before `db` is acquired below: `leave_room_for_session` calls
        // `session.db()` itself.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db
            .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
            .await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // `None` when no LiveKit client is configured or when token creation fails.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // The token is issued for the room name stored on the joined room, not for the
            // locally generated `live_kit_room` above.
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    // Let channel members know who is now in the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2669
2670async fn join_channel_buffer(
2671 request: proto::JoinChannelBuffer,
2672 response: Response<proto::JoinChannelBuffer>,
2673 session: Session,
2674) -> Result<()> {
2675 let db = session.db().await;
2676 let channel_id = ChannelId::from_proto(request.channel_id);
2677
2678 let open_response = db
2679 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2680 .await?;
2681
2682 let collaborators = open_response.collaborators.clone();
2683 response.send(open_response)?;
2684
2685 let update = UpdateChannelBufferCollaborators {
2686 channel_id: channel_id.to_proto(),
2687 collaborators: collaborators.clone(),
2688 };
2689 channel_buffer_updated(
2690 session.connection_id,
2691 collaborators
2692 .iter()
2693 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2694 &update,
2695 &session.peer,
2696 );
2697
2698 Ok(())
2699}
2700
2701async fn update_channel_buffer(
2702 request: proto::UpdateChannelBuffer,
2703 session: Session,
2704) -> Result<()> {
2705 let db = session.db().await;
2706 let channel_id = ChannelId::from_proto(request.channel_id);
2707
2708 let (collaborators, non_collaborators, epoch, version) = db
2709 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2710 .await?;
2711
2712 channel_buffer_updated(
2713 session.connection_id,
2714 collaborators,
2715 &proto::UpdateChannelBuffer {
2716 channel_id: channel_id.to_proto(),
2717 operations: request.operations,
2718 },
2719 &session.peer,
2720 );
2721
2722 let pool = &*session.connection_pool().await;
2723
2724 broadcast(
2725 None,
2726 non_collaborators
2727 .iter()
2728 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2729 |peer_id| {
2730 session.peer.send(
2731 peer_id.into(),
2732 proto::UpdateChannels {
2733 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2734 channel_id: channel_id.to_proto(),
2735 epoch: epoch as u64,
2736 version: version.clone(),
2737 }],
2738 ..Default::default()
2739 },
2740 )
2741 },
2742 );
2743
2744 Ok(())
2745}
2746
2747async fn rejoin_channel_buffers(
2748 request: proto::RejoinChannelBuffers,
2749 response: Response<proto::RejoinChannelBuffers>,
2750 session: Session,
2751) -> Result<()> {
2752 let db = session.db().await;
2753 let buffers = db
2754 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2755 .await?;
2756
2757 for rejoined_buffer in &buffers {
2758 let collaborators_to_notify = rejoined_buffer
2759 .buffer
2760 .collaborators
2761 .iter()
2762 .filter_map(|c| Some(c.peer_id?.into()));
2763 channel_buffer_updated(
2764 session.connection_id,
2765 collaborators_to_notify,
2766 &proto::UpdateChannelBufferCollaborators {
2767 channel_id: rejoined_buffer.buffer.channel_id,
2768 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2769 },
2770 &session.peer,
2771 );
2772 }
2773
2774 response.send(proto::RejoinChannelBuffersResponse {
2775 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2776 })?;
2777
2778 Ok(())
2779}
2780
2781async fn leave_channel_buffer(
2782 request: proto::LeaveChannelBuffer,
2783 response: Response<proto::LeaveChannelBuffer>,
2784 session: Session,
2785) -> Result<()> {
2786 let db = session.db().await;
2787 let channel_id = ChannelId::from_proto(request.channel_id);
2788
2789 let left_buffer = db
2790 .leave_channel_buffer(channel_id, session.connection_id)
2791 .await?;
2792
2793 response.send(Ack {})?;
2794
2795 channel_buffer_updated(
2796 session.connection_id,
2797 left_buffer.connections,
2798 &proto::UpdateChannelBufferCollaborators {
2799 channel_id: channel_id.to_proto(),
2800 collaborators: left_buffer.collaborators,
2801 },
2802 &session.peer,
2803 );
2804
2805 Ok(())
2806}
2807
2808fn channel_buffer_updated<T: EnvelopedMessage>(
2809 sender_id: ConnectionId,
2810 collaborators: impl IntoIterator<Item = ConnectionId>,
2811 message: &T,
2812 peer: &Peer,
2813) {
2814 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2815 peer.send(peer_id.into(), message.clone())
2816 });
2817}
2818
2819async fn send_channel_message(
2820 request: proto::SendChannelMessage,
2821 response: Response<proto::SendChannelMessage>,
2822 session: Session,
2823) -> Result<()> {
2824 // Validate the message body.
2825 let body = request.body.trim().to_string();
2826 if body.len() > MAX_MESSAGE_LEN {
2827 return Err(anyhow!("message is too long"))?;
2828 }
2829 if body.is_empty() {
2830 return Err(anyhow!("message can't be blank"))?;
2831 }
2832
2833 let timestamp = OffsetDateTime::now_utc();
2834 let nonce = request
2835 .nonce
2836 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2837
2838 let channel_id = ChannelId::from_proto(request.channel_id);
2839 let (message_id, connection_ids, non_participants) = session
2840 .db()
2841 .await
2842 .create_channel_message(
2843 channel_id,
2844 session.user_id,
2845 &body,
2846 timestamp,
2847 nonce.clone().into(),
2848 )
2849 .await?;
2850 let message = proto::ChannelMessage {
2851 sender_id: session.user_id.to_proto(),
2852 id: message_id.to_proto(),
2853 body,
2854 timestamp: timestamp.unix_timestamp() as u64,
2855 nonce: Some(nonce),
2856 };
2857 broadcast(Some(session.connection_id), connection_ids, |connection| {
2858 session.peer.send(
2859 connection,
2860 proto::ChannelMessageSent {
2861 channel_id: channel_id.to_proto(),
2862 message: Some(message.clone()),
2863 },
2864 )
2865 });
2866 response.send(proto::SendChannelMessageResponse {
2867 message: Some(message),
2868 })?;
2869
2870 let pool = &*session.connection_pool().await;
2871 broadcast(
2872 None,
2873 non_participants
2874 .iter()
2875 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2876 |peer_id| {
2877 session.peer.send(
2878 peer_id.into(),
2879 proto::UpdateChannels {
2880 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2881 channel_id: channel_id.to_proto(),
2882 message_id: message_id.to_proto(),
2883 }],
2884 ..Default::default()
2885 },
2886 )
2887 },
2888 );
2889
2890 Ok(())
2891}
2892
2893async fn remove_channel_message(
2894 request: proto::RemoveChannelMessage,
2895 response: Response<proto::RemoveChannelMessage>,
2896 session: Session,
2897) -> Result<()> {
2898 let channel_id = ChannelId::from_proto(request.channel_id);
2899 let message_id = MessageId::from_proto(request.message_id);
2900 let connection_ids = session
2901 .db()
2902 .await
2903 .remove_channel_message(channel_id, message_id, session.user_id)
2904 .await?;
2905 broadcast(Some(session.connection_id), connection_ids, |connection| {
2906 session.peer.send(connection, request.clone())
2907 });
2908 response.send(proto::Ack {})?;
2909 Ok(())
2910}
2911
2912async fn acknowledge_channel_message(
2913 request: proto::AckChannelMessage,
2914 session: Session,
2915) -> Result<()> {
2916 let channel_id = ChannelId::from_proto(request.channel_id);
2917 let message_id = MessageId::from_proto(request.message_id);
2918 session
2919 .db()
2920 .await
2921 .observe_channel_message(channel_id, session.user_id, message_id)
2922 .await?;
2923 Ok(())
2924}
2925
2926async fn acknowledge_buffer_version(
2927 request: proto::AckBufferOperation,
2928 session: Session,
2929) -> Result<()> {
2930 let buffer_id = BufferId::from_proto(request.buffer_id);
2931 session
2932 .db()
2933 .await
2934 .observe_buffer_version(
2935 buffer_id,
2936 session.user_id,
2937 request.epoch as i32,
2938 &request.version,
2939 )
2940 .await?;
2941 Ok(())
2942}
2943
2944async fn join_channel_chat(
2945 request: proto::JoinChannelChat,
2946 response: Response<proto::JoinChannelChat>,
2947 session: Session,
2948) -> Result<()> {
2949 let channel_id = ChannelId::from_proto(request.channel_id);
2950
2951 let db = session.db().await;
2952 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2953 .await?;
2954 let messages = db
2955 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2956 .await?;
2957 response.send(proto::JoinChannelChatResponse {
2958 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2959 messages,
2960 })?;
2961 Ok(())
2962}
2963
2964async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
2965 let channel_id = ChannelId::from_proto(request.channel_id);
2966 session
2967 .db()
2968 .await
2969 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
2970 .await?;
2971 Ok(())
2972}
2973
2974async fn get_channel_messages(
2975 request: proto::GetChannelMessages,
2976 response: Response<proto::GetChannelMessages>,
2977 session: Session,
2978) -> Result<()> {
2979 let channel_id = ChannelId::from_proto(request.channel_id);
2980 let messages = session
2981 .db()
2982 .await
2983 .get_channel_messages(
2984 channel_id,
2985 session.user_id,
2986 MESSAGE_COUNT_PER_PAGE,
2987 Some(MessageId::from_proto(request.before_message_id)),
2988 )
2989 .await?;
2990 response.send(proto::GetChannelMessagesResponse {
2991 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
2992 messages,
2993 })?;
2994 Ok(())
2995}
2996
2997async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2998 let project_id = ProjectId::from_proto(request.project_id);
2999 let project_connection_ids = session
3000 .db()
3001 .await
3002 .project_connection_ids(project_id, session.connection_id)
3003 .await?;
3004 broadcast(
3005 Some(session.connection_id),
3006 project_connection_ids.iter().copied(),
3007 |connection_id| {
3008 session
3009 .peer
3010 .forward_send(session.connection_id, connection_id, request.clone())
3011 },
3012 );
3013 Ok(())
3014}
3015
3016async fn get_private_user_info(
3017 _request: proto::GetPrivateUserInfo,
3018 response: Response<proto::GetPrivateUserInfo>,
3019 session: Session,
3020) -> Result<()> {
3021 let db = session.db().await;
3022
3023 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3024 let user = db
3025 .get_user_by_id(session.user_id)
3026 .await?
3027 .ok_or_else(|| anyhow!("user not found"))?;
3028 let flags = db.get_user_flags(session.user_id).await?;
3029
3030 response.send(proto::GetPrivateUserInfoResponse {
3031 metrics_id,
3032 staff: user.admin,
3033 flags,
3034 })?;
3035 Ok(())
3036}
3037
3038fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3039 match message {
3040 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3041 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3042 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3043 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3044 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3045 code: frame.code.into(),
3046 reason: frame.reason,
3047 })),
3048 }
3049}
3050
3051fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3052 match message {
3053 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3054 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3055 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3056 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3057 AxumMessage::Close(frame) => {
3058 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3059 code: frame.code.into(),
3060 reason: frame.reason,
3061 }))
3062 }
3063 }
3064}
3065
3066fn build_initial_channels_update(
3067 channels: ChannelsForUser,
3068 channel_invites: Vec<db::Channel>,
3069) -> proto::UpdateChannels {
3070 let mut update = proto::UpdateChannels::default();
3071
3072 for channel in channels.channels.channels {
3073 update.channels.push(proto::Channel {
3074 id: channel.id.to_proto(),
3075 name: channel.name,
3076 });
3077 }
3078
3079 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3080 update.unseen_channel_messages = channels.channel_messages;
3081 update.insert_edge = channels.channels.edges;
3082
3083 for (channel_id, participants) in channels.channel_participants {
3084 update
3085 .channel_participants
3086 .push(proto::ChannelParticipants {
3087 channel_id: channel_id.to_proto(),
3088 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3089 });
3090 }
3091
3092 update
3093 .channel_permissions
3094 .extend(
3095 channels
3096 .channels_with_admin_privileges
3097 .into_iter()
3098 .map(|id| proto::ChannelPermission {
3099 channel_id: id.to_proto(),
3100 is_admin: true,
3101 }),
3102 );
3103
3104 for channel in channel_invites {
3105 update.channel_invitations.push(proto::Channel {
3106 id: channel.id.to_proto(),
3107 name: channel.name,
3108 });
3109 }
3110
3111 update
3112}
3113
3114fn build_initial_contacts_update(
3115 contacts: Vec<db::Contact>,
3116 pool: &ConnectionPool,
3117) -> proto::UpdateContacts {
3118 let mut update = proto::UpdateContacts::default();
3119
3120 for contact in contacts {
3121 match contact {
3122 db::Contact::Accepted {
3123 user_id,
3124 should_notify,
3125 busy,
3126 } => {
3127 update
3128 .contacts
3129 .push(contact_for_user(user_id, should_notify, busy, &pool));
3130 }
3131 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3132 db::Contact::Incoming {
3133 user_id,
3134 should_notify,
3135 } => update
3136 .incoming_requests
3137 .push(proto::IncomingContactRequest {
3138 requester_id: user_id.to_proto(),
3139 should_notify,
3140 }),
3141 }
3142 }
3143
3144 update
3145}
3146
3147fn contact_for_user(
3148 user_id: UserId,
3149 should_notify: bool,
3150 busy: bool,
3151 pool: &ConnectionPool,
3152) -> proto::Contact {
3153 proto::Contact {
3154 user_id: user_id.to_proto(),
3155 online: pool.is_user_online(user_id),
3156 busy,
3157 should_notify,
3158 }
3159}
3160
3161fn room_updated(room: &proto::Room, peer: &Peer) {
3162 broadcast(
3163 None,
3164 room.participants
3165 .iter()
3166 .filter_map(|participant| Some(participant.peer_id?.into())),
3167 |peer_id| {
3168 peer.send(
3169 peer_id.into(),
3170 proto::RoomUpdated {
3171 room: Some(room.clone()),
3172 },
3173 )
3174 },
3175 );
3176}
3177
3178fn channel_updated(
3179 channel_id: ChannelId,
3180 room: &proto::Room,
3181 channel_members: &[UserId],
3182 peer: &Peer,
3183 pool: &ConnectionPool,
3184) {
3185 let participants = room
3186 .participants
3187 .iter()
3188 .map(|p| p.user_id)
3189 .collect::<Vec<_>>();
3190
3191 broadcast(
3192 None,
3193 channel_members
3194 .iter()
3195 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3196 |peer_id| {
3197 peer.send(
3198 peer_id.into(),
3199 proto::UpdateChannels {
3200 channel_participants: vec![proto::ChannelParticipants {
3201 channel_id: channel_id.to_proto(),
3202 participant_user_ids: participants.clone(),
3203 }],
3204 ..Default::default()
3205 },
3206 )
3207 },
3208 );
3209}
3210
3211async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3212 let db = session.db().await;
3213
3214 let contacts = db.get_contacts(user_id).await?;
3215 let busy = db.is_user_busy(user_id).await?;
3216
3217 let pool = session.connection_pool().await;
3218 let updated_contact = contact_for_user(user_id, false, busy, &pool);
3219 for contact in contacts {
3220 if let db::Contact::Accepted {
3221 user_id: contact_user_id,
3222 ..
3223 } = contact
3224 {
3225 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3226 session
3227 .peer
3228 .send(
3229 contact_conn_id,
3230 proto::UpdateContacts {
3231 contacts: vec![updated_contact.clone()],
3232 remove_contacts: Default::default(),
3233 incoming_requests: Default::default(),
3234 remove_incoming_requests: Default::default(),
3235 outgoing_requests: Default::default(),
3236 remove_outgoing_requests: Default::default(),
3237 },
3238 )
3239 .trace_err();
3240 }
3241 }
3242 }
3243 Ok(())
3244}
3245
/// Removes the session's connection from whatever room it is in and notifies everyone
/// affected.
///
/// If `db.leave_room` reports that the connection was not in a room, this returns `Ok(())`
/// without doing anything else. Otherwise, in order:
/// 1. peers of each left project are notified (`project_left`);
/// 2. the remaining room participants receive `RoomUpdated`;
/// 3. if the room belongs to a channel, channel members receive the new participant list;
/// 4. users whose calls were canceled receive `CallCanceled` on all their connections;
/// 5. contact entries are refreshed for the session's user and every canceled user;
/// 6. if a LiveKit client is configured, the user is removed from the LiveKit room and the
///    room is deleted when the database flagged it as deleted. LiveKit failures are only
///    logged via `trace_err`.
///
/// # Errors
/// Propagates errors from `db.leave_room` and from `update_user_contacts`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so that the values outlive
    // `left_room`, which goes out of scope at the end of that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken out of `left_room.room` before the room itself is taken,
        // so the `room` value used for the broadcasts below holds that field's `Default`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in any room.
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts` acquires its own.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3322
3323async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3324 let left_channel_buffers = session
3325 .db()
3326 .await
3327 .leave_channel_buffers(session.connection_id)
3328 .await?;
3329
3330 for left_buffer in left_channel_buffers {
3331 channel_buffer_updated(
3332 session.connection_id,
3333 left_buffer.connections,
3334 &proto::UpdateChannelBufferCollaborators {
3335 channel_id: left_buffer.channel_id.to_proto(),
3336 collaborators: left_buffer.collaborators,
3337 },
3338 &session.peer,
3339 );
3340 }
3341
3342 Ok(())
3343}
3344
3345fn project_left(project: &db::LeftProject, session: &Session) {
3346 for connection_id in &project.connection_ids {
3347 if project.host_user_id == session.user_id {
3348 session
3349 .peer
3350 .send(
3351 *connection_id,
3352 proto::UnshareProject {
3353 project_id: project.id.to_proto(),
3354 },
3355 )
3356 .trace_err();
3357 } else {
3358 session
3359 .peer
3360 .send(
3361 *connection_id,
3362 proto::RemoveProjectCollaborator {
3363 project_id: project.id.to_proto(),
3364 peer_id: Some(session.connection_id.into()),
3365 },
3366 )
3367 .trace_err();
3368 }
3369 }
3370}
3371
/// Extension trait for turning a fallible value into an `Option` instead of propagating
/// its error.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Consumes `self`, returning `Some` with the success value or `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
3377
3378impl<T, E> ResultExt for Result<T, E>
3379where
3380 E: std::fmt::Debug,
3381{
3382 type Ok = T;
3383
3384 fn trace_err(self) -> Option<T> {
3385 match self {
3386 Ok(value) => Some(value),
3387 Err(error) => {
3388 tracing::error!("{:?}", error);
3389 None
3390 }
3391 }
3392 }
3393}