1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69
/// How long `connection_lost` waits after a connection drops before it removes
/// the user's room and channel-buffer resources (unless the server is torn
/// down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the task spawned by `Server::start` waits before cleaning up
/// resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the constants below are not referenced in this part of the
// file; their descriptions are inferred from their names — TODO confirm
// against their usages.
/// Presumably the page size used when fetching channel chat history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel chat message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
76
// Prometheus gauges registered in the default registry; both are refreshed on
// every scrape by `handle_metrics`. Registration panics (unwrap) if a metric
// with the same name was already registered.
lazy_static! {
    // Non-admin connections currently in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Value of `project_count_excluding_admins` at the last scrape.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
86
/// Type-erased handler stored in `Server::handlers`: it receives the incoming
/// envelope and a clone of the connection's `Session`, and returns a boxed
/// future that performs (and logs) the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
/// One-shot reply handle passed to request handlers registered with
/// `Server::add_request_handler`.
struct Response<R> {
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Shared with the request wrapper, which checks after the handler finishes
    /// whether a reply was sent.
    responded: Arc<AtomicBool>,
}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
/// Per-connection context; a clone is handed to every message handler invoked
/// for that connection.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. `handle_connection` creates one
    /// per connection, so database access through a session (and its clones)
    /// is serialized.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; lock it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// `None` when no LiveKit server is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server. It owns the `Peer`, the pool of live
/// connections, and the table of registered message handlers.
pub struct Server {
    /// Server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Type-erased handlers keyed by the `TypeId` of the message they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, which makes the guard `!Send`. As a result it cannot
    /// be held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of the server's peer and connection pool.
///
/// Keeps the connection pool locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard points to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
impl Server {
    /// Creates the server and registers every RPC handler it understands.
    ///
    /// `id` is used both as the server id and (as `id.0 as u32`) as the `Peer`
    /// id. Panics if the same message type is registered twice (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe`.
            teardown: watch::channel(()).0,
        };

        // Request handlers must reply through their `Response`; message
        // handlers do not reply.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Schedules cleanup of resources attributed to stale servers.
    ///
    /// Spawns a detached task that sleeps for `CLEANUP_TIMEOUT` and then does
    /// the following:
    /// - fetches the room and channel ids reported by
    ///   `stale_server_resource_ids` for this environment and server id;
    /// - clears stale channel-buffer collaborators and stale room participants,
    ///   and notifies the affected peers;
    /// - calls `delete_stale_servers`.
    ///
    /// Every error inside the task is traced, not propagated; this method
    /// itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Send each remaining buffer connection the refreshed
                    // collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both removed participants and users whose calls
                            // were canceled need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once nobody is left.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each affected user that the
                        // call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Shuts the server down: tears down the peer, resets the connection pool,
    /// and signals every connection loop subscribed to `teardown`.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // A send error only means no connection is currently subscribed.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears down and restarts the server under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers a type-erased handler for message type `M`.
    ///
    /// The stored wrapper downcasts the envelope to `TypedEnvelope<M>` and logs
    /// receipt. It then runs `handler` inside a "handle message" span and logs
    /// the outcome together with the elapsed time in milliseconds.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are looked up by payload TypeId, so this downcast
                // matches the registered type.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler that receives only the message payload (no reply).
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for request type `M`. The handler replies through the
    /// `Response<M>` it is given.
    ///
    /// - If the handler returns `Err`, the error's message is sent back to the
    ///   requester as a `proto::Error` and the error is also returned.
    /// - If the handler returns `Ok` without having called `Response::send`, a
    ///   "handler did not send a response" error is returned (and logged by
    ///   `add_handler`).
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // NOTE(review): no reply is sent to the requester on
                            // this path; the error is only logged.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Serves one authenticated client connection until it closes, its I/O
    /// fails, or the server is torn down.
    ///
    /// Setup:
    /// - registers the connection with the peer and sends `Hello`;
    /// - reports the new connection id through `send_connection_id`, if given;
    /// - shows contacts on the user's first connection;
    /// - adds the connection to the pool;
    /// - sends the initial contacts, channels and any incoming call.
    ///
    /// It then dispatches incoming messages. Background messages are spawned
    /// detached; foreground ones are driven concurrently inside this loop. At
    /// most 256 handlers are in flight, because a semaphore permit is taken
    /// before each message is read.
    ///
    /// When the connection closes or I/O fails, `connection_lost` runs. On
    /// teardown the future returns immediately without running it.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may have gone away; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Scoped so the (synchronous) pool lock is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Upper bound on in-flight handlers (foreground + background) for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading, so no new message is
                // pulled off the wire while all permits are held.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O, foreground handlers,
                // then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only when the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped (cancelled) here.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter's connections that `invitee_id` redeemed their
    /// invite code.
    ///
    /// Each connection receives the invitee as a new contact plus refreshed
    /// invite info. Nothing is sent if the inviter does not exist or has no
    /// invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends `user_id`'s current invite link and invite count to all of that
    /// user's connections. Nothing is sent if the user is missing or has no
    /// invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Captures a serializable snapshot of the peer and connection pool.
    ///
    /// The connection pool stays locked until the returned snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
755
756impl<'a> Deref for ConnectionPoolGuard<'a> {
757 type Target = ConnectionPool;
758
759 fn deref(&self) -> &Self::Target {
760 &*self.guard
761 }
762}
763
764impl<'a> DerefMut for ConnectionPoolGuard<'a> {
765 fn deref_mut(&mut self) -> &mut Self::Target {
766 &mut *self.guard
767 }
768}
769
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs the pool's `check_invariants` every time a guard is
    /// released, so an inconsistent pool is caught where the lock is released.
    /// Does nothing in non-test builds.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
776
777fn broadcast<F>(
778 sender_id: Option<ConnectionId>,
779 receiver_ids: impl IntoIterator<Item = ConnectionId>,
780 mut f: F,
781) where
782 F: FnMut(ConnectionId) -> anyhow::Result<()>,
783{
784 for receiver_id in receiver_ids {
785 if Some(receiver_id) != sender_id {
786 if let Err(error) = f(receiver_id) {
787 tracing::error!("failed to send to {:?} {}", receiver_id, error);
788 }
789 }
790 }
791}
792
// Name of the request header carrying the client's RPC protocol version;
// decoded by `ProtocolVersion`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
796
797pub struct ProtocolVersion(u32);
798
799impl Header for ProtocolVersion {
800 fn name() -> &'static HeaderName {
801 &ZED_PROTOCOL_VERSION
802 }
803
804 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
805 where
806 Self: Sized,
807 I: Iterator<Item = &'i axum::http::HeaderValue>,
808 {
809 let version = values
810 .next()
811 .ok_or_else(axum::headers::Error::invalid)?
812 .to_str()
813 .map_err(|_| axum::headers::Error::invalid())?
814 .parse()
815 .map_err(|_| axum::headers::Error::invalid())?;
816 Ok(Self(version))
817 }
818
819 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
820 values.extend([self.0.to_string().parse().unwrap()]);
821 }
822}
823
/// Builds the HTTP router exposing `/rpc` (WebSocket upgrade) and `/metrics`.
///
/// Layer order matters:
/// - `Router::layer` wraps only the routes registered before it, so the
///   `AppState` extension and the `auth::validate_header` middleware apply to
///   `/rpc` but not to `/metrics`.
/// - The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // The AppState extension sits outside the auth middleware,
            // presumably so `validate_header` can read it — confirm.
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
835
836pub async fn handle_websocket_request(
837 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
838 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
839 Extension(server): Extension<Arc<Server>>,
840 Extension(user): Extension<User>,
841 ws: WebSocketUpgrade,
842) -> axum::response::Response {
843 if protocol_version != rpc::PROTOCOL_VERSION {
844 return (
845 StatusCode::UPGRADE_REQUIRED,
846 "client must be upgraded".to_string(),
847 )
848 .into_response();
849 }
850 let socket_address = socket_address.to_string();
851 ws.on_upgrade(move |socket| {
852 use util::ResultExt;
853 let socket = socket
854 .map_ok(to_tungstenite_message)
855 .err_into()
856 .with(|message| async move { Ok(to_axum_message(message)) });
857 let connection = Connection::new(Box::pin(socket));
858 async move {
859 server
860 .handle_connection(connection, socket_address, user, None, Executor::Production)
861 .await
862 .log_err();
863 }
864 })
865}
866
867pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
868 let connections = server
869 .connection_pool
870 .lock()
871 .connections()
872 .filter(|connection| !connection.admin)
873 .count();
874
875 METRIC_CONNECTIONS.set(connections as _);
876
877 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
878 METRIC_SHARED_PROJECTS.set(shared_projects as _);
879
880 let encoder = prometheus::TextEncoder::new();
881 let metric_families = prometheus::gather();
882 let encoded_metrics = encoder
883 .encode_to_string(&metric_families)
884 .map_err(|err| anyhow!("{}", err))?;
885 Ok(encoded_metrics)
886}
887
/// Cleans up after a connection has ended.
///
/// Immediately it does three things:
/// - disconnects the peer;
/// - removes the connection from the pool (a failure here is returned);
/// - records the loss in the database (a failure here is only traced).
///
/// It then waits for whichever comes first: `RECONNECT_TIMEOUT` elapsing or the
/// server tearing down. Only in the timeout case does it:
/// - leave the session's room and channel buffers;
/// - decline any call for the user, if none of the user's connections are
///   online;
/// - refresh the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a call when the user has no other online connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
933
934async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
935 response.send(proto::Ack {})?;
936 Ok(())
937}
938
939async fn create_room(
940 _request: proto::CreateRoom,
941 response: Response<proto::CreateRoom>,
942 session: Session,
943) -> Result<()> {
944 let live_kit_room = nanoid::nanoid!(30);
945
946 let live_kit_connection_info = {
947 let live_kit_room = live_kit_room.clone();
948 let live_kit = session.live_kit_client.as_ref();
949
950 util::async_maybe!({
951 let live_kit = live_kit?;
952
953 let token = live_kit
954 .room_token(&live_kit_room, &session.user_id.to_string())
955 .trace_err()?;
956
957 Some(proto::LiveKitConnectionInfo {
958 server_url: live_kit.url().into(),
959 token,
960 can_publish: true,
961 })
962 })
963 }
964 .await;
965
966 let room = session
967 .db()
968 .await
969 .create_room(
970 session.user_id,
971 session.connection_id,
972 &live_kit_room,
973 &session.zed_environment,
974 )
975 .await?;
976
977 response.send(proto::CreateRoomResponse {
978 room: Some(room.clone()),
979 live_kit_connection_info,
980 })?;
981
982 update_user_contacts(session.user_id, &session).await?;
983 Ok(())
984}
985
986async fn join_room(
987 request: proto::JoinRoom,
988 response: Response<proto::JoinRoom>,
989 session: Session,
990) -> Result<()> {
991 let room_id = RoomId::from_proto(request.id);
992
993 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
994
995 if let Some(channel_id) = channel_id {
996 return join_channel_internal(channel_id, Box::new(response), session).await;
997 }
998
999 let joined_room = {
1000 let room = session
1001 .db()
1002 .await
1003 .join_room(
1004 room_id,
1005 session.user_id,
1006 session.connection_id,
1007 session.zed_environment.as_ref(),
1008 )
1009 .await?;
1010 room_updated(&room.room, &session.peer);
1011 room.into_inner()
1012 };
1013
1014 for connection_id in session
1015 .connection_pool()
1016 .await
1017 .user_connection_ids(session.user_id)
1018 {
1019 session
1020 .peer
1021 .send(
1022 connection_id,
1023 proto::CallCanceled {
1024 room_id: room_id.to_proto(),
1025 },
1026 )
1027 .trace_err();
1028 }
1029
1030 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1031 if let Some(token) = live_kit
1032 .room_token(
1033 &joined_room.room.live_kit_room,
1034 &session.user_id.to_string(),
1035 )
1036 .trace_err()
1037 {
1038 Some(proto::LiveKitConnectionInfo {
1039 server_url: live_kit.url().into(),
1040 token,
1041 can_publish: true,
1042 })
1043 } else {
1044 None
1045 }
1046 } else {
1047 None
1048 };
1049
1050 response.send(proto::JoinRoomResponse {
1051 room: Some(joined_room.room),
1052 channel_id: None,
1053 live_kit_connection_info,
1054 })?;
1055
1056 update_user_contacts(session.user_id, &session).await?;
1057 Ok(())
1058}
1059
1060async fn rejoin_room(
1061 request: proto::RejoinRoom,
1062 response: Response<proto::RejoinRoom>,
1063 session: Session,
1064) -> Result<()> {
1065 let room;
1066 let channel_id;
1067 let channel_members;
1068 {
1069 let mut rejoined_room = session
1070 .db()
1071 .await
1072 .rejoin_room(request, session.user_id, session.connection_id)
1073 .await?;
1074
1075 response.send(proto::RejoinRoomResponse {
1076 room: Some(rejoined_room.room.clone()),
1077 reshared_projects: rejoined_room
1078 .reshared_projects
1079 .iter()
1080 .map(|project| proto::ResharedProject {
1081 id: project.id.to_proto(),
1082 collaborators: project
1083 .collaborators
1084 .iter()
1085 .map(|collaborator| collaborator.to_proto())
1086 .collect(),
1087 })
1088 .collect(),
1089 rejoined_projects: rejoined_room
1090 .rejoined_projects
1091 .iter()
1092 .map(|rejoined_project| proto::RejoinedProject {
1093 id: rejoined_project.id.to_proto(),
1094 worktrees: rejoined_project
1095 .worktrees
1096 .iter()
1097 .map(|worktree| proto::WorktreeMetadata {
1098 id: worktree.id,
1099 root_name: worktree.root_name.clone(),
1100 visible: worktree.visible,
1101 abs_path: worktree.abs_path.clone(),
1102 })
1103 .collect(),
1104 collaborators: rejoined_project
1105 .collaborators
1106 .iter()
1107 .map(|collaborator| collaborator.to_proto())
1108 .collect(),
1109 language_servers: rejoined_project.language_servers.clone(),
1110 })
1111 .collect(),
1112 })?;
1113 room_updated(&rejoined_room.room, &session.peer);
1114
1115 for project in &rejoined_room.reshared_projects {
1116 for collaborator in &project.collaborators {
1117 session
1118 .peer
1119 .send(
1120 collaborator.connection_id,
1121 proto::UpdateProjectCollaborator {
1122 project_id: project.id.to_proto(),
1123 old_peer_id: Some(project.old_connection_id.into()),
1124 new_peer_id: Some(session.connection_id.into()),
1125 },
1126 )
1127 .trace_err();
1128 }
1129
1130 broadcast(
1131 Some(session.connection_id),
1132 project
1133 .collaborators
1134 .iter()
1135 .map(|collaborator| collaborator.connection_id),
1136 |connection_id| {
1137 session.peer.forward_send(
1138 session.connection_id,
1139 connection_id,
1140 proto::UpdateProject {
1141 project_id: project.id.to_proto(),
1142 worktrees: project.worktrees.clone(),
1143 },
1144 )
1145 },
1146 );
1147 }
1148
1149 for project in &rejoined_room.rejoined_projects {
1150 for collaborator in &project.collaborators {
1151 session
1152 .peer
1153 .send(
1154 collaborator.connection_id,
1155 proto::UpdateProjectCollaborator {
1156 project_id: project.id.to_proto(),
1157 old_peer_id: Some(project.old_connection_id.into()),
1158 new_peer_id: Some(session.connection_id.into()),
1159 },
1160 )
1161 .trace_err();
1162 }
1163 }
1164
1165 for project in &mut rejoined_room.rejoined_projects {
1166 for worktree in mem::take(&mut project.worktrees) {
1167 #[cfg(any(test, feature = "test-support"))]
1168 const MAX_CHUNK_SIZE: usize = 2;
1169 #[cfg(not(any(test, feature = "test-support")))]
1170 const MAX_CHUNK_SIZE: usize = 256;
1171
1172 // Stream this worktree's entries.
1173 let message = proto::UpdateWorktree {
1174 project_id: project.id.to_proto(),
1175 worktree_id: worktree.id,
1176 abs_path: worktree.abs_path.clone(),
1177 root_name: worktree.root_name,
1178 updated_entries: worktree.updated_entries,
1179 removed_entries: worktree.removed_entries,
1180 scan_id: worktree.scan_id,
1181 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1182 updated_repositories: worktree.updated_repositories,
1183 removed_repositories: worktree.removed_repositories,
1184 };
1185 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1186 session.peer.send(session.connection_id, update.clone())?;
1187 }
1188
1189 // Stream this worktree's diagnostics.
1190 for summary in worktree.diagnostic_summaries {
1191 session.peer.send(
1192 session.connection_id,
1193 proto::UpdateDiagnosticSummary {
1194 project_id: project.id.to_proto(),
1195 worktree_id: worktree.id,
1196 summary: Some(summary),
1197 },
1198 )?;
1199 }
1200
1201 for settings_file in worktree.settings_files {
1202 session.peer.send(
1203 session.connection_id,
1204 proto::UpdateWorktreeSettings {
1205 project_id: project.id.to_proto(),
1206 worktree_id: worktree.id,
1207 path: settings_file.path,
1208 content: Some(settings_file.content),
1209 },
1210 )?;
1211 }
1212 }
1213
1214 for language_server in &project.language_servers {
1215 session.peer.send(
1216 session.connection_id,
1217 proto::UpdateLanguageServer {
1218 project_id: project.id.to_proto(),
1219 language_server_id: language_server.id,
1220 variant: Some(
1221 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1222 proto::LspDiskBasedDiagnosticsUpdated {},
1223 ),
1224 ),
1225 },
1226 )?;
1227 }
1228 }
1229
1230 let rejoined_room = rejoined_room.into_inner();
1231
1232 room = rejoined_room.room;
1233 channel_id = rejoined_room.channel_id;
1234 channel_members = rejoined_room.channel_members;
1235 }
1236
1237 if let Some(channel_id) = channel_id {
1238 channel_updated(
1239 channel_id,
1240 &room,
1241 &channel_members,
1242 &session.peer,
1243 &*session.connection_pool().await,
1244 );
1245 }
1246
1247 update_user_contacts(session.user_id, &session).await?;
1248 Ok(())
1249}
1250
1251async fn leave_room(
1252 _: proto::LeaveRoom,
1253 response: Response<proto::LeaveRoom>,
1254 session: Session,
1255) -> Result<()> {
1256 leave_room_for_session(&session).await?;
1257 response.send(proto::Ack {})?;
1258 Ok(())
1259}
1260
1261async fn call(
1262 request: proto::Call,
1263 response: Response<proto::Call>,
1264 session: Session,
1265) -> Result<()> {
1266 let room_id = RoomId::from_proto(request.room_id);
1267 let calling_user_id = session.user_id;
1268 let calling_connection_id = session.connection_id;
1269 let called_user_id = UserId::from_proto(request.called_user_id);
1270 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1271 if !session
1272 .db()
1273 .await
1274 .has_contact(calling_user_id, called_user_id)
1275 .await?
1276 {
1277 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1278 }
1279
1280 let incoming_call = {
1281 let (room, incoming_call) = &mut *session
1282 .db()
1283 .await
1284 .call(
1285 room_id,
1286 calling_user_id,
1287 calling_connection_id,
1288 called_user_id,
1289 initial_project_id,
1290 )
1291 .await?;
1292 room_updated(&room, &session.peer);
1293 mem::take(incoming_call)
1294 };
1295 update_user_contacts(called_user_id, &session).await?;
1296
1297 let mut calls = session
1298 .connection_pool()
1299 .await
1300 .user_connection_ids(called_user_id)
1301 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1302 .collect::<FuturesUnordered<_>>();
1303
1304 while let Some(call_response) = calls.next().await {
1305 match call_response.as_ref() {
1306 Ok(_) => {
1307 response.send(proto::Ack {})?;
1308 return Ok(());
1309 }
1310 Err(_) => {
1311 call_response.trace_err();
1312 }
1313 }
1314 }
1315
1316 {
1317 let room = session
1318 .db()
1319 .await
1320 .call_failed(room_id, called_user_id)
1321 .await?;
1322 room_updated(&room, &session.peer);
1323 }
1324 update_user_contacts(called_user_id, &session).await?;
1325
1326 Err(anyhow!("failed to ring user"))?
1327}
1328
1329async fn cancel_call(
1330 request: proto::CancelCall,
1331 response: Response<proto::CancelCall>,
1332 session: Session,
1333) -> Result<()> {
1334 let called_user_id = UserId::from_proto(request.called_user_id);
1335 let room_id = RoomId::from_proto(request.room_id);
1336 {
1337 let room = session
1338 .db()
1339 .await
1340 .cancel_call(room_id, session.connection_id, called_user_id)
1341 .await?;
1342 room_updated(&room, &session.peer);
1343 }
1344
1345 for connection_id in session
1346 .connection_pool()
1347 .await
1348 .user_connection_ids(called_user_id)
1349 {
1350 session
1351 .peer
1352 .send(
1353 connection_id,
1354 proto::CallCanceled {
1355 room_id: room_id.to_proto(),
1356 },
1357 )
1358 .trace_err();
1359 }
1360 response.send(proto::Ack {})?;
1361
1362 update_user_contacts(called_user_id, &session).await?;
1363 Ok(())
1364}
1365
1366async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1367 let room_id = RoomId::from_proto(message.room_id);
1368 {
1369 let room = session
1370 .db()
1371 .await
1372 .decline_call(Some(room_id), session.user_id)
1373 .await?
1374 .ok_or_else(|| anyhow!("failed to decline call"))?;
1375 room_updated(&room, &session.peer);
1376 }
1377
1378 for connection_id in session
1379 .connection_pool()
1380 .await
1381 .user_connection_ids(session.user_id)
1382 {
1383 session
1384 .peer
1385 .send(
1386 connection_id,
1387 proto::CallCanceled {
1388 room_id: room_id.to_proto(),
1389 },
1390 )
1391 .trace_err();
1392 }
1393 update_user_contacts(session.user_id, &session).await?;
1394 Ok(())
1395}
1396
1397async fn update_participant_location(
1398 request: proto::UpdateParticipantLocation,
1399 response: Response<proto::UpdateParticipantLocation>,
1400 session: Session,
1401) -> Result<()> {
1402 let room_id = RoomId::from_proto(request.room_id);
1403 let location = request
1404 .location
1405 .ok_or_else(|| anyhow!("invalid location"))?;
1406
1407 let db = session.db().await;
1408 let room = db
1409 .update_room_participant_location(room_id, session.connection_id, location)
1410 .await?;
1411
1412 room_updated(&room, &session.peer);
1413 response.send(proto::Ack {})?;
1414 Ok(())
1415}
1416
1417async fn share_project(
1418 request: proto::ShareProject,
1419 response: Response<proto::ShareProject>,
1420 session: Session,
1421) -> Result<()> {
1422 let (project_id, room) = &*session
1423 .db()
1424 .await
1425 .share_project(
1426 RoomId::from_proto(request.room_id),
1427 session.connection_id,
1428 &request.worktrees,
1429 )
1430 .await?;
1431 response.send(proto::ShareProjectResponse {
1432 project_id: project_id.to_proto(),
1433 })?;
1434 room_updated(&room, &session.peer);
1435
1436 Ok(())
1437}
1438
1439async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1440 let project_id = ProjectId::from_proto(message.project_id);
1441
1442 let (room, guest_connection_ids) = &*session
1443 .db()
1444 .await
1445 .unshare_project(project_id, session.connection_id)
1446 .await?;
1447
1448 broadcast(
1449 Some(session.connection_id),
1450 guest_connection_ids.iter().copied(),
1451 |conn_id| session.peer.send(conn_id, message.clone()),
1452 );
1453 room_updated(&room, &session.peer);
1454
1455 Ok(())
1456}
1457
1458async fn join_project(
1459 request: proto::JoinProject,
1460 response: Response<proto::JoinProject>,
1461 session: Session,
1462) -> Result<()> {
1463 let project_id = ProjectId::from_proto(request.project_id);
1464 let guest_user_id = session.user_id;
1465
1466 tracing::info!(%project_id, "join project");
1467
1468 let (project, replica_id) = &mut *session
1469 .db()
1470 .await
1471 .join_project(project_id, session.connection_id)
1472 .await?;
1473
1474 let collaborators = project
1475 .collaborators
1476 .iter()
1477 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1478 .map(|collaborator| collaborator.to_proto())
1479 .collect::<Vec<_>>();
1480
1481 let worktrees = project
1482 .worktrees
1483 .iter()
1484 .map(|(id, worktree)| proto::WorktreeMetadata {
1485 id: *id,
1486 root_name: worktree.root_name.clone(),
1487 visible: worktree.visible,
1488 abs_path: worktree.abs_path.clone(),
1489 })
1490 .collect::<Vec<_>>();
1491
1492 for collaborator in &collaborators {
1493 session
1494 .peer
1495 .send(
1496 collaborator.peer_id.unwrap().into(),
1497 proto::AddProjectCollaborator {
1498 project_id: project_id.to_proto(),
1499 collaborator: Some(proto::Collaborator {
1500 peer_id: Some(session.connection_id.into()),
1501 replica_id: replica_id.0 as u32,
1502 user_id: guest_user_id.to_proto(),
1503 }),
1504 },
1505 )
1506 .trace_err();
1507 }
1508
1509 // First, we send the metadata associated with each worktree.
1510 response.send(proto::JoinProjectResponse {
1511 worktrees: worktrees.clone(),
1512 replica_id: replica_id.0 as u32,
1513 collaborators: collaborators.clone(),
1514 language_servers: project.language_servers.clone(),
1515 })?;
1516
1517 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1518 #[cfg(any(test, feature = "test-support"))]
1519 const MAX_CHUNK_SIZE: usize = 2;
1520 #[cfg(not(any(test, feature = "test-support")))]
1521 const MAX_CHUNK_SIZE: usize = 256;
1522
1523 // Stream this worktree's entries.
1524 let message = proto::UpdateWorktree {
1525 project_id: project_id.to_proto(),
1526 worktree_id,
1527 abs_path: worktree.abs_path.clone(),
1528 root_name: worktree.root_name,
1529 updated_entries: worktree.entries,
1530 removed_entries: Default::default(),
1531 scan_id: worktree.scan_id,
1532 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1533 updated_repositories: worktree.repository_entries.into_values().collect(),
1534 removed_repositories: Default::default(),
1535 };
1536 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1537 session.peer.send(session.connection_id, update.clone())?;
1538 }
1539
1540 // Stream this worktree's diagnostics.
1541 for summary in worktree.diagnostic_summaries {
1542 session.peer.send(
1543 session.connection_id,
1544 proto::UpdateDiagnosticSummary {
1545 project_id: project_id.to_proto(),
1546 worktree_id: worktree.id,
1547 summary: Some(summary),
1548 },
1549 )?;
1550 }
1551
1552 for settings_file in worktree.settings_files {
1553 session.peer.send(
1554 session.connection_id,
1555 proto::UpdateWorktreeSettings {
1556 project_id: project_id.to_proto(),
1557 worktree_id: worktree.id,
1558 path: settings_file.path,
1559 content: Some(settings_file.content),
1560 },
1561 )?;
1562 }
1563 }
1564
1565 for language_server in &project.language_servers {
1566 session.peer.send(
1567 session.connection_id,
1568 proto::UpdateLanguageServer {
1569 project_id: project_id.to_proto(),
1570 language_server_id: language_server.id,
1571 variant: Some(
1572 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1573 proto::LspDiskBasedDiagnosticsUpdated {},
1574 ),
1575 ),
1576 },
1577 )?;
1578 }
1579
1580 Ok(())
1581}
1582
1583async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1584 let sender_id = session.connection_id;
1585 let project_id = ProjectId::from_proto(request.project_id);
1586
1587 let (room, project) = &*session
1588 .db()
1589 .await
1590 .leave_project(project_id, sender_id)
1591 .await?;
1592 tracing::info!(
1593 %project_id,
1594 host_user_id = %project.host_user_id,
1595 host_connection_id = %project.host_connection_id,
1596 "leave project"
1597 );
1598
1599 project_left(&project, &session);
1600 room_updated(&room, &session.peer);
1601
1602 Ok(())
1603}
1604
1605async fn update_project(
1606 request: proto::UpdateProject,
1607 response: Response<proto::UpdateProject>,
1608 session: Session,
1609) -> Result<()> {
1610 let project_id = ProjectId::from_proto(request.project_id);
1611 let (room, guest_connection_ids) = &*session
1612 .db()
1613 .await
1614 .update_project(project_id, session.connection_id, &request.worktrees)
1615 .await?;
1616 broadcast(
1617 Some(session.connection_id),
1618 guest_connection_ids.iter().copied(),
1619 |connection_id| {
1620 session
1621 .peer
1622 .forward_send(session.connection_id, connection_id, request.clone())
1623 },
1624 );
1625 room_updated(&room, &session.peer);
1626 response.send(proto::Ack {})?;
1627
1628 Ok(())
1629}
1630
1631async fn update_worktree(
1632 request: proto::UpdateWorktree,
1633 response: Response<proto::UpdateWorktree>,
1634 session: Session,
1635) -> Result<()> {
1636 let guest_connection_ids = session
1637 .db()
1638 .await
1639 .update_worktree(&request, session.connection_id)
1640 .await?;
1641
1642 broadcast(
1643 Some(session.connection_id),
1644 guest_connection_ids.iter().copied(),
1645 |connection_id| {
1646 session
1647 .peer
1648 .forward_send(session.connection_id, connection_id, request.clone())
1649 },
1650 );
1651 response.send(proto::Ack {})?;
1652 Ok(())
1653}
1654
1655async fn update_diagnostic_summary(
1656 message: proto::UpdateDiagnosticSummary,
1657 session: Session,
1658) -> Result<()> {
1659 let guest_connection_ids = session
1660 .db()
1661 .await
1662 .update_diagnostic_summary(&message, session.connection_id)
1663 .await?;
1664
1665 broadcast(
1666 Some(session.connection_id),
1667 guest_connection_ids.iter().copied(),
1668 |connection_id| {
1669 session
1670 .peer
1671 .forward_send(session.connection_id, connection_id, message.clone())
1672 },
1673 );
1674
1675 Ok(())
1676}
1677
1678async fn update_worktree_settings(
1679 message: proto::UpdateWorktreeSettings,
1680 session: Session,
1681) -> Result<()> {
1682 let guest_connection_ids = session
1683 .db()
1684 .await
1685 .update_worktree_settings(&message, session.connection_id)
1686 .await?;
1687
1688 broadcast(
1689 Some(session.connection_id),
1690 guest_connection_ids.iter().copied(),
1691 |connection_id| {
1692 session
1693 .peer
1694 .forward_send(session.connection_id, connection_id, message.clone())
1695 },
1696 );
1697
1698 Ok(())
1699}
1700
1701async fn start_language_server(
1702 request: proto::StartLanguageServer,
1703 session: Session,
1704) -> Result<()> {
1705 let guest_connection_ids = session
1706 .db()
1707 .await
1708 .start_language_server(&request, session.connection_id)
1709 .await?;
1710
1711 broadcast(
1712 Some(session.connection_id),
1713 guest_connection_ids.iter().copied(),
1714 |connection_id| {
1715 session
1716 .peer
1717 .forward_send(session.connection_id, connection_id, request.clone())
1718 },
1719 );
1720 Ok(())
1721}
1722
1723async fn update_language_server(
1724 request: proto::UpdateLanguageServer,
1725 session: Session,
1726) -> Result<()> {
1727 let project_id = ProjectId::from_proto(request.project_id);
1728 let project_connection_ids = session
1729 .db()
1730 .await
1731 .project_connection_ids(project_id, session.connection_id)
1732 .await?;
1733 broadcast(
1734 Some(session.connection_id),
1735 project_connection_ids.iter().copied(),
1736 |connection_id| {
1737 session
1738 .peer
1739 .forward_send(session.connection_id, connection_id, request.clone())
1740 },
1741 );
1742 Ok(())
1743}
1744
1745async fn forward_read_only_project_request<T>(
1746 request: T,
1747 response: Response<T>,
1748 session: Session,
1749) -> Result<()>
1750where
1751 T: EntityMessage + RequestMessage,
1752{
1753 let project_id = ProjectId::from_proto(request.remote_entity_id());
1754 let host_connection_id = session
1755 .db()
1756 .await
1757 .host_for_read_only_project_request(project_id, session.connection_id)
1758 .await?;
1759 let payload = session
1760 .peer
1761 .forward_request(session.connection_id, host_connection_id, request)
1762 .await?;
1763 response.send(payload)?;
1764 Ok(())
1765}
1766
1767async fn forward_mutating_project_request<T>(
1768 request: T,
1769 response: Response<T>,
1770 session: Session,
1771) -> Result<()>
1772where
1773 T: EntityMessage + RequestMessage,
1774{
1775 let project_id = ProjectId::from_proto(request.remote_entity_id());
1776 let host_connection_id = session
1777 .db()
1778 .await
1779 .host_for_mutating_project_request(project_id, session.connection_id)
1780 .await?;
1781 let payload = session
1782 .peer
1783 .forward_request(session.connection_id, host_connection_id, request)
1784 .await?;
1785 response.send(payload)?;
1786 Ok(())
1787}
1788
1789async fn create_buffer_for_peer(
1790 request: proto::CreateBufferForPeer,
1791 session: Session,
1792) -> Result<()> {
1793 session
1794 .db()
1795 .await
1796 .check_user_is_project_host(
1797 ProjectId::from_proto(request.project_id),
1798 session.connection_id,
1799 )
1800 .await?;
1801 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1802 session
1803 .peer
1804 .forward_send(session.connection_id, peer_id.into(), request)?;
1805 Ok(())
1806}
1807
1808async fn update_buffer(
1809 request: proto::UpdateBuffer,
1810 response: Response<proto::UpdateBuffer>,
1811 session: Session,
1812) -> Result<()> {
1813 let project_id = ProjectId::from_proto(request.project_id);
1814 let mut guest_connection_ids;
1815 let mut host_connection_id = None;
1816
1817 let mut requires_write_permission = false;
1818
1819 for op in request.operations.iter() {
1820 match op.variant {
1821 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1822 Some(_) => requires_write_permission = true,
1823 }
1824 }
1825
1826 {
1827 let collaborators = session
1828 .db()
1829 .await
1830 .project_collaborators_for_buffer_update(
1831 project_id,
1832 session.connection_id,
1833 requires_write_permission,
1834 )
1835 .await?;
1836 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1837 for collaborator in collaborators.iter() {
1838 if collaborator.is_host {
1839 host_connection_id = Some(collaborator.connection_id);
1840 } else {
1841 guest_connection_ids.push(collaborator.connection_id);
1842 }
1843 }
1844 }
1845 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1846
1847 broadcast(
1848 Some(session.connection_id),
1849 guest_connection_ids,
1850 |connection_id| {
1851 session
1852 .peer
1853 .forward_send(session.connection_id, connection_id, request.clone())
1854 },
1855 );
1856 if host_connection_id != session.connection_id {
1857 session
1858 .peer
1859 .forward_request(session.connection_id, host_connection_id, request.clone())
1860 .await?;
1861 }
1862
1863 response.send(proto::Ack {})?;
1864 Ok(())
1865}
1866
1867async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1868 request: T,
1869 session: Session,
1870) -> Result<()> {
1871 let project_id = ProjectId::from_proto(request.remote_entity_id());
1872 let project_connection_ids = session
1873 .db()
1874 .await
1875 .project_connection_ids(project_id, session.connection_id)
1876 .await?;
1877
1878 broadcast(
1879 Some(session.connection_id),
1880 project_connection_ids.iter().copied(),
1881 |connection_id| {
1882 session
1883 .peer
1884 .forward_send(session.connection_id, connection_id, request.clone())
1885 },
1886 );
1887 Ok(())
1888}
1889
1890async fn follow(
1891 request: proto::Follow,
1892 response: Response<proto::Follow>,
1893 session: Session,
1894) -> Result<()> {
1895 let room_id = RoomId::from_proto(request.room_id);
1896 let project_id = request.project_id.map(ProjectId::from_proto);
1897 let leader_id = request
1898 .leader_id
1899 .ok_or_else(|| anyhow!("invalid leader id"))?
1900 .into();
1901 let follower_id = session.connection_id;
1902
1903 session
1904 .db()
1905 .await
1906 .check_room_participants(room_id, leader_id, session.connection_id)
1907 .await?;
1908
1909 let response_payload = session
1910 .peer
1911 .forward_request(session.connection_id, leader_id, request)
1912 .await?;
1913 response.send(response_payload)?;
1914
1915 if let Some(project_id) = project_id {
1916 let room = session
1917 .db()
1918 .await
1919 .follow(room_id, project_id, leader_id, follower_id)
1920 .await?;
1921 room_updated(&room, &session.peer);
1922 }
1923
1924 Ok(())
1925}
1926
1927async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1928 let room_id = RoomId::from_proto(request.room_id);
1929 let project_id = request.project_id.map(ProjectId::from_proto);
1930 let leader_id = request
1931 .leader_id
1932 .ok_or_else(|| anyhow!("invalid leader id"))?
1933 .into();
1934 let follower_id = session.connection_id;
1935
1936 session
1937 .db()
1938 .await
1939 .check_room_participants(room_id, leader_id, session.connection_id)
1940 .await?;
1941
1942 session
1943 .peer
1944 .forward_send(session.connection_id, leader_id, request)?;
1945
1946 if let Some(project_id) = project_id {
1947 let room = session
1948 .db()
1949 .await
1950 .unfollow(room_id, project_id, leader_id, follower_id)
1951 .await?;
1952 room_updated(&room, &session.peer);
1953 }
1954
1955 Ok(())
1956}
1957
1958async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1959 let room_id = RoomId::from_proto(request.room_id);
1960 let database = session.db.lock().await;
1961
1962 let connection_ids = if let Some(project_id) = request.project_id {
1963 let project_id = ProjectId::from_proto(project_id);
1964 database
1965 .project_connection_ids(project_id, session.connection_id)
1966 .await?
1967 } else {
1968 database
1969 .room_connection_ids(room_id, session.connection_id)
1970 .await?
1971 };
1972
1973 // For now, don't send view update messages back to that view's current leader.
1974 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1975 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1976 _ => None,
1977 });
1978
1979 for follower_peer_id in request.follower_ids.iter().copied() {
1980 let follower_connection_id = follower_peer_id.into();
1981 if Some(follower_peer_id) != connection_id_to_omit
1982 && connection_ids.contains(&follower_connection_id)
1983 {
1984 session.peer.forward_send(
1985 session.connection_id,
1986 follower_connection_id,
1987 request.clone(),
1988 )?;
1989 }
1990 }
1991 Ok(())
1992}
1993
1994async fn get_users(
1995 request: proto::GetUsers,
1996 response: Response<proto::GetUsers>,
1997 session: Session,
1998) -> Result<()> {
1999 let user_ids = request
2000 .user_ids
2001 .into_iter()
2002 .map(UserId::from_proto)
2003 .collect();
2004 let users = session
2005 .db()
2006 .await
2007 .get_users_by_ids(user_ids)
2008 .await?
2009 .into_iter()
2010 .map(|user| proto::User {
2011 id: user.id.to_proto(),
2012 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2013 github_login: user.github_login,
2014 })
2015 .collect();
2016 response.send(proto::UsersResponse { users })?;
2017 Ok(())
2018}
2019
2020async fn fuzzy_search_users(
2021 request: proto::FuzzySearchUsers,
2022 response: Response<proto::FuzzySearchUsers>,
2023 session: Session,
2024) -> Result<()> {
2025 let query = request.query;
2026 let users = match query.len() {
2027 0 => vec![],
2028 1 | 2 => session
2029 .db()
2030 .await
2031 .get_user_by_github_login(&query)
2032 .await?
2033 .into_iter()
2034 .collect(),
2035 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2036 };
2037 let users = users
2038 .into_iter()
2039 .filter(|user| user.id != session.user_id)
2040 .map(|user| proto::User {
2041 id: user.id.to_proto(),
2042 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2043 github_login: user.github_login,
2044 })
2045 .collect();
2046 response.send(proto::UsersResponse { users })?;
2047 Ok(())
2048}
2049
2050async fn request_contact(
2051 request: proto::RequestContact,
2052 response: Response<proto::RequestContact>,
2053 session: Session,
2054) -> Result<()> {
2055 let requester_id = session.user_id;
2056 let responder_id = UserId::from_proto(request.responder_id);
2057 if requester_id == responder_id {
2058 return Err(anyhow!("cannot add yourself as a contact"))?;
2059 }
2060
2061 let notifications = session
2062 .db()
2063 .await
2064 .send_contact_request(requester_id, responder_id)
2065 .await?;
2066
2067 // Update outgoing contact requests of requester
2068 let mut update = proto::UpdateContacts::default();
2069 update.outgoing_requests.push(responder_id.to_proto());
2070 for connection_id in session
2071 .connection_pool()
2072 .await
2073 .user_connection_ids(requester_id)
2074 {
2075 session.peer.send(connection_id, update.clone())?;
2076 }
2077
2078 // Update incoming contact requests of responder
2079 let mut update = proto::UpdateContacts::default();
2080 update
2081 .incoming_requests
2082 .push(proto::IncomingContactRequest {
2083 requester_id: requester_id.to_proto(),
2084 });
2085 let connection_pool = session.connection_pool().await;
2086 for connection_id in connection_pool.user_connection_ids(responder_id) {
2087 session.peer.send(connection_id, update.clone())?;
2088 }
2089
2090 send_notifications(&*connection_pool, &session.peer, notifications);
2091
2092 response.send(proto::Ack {})?;
2093 Ok(())
2094}
2095
2096async fn respond_to_contact_request(
2097 request: proto::RespondToContactRequest,
2098 response: Response<proto::RespondToContactRequest>,
2099 session: Session,
2100) -> Result<()> {
2101 let responder_id = session.user_id;
2102 let requester_id = UserId::from_proto(request.requester_id);
2103 let db = session.db().await;
2104 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2105 db.dismiss_contact_notification(responder_id, requester_id)
2106 .await?;
2107 } else {
2108 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2109
2110 let notifications = db
2111 .respond_to_contact_request(responder_id, requester_id, accept)
2112 .await?;
2113 let requester_busy = db.is_user_busy(requester_id).await?;
2114 let responder_busy = db.is_user_busy(responder_id).await?;
2115
2116 let pool = session.connection_pool().await;
2117 // Update responder with new contact
2118 let mut update = proto::UpdateContacts::default();
2119 if accept {
2120 update
2121 .contacts
2122 .push(contact_for_user(requester_id, requester_busy, &pool));
2123 }
2124 update
2125 .remove_incoming_requests
2126 .push(requester_id.to_proto());
2127 for connection_id in pool.user_connection_ids(responder_id) {
2128 session.peer.send(connection_id, update.clone())?;
2129 }
2130
2131 // Update requester with new contact
2132 let mut update = proto::UpdateContacts::default();
2133 if accept {
2134 update
2135 .contacts
2136 .push(contact_for_user(responder_id, responder_busy, &pool));
2137 }
2138 update
2139 .remove_outgoing_requests
2140 .push(responder_id.to_proto());
2141
2142 for connection_id in pool.user_connection_ids(requester_id) {
2143 session.peer.send(connection_id, update.clone())?;
2144 }
2145
2146 send_notifications(&*pool, &session.peer, notifications);
2147 }
2148
2149 response.send(proto::Ack {})?;
2150 Ok(())
2151}
2152
2153async fn remove_contact(
2154 request: proto::RemoveContact,
2155 response: Response<proto::RemoveContact>,
2156 session: Session,
2157) -> Result<()> {
2158 let requester_id = session.user_id;
2159 let responder_id = UserId::from_proto(request.user_id);
2160 let db = session.db().await;
2161 let (contact_accepted, deleted_notification_id) =
2162 db.remove_contact(requester_id, responder_id).await?;
2163
2164 let pool = session.connection_pool().await;
2165 // Update outgoing contact requests of requester
2166 let mut update = proto::UpdateContacts::default();
2167 if contact_accepted {
2168 update.remove_contacts.push(responder_id.to_proto());
2169 } else {
2170 update
2171 .remove_outgoing_requests
2172 .push(responder_id.to_proto());
2173 }
2174 for connection_id in pool.user_connection_ids(requester_id) {
2175 session.peer.send(connection_id, update.clone())?;
2176 }
2177
2178 // Update incoming contact requests of responder
2179 let mut update = proto::UpdateContacts::default();
2180 if contact_accepted {
2181 update.remove_contacts.push(requester_id.to_proto());
2182 } else {
2183 update
2184 .remove_incoming_requests
2185 .push(requester_id.to_proto());
2186 }
2187 for connection_id in pool.user_connection_ids(responder_id) {
2188 session.peer.send(connection_id, update.clone())?;
2189 if let Some(notification_id) = deleted_notification_id {
2190 session.peer.send(
2191 connection_id,
2192 proto::DeleteNotification {
2193 notification_id: notification_id.to_proto(),
2194 },
2195 )?;
2196 }
2197 }
2198
2199 response.send(proto::Ack {})?;
2200 Ok(())
2201}
2202
2203async fn create_channel(
2204 request: proto::CreateChannel,
2205 response: Response<proto::CreateChannel>,
2206 session: Session,
2207) -> Result<()> {
2208 let db = session.db().await;
2209
2210 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2211 let CreateChannelResult {
2212 channel,
2213 participants_to_update,
2214 } = db
2215 .create_channel(&request.name, parent_id, session.user_id)
2216 .await?;
2217
2218 response.send(proto::CreateChannelResponse {
2219 channel: Some(channel.to_proto()),
2220 parent_id: request.parent_id,
2221 })?;
2222
2223 let connection_pool = session.connection_pool().await;
2224 for (user_id, channels) in participants_to_update {
2225 let update = build_channels_update(channels, vec![]);
2226 for connection_id in connection_pool.user_connection_ids(user_id) {
2227 if user_id == session.user_id {
2228 continue;
2229 }
2230 session.peer.send(connection_id, update.clone())?;
2231 }
2232 }
2233
2234 Ok(())
2235}
2236
2237async fn delete_channel(
2238 request: proto::DeleteChannel,
2239 response: Response<proto::DeleteChannel>,
2240 session: Session,
2241) -> Result<()> {
2242 let db = session.db().await;
2243
2244 let channel_id = request.channel_id;
2245 let (removed_channels, member_ids) = db
2246 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2247 .await?;
2248 response.send(proto::Ack {})?;
2249
2250 // Notify members of removed channels
2251 let mut update = proto::UpdateChannels::default();
2252 update
2253 .delete_channels
2254 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2255
2256 let connection_pool = session.connection_pool().await;
2257 for member_id in member_ids {
2258 for connection_id in connection_pool.user_connection_ids(member_id) {
2259 session.peer.send(connection_id, update.clone())?;
2260 }
2261 }
2262
2263 Ok(())
2264}
2265
2266async fn invite_channel_member(
2267 request: proto::InviteChannelMember,
2268 response: Response<proto::InviteChannelMember>,
2269 session: Session,
2270) -> Result<()> {
2271 let db = session.db().await;
2272 let channel_id = ChannelId::from_proto(request.channel_id);
2273 let invitee_id = UserId::from_proto(request.user_id);
2274 let InviteMemberResult {
2275 channel,
2276 notifications,
2277 } = db
2278 .invite_channel_member(
2279 channel_id,
2280 invitee_id,
2281 session.user_id,
2282 request.role().into(),
2283 )
2284 .await?;
2285
2286 let update = proto::UpdateChannels {
2287 channel_invitations: vec![channel.to_proto()],
2288 ..Default::default()
2289 };
2290
2291 let connection_pool = session.connection_pool().await;
2292 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2293 session.peer.send(connection_id, update.clone())?;
2294 }
2295
2296 send_notifications(&*connection_pool, &session.peer, notifications);
2297
2298 response.send(proto::Ack {})?;
2299 Ok(())
2300}
2301
2302async fn remove_channel_member(
2303 request: proto::RemoveChannelMember,
2304 response: Response<proto::RemoveChannelMember>,
2305 session: Session,
2306) -> Result<()> {
2307 let db = session.db().await;
2308 let channel_id = ChannelId::from_proto(request.channel_id);
2309 let member_id = UserId::from_proto(request.user_id);
2310
2311 let RemoveChannelMemberResult {
2312 membership_update,
2313 notification_id,
2314 } = db
2315 .remove_channel_member(channel_id, member_id, session.user_id)
2316 .await?;
2317
2318 let connection_pool = &session.connection_pool().await;
2319 notify_membership_updated(
2320 &connection_pool,
2321 membership_update,
2322 member_id,
2323 &session.peer,
2324 );
2325 for connection_id in connection_pool.user_connection_ids(member_id) {
2326 if let Some(notification_id) = notification_id {
2327 session
2328 .peer
2329 .send(
2330 connection_id,
2331 proto::DeleteNotification {
2332 notification_id: notification_id.to_proto(),
2333 },
2334 )
2335 .trace_err();
2336 }
2337 }
2338
2339 response.send(proto::Ack {})?;
2340 Ok(())
2341}
2342
2343async fn set_channel_visibility(
2344 request: proto::SetChannelVisibility,
2345 response: Response<proto::SetChannelVisibility>,
2346 session: Session,
2347) -> Result<()> {
2348 let db = session.db().await;
2349 let channel_id = ChannelId::from_proto(request.channel_id);
2350 let visibility = request.visibility().into();
2351
2352 let SetChannelVisibilityResult {
2353 participants_to_update,
2354 participants_to_remove,
2355 channels_to_remove,
2356 } = db
2357 .set_channel_visibility(channel_id, visibility, session.user_id)
2358 .await?;
2359
2360 let connection_pool = session.connection_pool().await;
2361 for (user_id, channels) in participants_to_update {
2362 let update = build_channels_update(channels, vec![]);
2363 for connection_id in connection_pool.user_connection_ids(user_id) {
2364 session.peer.send(connection_id, update.clone())?;
2365 }
2366 }
2367 for user_id in participants_to_remove {
2368 let update = proto::UpdateChannels {
2369 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2370 ..Default::default()
2371 };
2372 for connection_id in connection_pool.user_connection_ids(user_id) {
2373 session.peer.send(connection_id, update.clone())?;
2374 }
2375 }
2376
2377 response.send(proto::Ack {})?;
2378 Ok(())
2379}
2380
2381async fn set_channel_member_role(
2382 request: proto::SetChannelMemberRole,
2383 response: Response<proto::SetChannelMemberRole>,
2384 session: Session,
2385) -> Result<()> {
2386 let db = session.db().await;
2387 let channel_id = ChannelId::from_proto(request.channel_id);
2388 let member_id = UserId::from_proto(request.user_id);
2389 let result = db
2390 .set_channel_member_role(
2391 channel_id,
2392 session.user_id,
2393 member_id,
2394 request.role().into(),
2395 )
2396 .await?;
2397
2398 match result {
2399 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2400 let connection_pool = session.connection_pool().await;
2401 notify_membership_updated(
2402 &connection_pool,
2403 membership_update,
2404 member_id,
2405 &session.peer,
2406 )
2407 }
2408 db::SetMemberRoleResult::InviteUpdated(channel) => {
2409 let update = proto::UpdateChannels {
2410 channel_invitations: vec![channel.to_proto()],
2411 ..Default::default()
2412 };
2413
2414 for connection_id in session
2415 .connection_pool()
2416 .await
2417 .user_connection_ids(member_id)
2418 {
2419 session.peer.send(connection_id, update.clone())?;
2420 }
2421 }
2422 }
2423
2424 response.send(proto::Ack {})?;
2425 Ok(())
2426}
2427
2428async fn rename_channel(
2429 request: proto::RenameChannel,
2430 response: Response<proto::RenameChannel>,
2431 session: Session,
2432) -> Result<()> {
2433 let db = session.db().await;
2434 let channel_id = ChannelId::from_proto(request.channel_id);
2435 let RenameChannelResult {
2436 channel,
2437 participants_to_update,
2438 } = db
2439 .rename_channel(channel_id, session.user_id, &request.name)
2440 .await?;
2441
2442 response.send(proto::RenameChannelResponse {
2443 channel: Some(channel.to_proto()),
2444 })?;
2445
2446 let connection_pool = session.connection_pool().await;
2447 for (user_id, channel) in participants_to_update {
2448 for connection_id in connection_pool.user_connection_ids(user_id) {
2449 let update = proto::UpdateChannels {
2450 channels: vec![channel.to_proto()],
2451 ..Default::default()
2452 };
2453
2454 session.peer.send(connection_id, update.clone())?;
2455 }
2456 }
2457
2458 Ok(())
2459}
2460
2461async fn move_channel(
2462 request: proto::MoveChannel,
2463 response: Response<proto::MoveChannel>,
2464 session: Session,
2465) -> Result<()> {
2466 let channel_id = ChannelId::from_proto(request.channel_id);
2467 let to = request.to.map(ChannelId::from_proto);
2468
2469 let result = session
2470 .db()
2471 .await
2472 .move_channel(channel_id, to, session.user_id)
2473 .await?;
2474
2475 notify_channel_moved(result, session).await?;
2476
2477 response.send(Ack {})?;
2478 Ok(())
2479}
2480
2481async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2482 let Some(MoveChannelResult {
2483 participants_to_remove,
2484 participants_to_update,
2485 moved_channels,
2486 }) = result
2487 else {
2488 return Ok(());
2489 };
2490 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2491
2492 let connection_pool = session.connection_pool().await;
2493 for (user_id, channels) in participants_to_update {
2494 let mut update = build_channels_update(channels, vec![]);
2495 update.delete_channels = moved_channels.clone();
2496 for connection_id in connection_pool.user_connection_ids(user_id) {
2497 session.peer.send(connection_id, update.clone())?;
2498 }
2499 }
2500
2501 for user_id in participants_to_remove {
2502 let update = proto::UpdateChannels {
2503 delete_channels: moved_channels.clone(),
2504 ..Default::default()
2505 };
2506 for connection_id in connection_pool.user_connection_ids(user_id) {
2507 session.peer.send(connection_id, update.clone())?;
2508 }
2509 }
2510 Ok(())
2511}
2512
2513async fn get_channel_members(
2514 request: proto::GetChannelMembers,
2515 response: Response<proto::GetChannelMembers>,
2516 session: Session,
2517) -> Result<()> {
2518 let db = session.db().await;
2519 let channel_id = ChannelId::from_proto(request.channel_id);
2520 let members = db
2521 .get_channel_participant_details(channel_id, session.user_id)
2522 .await?;
2523 response.send(proto::GetChannelMembersResponse { members })?;
2524 Ok(())
2525}
2526
2527async fn respond_to_channel_invite(
2528 request: proto::RespondToChannelInvite,
2529 response: Response<proto::RespondToChannelInvite>,
2530 session: Session,
2531) -> Result<()> {
2532 let db = session.db().await;
2533 let channel_id = ChannelId::from_proto(request.channel_id);
2534 let RespondToChannelInvite {
2535 membership_update,
2536 notifications,
2537 } = db
2538 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2539 .await?;
2540
2541 let connection_pool = session.connection_pool().await;
2542 if let Some(membership_update) = membership_update {
2543 notify_membership_updated(
2544 &connection_pool,
2545 membership_update,
2546 session.user_id,
2547 &session.peer,
2548 );
2549 } else {
2550 let update = proto::UpdateChannels {
2551 remove_channel_invitations: vec![channel_id.to_proto()],
2552 ..Default::default()
2553 };
2554
2555 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2556 session.peer.send(connection_id, update.clone())?;
2557 }
2558 };
2559
2560 send_notifications(&*connection_pool, &session.peer, notifications);
2561
2562 response.send(proto::Ack {})?;
2563
2564 Ok(())
2565}
2566
2567async fn join_channel(
2568 request: proto::JoinChannel,
2569 response: Response<proto::JoinChannel>,
2570 session: Session,
2571) -> Result<()> {
2572 let channel_id = ChannelId::from_proto(request.channel_id);
2573 join_channel_internal(channel_id, Box::new(response), session).await
2574}
2575
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// kind of request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// NOTE(review): the fully-qualified `Response::<..>::send` calls below
// presumably resolve to the inherent `Response::send` defined elsewhere in this
// file. They are spelled out so they are not read as (or rewritten into) a
// recursive call to this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2589
/// Moves the session's connection into the room of `channel_id`, then replies
/// through `response` and notifies interested parties.
///
/// Steps, in order:
/// 1. Leaves whatever room the connection is currently in
///    (`leave_room_for_session`).
/// 2. Joins the channel's room in the database. This also yields an optional
///    membership update and the user's role in the channel.
/// 3. If a LiveKit client is configured, mints a token:
///    - guests get a `guest_token` and `can_publish: false`;
///    - everyone else gets a `room_token` and `can_publish: true`.
///
///    A token error (passed through `trace_err`) yields no connection info
///    instead of failing the join.
/// 4. Replies with the joined room. Then pushes any membership update to the
///    user's connections and broadcasts the room to its participants.
/// 5. Broadcasts the channel's participant list to channel members and
///    refreshes the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` guard and the connection-pool guard acquired inside this block
    // are dropped when it ends. That happens before the pool is acquired again
    // for `channel_updated` and before `update_user_contacts` runs.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when there is no LiveKit client or when minting the token
        // failed (the `?` inside the closure short-circuits to `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2670
