1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelVisibility, ChannelsForUser, Database,
7 MessageId, ProjectId, RoomId, ServerId, User, UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
42 RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66use util::channel::RELEASE_CHANNEL_NAME;
67
/// Grace period after a connection is lost before `connection_lost` makes the session leave
/// its room and channel buffers (presumably giving the client time to reconnect).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before cleaning up resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): not referenced in this part of the file; presumably used by the channel-chat
// handlers further down (page size for message history and maximum message length).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;

lazy_static! {
    /// Prometheus gauge with the number of non-admin connections; refreshed by
    /// `handle_metrics` on every scrape.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Prometheus gauge with the shared-project count reported by
    /// `project_count_excluding_admins`; refreshed by `handle_metrics` on every scrape.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
83
/// Type-erased handler for one message type.
///
/// It receives the boxed envelope and the sending connection's `Session`, and returns a boxed
/// future that performs the handling. Instances are built by `Server::add_handler` and stored
/// in `Server::handlers`, keyed by the message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
86
87struct Response<R> {
88 peer: Arc<Peer>,
89 receipt: Receipt<R>,
90 responded: Arc<AtomicBool>,
91}
92
93impl<R: RequestMessage> Response<R> {
94 fn send(self, payload: R::Response) -> Result<()> {
95 self.responded.store(true, SeqCst);
96 self.peer.respond(self.receipt, payload)?;
97 Ok(())
98 }
99}
100
101#[derive(Clone)]
102struct Session {
103 user_id: UserId,
104 connection_id: ConnectionId,
105 db: Arc<tokio::sync::Mutex<DbHandle>>,
106 peer: Arc<Peer>,
107 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
108 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
109 executor: Executor,
110}
111
112impl Session {
113 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
114 #[cfg(test)]
115 tokio::task::yield_now().await;
116 let guard = self.db.lock().await;
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 guard
120 }
121
122 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
123 #[cfg(test)]
124 tokio::task::yield_now().await;
125 let guard = self.connection_pool.lock();
126 ConnectionPoolGuard {
127 guard,
128 _not_send: PhantomData,
129 }
130 }
131}
132
133impl fmt::Debug for Session {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.debug_struct("Session")
136 .field("user_id", &self.user_id)
137 .field("connection_id", &self.connection_id)
138 .finish()
139 }
140}
141
142struct DbHandle(Arc<Database>);
143
144impl Deref for DbHandle {
145 type Target = Database;
146
147 fn deref(&self) -> &Self::Target {
148 self.0.as_ref()
149 }
150}
151
/// The collaboration RPC server: owns the peer, the pool of live connections and the table
/// of message handlers registered in `Server::new`.
pub struct Server {
    /// Identifier of this server instance; mutable so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections, shared with every `Session` created by `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `Server::teardown`; connection tasks subscribe to it to stop early.
    teardown: watch::Sender<()>,
}
161
/// Lock guard for the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker guarantees the guard is `!Send`. Handler futures must be
/// `Send`, so holding the synchronous pool lock across an `.await` in a handler becomes a
/// compile error.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
166
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The connection-pool lock is held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
173
174pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
175where
176 S: Serializer,
177 T: Deref<Target = U>,
178 U: Serialize,
179{
180 Serialize::serialize(value.deref(), serializer)
181}
182
impl Server {
    /// Creates a server with the given `id`, application state and executor, and registers
    /// every RPC handler it understands.
    ///
    /// Returns the server inside an `Arc` so it can be shared with connection tasks.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // The peer is numbered after the server id.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created with `subscribe()` per connection.
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` is for request/response messages (the handler must reply via
        // `Response::send`); `add_message_handler` is for one-way messages.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Schedules cleanup of resources left behind by stale servers and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, a detached task asks the database for the room and channel
    /// ids it reports as stale for this environment and server id, then:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator list
    ///   to the connections returned by that refresh;
    /// - clears stale room participants, then broadcasts the updated room (and its channel,
    ///   if any), sends `CallCanceled` for canceled calls, pushes updated contact entries to
    ///   affected users' accepted contacts, and deletes the LiveKit room once the room has
    ///   no participants;
    /// - finally asks the database to delete stale servers.
    ///
    /// Failures inside the task are only traced; this function always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here so the timer starts now rather than when the task is first polled.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Populated only if clearing the room's stale participants succeeds.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block and never held across `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries run before the pool is locked below.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears the server down: tears down the peer, resets the connection pool and signals
    /// every subscriber of the teardown channel. Connection loops return on that signal.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Sending fails only when there are no subscribers, which is fine to ignore.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and restarts it under a new `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only: returns the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` as the handler for message type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope back to `TypedEnvelope<M>`,
    /// runs the handler inside a "handle message" tracing span, and logs the outcome with the
    /// elapsed time in milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are looked up by `payload_type_id()`, presumably `TypeId::of::<M>()`,
                // so this downcast is expected never to fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Microsecond resolution, reported as fractional milliseconds.
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a one-way message; the handler receives only the payload.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request message.
    ///
    /// The handler receives the payload and a `Response` it must use to reply. If the handler
    /// returns an error, that error's text is sent to the requester as a `proto::Error`
    /// before the error is propagated for logging. If the handler returns `Ok` without
    /// replying, an error is returned for logging.
    ///
    /// NOTE(review): in the `Ok`-without-reply case nothing is sent to the requester, so it
    /// receives no response at all — confirm whether clients time out on this.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // `?` converts the `anyhow::Error` into the crate's error type.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // `receipt` is used again here after being copied into `response`.
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives one client connection from handshake to sign-out.
    ///
    /// Setup steps, in order:
    /// - register the connection with the peer and send `Hello`;
    /// - report the connection id through `send_connection_id`, if provided;
    /// - on the user's first connection, send `ShowContacts` and record `connected_once`;
    /// - add the connection to the pool and send the initial contacts and channels;
    /// - send any incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers until one of these
    /// happens:
    /// - the I/O future finishes or the incoming stream ends: `connection_lost` runs before
    ///   returning;
    /// - the server is torn down: the function returns immediately without running
    ///   `connection_lost`.
    ///
    /// # Errors
    ///
    /// Returns an error if any send or database query in the setup phase fails. Errors from
    /// `connection_lost` are only logged.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the future starts, so a later teardown is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // The three queries run concurrently; the first error aborts the handshake.
            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // NOTE(review): once the connection is added to the pool here, any `?` below
            // returns without reaching `connection_lost`, so the connection appears to stay
            // registered in the pool — confirm whether it is cleaned up elsewhere.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers of this connection are in flight: a permit is acquired
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are polled in order: teardown, I/O completion, foreground handlers,
                // then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Tells every connection of `inviter_id` that `invitee_id` is now a contact and sends
    /// refreshed invite info. Does nothing if the inviter is unknown or has no invite code.
    ///
    /// # Errors
    ///
    /// Returns the first database or send error; remaining connections are then skipped.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and remaining count) to every connection of
    /// `user_id`. Does nothing if the user is unknown or has no invite code.
    ///
    /// # Errors
    ///
    /// Returns the first database or send error; remaining connections are then skipped.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool. The connection-pool
    /// lock stays held for as long as the returned snapshot is alive.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
745
746impl<'a> Deref for ConnectionPoolGuard<'a> {
747 type Target = ConnectionPool;
748
749 fn deref(&self) -> &Self::Target {
750 &*self.guard
751 }
752}
753
754impl<'a> DerefMut for ConnectionPoolGuard<'a> {
755 fn deref_mut(&mut self) -> &mut Self::Target {
756 &mut *self.guard
757 }
758}
759
760impl<'a> Drop for ConnectionPoolGuard<'a> {
761 fn drop(&mut self) {
762 #[cfg(test)]
763 self.check_invariants();
764 }
765}
766
767fn broadcast<F>(
768 sender_id: Option<ConnectionId>,
769 receiver_ids: impl IntoIterator<Item = ConnectionId>,
770 mut f: F,
771) where
772 F: FnMut(ConnectionId) -> anyhow::Result<()>,
773{
774 for receiver_id in receiver_ids {
775 if Some(receiver_id) != sender_id {
776 if let Err(error) = f(receiver_id) {
777 tracing::error!("failed to send to {:?} {}", receiver_id, error);
778 }
779 }
780 }
781}
782
lazy_static! {
    /// Name of the request header carrying the client's RPC protocol version; see
    /// `ProtocolVersion` and `handle_websocket_request`.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
786
787pub struct ProtocolVersion(u32);
788
789impl Header for ProtocolVersion {
790 fn name() -> &'static HeaderName {
791 &ZED_PROTOCOL_VERSION
792 }
793
794 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
795 where
796 Self: Sized,
797 I: Iterator<Item = &'i axum::http::HeaderValue>,
798 {
799 let version = values
800 .next()
801 .ok_or_else(axum::headers::Error::invalid)?
802 .to_str()
803 .map_err(|_| axum::headers::Error::invalid())?
804 .parse()
805 .map_err(|_| axum::headers::Error::invalid())?;
806 Ok(Self(version))
807 }
808
809 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
810 values.extend([self.0.to_string().parse().unwrap()]);
811 }
812}
813
814pub fn routes(server: Arc<Server>) -> Router<Body> {
815 Router::new()
816 .route("/rpc", get(handle_websocket_request))
817 .layer(
818 ServiceBuilder::new()
819 .layer(Extension(server.app_state.clone()))
820 .layer(middleware::from_fn(auth::validate_header)),
821 )
822 .route("/metrics", get(handle_metrics))
823 .layer(Extension(server))
824}
825
826pub async fn handle_websocket_request(
827 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
828 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
829 Extension(server): Extension<Arc<Server>>,
830 Extension(user): Extension<User>,
831 ws: WebSocketUpgrade,
832) -> axum::response::Response {
833 if protocol_version != rpc::PROTOCOL_VERSION {
834 return (
835 StatusCode::UPGRADE_REQUIRED,
836 "client must be upgraded".to_string(),
837 )
838 .into_response();
839 }
840 let socket_address = socket_address.to_string();
841 ws.on_upgrade(move |socket| {
842 use util::ResultExt;
843 let socket = socket
844 .map_ok(to_tungstenite_message)
845 .err_into()
846 .with(|message| async move { Ok(to_axum_message(message)) });
847 let connection = Connection::new(Box::pin(socket));
848 async move {
849 server
850 .handle_connection(connection, socket_address, user, None, Executor::Production)
851 .await
852 .log_err();
853 }
854 })
855}
856
857pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
858 let connections = server
859 .connection_pool
860 .lock()
861 .connections()
862 .filter(|connection| !connection.admin)
863 .count();
864
865 METRIC_CONNECTIONS.set(connections as _);
866
867 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
868 METRIC_SHARED_PROJECTS.set(shared_projects as _);
869
870 let encoder = prometheus::TextEncoder::new();
871 let metric_families = prometheus::gather();
872 let encoded_metrics = encoder
873 .encode_to_string(&metric_families)
874 .map_err(|err| anyhow!("{}", err))?;
875 Ok(encoded_metrics)
876}
877
/// Cleans up after a connection has gone away.
///
/// Immediately, it disconnects the peer, removes the connection from the pool and reports
/// the loss to the database. It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers. If the
///   user has no other online connection, `decline_call` is run and any returned room is
///   broadcast. Finally the user's contacts are updated.
/// - the server is torn down: it returns without further cleanup.
///
/// The `select_biased!` polls the timeout first.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails or if updating contacts
/// fails. The other database failures are only traced.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard is a temporary of the `if` condition and is released before the body runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
923
924async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
925 response.send(proto::Ack {})?;
926 Ok(())
927}
928
929async fn create_room(
930 _request: proto::CreateRoom,
931 response: Response<proto::CreateRoom>,
932 session: Session,
933) -> Result<()> {
934 let live_kit_room = nanoid::nanoid!(30);
935
936 let live_kit_connection_info = {
937 let live_kit_room = live_kit_room.clone();
938 let live_kit = session.live_kit_client.as_ref();
939
940 util::async_iife!({
941 let live_kit = live_kit?;
942
943 let token = live_kit
944 .room_token(&live_kit_room, &session.user_id.to_string())
945 .trace_err()?;
946
947 Some(proto::LiveKitConnectionInfo {
948 server_url: live_kit.url().into(),
949 token,
950 })
951 })
952 }
953 .await;
954
955 let room = session
956 .db()
957 .await
958 .create_room(
959 session.user_id,
960 session.connection_id,
961 &live_kit_room,
962 RELEASE_CHANNEL_NAME.as_str(),
963 )
964 .await?;
965
966 response.send(proto::CreateRoomResponse {
967 room: Some(room.clone()),
968 live_kit_connection_info,
969 })?;
970
971 update_user_contacts(session.user_id, &session).await?;
972 Ok(())
973}
974
975async fn join_room(
976 request: proto::JoinRoom,
977 response: Response<proto::JoinRoom>,
978 session: Session,
979) -> Result<()> {
980 let room_id = RoomId::from_proto(request.id);
981
982 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
983
984 if let Some(channel_id) = channel_id {
985 return join_channel_internal(channel_id, Box::new(response), session).await;
986 }
987
988 let joined_room = {
989 let room = session
990 .db()
991 .await
992 .join_room(
993 room_id,
994 session.user_id,
995 session.connection_id,
996 RELEASE_CHANNEL_NAME.as_str(),
997 )
998 .await?;
999 room_updated(&room.room, &session.peer);
1000 room.into_inner()
1001 };
1002
1003 for connection_id in session
1004 .connection_pool()
1005 .await
1006 .user_connection_ids(session.user_id)
1007 {
1008 session
1009 .peer
1010 .send(
1011 connection_id,
1012 proto::CallCanceled {
1013 room_id: room_id.to_proto(),
1014 },
1015 )
1016 .trace_err();
1017 }
1018
1019 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1020 if let Some(token) = live_kit
1021 .room_token(
1022 &joined_room.room.live_kit_room,
1023 &session.user_id.to_string(),
1024 )
1025 .trace_err()
1026 {
1027 Some(proto::LiveKitConnectionInfo {
1028 server_url: live_kit.url().into(),
1029 token,
1030 })
1031 } else {
1032 None
1033 }
1034 } else {
1035 None
1036 };
1037
1038 response.send(proto::JoinRoomResponse {
1039 room: Some(joined_room.room),
1040 channel_id: None,
1041 live_kit_connection_info,
1042 })?;
1043
1044 update_user_contacts(session.user_id, &session).await?;
1045 Ok(())
1046}
1047
1048async fn rejoin_room(
1049 request: proto::RejoinRoom,
1050 response: Response<proto::RejoinRoom>,
1051 session: Session,
1052) -> Result<()> {
1053 let room;
1054 let channel_id;
1055 let channel_members;
1056 {
1057 let mut rejoined_room = session
1058 .db()
1059 .await
1060 .rejoin_room(request, session.user_id, session.connection_id)
1061 .await?;
1062
1063 response.send(proto::RejoinRoomResponse {
1064 room: Some(rejoined_room.room.clone()),
1065 reshared_projects: rejoined_room
1066 .reshared_projects
1067 .iter()
1068 .map(|project| proto::ResharedProject {
1069 id: project.id.to_proto(),
1070 collaborators: project
1071 .collaborators
1072 .iter()
1073 .map(|collaborator| collaborator.to_proto())
1074 .collect(),
1075 })
1076 .collect(),
1077 rejoined_projects: rejoined_room
1078 .rejoined_projects
1079 .iter()
1080 .map(|rejoined_project| proto::RejoinedProject {
1081 id: rejoined_project.id.to_proto(),
1082 worktrees: rejoined_project
1083 .worktrees
1084 .iter()
1085 .map(|worktree| proto::WorktreeMetadata {
1086 id: worktree.id,
1087 root_name: worktree.root_name.clone(),
1088 visible: worktree.visible,
1089 abs_path: worktree.abs_path.clone(),
1090 })
1091 .collect(),
1092 collaborators: rejoined_project
1093 .collaborators
1094 .iter()
1095 .map(|collaborator| collaborator.to_proto())
1096 .collect(),
1097 language_servers: rejoined_project.language_servers.clone(),
1098 })
1099 .collect(),
1100 })?;
1101 room_updated(&rejoined_room.room, &session.peer);
1102
1103 for project in &rejoined_room.reshared_projects {
1104 for collaborator in &project.collaborators {
1105 session
1106 .peer
1107 .send(
1108 collaborator.connection_id,
1109 proto::UpdateProjectCollaborator {
1110 project_id: project.id.to_proto(),
1111 old_peer_id: Some(project.old_connection_id.into()),
1112 new_peer_id: Some(session.connection_id.into()),
1113 },
1114 )
1115 .trace_err();
1116 }
1117
1118 broadcast(
1119 Some(session.connection_id),
1120 project
1121 .collaborators
1122 .iter()
1123 .map(|collaborator| collaborator.connection_id),
1124 |connection_id| {
1125 session.peer.forward_send(
1126 session.connection_id,
1127 connection_id,
1128 proto::UpdateProject {
1129 project_id: project.id.to_proto(),
1130 worktrees: project.worktrees.clone(),
1131 },
1132 )
1133 },
1134 );
1135 }
1136
1137 for project in &rejoined_room.rejoined_projects {
1138 for collaborator in &project.collaborators {
1139 session
1140 .peer
1141 .send(
1142 collaborator.connection_id,
1143 proto::UpdateProjectCollaborator {
1144 project_id: project.id.to_proto(),
1145 old_peer_id: Some(project.old_connection_id.into()),
1146 new_peer_id: Some(session.connection_id.into()),
1147 },
1148 )
1149 .trace_err();
1150 }
1151 }
1152
1153 for project in &mut rejoined_room.rejoined_projects {
1154 for worktree in mem::take(&mut project.worktrees) {
1155 #[cfg(any(test, feature = "test-support"))]
1156 const MAX_CHUNK_SIZE: usize = 2;
1157 #[cfg(not(any(test, feature = "test-support")))]
1158 const MAX_CHUNK_SIZE: usize = 256;
1159
1160 // Stream this worktree's entries.
1161 let message = proto::UpdateWorktree {
1162 project_id: project.id.to_proto(),
1163 worktree_id: worktree.id,
1164 abs_path: worktree.abs_path.clone(),
1165 root_name: worktree.root_name,
1166 updated_entries: worktree.updated_entries,
1167 removed_entries: worktree.removed_entries,
1168 scan_id: worktree.scan_id,
1169 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1170 updated_repositories: worktree.updated_repositories,
1171 removed_repositories: worktree.removed_repositories,
1172 };
1173 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1174 session.peer.send(session.connection_id, update.clone())?;
1175 }
1176
1177 // Stream this worktree's diagnostics.
1178 for summary in worktree.diagnostic_summaries {
1179 session.peer.send(
1180 session.connection_id,
1181 proto::UpdateDiagnosticSummary {
1182 project_id: project.id.to_proto(),
1183 worktree_id: worktree.id,
1184 summary: Some(summary),
1185 },
1186 )?;
1187 }
1188
1189 for settings_file in worktree.settings_files {
1190 session.peer.send(
1191 session.connection_id,
1192 proto::UpdateWorktreeSettings {
1193 project_id: project.id.to_proto(),
1194 worktree_id: worktree.id,
1195 path: settings_file.path,
1196 content: Some(settings_file.content),
1197 },
1198 )?;
1199 }
1200 }
1201
1202 for language_server in &project.language_servers {
1203 session.peer.send(
1204 session.connection_id,
1205 proto::UpdateLanguageServer {
1206 project_id: project.id.to_proto(),
1207 language_server_id: language_server.id,
1208 variant: Some(
1209 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1210 proto::LspDiskBasedDiagnosticsUpdated {},
1211 ),
1212 ),
1213 },
1214 )?;
1215 }
1216 }
1217
1218 let rejoined_room = rejoined_room.into_inner();
1219
1220 room = rejoined_room.room;
1221 channel_id = rejoined_room.channel_id;
1222 channel_members = rejoined_room.channel_members;
1223 }
1224
1225 if let Some(channel_id) = channel_id {
1226 channel_updated(
1227 channel_id,
1228 &room,
1229 &channel_members,
1230 &session.peer,
1231 &*session.connection_pool().await,
1232 );
1233 }
1234
1235 update_user_contacts(session.user_id, &session).await?;
1236 Ok(())
1237}
1238
1239async fn leave_room(
1240 _: proto::LeaveRoom,
1241 response: Response<proto::LeaveRoom>,
1242 session: Session,
1243) -> Result<()> {
1244 leave_room_for_session(&session).await?;
1245 response.send(proto::Ack {})?;
1246 Ok(())
1247}
1248
1249async fn call(
1250 request: proto::Call,
1251 response: Response<proto::Call>,
1252 session: Session,
1253) -> Result<()> {
1254 let room_id = RoomId::from_proto(request.room_id);
1255 let calling_user_id = session.user_id;
1256 let calling_connection_id = session.connection_id;
1257 let called_user_id = UserId::from_proto(request.called_user_id);
1258 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1259 if !session
1260 .db()
1261 .await
1262 .has_contact(calling_user_id, called_user_id)
1263 .await?
1264 {
1265 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1266 }
1267
1268 let incoming_call = {
1269 let (room, incoming_call) = &mut *session
1270 .db()
1271 .await
1272 .call(
1273 room_id,
1274 calling_user_id,
1275 calling_connection_id,
1276 called_user_id,
1277 initial_project_id,
1278 )
1279 .await?;
1280 room_updated(&room, &session.peer);
1281 mem::take(incoming_call)
1282 };
1283 update_user_contacts(called_user_id, &session).await?;
1284
1285 let mut calls = session
1286 .connection_pool()
1287 .await
1288 .user_connection_ids(called_user_id)
1289 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1290 .collect::<FuturesUnordered<_>>();
1291
1292 while let Some(call_response) = calls.next().await {
1293 match call_response.as_ref() {
1294 Ok(_) => {
1295 response.send(proto::Ack {})?;
1296 return Ok(());
1297 }
1298 Err(_) => {
1299 call_response.trace_err();
1300 }
1301 }
1302 }
1303
1304 {
1305 let room = session
1306 .db()
1307 .await
1308 .call_failed(room_id, called_user_id)
1309 .await?;
1310 room_updated(&room, &session.peer);
1311 }
1312 update_user_contacts(called_user_id, &session).await?;
1313
1314 Err(anyhow!("failed to ring user"))?
1315}
1316
1317async fn cancel_call(
1318 request: proto::CancelCall,
1319 response: Response<proto::CancelCall>,
1320 session: Session,
1321) -> Result<()> {
1322 let called_user_id = UserId::from_proto(request.called_user_id);
1323 let room_id = RoomId::from_proto(request.room_id);
1324 {
1325 let room = session
1326 .db()
1327 .await
1328 .cancel_call(room_id, session.connection_id, called_user_id)
1329 .await?;
1330 room_updated(&room, &session.peer);
1331 }
1332
1333 for connection_id in session
1334 .connection_pool()
1335 .await
1336 .user_connection_ids(called_user_id)
1337 {
1338 session
1339 .peer
1340 .send(
1341 connection_id,
1342 proto::CallCanceled {
1343 room_id: room_id.to_proto(),
1344 },
1345 )
1346 .trace_err();
1347 }
1348 response.send(proto::Ack {})?;
1349
1350 update_user_contacts(called_user_id, &session).await?;
1351 Ok(())
1352}
1353
1354async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1355 let room_id = RoomId::from_proto(message.room_id);
1356 {
1357 let room = session
1358 .db()
1359 .await
1360 .decline_call(Some(room_id), session.user_id)
1361 .await?
1362 .ok_or_else(|| anyhow!("failed to decline call"))?;
1363 room_updated(&room, &session.peer);
1364 }
1365
1366 for connection_id in session
1367 .connection_pool()
1368 .await
1369 .user_connection_ids(session.user_id)
1370 {
1371 session
1372 .peer
1373 .send(
1374 connection_id,
1375 proto::CallCanceled {
1376 room_id: room_id.to_proto(),
1377 },
1378 )
1379 .trace_err();
1380 }
1381 update_user_contacts(session.user_id, &session).await?;
1382 Ok(())
1383}
1384
1385async fn update_participant_location(
1386 request: proto::UpdateParticipantLocation,
1387 response: Response<proto::UpdateParticipantLocation>,
1388 session: Session,
1389) -> Result<()> {
1390 let room_id = RoomId::from_proto(request.room_id);
1391 let location = request
1392 .location
1393 .ok_or_else(|| anyhow!("invalid location"))?;
1394
1395 let db = session.db().await;
1396 let room = db
1397 .update_room_participant_location(room_id, session.connection_id, location)
1398 .await?;
1399
1400 room_updated(&room, &session.peer);
1401 response.send(proto::Ack {})?;
1402 Ok(())
1403}
1404
1405async fn share_project(
1406 request: proto::ShareProject,
1407 response: Response<proto::ShareProject>,
1408 session: Session,
1409) -> Result<()> {
1410 let (project_id, room) = &*session
1411 .db()
1412 .await
1413 .share_project(
1414 RoomId::from_proto(request.room_id),
1415 session.connection_id,
1416 &request.worktrees,
1417 )
1418 .await?;
1419 response.send(proto::ShareProjectResponse {
1420 project_id: project_id.to_proto(),
1421 })?;
1422 room_updated(&room, &session.peer);
1423
1424 Ok(())
1425}
1426
1427async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1428 let project_id = ProjectId::from_proto(message.project_id);
1429
1430 let (room, guest_connection_ids) = &*session
1431 .db()
1432 .await
1433 .unshare_project(project_id, session.connection_id)
1434 .await?;
1435
1436 broadcast(
1437 Some(session.connection_id),
1438 guest_connection_ids.iter().copied(),
1439 |conn_id| session.peer.send(conn_id, message.clone()),
1440 );
1441 room_updated(&room, &session.peer);
1442
1443 Ok(())
1444}
1445
1446async fn join_project(
1447 request: proto::JoinProject,
1448 response: Response<proto::JoinProject>,
1449 session: Session,
1450) -> Result<()> {
1451 let project_id = ProjectId::from_proto(request.project_id);
1452 let guest_user_id = session.user_id;
1453
1454 tracing::info!(%project_id, "join project");
1455
1456 let (project, replica_id) = &mut *session
1457 .db()
1458 .await
1459 .join_project(project_id, session.connection_id)
1460 .await?;
1461
1462 let collaborators = project
1463 .collaborators
1464 .iter()
1465 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1466 .map(|collaborator| collaborator.to_proto())
1467 .collect::<Vec<_>>();
1468
1469 let worktrees = project
1470 .worktrees
1471 .iter()
1472 .map(|(id, worktree)| proto::WorktreeMetadata {
1473 id: *id,
1474 root_name: worktree.root_name.clone(),
1475 visible: worktree.visible,
1476 abs_path: worktree.abs_path.clone(),
1477 })
1478 .collect::<Vec<_>>();
1479
1480 for collaborator in &collaborators {
1481 session
1482 .peer
1483 .send(
1484 collaborator.peer_id.unwrap().into(),
1485 proto::AddProjectCollaborator {
1486 project_id: project_id.to_proto(),
1487 collaborator: Some(proto::Collaborator {
1488 peer_id: Some(session.connection_id.into()),
1489 replica_id: replica_id.0 as u32,
1490 user_id: guest_user_id.to_proto(),
1491 }),
1492 },
1493 )
1494 .trace_err();
1495 }
1496
1497 // First, we send the metadata associated with each worktree.
1498 response.send(proto::JoinProjectResponse {
1499 worktrees: worktrees.clone(),
1500 replica_id: replica_id.0 as u32,
1501 collaborators: collaborators.clone(),
1502 language_servers: project.language_servers.clone(),
1503 })?;
1504
1505 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1506 #[cfg(any(test, feature = "test-support"))]
1507 const MAX_CHUNK_SIZE: usize = 2;
1508 #[cfg(not(any(test, feature = "test-support")))]
1509 const MAX_CHUNK_SIZE: usize = 256;
1510
1511 // Stream this worktree's entries.
1512 let message = proto::UpdateWorktree {
1513 project_id: project_id.to_proto(),
1514 worktree_id,
1515 abs_path: worktree.abs_path.clone(),
1516 root_name: worktree.root_name,
1517 updated_entries: worktree.entries,
1518 removed_entries: Default::default(),
1519 scan_id: worktree.scan_id,
1520 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1521 updated_repositories: worktree.repository_entries.into_values().collect(),
1522 removed_repositories: Default::default(),
1523 };
1524 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1525 session.peer.send(session.connection_id, update.clone())?;
1526 }
1527
1528 // Stream this worktree's diagnostics.
1529 for summary in worktree.diagnostic_summaries {
1530 session.peer.send(
1531 session.connection_id,
1532 proto::UpdateDiagnosticSummary {
1533 project_id: project_id.to_proto(),
1534 worktree_id: worktree.id,
1535 summary: Some(summary),
1536 },
1537 )?;
1538 }
1539
1540 for settings_file in worktree.settings_files {
1541 session.peer.send(
1542 session.connection_id,
1543 proto::UpdateWorktreeSettings {
1544 project_id: project_id.to_proto(),
1545 worktree_id: worktree.id,
1546 path: settings_file.path,
1547 content: Some(settings_file.content),
1548 },
1549 )?;
1550 }
1551 }
1552
1553 for language_server in &project.language_servers {
1554 session.peer.send(
1555 session.connection_id,
1556 proto::UpdateLanguageServer {
1557 project_id: project_id.to_proto(),
1558 language_server_id: language_server.id,
1559 variant: Some(
1560 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1561 proto::LspDiskBasedDiagnosticsUpdated {},
1562 ),
1563 ),
1564 },
1565 )?;
1566 }
1567
1568 Ok(())
1569}
1570
1571async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1572 let sender_id = session.connection_id;
1573 let project_id = ProjectId::from_proto(request.project_id);
1574
1575 let (room, project) = &*session
1576 .db()
1577 .await
1578 .leave_project(project_id, sender_id)
1579 .await?;
1580 tracing::info!(
1581 %project_id,
1582 host_user_id = %project.host_user_id,
1583 host_connection_id = %project.host_connection_id,
1584 "leave project"
1585 );
1586
1587 project_left(&project, &session);
1588 room_updated(&room, &session.peer);
1589
1590 Ok(())
1591}
1592
1593async fn update_project(
1594 request: proto::UpdateProject,
1595 response: Response<proto::UpdateProject>,
1596 session: Session,
1597) -> Result<()> {
1598 let project_id = ProjectId::from_proto(request.project_id);
1599 let (room, guest_connection_ids) = &*session
1600 .db()
1601 .await
1602 .update_project(project_id, session.connection_id, &request.worktrees)
1603 .await?;
1604 broadcast(
1605 Some(session.connection_id),
1606 guest_connection_ids.iter().copied(),
1607 |connection_id| {
1608 session
1609 .peer
1610 .forward_send(session.connection_id, connection_id, request.clone())
1611 },
1612 );
1613 room_updated(&room, &session.peer);
1614 response.send(proto::Ack {})?;
1615
1616 Ok(())
1617}
1618
1619async fn update_worktree(
1620 request: proto::UpdateWorktree,
1621 response: Response<proto::UpdateWorktree>,
1622 session: Session,
1623) -> Result<()> {
1624 let guest_connection_ids = session
1625 .db()
1626 .await
1627 .update_worktree(&request, session.connection_id)
1628 .await?;
1629
1630 broadcast(
1631 Some(session.connection_id),
1632 guest_connection_ids.iter().copied(),
1633 |connection_id| {
1634 session
1635 .peer
1636 .forward_send(session.connection_id, connection_id, request.clone())
1637 },
1638 );
1639 response.send(proto::Ack {})?;
1640 Ok(())
1641}
1642
1643async fn update_diagnostic_summary(
1644 message: proto::UpdateDiagnosticSummary,
1645 session: Session,
1646) -> Result<()> {
1647 let guest_connection_ids = session
1648 .db()
1649 .await
1650 .update_diagnostic_summary(&message, session.connection_id)
1651 .await?;
1652
1653 broadcast(
1654 Some(session.connection_id),
1655 guest_connection_ids.iter().copied(),
1656 |connection_id| {
1657 session
1658 .peer
1659 .forward_send(session.connection_id, connection_id, message.clone())
1660 },
1661 );
1662
1663 Ok(())
1664}
1665
1666async fn update_worktree_settings(
1667 message: proto::UpdateWorktreeSettings,
1668 session: Session,
1669) -> Result<()> {
1670 let guest_connection_ids = session
1671 .db()
1672 .await
1673 .update_worktree_settings(&message, session.connection_id)
1674 .await?;
1675
1676 broadcast(
1677 Some(session.connection_id),
1678 guest_connection_ids.iter().copied(),
1679 |connection_id| {
1680 session
1681 .peer
1682 .forward_send(session.connection_id, connection_id, message.clone())
1683 },
1684 );
1685
1686 Ok(())
1687}
1688
1689async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1690 broadcast_project_message(request.project_id, request, session).await
1691}
1692
1693async fn start_language_server(
1694 request: proto::StartLanguageServer,
1695 session: Session,
1696) -> Result<()> {
1697 let guest_connection_ids = session
1698 .db()
1699 .await
1700 .start_language_server(&request, session.connection_id)
1701 .await?;
1702
1703 broadcast(
1704 Some(session.connection_id),
1705 guest_connection_ids.iter().copied(),
1706 |connection_id| {
1707 session
1708 .peer
1709 .forward_send(session.connection_id, connection_id, request.clone())
1710 },
1711 );
1712 Ok(())
1713}
1714
1715async fn update_language_server(
1716 request: proto::UpdateLanguageServer,
1717 session: Session,
1718) -> Result<()> {
1719 session.executor.record_backtrace();
1720 let project_id = ProjectId::from_proto(request.project_id);
1721 let project_connection_ids = session
1722 .db()
1723 .await
1724 .project_connection_ids(project_id, session.connection_id)
1725 .await?;
1726 broadcast(
1727 Some(session.connection_id),
1728 project_connection_ids.iter().copied(),
1729 |connection_id| {
1730 session
1731 .peer
1732 .forward_send(session.connection_id, connection_id, request.clone())
1733 },
1734 );
1735 Ok(())
1736}
1737
1738async fn forward_project_request<T>(
1739 request: T,
1740 response: Response<T>,
1741 session: Session,
1742) -> Result<()>
1743where
1744 T: EntityMessage + RequestMessage,
1745{
1746 session.executor.record_backtrace();
1747 let project_id = ProjectId::from_proto(request.remote_entity_id());
1748 let host_connection_id = {
1749 let collaborators = session
1750 .db()
1751 .await
1752 .project_collaborators(project_id, session.connection_id)
1753 .await?;
1754 collaborators
1755 .iter()
1756 .find(|collaborator| collaborator.is_host)
1757 .ok_or_else(|| anyhow!("host not found"))?
1758 .connection_id
1759 };
1760
1761 let payload = session
1762 .peer
1763 .forward_request(session.connection_id, host_connection_id, request)
1764 .await?;
1765
1766 response.send(payload)?;
1767 Ok(())
1768}
1769
1770async fn create_buffer_for_peer(
1771 request: proto::CreateBufferForPeer,
1772 session: Session,
1773) -> Result<()> {
1774 session.executor.record_backtrace();
1775 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1776 session
1777 .peer
1778 .forward_send(session.connection_id, peer_id.into(), request)?;
1779 Ok(())
1780}
1781
1782async fn update_buffer(
1783 request: proto::UpdateBuffer,
1784 response: Response<proto::UpdateBuffer>,
1785 session: Session,
1786) -> Result<()> {
1787 session.executor.record_backtrace();
1788 let project_id = ProjectId::from_proto(request.project_id);
1789 let mut guest_connection_ids;
1790 let mut host_connection_id = None;
1791 {
1792 let collaborators = session
1793 .db()
1794 .await
1795 .project_collaborators(project_id, session.connection_id)
1796 .await?;
1797 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1798 for collaborator in collaborators.iter() {
1799 if collaborator.is_host {
1800 host_connection_id = Some(collaborator.connection_id);
1801 } else {
1802 guest_connection_ids.push(collaborator.connection_id);
1803 }
1804 }
1805 }
1806 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1807
1808 session.executor.record_backtrace();
1809 broadcast(
1810 Some(session.connection_id),
1811 guest_connection_ids,
1812 |connection_id| {
1813 session
1814 .peer
1815 .forward_send(session.connection_id, connection_id, request.clone())
1816 },
1817 );
1818 if host_connection_id != session.connection_id {
1819 session
1820 .peer
1821 .forward_request(session.connection_id, host_connection_id, request.clone())
1822 .await?;
1823 }
1824
1825 response.send(proto::Ack {})?;
1826 Ok(())
1827}
1828
1829async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1830 let project_id = ProjectId::from_proto(request.project_id);
1831 let project_connection_ids = session
1832 .db()
1833 .await
1834 .project_connection_ids(project_id, session.connection_id)
1835 .await?;
1836
1837 broadcast(
1838 Some(session.connection_id),
1839 project_connection_ids.iter().copied(),
1840 |connection_id| {
1841 session
1842 .peer
1843 .forward_send(session.connection_id, connection_id, request.clone())
1844 },
1845 );
1846 Ok(())
1847}
1848
1849async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1850 let project_id = ProjectId::from_proto(request.project_id);
1851 let project_connection_ids = session
1852 .db()
1853 .await
1854 .project_connection_ids(project_id, session.connection_id)
1855 .await?;
1856 broadcast(
1857 Some(session.connection_id),
1858 project_connection_ids.iter().copied(),
1859 |connection_id| {
1860 session
1861 .peer
1862 .forward_send(session.connection_id, connection_id, request.clone())
1863 },
1864 );
1865 Ok(())
1866}
1867
1868async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1869 broadcast_project_message(request.project_id, request, session).await
1870}
1871
1872async fn broadcast_project_message<T: EnvelopedMessage>(
1873 project_id: u64,
1874 request: T,
1875 session: Session,
1876) -> Result<()> {
1877 let project_id = ProjectId::from_proto(project_id);
1878 let project_connection_ids = session
1879 .db()
1880 .await
1881 .project_connection_ids(project_id, session.connection_id)
1882 .await?;
1883 broadcast(
1884 Some(session.connection_id),
1885 project_connection_ids.iter().copied(),
1886 |connection_id| {
1887 session
1888 .peer
1889 .forward_send(session.connection_id, connection_id, request.clone())
1890 },
1891 );
1892 Ok(())
1893}
1894
1895async fn follow(
1896 request: proto::Follow,
1897 response: Response<proto::Follow>,
1898 session: Session,
1899) -> Result<()> {
1900 let room_id = RoomId::from_proto(request.room_id);
1901 let project_id = request.project_id.map(ProjectId::from_proto);
1902 let leader_id = request
1903 .leader_id
1904 .ok_or_else(|| anyhow!("invalid leader id"))?
1905 .into();
1906 let follower_id = session.connection_id;
1907
1908 session
1909 .db()
1910 .await
1911 .check_room_participants(room_id, leader_id, session.connection_id)
1912 .await?;
1913
1914 let response_payload = session
1915 .peer
1916 .forward_request(session.connection_id, leader_id, request)
1917 .await?;
1918 response.send(response_payload)?;
1919
1920 if let Some(project_id) = project_id {
1921 let room = session
1922 .db()
1923 .await
1924 .follow(room_id, project_id, leader_id, follower_id)
1925 .await?;
1926 room_updated(&room, &session.peer);
1927 }
1928
1929 Ok(())
1930}
1931
1932async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1933 let room_id = RoomId::from_proto(request.room_id);
1934 let project_id = request.project_id.map(ProjectId::from_proto);
1935 let leader_id = request
1936 .leader_id
1937 .ok_or_else(|| anyhow!("invalid leader id"))?
1938 .into();
1939 let follower_id = session.connection_id;
1940
1941 session
1942 .db()
1943 .await
1944 .check_room_participants(room_id, leader_id, session.connection_id)
1945 .await?;
1946
1947 session
1948 .peer
1949 .forward_send(session.connection_id, leader_id, request)?;
1950
1951 if let Some(project_id) = project_id {
1952 let room = session
1953 .db()
1954 .await
1955 .unfollow(room_id, project_id, leader_id, follower_id)
1956 .await?;
1957 room_updated(&room, &session.peer);
1958 }
1959
1960 Ok(())
1961}
1962
1963async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1964 let room_id = RoomId::from_proto(request.room_id);
1965 let database = session.db.lock().await;
1966
1967 let connection_ids = if let Some(project_id) = request.project_id {
1968 let project_id = ProjectId::from_proto(project_id);
1969 database
1970 .project_connection_ids(project_id, session.connection_id)
1971 .await?
1972 } else {
1973 database
1974 .room_connection_ids(room_id, session.connection_id)
1975 .await?
1976 };
1977
1978 // For now, don't send view update messages back to that view's current leader.
1979 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1980 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1981 _ => None,
1982 });
1983
1984 for follower_peer_id in request.follower_ids.iter().copied() {
1985 let follower_connection_id = follower_peer_id.into();
1986 if Some(follower_peer_id) != connection_id_to_omit
1987 && connection_ids.contains(&follower_connection_id)
1988 {
1989 session.peer.forward_send(
1990 session.connection_id,
1991 follower_connection_id,
1992 request.clone(),
1993 )?;
1994 }
1995 }
1996 Ok(())
1997}
1998
1999async fn get_users(
2000 request: proto::GetUsers,
2001 response: Response<proto::GetUsers>,
2002 session: Session,
2003) -> Result<()> {
2004 let user_ids = request
2005 .user_ids
2006 .into_iter()
2007 .map(UserId::from_proto)
2008 .collect();
2009 let users = session
2010 .db()
2011 .await
2012 .get_users_by_ids(user_ids)
2013 .await?
2014 .into_iter()
2015 .map(|user| proto::User {
2016 id: user.id.to_proto(),
2017 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2018 github_login: user.github_login,
2019 })
2020 .collect();
2021 response.send(proto::UsersResponse { users })?;
2022 Ok(())
2023}
2024
2025async fn fuzzy_search_users(
2026 request: proto::FuzzySearchUsers,
2027 response: Response<proto::FuzzySearchUsers>,
2028 session: Session,
2029) -> Result<()> {
2030 let query = request.query;
2031 let users = match query.len() {
2032 0 => vec![],
2033 1 | 2 => session
2034 .db()
2035 .await
2036 .get_user_by_github_login(&query)
2037 .await?
2038 .into_iter()
2039 .collect(),
2040 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2041 };
2042 let users = users
2043 .into_iter()
2044 .filter(|user| user.id != session.user_id)
2045 .map(|user| proto::User {
2046 id: user.id.to_proto(),
2047 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2048 github_login: user.github_login,
2049 })
2050 .collect();
2051 response.send(proto::UsersResponse { users })?;
2052 Ok(())
2053}
2054
2055async fn request_contact(
2056 request: proto::RequestContact,
2057 response: Response<proto::RequestContact>,
2058 session: Session,
2059) -> Result<()> {
2060 let requester_id = session.user_id;
2061 let responder_id = UserId::from_proto(request.responder_id);
2062 if requester_id == responder_id {
2063 return Err(anyhow!("cannot add yourself as a contact"))?;
2064 }
2065
2066 session
2067 .db()
2068 .await
2069 .send_contact_request(requester_id, responder_id)
2070 .await?;
2071
2072 // Update outgoing contact requests of requester
2073 let mut update = proto::UpdateContacts::default();
2074 update.outgoing_requests.push(responder_id.to_proto());
2075 for connection_id in session
2076 .connection_pool()
2077 .await
2078 .user_connection_ids(requester_id)
2079 {
2080 session.peer.send(connection_id, update.clone())?;
2081 }
2082
2083 // Update incoming contact requests of responder
2084 let mut update = proto::UpdateContacts::default();
2085 update
2086 .incoming_requests
2087 .push(proto::IncomingContactRequest {
2088 requester_id: requester_id.to_proto(),
2089 should_notify: true,
2090 });
2091 for connection_id in session
2092 .connection_pool()
2093 .await
2094 .user_connection_ids(responder_id)
2095 {
2096 session.peer.send(connection_id, update.clone())?;
2097 }
2098
2099 response.send(proto::Ack {})?;
2100 Ok(())
2101}
2102
2103async fn respond_to_contact_request(
2104 request: proto::RespondToContactRequest,
2105 response: Response<proto::RespondToContactRequest>,
2106 session: Session,
2107) -> Result<()> {
2108 let responder_id = session.user_id;
2109 let requester_id = UserId::from_proto(request.requester_id);
2110 let db = session.db().await;
2111 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2112 db.dismiss_contact_notification(responder_id, requester_id)
2113 .await?;
2114 } else {
2115 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2116
2117 db.respond_to_contact_request(responder_id, requester_id, accept)
2118 .await?;
2119 let requester_busy = db.is_user_busy(requester_id).await?;
2120 let responder_busy = db.is_user_busy(responder_id).await?;
2121
2122 let pool = session.connection_pool().await;
2123 // Update responder with new contact
2124 let mut update = proto::UpdateContacts::default();
2125 if accept {
2126 update
2127 .contacts
2128 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2129 }
2130 update
2131 .remove_incoming_requests
2132 .push(requester_id.to_proto());
2133 for connection_id in pool.user_connection_ids(responder_id) {
2134 session.peer.send(connection_id, update.clone())?;
2135 }
2136
2137 // Update requester with new contact
2138 let mut update = proto::UpdateContacts::default();
2139 if accept {
2140 update
2141 .contacts
2142 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2143 }
2144 update
2145 .remove_outgoing_requests
2146 .push(responder_id.to_proto());
2147 for connection_id in pool.user_connection_ids(requester_id) {
2148 session.peer.send(connection_id, update.clone())?;
2149 }
2150 }
2151
2152 response.send(proto::Ack {})?;
2153 Ok(())
2154}
2155
2156async fn remove_contact(
2157 request: proto::RemoveContact,
2158 response: Response<proto::RemoveContact>,
2159 session: Session,
2160) -> Result<()> {
2161 let requester_id = session.user_id;
2162 let responder_id = UserId::from_proto(request.user_id);
2163 let db = session.db().await;
2164 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2165
2166 let pool = session.connection_pool().await;
2167 // Update outgoing contact requests of requester
2168 let mut update = proto::UpdateContacts::default();
2169 if contact_accepted {
2170 update.remove_contacts.push(responder_id.to_proto());
2171 } else {
2172 update
2173 .remove_outgoing_requests
2174 .push(responder_id.to_proto());
2175 }
2176 for connection_id in pool.user_connection_ids(requester_id) {
2177 session.peer.send(connection_id, update.clone())?;
2178 }
2179
2180 // Update incoming contact requests of responder
2181 let mut update = proto::UpdateContacts::default();
2182 if contact_accepted {
2183 update.remove_contacts.push(requester_id.to_proto());
2184 } else {
2185 update
2186 .remove_incoming_requests
2187 .push(requester_id.to_proto());
2188 }
2189 for connection_id in pool.user_connection_ids(responder_id) {
2190 session.peer.send(connection_id, update.clone())?;
2191 }
2192
2193 response.send(proto::Ack {})?;
2194 Ok(())
2195}
2196
2197async fn create_channel(
2198 request: proto::CreateChannel,
2199 response: Response<proto::CreateChannel>,
2200 session: Session,
2201) -> Result<()> {
2202 let db = session.db().await;
2203
2204 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2205 let id = db
2206 .create_channel(&request.name, parent_id, session.user_id)
2207 .await?;
2208
2209 response.send(proto::CreateChannelResponse {
2210 channel: Some(proto::Channel {
2211 id: id.to_proto(),
2212 name: request.name,
2213 visibility: proto::ChannelVisibility::Members as i32,
2214 role: proto::ChannelRole::Admin.into(),
2215 }),
2216 parent_id: request.parent_id,
2217 })?;
2218
2219 let Some(parent_id) = parent_id else {
2220 return Ok(());
2221 };
2222
2223 let updates = db
2224 .participants_to_notify_for_channel_change(parent_id, session.user_id)
2225 .await?;
2226
2227 let connection_pool = session.connection_pool().await;
2228 for (user_id, channels) in updates {
2229 let update = build_initial_channels_update(channels, vec![]);
2230 for connection_id in connection_pool.user_connection_ids(user_id) {
2231 if user_id == session.user_id {
2232 continue;
2233 }
2234 session.peer.send(connection_id, update.clone())?;
2235 }
2236 }
2237
2238 Ok(())
2239}
2240
2241async fn delete_channel(
2242 request: proto::DeleteChannel,
2243 response: Response<proto::DeleteChannel>,
2244 session: Session,
2245) -> Result<()> {
2246 let db = session.db().await;
2247
2248 let channel_id = request.channel_id;
2249 let (removed_channels, member_ids) = db
2250 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2251 .await?;
2252 response.send(proto::Ack {})?;
2253
2254 // Notify members of removed channels
2255 let mut update = proto::UpdateChannels::default();
2256 update
2257 .delete_channels
2258 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2259
2260 let connection_pool = session.connection_pool().await;
2261 for member_id in member_ids {
2262 for connection_id in connection_pool.user_connection_ids(member_id) {
2263 session.peer.send(connection_id, update.clone())?;
2264 }
2265 }
2266
2267 Ok(())
2268}
2269
2270async fn invite_channel_member(
2271 request: proto::InviteChannelMember,
2272 response: Response<proto::InviteChannelMember>,
2273 session: Session,
2274) -> Result<()> {
2275 let db = session.db().await;
2276 let channel_id = ChannelId::from_proto(request.channel_id);
2277 let invitee_id = UserId::from_proto(request.user_id);
2278 db.invite_channel_member(
2279 channel_id,
2280 invitee_id,
2281 session.user_id,
2282 request.role().into(),
2283 )
2284 .await?;
2285
2286 let channel = db.get_channel(channel_id, session.user_id).await?;
2287
2288 let mut update = proto::UpdateChannels::default();
2289 update.channel_invitations.push(proto::Channel {
2290 id: channel.id.to_proto(),
2291 visibility: channel.visibility.into(),
2292 name: channel.name,
2293 role: request.role().into(),
2294 });
2295 for connection_id in session
2296 .connection_pool()
2297 .await
2298 .user_connection_ids(invitee_id)
2299 {
2300 session.peer.send(connection_id, update.clone())?;
2301 }
2302
2303 response.send(proto::Ack {})?;
2304 Ok(())
2305}
2306
2307async fn remove_channel_member(
2308 request: proto::RemoveChannelMember,
2309 response: Response<proto::RemoveChannelMember>,
2310 session: Session,
2311) -> Result<()> {
2312 let db = session.db().await;
2313 let channel_id = ChannelId::from_proto(request.channel_id);
2314 let member_id = UserId::from_proto(request.user_id);
2315
2316 db.remove_channel_member(channel_id, member_id, session.user_id)
2317 .await?;
2318
2319 let mut update = proto::UpdateChannels::default();
2320 update.delete_channels.push(channel_id.to_proto());
2321
2322 for connection_id in session
2323 .connection_pool()
2324 .await
2325 .user_connection_ids(member_id)
2326 {
2327 session.peer.send(connection_id, update.clone())?;
2328 }
2329
2330 response.send(proto::Ack {})?;
2331 Ok(())
2332}
2333
/// Changes the visibility of `request.channel_id` and re-broadcasts channel
/// state to affected users, then acknowledges the request.
///
/// Participant details are captured *before* the change. When the channel
/// becomes members-only, any previous member whose role
/// `can_only_see_public_descendants()` is sent a deletion for the channel.
/// When it becomes public and a public parent exists, the parent's notify set
/// is merged into the updates.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    // Snapshot membership before the visibility change is applied.
    let previous_members = db
        .get_channel_participant_details(channel_id, session.user_id)
        .await?;

    db.set_channel_visibility(channel_id, visibility, session.user_id)
        .await?;

    // Users who should receive a fresh channel snapshot, keyed by user.
    let mut updates: HashMap<UserId, ChannelsForUser> = db
        .participants_to_notify_for_channel_change(channel_id, session.user_id)
        .await?
        .into_iter()
        .collect();

    let mut participants_who_lost_access: HashSet<UserId> = HashSet::default();
    match visibility {
        ChannelVisibility::Members => {
            // Roles limited to public descendants can no longer see this channel.
            for member in previous_members {
                if ChannelRole::from(member.role()).can_only_see_public_descendants() {
                    participants_who_lost_access.insert(UserId::from_proto(member.user_id));
                }
            }
        }
        ChannelVisibility::Public => {
            if let Some(public_parent_id) = db.public_parent_channel_id(channel_id).await? {
                let parent_updates = db
                    .participants_to_notify_for_channel_change(public_parent_id, session.user_id)
                    .await?;

                // Parent entries replace any entry already collected for the
                // same user. NOTE(review): presumably because the parent's
                // snapshot covers this channel as well — confirm.
                for (user_id, channels) in parent_updates {
                    updates.insert(user_id, channels);
                }
            }
        }
    }

    let connection_pool = session.connection_pool().await;
    for (user_id, channels) in updates {
        let update = build_initial_channels_update(channels, vec![]);
        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }
    // Deletions are sent after the snapshots above.
    for user_id in participants_who_lost_access {
        let update = proto::UpdateChannels {
            delete_channels: vec![channel_id.to_proto()],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2398
2399async fn set_channel_member_role(
2400 request: proto::SetChannelMemberRole,
2401 response: Response<proto::SetChannelMemberRole>,
2402 session: Session,
2403) -> Result<()> {
2404 let db = session.db().await;
2405 let channel_id = ChannelId::from_proto(request.channel_id);
2406 let member_id = UserId::from_proto(request.user_id);
2407 let channel_member = db
2408 .set_channel_member_role(
2409 channel_id,
2410 session.user_id,
2411 member_id,
2412 request.role().into(),
2413 )
2414 .await?;
2415
2416 let mut update = proto::UpdateChannels::default();
2417 if channel_member.accepted {
2418 let channels = db.get_channel_for_user(channel_id, member_id).await?;
2419 update = build_initial_channels_update(channels, vec![]);
2420 } else {
2421 let channel = db.get_channel(channel_id, session.user_id).await?;
2422 update.channel_invitations.push(proto::Channel {
2423 id: channel_id.to_proto(),
2424 visibility: channel.visibility.into(),
2425 name: channel.name,
2426 role: request.role().into(),
2427 });
2428 }
2429
2430 for connection_id in session
2431 .connection_pool()
2432 .await
2433 .user_connection_ids(member_id)
2434 {
2435 session.peer.send(connection_id, update.clone())?;
2436 }
2437
2438 response.send(proto::Ack {})?;
2439 Ok(())
2440}
2441
2442async fn rename_channel(
2443 request: proto::RenameChannel,
2444 response: Response<proto::RenameChannel>,
2445 session: Session,
2446) -> Result<()> {
2447 let db = session.db().await;
2448 let channel_id = ChannelId::from_proto(request.channel_id);
2449 let channel = db
2450 .rename_channel(channel_id, session.user_id, &request.name)
2451 .await?;
2452
2453 response.send(proto::RenameChannelResponse {
2454 channel: Some(proto::Channel {
2455 id: channel.id.to_proto(),
2456 name: channel.name.clone(),
2457 visibility: channel.visibility.into(),
2458 role: proto::ChannelRole::Admin.into(),
2459 }),
2460 })?;
2461
2462 let members = db
2463 .get_channel_participant_details(channel_id, session.user_id)
2464 .await?;
2465
2466 let connection_pool = session.connection_pool().await;
2467 for member in members {
2468 for connection_id in connection_pool.user_connection_ids(UserId::from_proto(member.user_id))
2469 {
2470 let mut update = proto::UpdateChannels::default();
2471 update.channels.push(proto::Channel {
2472 id: channel.id.to_proto(),
2473 name: channel.name.clone(),
2474 visibility: channel.visibility.into(),
2475 role: member.role.into(),
2476 });
2477
2478 session.peer.send(connection_id, update.clone())?;
2479 }
2480 }
2481
2482 Ok(())
2483}
2484
2485// TODO: Implement in terms of symlinks
2486// Current behavior of this is more like 'Move root channel'
2487async fn link_channel(
2488 request: proto::LinkChannel,
2489 response: Response<proto::LinkChannel>,
2490 session: Session,
2491) -> Result<()> {
2492 let db = session.db().await;
2493 let channel_id = ChannelId::from_proto(request.channel_id);
2494 let to = ChannelId::from_proto(request.to);
2495
2496 // TODO: Remove this restriction once we have symlinks
2497 db.assert_root_channel(channel_id).await?;
2498
2499 db.link_channel(session.user_id, channel_id, to).await?;
2500
2501 let member_updates = db
2502 .participants_to_notify_for_channel_change(to, session.user_id)
2503 .await?;
2504
2505 dbg!(&member_updates);
2506
2507 let connection_pool = session.connection_pool().await;
2508
2509 for (member_id, channels) in member_updates {
2510 let update = build_initial_channels_update(channels, vec![]);
2511 for connection_id in connection_pool.user_connection_ids(member_id) {
2512 session.peer.send(connection_id, update.clone())?;
2513 }
2514 }
2515
2516 response.send(Ack {})?;
2517
2518 Ok(())
2519}
2520
2521// TODO: Implement in terms of symlinks
2522async fn unlink_channel(
2523 _request: proto::UnlinkChannel,
2524 _response: Response<proto::UnlinkChannel>,
2525 _session: Session,
2526) -> Result<()> {
2527 Err(anyhow!("unimplemented").into())
2528}
2529
/// Moves `request.channel_id` from parent `request.from` to parent
/// `request.to`, then re-broadcasts channel state and acknowledges.
///
/// The channel's participant details are captured before the move. Users
/// returned by `participants_to_notify_for_channel_change(to, ..)` get a fresh
/// snapshot that also deletes the moved channel and its descendants (so the
/// client can re-insert them in the new position). Previous participants
/// whose membership kind was `AncestorMember`, and who are not in that notify
/// set, only receive the deletions.
async fn move_channel(
    request: proto::MoveChannel,
    response: Response<proto::MoveChannel>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let from_parent = ChannelId::from_proto(request.from);
    let to = ChannelId::from_proto(request.to);

    // Snapshot participants before the move changes inherited membership.
    let previous_participants = db
        .get_channel_participant_details(channel_id, session.user_id)
        .await?;

    // NOTE(review): `from_parent` is client-supplied; in debug builds a
    // mismatch panics this handler rather than returning an error.
    debug_assert_eq!(db.parent_channel_id(channel_id).await?, Some(from_parent));

    let channels_to_send = db
        .move_channel(session.user_id, channel_id, from_parent, to)
        .await?;

    // Nothing changed from the database's point of view; just acknowledge.
    if channels_to_send.is_empty() {
        response.send(Ack {})?;
        return Ok(());
    }

    let updates = db
        .participants_to_notify_for_channel_change(to, session.user_id)
        .await?;

    // The moved channel and its whole subtree are deleted client-side first.
    let mut participants_who_lost_access: HashSet<UserId> = HashSet::default();
    let mut channels_to_delete = db.get_channel_descendant_ids(channel_id).await?;
    channels_to_delete.insert(channel_id);

    // Members who only had access via an ancestor may lose it after the move.
    for previous_participant in previous_participants.iter() {
        let user_id = UserId::from_proto(previous_participant.user_id);
        if previous_participant.kind() == proto::channel_member::Kind::AncestorMember {
            participants_who_lost_access.insert(user_id);
        }
    }

    let connection_pool = session.connection_pool().await;
    for (user_id, channels) in updates {
        let mut update = build_initial_channels_update(channels, vec![]);
        update.delete_channels = channels_to_delete
            .iter()
            .map(|channel_id| channel_id.to_proto())
            .collect();
        // Anyone still in the notify set has not actually lost access.
        participants_who_lost_access.remove(&user_id);
        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    for user_id in participants_who_lost_access {
        let update = proto::UpdateChannels {
            delete_channels: channels_to_delete
                .iter()
                .map(|channel_id| channel_id.to_proto())
                .collect(),
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;

    Ok(())
}
2600
2601async fn get_channel_members(
2602 request: proto::GetChannelMembers,
2603 response: Response<proto::GetChannelMembers>,
2604 session: Session,
2605) -> Result<()> {
2606 let db = session.db().await;
2607 let channel_id = ChannelId::from_proto(request.channel_id);
2608 let members = db
2609 .get_channel_participant_details(channel_id, session.user_id)
2610 .await?;
2611 response.send(proto::GetChannelMembersResponse { members })?;
2612 Ok(())
2613}
2614
2615async fn respond_to_channel_invite(
2616 request: proto::RespondToChannelInvite,
2617 response: Response<proto::RespondToChannelInvite>,
2618 session: Session,
2619) -> Result<()> {
2620 let db = session.db().await;
2621 let channel_id = ChannelId::from_proto(request.channel_id);
2622 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2623 .await?;
2624
2625 if request.accept {
2626 channel_membership_updated(db, channel_id, &session).await?;
2627 } else {
2628 let mut update = proto::UpdateChannels::default();
2629 update
2630 .remove_channel_invitations
2631 .push(channel_id.to_proto());
2632 session.peer.send(session.connection_id, update)?;
2633 }
2634 response.send(proto::Ack {})?;
2635
2636 Ok(())
2637}
2638
2639async fn channel_membership_updated(
2640 db: tokio::sync::MutexGuard<'_, DbHandle>,
2641 channel_id: ChannelId,
2642 session: &Session,
2643) -> Result<(), crate::Error> {
2644 let mut update = proto::UpdateChannels::default();
2645 update
2646 .remove_channel_invitations
2647 .push(channel_id.to_proto());
2648
2649 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2650 update.channels.extend(
2651 result
2652 .channels
2653 .channels
2654 .into_iter()
2655 .map(|channel| proto::Channel {
2656 id: channel.id.to_proto(),
2657 visibility: channel.visibility.into(),
2658 role: channel.role.into(),
2659 name: channel.name,
2660 }),
2661 );
2662 update.unseen_channel_messages = result.channel_messages;
2663 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2664 update.insert_edge = result.channels.edges;
2665 update
2666 .channel_participants
2667 .extend(
2668 result
2669 .channel_participants
2670 .into_iter()
2671 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2672 channel_id: channel_id.to_proto(),
2673 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2674 }),
2675 );
2676 session.peer.send(session.connection_id, update)?;
2677 Ok(())
2678}
2679
2680async fn join_channel(
2681 request: proto::JoinChannel,
2682 response: Response<proto::JoinChannel>,
2683 session: Session,
2684) -> Result<()> {
2685 let channel_id = ChannelId::from_proto(request.channel_id);
2686 join_channel_internal(channel_id, Box::new(response), session).await
2687}
2688
/// Abstracts over the response handles that can answer a channel join with a
/// `proto::JoinRoomResponse`, so that `join_channel_internal` can serve both
/// `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Forwards to the inherent `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Forwards to the inherent `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2702
/// Moves the session into the room associated with `channel_id`.
///
/// Any room the session currently occupies is left first. The join result is
/// sent through `response`. Then, in order:
/// - a channel membership update is sent if the database reports a newly
///   joined channel;
/// - the room's participants get a room update;
/// - the channel's members get a participant update;
/// - the joining user's contacts are refreshed.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database guard. It is either consumed by
    // `channel_membership_updated` or dropped at the end of the block, in both
    // cases before the connection pool is locked below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // Connection info is only present when a LiveKit client is configured
        // and token creation succeeds. `trace_err()?` maps a token error to
        // `None` here instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2761
2762async fn join_channel_buffer(
2763 request: proto::JoinChannelBuffer,
2764 response: Response<proto::JoinChannelBuffer>,
2765 session: Session,
2766) -> Result<()> {
2767 let db = session.db().await;
2768 let channel_id = ChannelId::from_proto(request.channel_id);
2769
2770 let open_response = db
2771 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2772 .await?;
2773
2774 let collaborators = open_response.collaborators.clone();
2775 response.send(open_response)?;
2776
2777 let update = UpdateChannelBufferCollaborators {
2778 channel_id: channel_id.to_proto(),
2779 collaborators: collaborators.clone(),
2780 };
2781 channel_buffer_updated(
2782 session.connection_id,
2783 collaborators
2784 .iter()
2785 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2786 &update,
2787 &session.peer,
2788 );
2789
2790 Ok(())
2791}
2792
/// Applies `request.operations` to the channel's buffer and forwards them to
/// the buffer's collaborators via `channel_buffer_updated`.
///
/// Connected users in `non_collaborators` are told that the buffer changed,
/// with its new epoch and version. The database guard stays held for the
/// whole handler, including while the connection pool is locked.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, non_collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id, &request.operations)
        .await?;

    // Forward the raw operations to the connections editing this buffer.
    channel_buffer_updated(
        session.connection_id,
        collaborators,
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Everyone else only learns that there are unseen changes.
    broadcast(
        None,
        non_collaborators
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
                        channel_id: channel_id.to_proto(),
                        epoch: epoch as u64,
                        version: version.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );

    Ok(())
}
2838
2839async fn rejoin_channel_buffers(
2840 request: proto::RejoinChannelBuffers,
2841 response: Response<proto::RejoinChannelBuffers>,
2842 session: Session,
2843) -> Result<()> {
2844 let db = session.db().await;
2845 let buffers = db
2846 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2847 .await?;
2848
2849 for rejoined_buffer in &buffers {
2850 let collaborators_to_notify = rejoined_buffer
2851 .buffer
2852 .collaborators
2853 .iter()
2854 .filter_map(|c| Some(c.peer_id?.into()));
2855 channel_buffer_updated(
2856 session.connection_id,
2857 collaborators_to_notify,
2858 &proto::UpdateChannelBufferCollaborators {
2859 channel_id: rejoined_buffer.buffer.channel_id,
2860 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2861 },
2862 &session.peer,
2863 );
2864 }
2865
2866 response.send(proto::RejoinChannelBuffersResponse {
2867 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2868 })?;
2869
2870 Ok(())
2871}
2872
2873async fn leave_channel_buffer(
2874 request: proto::LeaveChannelBuffer,
2875 response: Response<proto::LeaveChannelBuffer>,
2876 session: Session,
2877) -> Result<()> {
2878 let db = session.db().await;
2879 let channel_id = ChannelId::from_proto(request.channel_id);
2880
2881 let left_buffer = db
2882 .leave_channel_buffer(channel_id, session.connection_id)
2883 .await?;
2884
2885 response.send(Ack {})?;
2886
2887 channel_buffer_updated(
2888 session.connection_id,
2889 left_buffer.connections,
2890 &proto::UpdateChannelBufferCollaborators {
2891 channel_id: channel_id.to_proto(),
2892 collaborators: left_buffer.collaborators,
2893 },
2894 &session.peer,
2895 );
2896
2897 Ok(())
2898}
2899
2900fn channel_buffer_updated<T: EnvelopedMessage>(
2901 sender_id: ConnectionId,
2902 collaborators: impl IntoIterator<Item = ConnectionId>,
2903 message: &T,
2904 peer: &Peer,
2905) {
2906 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2907 peer.send(peer_id.into(), message.clone())
2908 });
2909}
2910
/// Validates and stores a chat message for `request.channel_id`.
///
/// The message is broadcast to the chat's connections (via `broadcast` with
/// the sender passed as origin), returned to the sender in the response, and
/// announced as unseen to the connected users in `non_participants`.
///
/// # Errors
/// Fails if the trimmed body is longer than `MAX_MESSAGE_LEN` bytes, if it is
/// empty, if `nonce` is missing, or if the database call fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Note: `len()` counts UTF-8 bytes, not characters.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    // The database guard is a temporary here, released at the end of this statement.
    let (message_id, connection_ids, non_participants) = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    broadcast(Some(session.connection_id), connection_ids, |connection| {
        session.peer.send(
            connection,
            proto::ChannelMessageSent {
                channel_id: channel_id.to_proto(),
                message: Some(message.clone()),
            },
        )
    });
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Users outside the chat only get an "unseen message" marker.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        non_participants
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );

    Ok(())
}
2984
2985async fn remove_channel_message(
2986 request: proto::RemoveChannelMessage,
2987 response: Response<proto::RemoveChannelMessage>,
2988 session: Session,
2989) -> Result<()> {
2990 let channel_id = ChannelId::from_proto(request.channel_id);
2991 let message_id = MessageId::from_proto(request.message_id);
2992 let connection_ids = session
2993 .db()
2994 .await
2995 .remove_channel_message(channel_id, message_id, session.user_id)
2996 .await?;
2997 broadcast(Some(session.connection_id), connection_ids, |connection| {
2998 session.peer.send(connection, request.clone())
2999 });
3000 response.send(proto::Ack {})?;
3001 Ok(())
3002}
3003
3004async fn acknowledge_channel_message(
3005 request: proto::AckChannelMessage,
3006 session: Session,
3007) -> Result<()> {
3008 let channel_id = ChannelId::from_proto(request.channel_id);
3009 let message_id = MessageId::from_proto(request.message_id);
3010 session
3011 .db()
3012 .await
3013 .observe_channel_message(channel_id, session.user_id, message_id)
3014 .await?;
3015 Ok(())
3016}
3017
3018async fn acknowledge_buffer_version(
3019 request: proto::AckBufferOperation,
3020 session: Session,
3021) -> Result<()> {
3022 let buffer_id = BufferId::from_proto(request.buffer_id);
3023 session
3024 .db()
3025 .await
3026 .observe_buffer_version(
3027 buffer_id,
3028 session.user_id,
3029 request.epoch as i32,
3030 &request.version,
3031 )
3032 .await?;
3033 Ok(())
3034}
3035
3036async fn join_channel_chat(
3037 request: proto::JoinChannelChat,
3038 response: Response<proto::JoinChannelChat>,
3039 session: Session,
3040) -> Result<()> {
3041 let channel_id = ChannelId::from_proto(request.channel_id);
3042
3043 let db = session.db().await;
3044 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3045 .await?;
3046 let messages = db
3047 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3048 .await?;
3049 response.send(proto::JoinChannelChatResponse {
3050 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3051 messages,
3052 })?;
3053 Ok(())
3054}
3055
3056async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 session
3059 .db()
3060 .await
3061 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3062 .await?;
3063 Ok(())
3064}
3065
3066async fn get_channel_messages(
3067 request: proto::GetChannelMessages,
3068 response: Response<proto::GetChannelMessages>,
3069 session: Session,
3070) -> Result<()> {
3071 let channel_id = ChannelId::from_proto(request.channel_id);
3072 let messages = session
3073 .db()
3074 .await
3075 .get_channel_messages(
3076 channel_id,
3077 session.user_id,
3078 MESSAGE_COUNT_PER_PAGE,
3079 Some(MessageId::from_proto(request.before_message_id)),
3080 )
3081 .await?;
3082 response.send(proto::GetChannelMessagesResponse {
3083 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3084 messages,
3085 })?;
3086 Ok(())
3087}
3088
3089async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3090 let project_id = ProjectId::from_proto(request.project_id);
3091 let project_connection_ids = session
3092 .db()
3093 .await
3094 .project_connection_ids(project_id, session.connection_id)
3095 .await?;
3096 broadcast(
3097 Some(session.connection_id),
3098 project_connection_ids.iter().copied(),
3099 |connection_id| {
3100 session
3101 .peer
3102 .forward_send(session.connection_id, connection_id, request.clone())
3103 },
3104 );
3105 Ok(())
3106}
3107
3108async fn get_private_user_info(
3109 _request: proto::GetPrivateUserInfo,
3110 response: Response<proto::GetPrivateUserInfo>,
3111 session: Session,
3112) -> Result<()> {
3113 let db = session.db().await;
3114
3115 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3116 let user = db
3117 .get_user_by_id(session.user_id)
3118 .await?
3119 .ok_or_else(|| anyhow!("user not found"))?;
3120 let flags = db.get_user_flags(session.user_id).await?;
3121
3122 response.send(proto::GetPrivateUserInfoResponse {
3123 metrics_id,
3124 staff: user.admin,
3125 flags,
3126 })?;
3127 Ok(())
3128}
3129
3130fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3131 match message {
3132 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3133 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3134 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3135 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3136 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3137 code: frame.code.into(),
3138 reason: frame.reason,
3139 })),
3140 }
3141}
3142
3143fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3144 match message {
3145 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3146 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3147 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3148 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3149 AxumMessage::Close(frame) => {
3150 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3151 code: frame.code.into(),
3152 reason: frame.reason,
3153 }))
3154 }
3155 }
3156}
3157
3158fn build_initial_channels_update(
3159 channels: ChannelsForUser,
3160 channel_invites: Vec<db::Channel>,
3161) -> proto::UpdateChannels {
3162 let mut update = proto::UpdateChannels::default();
3163
3164 for channel in channels.channels.channels {
3165 update.channels.push(proto::Channel {
3166 id: channel.id.to_proto(),
3167 name: channel.name,
3168 visibility: channel.visibility.into(),
3169 role: channel.role.into(),
3170 });
3171 }
3172
3173 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3174 update.unseen_channel_messages = channels.channel_messages;
3175 update.insert_edge = channels.channels.edges;
3176
3177 for (channel_id, participants) in channels.channel_participants {
3178 update
3179 .channel_participants
3180 .push(proto::ChannelParticipants {
3181 channel_id: channel_id.to_proto(),
3182 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3183 });
3184 }
3185
3186 for channel in channel_invites {
3187 update.channel_invitations.push(proto::Channel {
3188 id: channel.id.to_proto(),
3189 name: channel.name,
3190 visibility: channel.visibility.into(),
3191 role: channel.role.into(),
3192 });
3193 }
3194
3195 update
3196}
3197
3198fn build_initial_contacts_update(
3199 contacts: Vec<db::Contact>,
3200 pool: &ConnectionPool,
3201) -> proto::UpdateContacts {
3202 let mut update = proto::UpdateContacts::default();
3203
3204 for contact in contacts {
3205 match contact {
3206 db::Contact::Accepted {
3207 user_id,
3208 should_notify,
3209 busy,
3210 } => {
3211 update
3212 .contacts
3213 .push(contact_for_user(user_id, should_notify, busy, &pool));
3214 }
3215 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3216 db::Contact::Incoming {
3217 user_id,
3218 should_notify,
3219 } => update
3220 .incoming_requests
3221 .push(proto::IncomingContactRequest {
3222 requester_id: user_id.to_proto(),
3223 should_notify,
3224 }),
3225 }
3226 }
3227
3228 update
3229}
3230
3231fn contact_for_user(
3232 user_id: UserId,
3233 should_notify: bool,
3234 busy: bool,
3235 pool: &ConnectionPool,
3236) -> proto::Contact {
3237 proto::Contact {
3238 user_id: user_id.to_proto(),
3239 online: pool.is_user_online(user_id),
3240 busy,
3241 should_notify,
3242 }
3243}
3244
3245fn room_updated(room: &proto::Room, peer: &Peer) {
3246 broadcast(
3247 None,
3248 room.participants
3249 .iter()
3250 .filter_map(|participant| Some(participant.peer_id?.into())),
3251 |peer_id| {
3252 peer.send(
3253 peer_id.into(),
3254 proto::RoomUpdated {
3255 room: Some(room.clone()),
3256 },
3257 )
3258 },
3259 );
3260}
3261
3262fn channel_updated(
3263 channel_id: ChannelId,
3264 room: &proto::Room,
3265 channel_members: &[UserId],
3266 peer: &Peer,
3267 pool: &ConnectionPool,
3268) {
3269 let participants = room
3270 .participants
3271 .iter()
3272 .map(|p| p.user_id)
3273 .collect::<Vec<_>>();
3274
3275 broadcast(
3276 None,
3277 channel_members
3278 .iter()
3279 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3280 |peer_id| {
3281 peer.send(
3282 peer_id.into(),
3283 proto::UpdateChannels {
3284 channel_participants: vec![proto::ChannelParticipants {
3285 channel_id: channel_id.to_proto(),
3286 participant_user_ids: participants.clone(),
3287 }],
3288 ..Default::default()
3289 },
3290 )
3291 },
3292 );
3293}
3294
3295async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3296 let db = session.db().await;
3297
3298 let contacts = db.get_contacts(user_id).await?;
3299 let busy = db.is_user_busy(user_id).await?;
3300
3301 let pool = session.connection_pool().await;
3302 let updated_contact = contact_for_user(user_id, false, busy, &pool);
3303 for contact in contacts {
3304 if let db::Contact::Accepted {
3305 user_id: contact_user_id,
3306 ..
3307 } = contact
3308 {
3309 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3310 session
3311 .peer
3312 .send(
3313 contact_conn_id,
3314 proto::UpdateContacts {
3315 contacts: vec![updated_contact.clone()],
3316 remove_contacts: Default::default(),
3317 incoming_requests: Default::default(),
3318 remove_incoming_requests: Default::default(),
3319 outgoing_requests: Default::default(),
3320 remove_outgoing_requests: Default::default(),
3321 },
3322 )
3323 .trace_err();
3324 }
3325 }
3326 }
3327 Ok(())
3328}
3329
3330async fn leave_room_for_session(session: &Session) -> Result<()> {
3331 let mut contacts_to_update = HashSet::default();
3332
3333 let room_id;
3334 let canceled_calls_to_user_ids;
3335 let live_kit_room;
3336 let delete_live_kit_room;
3337 let room;
3338 let channel_members;
3339 let channel_id;
3340
3341 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3342 contacts_to_update.insert(session.user_id);
3343
3344 for project in left_room.left_projects.values() {
3345 project_left(project, session);
3346 }
3347
3348 room_id = RoomId::from_proto(left_room.room.id);
3349 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3350 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3351 delete_live_kit_room = left_room.deleted;
3352 room = mem::take(&mut left_room.room);
3353 channel_members = mem::take(&mut left_room.channel_members);
3354 channel_id = left_room.channel_id;
3355
3356 room_updated(&room, &session.peer);
3357 } else {
3358 return Ok(());
3359 }
3360
3361 if let Some(channel_id) = channel_id {
3362 channel_updated(
3363 channel_id,
3364 &room,
3365 &channel_members,
3366 &session.peer,
3367 &*session.connection_pool().await,
3368 );
3369 }
3370
3371 {
3372 let pool = session.connection_pool().await;
3373 for canceled_user_id in canceled_calls_to_user_ids {
3374 for connection_id in pool.user_connection_ids(canceled_user_id) {
3375 session
3376 .peer
3377 .send(
3378 connection_id,
3379 proto::CallCanceled {
3380 room_id: room_id.to_proto(),
3381 },
3382 )
3383 .trace_err();
3384 }
3385 contacts_to_update.insert(canceled_user_id);
3386 }
3387 }
3388
3389 for contact_user_id in contacts_to_update {
3390 update_user_contacts(contact_user_id, &session).await?;
3391 }
3392
3393 if let Some(live_kit) = session.live_kit_client.as_ref() {
3394 live_kit
3395 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3396 .await
3397 .trace_err();
3398
3399 if delete_live_kit_room {
3400 live_kit.delete_room(live_kit_room).await.trace_err();
3401 }
3402 }
3403
3404 Ok(())
3405}
3406
3407async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3408 let left_channel_buffers = session
3409 .db()
3410 .await
3411 .leave_channel_buffers(session.connection_id)
3412 .await?;
3413
3414 for left_buffer in left_channel_buffers {
3415 channel_buffer_updated(
3416 session.connection_id,
3417 left_buffer.connections,
3418 &proto::UpdateChannelBufferCollaborators {
3419 channel_id: left_buffer.channel_id.to_proto(),
3420 collaborators: left_buffer.collaborators,
3421 },
3422 &session.peer,
3423 );
3424 }
3425
3426 Ok(())
3427}
3428
3429fn project_left(project: &db::LeftProject, session: &Session) {
3430 for connection_id in &project.connection_ids {
3431 if project.host_user_id == session.user_id {
3432 session
3433 .peer
3434 .send(
3435 *connection_id,
3436 proto::UnshareProject {
3437 project_id: project.id.to_proto(),
3438 },
3439 )
3440 .trace_err();
3441 } else {
3442 session
3443 .peer
3444 .send(
3445 *connection_id,
3446 proto::RemoveProjectCollaborator {
3447 project_id: project.id.to_proto(),
3448 peer_id: Some(session.connection_id.into()),
3449 },
3450 )
3451 .trace_err();
3452 }
3453 }
3454}
3455
3456pub trait ResultExt {
3457 type Ok;
3458
3459 fn trace_err(self) -> Option<Self::Ok>;
3460}
3461
3462impl<T, E> ResultExt for Result<T, E>
3463where
3464 E: std::fmt::Debug,
3465{
3466 type Ok = T;
3467
3468 fn trace_err(self) -> Option<T> {
3469 match self {
3470 Ok(value) => Some(value),
3471 Err(error) => {
3472 tracing::error!("{:?}", error);
3473 None
3474 }
3475 }
3476 }
3477}