2671async fn join_channel_buffer(
2672 request: proto::JoinChannelBuffer,
2673 response: Response<proto::JoinChannelBuffer>,
2674 session: Session,
2675) -> Result<()> {
2676 let db = session.db().await;
2677 let channel_id = ChannelId::from_proto(request.channel_id);
2678
2679 let open_response = db
2680 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2681 .await?;
2682
2683 let collaborators = open_response.collaborators.clone();
2684 response.send(open_response)?;
2685
2686 let update = UpdateChannelBufferCollaborators {
2687 channel_id: channel_id.to_proto(),
2688 collaborators: collaborators.clone(),
2689 };
2690 channel_buffer_updated(
2691 session.connection_id,
2692 collaborators
2693 .iter()
2694 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2695 &update,
2696 &session.peer,
2697 );
2698
2699 Ok(())
2700}
2701
2702async fn update_channel_buffer(
2703 request: proto::UpdateChannelBuffer,
2704 session: Session,
2705) -> Result<()> {
2706 let db = session.db().await;
2707 let channel_id = ChannelId::from_proto(request.channel_id);
2708
2709 let (collaborators, non_collaborators, epoch, version) = db
2710 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2711 .await?;
2712
2713 channel_buffer_updated(
2714 session.connection_id,
2715 collaborators,
2716 &proto::UpdateChannelBuffer {
2717 channel_id: channel_id.to_proto(),
2718 operations: request.operations,
2719 },
2720 &session.peer,
2721 );
2722
2723 let pool = &*session.connection_pool().await;
2724
2725 broadcast(
2726 None,
2727 non_collaborators
2728 .iter()
2729 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2730 |peer_id| {
2731 session.peer.send(
2732 peer_id.into(),
2733 proto::UpdateChannels {
2734 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2735 channel_id: channel_id.to_proto(),
2736 epoch: epoch as u64,
2737 version: version.clone(),
2738 }],
2739 ..Default::default()
2740 },
2741 )
2742 },
2743 );
2744
2745 Ok(())
2746}
2747
2748async fn rejoin_channel_buffers(
2749 request: proto::RejoinChannelBuffers,
2750 response: Response<proto::RejoinChannelBuffers>,
2751 session: Session,
2752) -> Result<()> {
2753 let db = session.db().await;
2754 let buffers = db
2755 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2756 .await?;
2757
2758 for rejoined_buffer in &buffers {
2759 let collaborators_to_notify = rejoined_buffer
2760 .buffer
2761 .collaborators
2762 .iter()
2763 .filter_map(|c| Some(c.peer_id?.into()));
2764 channel_buffer_updated(
2765 session.connection_id,
2766 collaborators_to_notify,
2767 &proto::UpdateChannelBufferCollaborators {
2768 channel_id: rejoined_buffer.buffer.channel_id,
2769 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2770 },
2771 &session.peer,
2772 );
2773 }
2774
2775 response.send(proto::RejoinChannelBuffersResponse {
2776 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2777 })?;
2778
2779 Ok(())
2780}
2781
2782async fn leave_channel_buffer(
2783 request: proto::LeaveChannelBuffer,
2784 response: Response<proto::LeaveChannelBuffer>,
2785 session: Session,
2786) -> Result<()> {
2787 let db = session.db().await;
2788 let channel_id = ChannelId::from_proto(request.channel_id);
2789
2790 let left_buffer = db
2791 .leave_channel_buffer(channel_id, session.connection_id)
2792 .await?;
2793
2794 response.send(Ack {})?;
2795
2796 channel_buffer_updated(
2797 session.connection_id,
2798 left_buffer.connections,
2799 &proto::UpdateChannelBufferCollaborators {
2800 channel_id: channel_id.to_proto(),
2801 collaborators: left_buffer.collaborators,
2802 },
2803 &session.peer,
2804 );
2805
2806 Ok(())
2807}
2808
2809fn channel_buffer_updated<T: EnvelopedMessage>(
2810 sender_id: ConnectionId,
2811 collaborators: impl IntoIterator<Item = ConnectionId>,
2812 message: &T,
2813 peer: &Peer,
2814) {
2815 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2816 peer.send(peer_id.into(), message.clone())
2817 });
2818}
2819
2820fn send_notifications(
2821 connection_pool: &ConnectionPool,
2822 peer: &Peer,
2823 notifications: db::NotificationBatch,
2824) {
2825 for (user_id, notification) in notifications {
2826 for connection_id in connection_pool.user_connection_ids(user_id) {
2827 if let Err(error) = peer.send(
2828 connection_id,
2829 proto::AddNotification {
2830 notification: Some(notification.clone()),
2831 },
2832 ) {
2833 tracing::error!(
2834 "failed to send notification to {:?} {}",
2835 connection_id,
2836 error
2837 );
2838 }
2839 }
2840 }
2841}
2842
2843async fn send_channel_message(
2844 request: proto::SendChannelMessage,
2845 response: Response<proto::SendChannelMessage>,
2846 session: Session,
2847) -> Result<()> {
2848 // Validate the message body.
2849 let body = request.body.trim().to_string();
2850 if body.len() > MAX_MESSAGE_LEN {
2851 return Err(anyhow!("message is too long"))?;
2852 }
2853 if body.is_empty() {
2854 return Err(anyhow!("message can't be blank"))?;
2855 }
2856
2857 // TODO: adjust mentions if body is trimmed
2858
2859 let timestamp = OffsetDateTime::now_utc();
2860 let nonce = request
2861 .nonce
2862 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2863
2864 let channel_id = ChannelId::from_proto(request.channel_id);
2865 let CreatedChannelMessage {
2866 message_id,
2867 participant_connection_ids,
2868 channel_members,
2869 notifications,
2870 } = session
2871 .db()
2872 .await
2873 .create_channel_message(
2874 channel_id,
2875 session.user_id,
2876 &body,
2877 &request.mentions,
2878 timestamp,
2879 nonce.clone().into(),
2880 )
2881 .await?;
2882 let message = proto::ChannelMessage {
2883 sender_id: session.user_id.to_proto(),
2884 id: message_id.to_proto(),
2885 body,
2886 mentions: request.mentions,
2887 timestamp: timestamp.unix_timestamp() as u64,
2888 nonce: Some(nonce),
2889 };
2890 broadcast(
2891 Some(session.connection_id),
2892 participant_connection_ids,
2893 |connection| {
2894 session.peer.send(
2895 connection,
2896 proto::ChannelMessageSent {
2897 channel_id: channel_id.to_proto(),
2898 message: Some(message.clone()),
2899 },
2900 )
2901 },
2902 );
2903 response.send(proto::SendChannelMessageResponse {
2904 message: Some(message),
2905 })?;
2906
2907 let pool = &*session.connection_pool().await;
2908 broadcast(
2909 None,
2910 channel_members
2911 .iter()
2912 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2913 |peer_id| {
2914 session.peer.send(
2915 peer_id.into(),
2916 proto::UpdateChannels {
2917 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2918 channel_id: channel_id.to_proto(),
2919 message_id: message_id.to_proto(),
2920 }],
2921 ..Default::default()
2922 },
2923 )
2924 },
2925 );
2926 send_notifications(pool, &session.peer, notifications);
2927
2928 Ok(())
2929}
2930
2931async fn remove_channel_message(
2932 request: proto::RemoveChannelMessage,
2933 response: Response<proto::RemoveChannelMessage>,
2934 session: Session,
2935) -> Result<()> {
2936 let channel_id = ChannelId::from_proto(request.channel_id);
2937 let message_id = MessageId::from_proto(request.message_id);
2938 let connection_ids = session
2939 .db()
2940 .await
2941 .remove_channel_message(channel_id, message_id, session.user_id)
2942 .await?;
2943 broadcast(Some(session.connection_id), connection_ids, |connection| {
2944 session.peer.send(connection, request.clone())
2945 });
2946 response.send(proto::Ack {})?;
2947 Ok(())
2948}
2949
2950async fn acknowledge_channel_message(
2951 request: proto::AckChannelMessage,
2952 session: Session,
2953) -> Result<()> {
2954 let channel_id = ChannelId::from_proto(request.channel_id);
2955 let message_id = MessageId::from_proto(request.message_id);
2956 let notifications = session
2957 .db()
2958 .await
2959 .observe_channel_message(channel_id, session.user_id, message_id)
2960 .await?;
2961 send_notifications(
2962 &*session.connection_pool().await,
2963 &session.peer,
2964 notifications,
2965 );
2966 Ok(())
2967}
2968
2969async fn acknowledge_buffer_version(
2970 request: proto::AckBufferOperation,
2971 session: Session,
2972) -> Result<()> {
2973 let buffer_id = BufferId::from_proto(request.buffer_id);
2974 session
2975 .db()
2976 .await
2977 .observe_buffer_version(
2978 buffer_id,
2979 session.user_id,
2980 request.epoch as i32,
2981 &request.version,
2982 )
2983 .await?;
2984 Ok(())
2985}
2986
2987async fn join_channel_chat(
2988 request: proto::JoinChannelChat,
2989 response: Response<proto::JoinChannelChat>,
2990 session: Session,
2991) -> Result<()> {
2992 let channel_id = ChannelId::from_proto(request.channel_id);
2993
2994 let db = session.db().await;
2995 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
2996 .await?;
2997 let messages = db
2998 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
2999 .await?;
3000 response.send(proto::JoinChannelChatResponse {
3001 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3002 messages,
3003 })?;
3004 Ok(())
3005}
3006
3007async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3008 let channel_id = ChannelId::from_proto(request.channel_id);
3009 session
3010 .db()
3011 .await
3012 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3013 .await?;
3014 Ok(())
3015}
3016
3017async fn get_channel_messages(
3018 request: proto::GetChannelMessages,
3019 response: Response<proto::GetChannelMessages>,
3020 session: Session,
3021) -> Result<()> {
3022 let channel_id = ChannelId::from_proto(request.channel_id);
3023 let messages = session
3024 .db()
3025 .await
3026 .get_channel_messages(
3027 channel_id,
3028 session.user_id,
3029 MESSAGE_COUNT_PER_PAGE,
3030 Some(MessageId::from_proto(request.before_message_id)),
3031 )
3032 .await?;
3033 response.send(proto::GetChannelMessagesResponse {
3034 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3035 messages,
3036 })?;
3037 Ok(())
3038}
3039
3040async fn get_channel_messages_by_id(
3041 request: proto::GetChannelMessagesById,
3042 response: Response<proto::GetChannelMessagesById>,
3043 session: Session,
3044) -> Result<()> {
3045 let message_ids = request
3046 .message_ids
3047 .iter()
3048 .map(|id| MessageId::from_proto(*id))
3049 .collect::<Vec<_>>();
3050 let messages = session
3051 .db()
3052 .await
3053 .get_channel_messages_by_id(session.user_id, &message_ids)
3054 .await?;
3055 response.send(proto::GetChannelMessagesResponse {
3056 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3057 messages,
3058 })?;
3059 Ok(())
3060}
3061
3062async fn get_notifications(
3063 request: proto::GetNotifications,
3064 response: Response<proto::GetNotifications>,
3065 session: Session,
3066) -> Result<()> {
3067 let notifications = session
3068 .db()
3069 .await
3070 .get_notifications(
3071 session.user_id,
3072 NOTIFICATION_COUNT_PER_PAGE,
3073 request
3074 .before_id
3075 .map(|id| db::NotificationId::from_proto(id)),
3076 )
3077 .await?;
3078 response.send(proto::GetNotificationsResponse {
3079 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3080 notifications,
3081 })?;
3082 Ok(())
3083}
3084
3085async fn mark_notification_as_read(
3086 request: proto::MarkNotificationRead,
3087 response: Response<proto::MarkNotificationRead>,
3088 session: Session,
3089) -> Result<()> {
3090 let database = &session.db().await;
3091 let notifications = database
3092 .mark_notification_as_read_by_id(
3093 session.user_id,
3094 NotificationId::from_proto(request.notification_id),
3095 )
3096 .await?;
3097 send_notifications(
3098 &*session.connection_pool().await,
3099 &session.peer,
3100 notifications,
3101 );
3102 response.send(proto::Ack {})?;
3103 Ok(())
3104}
3105
3106async fn get_private_user_info(
3107 _request: proto::GetPrivateUserInfo,
3108 response: Response<proto::GetPrivateUserInfo>,
3109 session: Session,
3110) -> Result<()> {
3111 let db = session.db().await;
3112
3113 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3114 let user = db
3115 .get_user_by_id(session.user_id)
3116 .await?
3117 .ok_or_else(|| anyhow!("user not found"))?;
3118 let flags = db.get_user_flags(session.user_id).await?;
3119
3120 response.send(proto::GetPrivateUserInfoResponse {
3121 metrics_id,
3122 staff: user.admin,
3123 flags,
3124 })?;
3125 Ok(())
3126}
3127
3128fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3129 match message {
3130 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3131 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3132 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3133 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3134 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3135 code: frame.code.into(),
3136 reason: frame.reason,
3137 })),
3138 }
3139}
3140
3141fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3142 match message {
3143 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3144 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3145 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3146 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3147 AxumMessage::Close(frame) => {
3148 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3149 code: frame.code.into(),
3150 reason: frame.reason,
3151 }))
3152 }
3153 }
3154}
3155
3156fn notify_membership_updated(
3157 connection_pool: &ConnectionPool,
3158 result: MembershipUpdated,
3159 user_id: UserId,
3160 peer: &Peer,
3161) {
3162 let mut update = build_channels_update(result.new_channels, vec![]);
3163 update.delete_channels = result
3164 .removed_channels
3165 .into_iter()
3166 .map(|id| id.to_proto())
3167 .collect();
3168 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3169
3170 for connection_id in connection_pool.user_connection_ids(user_id) {
3171 peer.send(connection_id, update.clone()).trace_err();
3172 }
3173}
3174
3175fn build_channels_update(
3176 channels: ChannelsForUser,
3177 channel_invites: Vec<db::Channel>,
3178) -> proto::UpdateChannels {
3179 let mut update = proto::UpdateChannels::default();
3180
3181 for channel in channels.channels {
3182 update.channels.push(channel.to_proto());
3183 }
3184
3185 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3186 update.unseen_channel_messages = channels.channel_messages;
3187
3188 for (channel_id, participants) in channels.channel_participants {
3189 update
3190 .channel_participants
3191 .push(proto::ChannelParticipants {
3192 channel_id: channel_id.to_proto(),
3193 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3194 });
3195 }
3196
3197 for channel in channel_invites {
3198 update.channel_invitations.push(channel.to_proto());
3199 }
3200
3201 update
3202}
3203
3204fn build_initial_contacts_update(
3205 contacts: Vec<db::Contact>,
3206 pool: &ConnectionPool,
3207) -> proto::UpdateContacts {
3208 let mut update = proto::UpdateContacts::default();
3209
3210 for contact in contacts {
3211 match contact {
3212 db::Contact::Accepted { user_id, busy } => {
3213 update.contacts.push(contact_for_user(user_id, busy, &pool));
3214 }
3215 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3216 db::Contact::Incoming { user_id } => {
3217 update
3218 .incoming_requests
3219 .push(proto::IncomingContactRequest {
3220 requester_id: user_id.to_proto(),
3221 })
3222 }
3223 }
3224 }
3225
3226 update
3227}
3228
3229fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3230 proto::Contact {
3231 user_id: user_id.to_proto(),
3232 online: pool.is_user_online(user_id),
3233 busy,
3234 }
3235}
3236
3237fn room_updated(room: &proto::Room, peer: &Peer) {
3238 broadcast(
3239 None,
3240 room.participants
3241 .iter()
3242 .filter_map(|participant| Some(participant.peer_id?.into())),
3243 |peer_id| {
3244 peer.send(
3245 peer_id.into(),
3246 proto::RoomUpdated {
3247 room: Some(room.clone()),
3248 },
3249 )
3250 },
3251 );
3252}
3253
3254fn channel_updated(
3255 channel_id: ChannelId,
3256 room: &proto::Room,
3257 channel_members: &[UserId],
3258 peer: &Peer,
3259 pool: &ConnectionPool,
3260) {
3261 let participants = room
3262 .participants
3263 .iter()
3264 .map(|p| p.user_id)
3265 .collect::<Vec<_>>();
3266
3267 broadcast(
3268 None,
3269 channel_members
3270 .iter()
3271 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3272 |peer_id| {
3273 peer.send(
3274 peer_id.into(),
3275 proto::UpdateChannels {
3276 channel_participants: vec![proto::ChannelParticipants {
3277 channel_id: channel_id.to_proto(),
3278 participant_user_ids: participants.clone(),
3279 }],
3280 ..Default::default()
3281 },
3282 )
3283 },
3284 );
3285}
3286
3287async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3288 let db = session.db().await;
3289
3290 let contacts = db.get_contacts(user_id).await?;
3291 let busy = db.is_user_busy(user_id).await?;
3292
3293 let pool = session.connection_pool().await;
3294 let updated_contact = contact_for_user(user_id, busy, &pool);
3295 for contact in contacts {
3296 if let db::Contact::Accepted {
3297 user_id: contact_user_id,
3298 ..
3299 } = contact
3300 {
3301 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3302 session
3303 .peer
3304 .send(
3305 contact_conn_id,
3306 proto::UpdateContacts {
3307 contacts: vec![updated_contact.clone()],
3308 remove_contacts: Default::default(),
3309 incoming_requests: Default::default(),
3310 remove_incoming_requests: Default::default(),
3311 outgoing_requests: Default::default(),
3312 remove_outgoing_requests: Default::default(),
3313 },
3314 )
3315 .trace_err();
3316 }
3317 }
3318 }
3319 Ok(())
3320}
3321
/// Removes the session's connection from whatever room it is in.
///
/// Returns `Ok(())` immediately if the database reports no room was left.
/// Otherwise:
/// 1. Runs `project_left` for each project the connection left, and broadcasts
///    the updated room to its participants.
/// 2. If the room belonged to a channel, broadcasts the channel's participant
///    list to its members.
/// 3. Sends `CallCanceled` to every user whose pending call was canceled.
/// 4. Refreshes the contact status of the session user and of those users.
/// 5. If a LiveKit client is configured, removes the user from the LiveKit room
///    and deletes the room when the database marked it as deleted. LiveKit
///    failures go through `trace_err` and do not fail the call.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared without initializers; they are assigned only in the `if let`
    // branch below. The `else` branch returns early, so every later use sees an
    // initialized value.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): under Rust 2021 temporary rules, the value produced by
    // `session.db().await` (presumably a lock guard) lives until the end of this
    // `if let`, i.e. through `project_left` and `room_updated`. Confirm this is
    // intended.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken out of `left_room.room` *before* the room
        // itself is taken. The `room` broadcast below therefore carries the
        // default (presumably empty) `live_kit_room`. NOTE(review): confirm
        // clients do not rely on that field here.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before
    // `update_user_contacts` acquires the database and the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3398
3399async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3400 let left_channel_buffers = session
3401 .db()
3402 .await
3403 .leave_channel_buffers(session.connection_id)
3404 .await?;
3405
3406 for left_buffer in left_channel_buffers {
3407 channel_buffer_updated(
3408 session.connection_id,
3409 left_buffer.connections,
3410 &proto::UpdateChannelBufferCollaborators {
3411 channel_id: left_buffer.channel_id.to_proto(),
3412 collaborators: left_buffer.collaborators,
3413 },
3414 &session.peer,
3415 );
3416 }
3417
3418 Ok(())
3419}
3420
3421fn project_left(project: &db::LeftProject, session: &Session) {
3422 for connection_id in &project.connection_ids {
3423 if project.host_user_id == session.user_id {
3424 session
3425 .peer
3426 .send(
3427 *connection_id,
3428 proto::UnshareProject {
3429 project_id: project.id.to_proto(),
3430 },
3431 )
3432 .trace_err();
3433 } else {
3434 session
3435 .peer
3436 .send(
3437 *connection_id,
3438 proto::RemoveProjectCollaborator {
3439 project_id: project.id.to_proto(),
3440 peer_id: Some(session.connection_id.into()),
3441 },
3442 )
3443 .trace_err();
3444 }
3445 }
3446}
3447
3448pub trait ResultExt {
3449 type Ok;
3450
3451 fn trace_err(self) -> Option<Self::Ok>;
3452}
3453
3454impl<T, E> ResultExt for Result<T, E>
3455where
3456 E: std::fmt::Debug,
3457{
3458 type Ok = T;
3459
3460 fn trace_err(self) -> Option<T> {
3461 match self {
3462 Ok(value) => Some(value),
3463 Err(error) => {
3464 tracing::error!("{:?}", error);
3465 None
3466 }
3467 }
3468 }
3469}