1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{info_span, instrument, Instrument};
69
/// Grace period `connection_lost` waits after a connection drops before it removes the
/// user's room and channel-buffer resources.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay between `Server::start` and the cleanup of stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the constants below are not referenced in the visible part of this file;
// the descriptions are inferred from their names — confirm against their call sites.
/// Page size for channel-message fetches (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length of a channel message (presumably; units not established here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification fetches (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
76
77lazy_static! {
78 static ref METRIC_CONNECTIONS: IntGauge =
79 register_int_gauge!("connections", "number of connections").unwrap();
80 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
81 "shared_projects",
82 "number of open projects with one or more guests"
83 )
84 .unwrap();
85}
86
87type MessageHandler =
88 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone)]
105struct Session {
106 zed_environment: Arc<str>,
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 _executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
156pub struct Server {
157 id: parking_lot::Mutex<ServerId>,
158 peer: Arc<Peer>,
159 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
160 app_state: Arc<AppState>,
161 executor: Executor,
162 handlers: HashMap<TypeId, MessageHandler>,
163 teardown: watch::Sender<()>,
164}
165
166pub(crate) struct ConnectionPoolGuard<'a> {
167 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
168 _not_send: PhantomData<Rc<()>>,
169}
170
171#[derive(Serialize)]
172pub struct ServerSnapshot<'a> {
173 peer: &'a Peer,
174 #[serde(serialize_with = "serialize_deref")]
175 connection_pool: ConnectionPoolGuard<'a>,
176}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
187impl Server {
188 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
189 let mut server = Self {
190 id: parking_lot::Mutex::new(id),
191 peer: Peer::new(id.0 as u32),
192 app_state,
193 executor,
194 connection_pool: Default::default(),
195 handlers: Default::default(),
196 teardown: watch::channel(()).0,
197 };
198
199 server
200 .add_request_handler(ping)
201 .add_request_handler(create_room)
202 .add_request_handler(join_room)
203 .add_request_handler(rejoin_room)
204 .add_request_handler(leave_room)
205 .add_request_handler(set_room_participant_role)
206 .add_request_handler(call)
207 .add_request_handler(cancel_call)
208 .add_message_handler(decline_call)
209 .add_request_handler(update_participant_location)
210 .add_request_handler(share_project)
211 .add_message_handler(unshare_project)
212 .add_request_handler(join_project)
213 .add_message_handler(leave_project)
214 .add_request_handler(update_project)
215 .add_request_handler(update_worktree)
216 .add_message_handler(start_language_server)
217 .add_message_handler(update_language_server)
218 .add_message_handler(update_diagnostic_summary)
219 .add_message_handler(update_worktree_settings)
220 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
221 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
222 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
223 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
224 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
225 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
226 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
227 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
228 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
229 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
230 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
231 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
232 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
233 .add_request_handler(
234 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
235 )
236 .add_request_handler(
237 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
238 )
239 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
240 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
241 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
242 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
243 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
244 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
245 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
246 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
247 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
248 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
249 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
250 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
251 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
252 .add_message_handler(create_buffer_for_peer)
253 .add_request_handler(update_buffer)
254 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
255 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
256 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
257 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
258 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
259 .add_request_handler(get_users)
260 .add_request_handler(fuzzy_search_users)
261 .add_request_handler(request_contact)
262 .add_request_handler(remove_contact)
263 .add_request_handler(respond_to_contact_request)
264 .add_request_handler(create_channel)
265 .add_request_handler(delete_channel)
266 .add_request_handler(invite_channel_member)
267 .add_request_handler(remove_channel_member)
268 .add_request_handler(set_channel_member_role)
269 .add_request_handler(set_channel_visibility)
270 .add_request_handler(rename_channel)
271 .add_request_handler(join_channel_buffer)
272 .add_request_handler(leave_channel_buffer)
273 .add_message_handler(update_channel_buffer)
274 .add_request_handler(rejoin_channel_buffers)
275 .add_request_handler(get_channel_members)
276 .add_request_handler(respond_to_channel_invite)
277 .add_request_handler(join_channel)
278 .add_request_handler(join_channel_chat)
279 .add_message_handler(leave_channel_chat)
280 .add_request_handler(send_channel_message)
281 .add_request_handler(remove_channel_message)
282 .add_request_handler(get_channel_messages)
283 .add_request_handler(get_channel_messages_by_id)
284 .add_request_handler(get_notifications)
285 .add_request_handler(mark_notification_as_read)
286 .add_request_handler(move_channel)
287 .add_request_handler(follow)
288 .add_message_handler(unfollow)
289 .add_message_handler(update_followers)
290 .add_request_handler(get_private_user_info)
291 .add_message_handler(acknowledge_channel_message)
292 .add_message_handler(acknowledge_buffer_version);
293
294 Arc::new(server)
295 }
296
297 pub async fn start(&self) -> Result<()> {
298 let server_id = *self.id.lock();
299 let app_state = self.app_state.clone();
300 let peer = self.peer.clone();
301 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
302 let pool = self.connection_pool.clone();
303 let live_kit_client = self.app_state.live_kit_client.clone();
304
305 let span = info_span!("start server");
306 self.executor.spawn_detached(
307 async move {
308 tracing::info!("waiting for cleanup timeout");
309 timeout.await;
310 tracing::info!("cleanup timeout expired, retrieving stale rooms");
311 if let Some((room_ids, channel_ids)) = app_state
312 .db
313 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
314 .await
315 .trace_err()
316 {
317 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
318 tracing::info!(
319 stale_channel_buffer_count = channel_ids.len(),
320 "retrieved stale channel buffers"
321 );
322
323 for channel_id in channel_ids {
324 if let Some(refreshed_channel_buffer) = app_state
325 .db
326 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
327 .await
328 .trace_err()
329 {
330 for connection_id in refreshed_channel_buffer.connection_ids {
331 peer.send(
332 connection_id,
333 proto::UpdateChannelBufferCollaborators {
334 channel_id: channel_id.to_proto(),
335 collaborators: refreshed_channel_buffer
336 .collaborators
337 .clone(),
338 },
339 )
340 .trace_err();
341 }
342 }
343 }
344
345 for room_id in room_ids {
346 let mut contacts_to_update = HashSet::default();
347 let mut canceled_calls_to_user_ids = Vec::new();
348 let mut live_kit_room = String::new();
349 let mut delete_live_kit_room = false;
350
351 if let Some(mut refreshed_room) = app_state
352 .db
353 .clear_stale_room_participants(room_id, server_id)
354 .await
355 .trace_err()
356 {
357 tracing::info!(
358 room_id = room_id.0,
359 new_participant_count = refreshed_room.room.participants.len(),
360 "refreshed room"
361 );
362 room_updated(&refreshed_room.room, &peer);
363 if let Some(channel_id) = refreshed_room.channel_id {
364 channel_updated(
365 channel_id,
366 &refreshed_room.room,
367 &refreshed_room.channel_members,
368 &peer,
369 &*pool.lock(),
370 );
371 }
372 contacts_to_update
373 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
374 contacts_to_update
375 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
376 canceled_calls_to_user_ids =
377 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
378 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
379 delete_live_kit_room = refreshed_room.room.participants.is_empty();
380 }
381
382 {
383 let pool = pool.lock();
384 for canceled_user_id in canceled_calls_to_user_ids {
385 for connection_id in pool.user_connection_ids(canceled_user_id) {
386 peer.send(
387 connection_id,
388 proto::CallCanceled {
389 room_id: room_id.to_proto(),
390 },
391 )
392 .trace_err();
393 }
394 }
395 }
396
397 for user_id in contacts_to_update {
398 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
399 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
400 if let Some((busy, contacts)) = busy.zip(contacts) {
401 let pool = pool.lock();
402 let updated_contact = contact_for_user(user_id, busy, &pool);
403 for contact in contacts {
404 if let db::Contact::Accepted {
405 user_id: contact_user_id,
406 ..
407 } = contact
408 {
409 for contact_conn_id in
410 pool.user_connection_ids(contact_user_id)
411 {
412 peer.send(
413 contact_conn_id,
414 proto::UpdateContacts {
415 contacts: vec![updated_contact.clone()],
416 remove_contacts: Default::default(),
417 incoming_requests: Default::default(),
418 remove_incoming_requests: Default::default(),
419 outgoing_requests: Default::default(),
420 remove_outgoing_requests: Default::default(),
421 },
422 )
423 .trace_err();
424 }
425 }
426 }
427 }
428 }
429
430 if let Some(live_kit) = live_kit_client.as_ref() {
431 if delete_live_kit_room {
432 live_kit.delete_room(live_kit_room).await.trace_err();
433 }
434 }
435 }
436 }
437
438 app_state
439 .db
440 .delete_stale_servers(&app_state.config.zed_environment, server_id)
441 .await
442 .trace_err();
443 }
444 .instrument(span),
445 );
446 Ok(())
447 }
448
449 pub fn teardown(&self) {
450 self.peer.teardown();
451 self.connection_pool.lock().reset();
452 let _ = self.teardown.send(());
453 }
454
455 #[cfg(test)]
456 pub fn reset(&self, id: ServerId) {
457 self.teardown();
458 *self.id.lock() = id;
459 self.peer.reset(id.0 as u32);
460 }
461
462 #[cfg(test)]
463 pub fn id(&self) -> ServerId {
464 *self.id.lock()
465 }
466
467 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
468 where
469 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
470 Fut: 'static + Send + Future<Output = Result<()>>,
471 M: EnvelopedMessage,
472 {
473 let prev_handler = self.handlers.insert(
474 TypeId::of::<M>(),
475 Box::new(move |envelope, session| {
476 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
477 let span = info_span!(
478 "handle message",
479 payload_type = envelope.payload_type_name()
480 );
481 span.in_scope(|| {
482 tracing::info!(
483 payload_type = envelope.payload_type_name(),
484 "message received"
485 );
486 });
487 let start_time = Instant::now();
488 let future = (handler)(*envelope, session);
489 async move {
490 let result = future.await;
491 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
492 match result {
493 Err(error) => {
494 tracing::error!(%error, ?duration_ms, "error handling message")
495 }
496 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
497 }
498 }
499 .instrument(span)
500 .boxed()
501 }),
502 );
503 if prev_handler.is_some() {
504 panic!("registered a handler for the same message twice");
505 }
506 self
507 }
508
509 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
510 where
511 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
512 Fut: 'static + Send + Future<Output = Result<()>>,
513 M: EnvelopedMessage,
514 {
515 self.add_handler(move |envelope, session| handler(envelope.payload, session));
516 self
517 }
518
519 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
520 where
521 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
522 Fut: Send + Future<Output = Result<()>>,
523 M: RequestMessage,
524 {
525 let handler = Arc::new(handler);
526 self.add_handler(move |envelope, session| {
527 let receipt = envelope.receipt();
528 let handler = handler.clone();
529 async move {
530 let peer = session.peer.clone();
531 let responded = Arc::new(AtomicBool::default());
532 let response = Response {
533 peer: peer.clone(),
534 responded: responded.clone(),
535 receipt,
536 };
537 match (handler)(envelope.payload, response, session).await {
538 Ok(()) => {
539 if responded.load(std::sync::atomic::Ordering::SeqCst) {
540 Ok(())
541 } else {
542 Err(anyhow!("handler did not send a response"))?
543 }
544 }
545 Err(error) => {
546 peer.respond_with_error(
547 receipt,
548 proto::Error {
549 message: error.to_string(),
550 },
551 )?;
552 Err(error)
553 }
554 }
555 }
556 })
557 }
558
559 pub fn handle_connection(
560 self: &Arc<Self>,
561 connection: Connection,
562 address: String,
563 user: User,
564 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
565 executor: Executor,
566 ) -> impl Future<Output = Result<()>> {
567 let this = self.clone();
568 let user_id = user.id;
569 let login = user.github_login;
570 let span = info_span!("handle connection", %user_id, %login, %address);
571 let mut teardown = self.teardown.subscribe();
572 async move {
573 let (connection_id, handle_io, mut incoming_rx) = this
574 .peer
575 .add_connection(connection, {
576 let executor = executor.clone();
577 move |duration| executor.sleep(duration)
578 });
579
580 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
581 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
582 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
583
584 if let Some(send_connection_id) = send_connection_id.take() {
585 let _ = send_connection_id.send(connection_id);
586 }
587
588 if !user.connected_once {
589 this.peer.send(connection_id, proto::ShowContacts {})?;
590 this.app_state.db.set_user_connected_once(user_id, true).await?;
591 }
592
593 let (contacts, channels_for_user, channel_invites) = future::try_join3(
594 this.app_state.db.get_contacts(user_id),
595 this.app_state.db.get_channels_for_user(user_id),
596 this.app_state.db.get_channel_invites_for_user(user_id),
597 ).await?;
598
599 {
600 let mut pool = this.connection_pool.lock();
601 pool.add_connection(connection_id, user_id, user.admin);
602 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
603 this.peer.send(connection_id, build_channels_update(
604 channels_for_user,
605 channel_invites
606 ))?;
607 }
608
609 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
610 this.peer.send(connection_id, incoming_call)?;
611 }
612
613 let session = Session {
614 user_id,
615 connection_id,
616 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
617 zed_environment: this.app_state.config.zed_environment.clone(),
618 peer: this.peer.clone(),
619 connection_pool: this.connection_pool.clone(),
620 live_kit_client: this.app_state.live_kit_client.clone(),
621 _executor: executor.clone()
622 };
623 update_user_contacts(user_id, &session).await?;
624
625 let handle_io = handle_io.fuse();
626 futures::pin_mut!(handle_io);
627
628 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
629 // This prevents deadlocks when e.g., client A performs a request to client B and
630 // client B performs a request to client A. If both clients stop processing further
631 // messages until their respective request completes, they won't have a chance to
632 // respond to the other client's request and cause a deadlock.
633 //
634 // This arrangement ensures we will attempt to process earlier messages first, but fall
635 // back to processing messages arrived later in the spirit of making progress.
636 let mut foreground_message_handlers = FuturesUnordered::new();
637 let concurrent_handlers = Arc::new(Semaphore::new(256));
638 loop {
639 let next_message = async {
640 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
641 let message = incoming_rx.next().await;
642 (permit, message)
643 }.fuse();
644 futures::pin_mut!(next_message);
645 futures::select_biased! {
646 _ = teardown.changed().fuse() => return Ok(()),
647 result = handle_io => {
648 if let Err(error) = result {
649 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
650 }
651 break;
652 }
653 _ = foreground_message_handlers.next() => {}
654 next_message = next_message => {
655 let (permit, message) = next_message;
656 if let Some(message) = message {
657 let type_name = message.payload_type_name();
658 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
659 let span_enter = span.enter();
660 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
661 let is_background = message.is_background();
662 let handle_message = (handler)(message, session.clone());
663 drop(span_enter);
664
665 let handle_message = async move {
666 handle_message.await;
667 drop(permit);
668 }.instrument(span);
669 if is_background {
670 executor.spawn_detached(handle_message);
671 } else {
672 foreground_message_handlers.push(handle_message);
673 }
674 } else {
675 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
676 }
677 } else {
678 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
679 break;
680 }
681 }
682 }
683 }
684
685 drop(foreground_message_handlers);
686 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
687 if let Err(error) = connection_lost(session, teardown, executor).await {
688 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
689 }
690
691 Ok(())
692 }.instrument(span)
693 }
694
695 pub async fn invite_code_redeemed(
696 self: &Arc<Self>,
697 inviter_id: UserId,
698 invitee_id: UserId,
699 ) -> Result<()> {
700 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
701 if let Some(code) = &user.invite_code {
702 let pool = self.connection_pool.lock();
703 let invitee_contact = contact_for_user(invitee_id, false, &pool);
704 for connection_id in pool.user_connection_ids(inviter_id) {
705 self.peer.send(
706 connection_id,
707 proto::UpdateContacts {
708 contacts: vec![invitee_contact.clone()],
709 ..Default::default()
710 },
711 )?;
712 self.peer.send(
713 connection_id,
714 proto::UpdateInviteInfo {
715 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
716 count: user.invite_count as u32,
717 },
718 )?;
719 }
720 }
721 }
722 Ok(())
723 }
724
725 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
726 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
727 if let Some(invite_code) = &user.invite_code {
728 let pool = self.connection_pool.lock();
729 for connection_id in pool.user_connection_ids(user_id) {
730 self.peer.send(
731 connection_id,
732 proto::UpdateInviteInfo {
733 url: format!(
734 "{}{}",
735 self.app_state.config.invite_link_prefix, invite_code
736 ),
737 count: user.invite_count as u32,
738 },
739 )?;
740 }
741 }
742 }
743 Ok(())
744 }
745
746 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
747 ServerSnapshot {
748 connection_pool: ConnectionPoolGuard {
749 guard: self.connection_pool.lock(),
750 _not_send: PhantomData,
751 },
752 peer: &self.peer,
753 }
754 }
755}
756
757impl<'a> Deref for ConnectionPoolGuard<'a> {
758 type Target = ConnectionPool;
759
760 fn deref(&self) -> &Self::Target {
761 &*self.guard
762 }
763}
764
765impl<'a> DerefMut for ConnectionPoolGuard<'a> {
766 fn deref_mut(&mut self) -> &mut Self::Target {
767 &mut *self.guard
768 }
769}
770
771impl<'a> Drop for ConnectionPoolGuard<'a> {
772 fn drop(&mut self) {
773 #[cfg(test)]
774 self.check_invariants();
775 }
776}
777
778fn broadcast<F>(
779 sender_id: Option<ConnectionId>,
780 receiver_ids: impl IntoIterator<Item = ConnectionId>,
781 mut f: F,
782) where
783 F: FnMut(ConnectionId) -> anyhow::Result<()>,
784{
785 for receiver_id in receiver_ids {
786 if Some(receiver_id) != sender_id {
787 if let Err(error) = f(receiver_id) {
788 tracing::error!("failed to send to {:?} {}", receiver_id, error);
789 }
790 }
791 }
792}
793
794lazy_static! {
795 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
796}
797
/// Typed value of the `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
799
800impl Header for ProtocolVersion {
801 fn name() -> &'static HeaderName {
802 &ZED_PROTOCOL_VERSION
803 }
804
805 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
806 where
807 Self: Sized,
808 I: Iterator<Item = &'i axum::http::HeaderValue>,
809 {
810 let version = values
811 .next()
812 .ok_or_else(axum::headers::Error::invalid)?
813 .to_str()
814 .map_err(|_| axum::headers::Error::invalid())?
815 .parse()
816 .map_err(|_| axum::headers::Error::invalid())?;
817 Ok(Self(version))
818 }
819
820 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
821 values.extend([self.0.to_string().parse().unwrap()]);
822 }
823}
824
825pub fn routes(server: Arc<Server>) -> Router<Body> {
826 Router::new()
827 .route("/rpc", get(handle_websocket_request))
828 .layer(
829 ServiceBuilder::new()
830 .layer(Extension(server.app_state.clone()))
831 .layer(middleware::from_fn(auth::validate_header)),
832 )
833 .route("/metrics", get(handle_metrics))
834 .layer(Extension(server))
835}
836
837pub async fn handle_websocket_request(
838 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
839 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
840 Extension(server): Extension<Arc<Server>>,
841 Extension(user): Extension<User>,
842 ws: WebSocketUpgrade,
843) -> axum::response::Response {
844 if protocol_version != rpc::PROTOCOL_VERSION {
845 return (
846 StatusCode::UPGRADE_REQUIRED,
847 "client must be upgraded".to_string(),
848 )
849 .into_response();
850 }
851 let socket_address = socket_address.to_string();
852 ws.on_upgrade(move |socket| {
853 use util::ResultExt;
854 let socket = socket
855 .map_ok(to_tungstenite_message)
856 .err_into()
857 .with(|message| async move { Ok(to_axum_message(message)) });
858 let connection = Connection::new(Box::pin(socket));
859 async move {
860 server
861 .handle_connection(connection, socket_address, user, None, Executor::Production)
862 .await
863 .log_err();
864 }
865 })
866}
867
868pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
869 let connections = server
870 .connection_pool
871 .lock()
872 .connections()
873 .filter(|connection| !connection.admin)
874 .count();
875
876 METRIC_CONNECTIONS.set(connections as _);
877
878 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
879 METRIC_SHARED_PROJECTS.set(shared_projects as _);
880
881 let encoder = prometheus::TextEncoder::new();
882 let metric_families = prometheus::gather();
883 let encoded_metrics = encoder
884 .encode_to_string(&metric_families)
885 .map_err(|err| anyhow!("{}", err))?;
886 Ok(encoded_metrics)
887}
888
889#[instrument(err, skip(executor))]
890async fn connection_lost(
891 session: Session,
892 mut teardown: watch::Receiver<()>,
893 executor: Executor,
894) -> Result<()> {
895 session.peer.disconnect(session.connection_id);
896 session
897 .connection_pool()
898 .await
899 .remove_connection(session.connection_id)?;
900
901 session
902 .db()
903 .await
904 .connection_lost(session.connection_id)
905 .await
906 .trace_err();
907
908 futures::select_biased! {
909 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
910 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
911 leave_room_for_session(&session).await.trace_err();
912 leave_channel_buffers_for_session(&session)
913 .await
914 .trace_err();
915
916 if !session
917 .connection_pool()
918 .await
919 .is_user_online(session.user_id)
920 {
921 let db = session.db().await;
922 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
923 room_updated(&room, &session.peer);
924 }
925 }
926
927 update_user_contacts(session.user_id, &session).await?;
928 }
929 _ = teardown.changed().fuse() => {}
930 }
931
932 Ok(())
933}
934
935async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
936 response.send(proto::Ack {})?;
937 Ok(())
938}
939
940async fn create_room(
941 _request: proto::CreateRoom,
942 response: Response<proto::CreateRoom>,
943 session: Session,
944) -> Result<()> {
945 let live_kit_room = nanoid::nanoid!(30);
946
947 let live_kit_connection_info = {
948 let live_kit_room = live_kit_room.clone();
949 let live_kit = session.live_kit_client.as_ref();
950
951 util::async_maybe!({
952 let live_kit = live_kit?;
953
954 let token = live_kit
955 .room_token(&live_kit_room, &session.user_id.to_string())
956 .trace_err()?;
957
958 Some(proto::LiveKitConnectionInfo {
959 server_url: live_kit.url().into(),
960 token,
961 can_publish: true,
962 })
963 })
964 }
965 .await;
966
967 let room = session
968 .db()
969 .await
970 .create_room(
971 session.user_id,
972 session.connection_id,
973 &live_kit_room,
974 &session.zed_environment,
975 )
976 .await?;
977
978 response.send(proto::CreateRoomResponse {
979 room: Some(room.clone()),
980 live_kit_connection_info,
981 })?;
982
983 update_user_contacts(session.user_id, &session).await?;
984 Ok(())
985}
986
987async fn join_room(
988 request: proto::JoinRoom,
989 response: Response<proto::JoinRoom>,
990 session: Session,
991) -> Result<()> {
992 let room_id = RoomId::from_proto(request.id);
993
994 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
995
996 if let Some(channel_id) = channel_id {
997 return join_channel_internal(channel_id, Box::new(response), session).await;
998 }
999
1000 let joined_room = {
1001 let room = session
1002 .db()
1003 .await
1004 .join_room(
1005 room_id,
1006 session.user_id,
1007 session.connection_id,
1008 session.zed_environment.as_ref(),
1009 )
1010 .await?;
1011 room_updated(&room.room, &session.peer);
1012 room.into_inner()
1013 };
1014
1015 for connection_id in session
1016 .connection_pool()
1017 .await
1018 .user_connection_ids(session.user_id)
1019 {
1020 session
1021 .peer
1022 .send(
1023 connection_id,
1024 proto::CallCanceled {
1025 room_id: room_id.to_proto(),
1026 },
1027 )
1028 .trace_err();
1029 }
1030
1031 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1032 if let Some(token) = live_kit
1033 .room_token(
1034 &joined_room.room.live_kit_room,
1035 &session.user_id.to_string(),
1036 )
1037 .trace_err()
1038 {
1039 Some(proto::LiveKitConnectionInfo {
1040 server_url: live_kit.url().into(),
1041 token,
1042 can_publish: true,
1043 })
1044 } else {
1045 None
1046 }
1047 } else {
1048 None
1049 };
1050
1051 response.send(proto::JoinRoomResponse {
1052 room: Some(joined_room.room),
1053 channel_id: None,
1054 live_kit_connection_info,
1055 })?;
1056
1057 update_user_contacts(session.user_id, &session).await?;
1058 Ok(())
1059}
1060
/// Handles `RejoinRoom`: re-attaches the caller to a room on its current connection.
///
/// `Database::rejoin_room` returns the room plus two project lists, each carrying the
/// connection id the user previously had in that project (`old_connection_id`):
/// - `reshared_projects`: projects whose worktree metadata this connection re-announces to
///   collaborators via `UpdateProject` (presumably the ones the user hosts; verify in db).
/// - `rejoined_projects`: projects whose worktree contents, diagnostics, settings files and
///   language-server status are re-streamed to this connection.
///
/// For both lists, collaborators receive `UpdateProjectCollaborator` mapping the old peer id
/// to `session.connection_id`. Afterwards the channel (if the room belongs to one) and the
/// user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they survive after the room guard is consumed below.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply first with the room and every reshared/rejoined project's metadata.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator this user's peer id moved to the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone except the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Same peer-id migration notice for the rejoined projects' collaborators.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream each rejoined project's state to this connection. `mem::take` moves the
        // worktrees out so their fields can be moved into the outgoing messages.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Much smaller chunks in test builds (presumably to exercise splitting).
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Emit a disk-based-diagnostics-updated event per language server in the project.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Consume the guard and keep only what the channel/contact updates need.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1251
1252async fn leave_room(
1253 _: proto::LeaveRoom,
1254 response: Response<proto::LeaveRoom>,
1255 session: Session,
1256) -> Result<()> {
1257 leave_room_for_session(&session).await?;
1258 response.send(proto::Ack {})?;
1259 Ok(())
1260}
1261
1262async fn set_room_participant_role(
1263 request: proto::SetRoomParticipantRole,
1264 response: Response<proto::SetRoomParticipantRole>,
1265 session: Session,
1266) -> Result<()> {
1267 let (live_kit_room, can_publish) = {
1268 let room = session
1269 .db()
1270 .await
1271 .set_room_participant_role(
1272 session.user_id,
1273 RoomId::from_proto(request.room_id),
1274 UserId::from_proto(request.user_id),
1275 ChannelRole::from(request.role()),
1276 )
1277 .await?;
1278
1279 let live_kit_room = room.live_kit_room.clone();
1280 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1281 room_updated(&room, &session.peer);
1282 (live_kit_room, can_publish)
1283 };
1284
1285 if let Some(live_kit) = session.live_kit_client.as_ref() {
1286 live_kit
1287 .update_participant(
1288 live_kit_room.clone(),
1289 request.user_id.to_string(),
1290 live_kit_server::proto::ParticipantPermission {
1291 can_subscribe: true,
1292 can_publish,
1293 can_publish_data: can_publish,
1294 hidden: false,
1295 recorder: false,
1296 },
1297 )
1298 .await
1299 .trace_err();
1300 }
1301
1302 response.send(proto::Ack {})?;
1303 Ok(())
1304}
1305
1306async fn call(
1307 request: proto::Call,
1308 response: Response<proto::Call>,
1309 session: Session,
1310) -> Result<()> {
1311 let room_id = RoomId::from_proto(request.room_id);
1312 let calling_user_id = session.user_id;
1313 let calling_connection_id = session.connection_id;
1314 let called_user_id = UserId::from_proto(request.called_user_id);
1315 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1316 if !session
1317 .db()
1318 .await
1319 .has_contact(calling_user_id, called_user_id)
1320 .await?
1321 {
1322 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1323 }
1324
1325 let incoming_call = {
1326 let (room, incoming_call) = &mut *session
1327 .db()
1328 .await
1329 .call(
1330 room_id,
1331 calling_user_id,
1332 calling_connection_id,
1333 called_user_id,
1334 initial_project_id,
1335 )
1336 .await?;
1337 room_updated(&room, &session.peer);
1338 mem::take(incoming_call)
1339 };
1340 update_user_contacts(called_user_id, &session).await?;
1341
1342 let mut calls = session
1343 .connection_pool()
1344 .await
1345 .user_connection_ids(called_user_id)
1346 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1347 .collect::<FuturesUnordered<_>>();
1348
1349 while let Some(call_response) = calls.next().await {
1350 match call_response.as_ref() {
1351 Ok(_) => {
1352 response.send(proto::Ack {})?;
1353 return Ok(());
1354 }
1355 Err(_) => {
1356 call_response.trace_err();
1357 }
1358 }
1359 }
1360
1361 {
1362 let room = session
1363 .db()
1364 .await
1365 .call_failed(room_id, called_user_id)
1366 .await?;
1367 room_updated(&room, &session.peer);
1368 }
1369 update_user_contacts(called_user_id, &session).await?;
1370
1371 Err(anyhow!("failed to ring user"))?
1372}
1373
1374async fn cancel_call(
1375 request: proto::CancelCall,
1376 response: Response<proto::CancelCall>,
1377 session: Session,
1378) -> Result<()> {
1379 let called_user_id = UserId::from_proto(request.called_user_id);
1380 let room_id = RoomId::from_proto(request.room_id);
1381 {
1382 let room = session
1383 .db()
1384 .await
1385 .cancel_call(room_id, session.connection_id, called_user_id)
1386 .await?;
1387 room_updated(&room, &session.peer);
1388 }
1389
1390 for connection_id in session
1391 .connection_pool()
1392 .await
1393 .user_connection_ids(called_user_id)
1394 {
1395 session
1396 .peer
1397 .send(
1398 connection_id,
1399 proto::CallCanceled {
1400 room_id: room_id.to_proto(),
1401 },
1402 )
1403 .trace_err();
1404 }
1405 response.send(proto::Ack {})?;
1406
1407 update_user_contacts(called_user_id, &session).await?;
1408 Ok(())
1409}
1410
1411async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1412 let room_id = RoomId::from_proto(message.room_id);
1413 {
1414 let room = session
1415 .db()
1416 .await
1417 .decline_call(Some(room_id), session.user_id)
1418 .await?
1419 .ok_or_else(|| anyhow!("failed to decline call"))?;
1420 room_updated(&room, &session.peer);
1421 }
1422
1423 for connection_id in session
1424 .connection_pool()
1425 .await
1426 .user_connection_ids(session.user_id)
1427 {
1428 session
1429 .peer
1430 .send(
1431 connection_id,
1432 proto::CallCanceled {
1433 room_id: room_id.to_proto(),
1434 },
1435 )
1436 .trace_err();
1437 }
1438 update_user_contacts(session.user_id, &session).await?;
1439 Ok(())
1440}
1441
1442async fn update_participant_location(
1443 request: proto::UpdateParticipantLocation,
1444 response: Response<proto::UpdateParticipantLocation>,
1445 session: Session,
1446) -> Result<()> {
1447 let room_id = RoomId::from_proto(request.room_id);
1448 let location = request
1449 .location
1450 .ok_or_else(|| anyhow!("invalid location"))?;
1451
1452 let db = session.db().await;
1453 let room = db
1454 .update_room_participant_location(room_id, session.connection_id, location)
1455 .await?;
1456
1457 room_updated(&room, &session.peer);
1458 response.send(proto::Ack {})?;
1459 Ok(())
1460}
1461
1462async fn share_project(
1463 request: proto::ShareProject,
1464 response: Response<proto::ShareProject>,
1465 session: Session,
1466) -> Result<()> {
1467 let (project_id, room) = &*session
1468 .db()
1469 .await
1470 .share_project(
1471 RoomId::from_proto(request.room_id),
1472 session.connection_id,
1473 &request.worktrees,
1474 )
1475 .await?;
1476 response.send(proto::ShareProjectResponse {
1477 project_id: project_id.to_proto(),
1478 })?;
1479 room_updated(&room, &session.peer);
1480
1481 Ok(())
1482}
1483
1484async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1485 let project_id = ProjectId::from_proto(message.project_id);
1486
1487 let (room, guest_connection_ids) = &*session
1488 .db()
1489 .await
1490 .unshare_project(project_id, session.connection_id)
1491 .await?;
1492
1493 broadcast(
1494 Some(session.connection_id),
1495 guest_connection_ids.iter().copied(),
1496 |conn_id| session.peer.send(conn_id, message.clone()),
1497 );
1498 room_updated(&room, &session.peer);
1499
1500 Ok(())
1501}
1502
/// Handles `JoinProject`: adds the session's connection as a guest of `request.project_id`.
///
/// Steps, in order:
/// 1. Join through `Database::join_project`, which yields the project and the replica id
///    assigned to this connection.
/// 2. Announce the new collaborator to the existing collaborators with
///    `AddProjectCollaborator` (send errors are logged).
/// 3. Reply with the worktree metadata, replica id, collaborators and language servers.
/// 4. Stream each worktree's entries, diagnostics and settings files to the joining
///    connection.
/// 5. Send one disk-based-diagnostics-updated event per language server.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let guest_user_id = session.user_id;

    tracing::info!(%project_id, "join project");

    let (project, replica_id) = &mut *session
        .db()
        .await
        .join_project(project_id, session.connection_id)
        .await?;

    // Everyone already in the project, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // NOTE(review): assumes `to_proto` always fills `peer_id`; the unwrap panics otherwise.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                proto::AddProjectCollaborator {
                    project_id: project_id.to_proto(),
                    collaborator: Some(proto::Collaborator {
                        peer_id: Some(session.connection_id.into()),
                        replica_id: replica_id.0 as u32,
                        user_id: guest_user_id.to_proto(),
                    }),
                },
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
    })?;

    // `mem::take` moves the worktrees out so their fields can be moved into the messages.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Much smaller chunks in test builds (presumably to exercise splitting).
        #[cfg(any(test, feature = "test-support"))]
        const MAX_CHUNK_SIZE: usize = 2;
        #[cfg(not(any(test, feature = "test-support")))]
        const MAX_CHUNK_SIZE: usize = 256;

        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // NOTE(review): uses `worktree.id` where the map key `worktree_id` is used above;
        // presumably identical values — confirm.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                },
            )?;
        }
    }

    // Emit a disk-based-diagnostics-updated event for each of the project's language servers.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1627
1628async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1629 let sender_id = session.connection_id;
1630 let project_id = ProjectId::from_proto(request.project_id);
1631
1632 let (room, project) = &*session
1633 .db()
1634 .await
1635 .leave_project(project_id, sender_id)
1636 .await?;
1637 tracing::info!(
1638 %project_id,
1639 host_user_id = %project.host_user_id,
1640 host_connection_id = %project.host_connection_id,
1641 "leave project"
1642 );
1643
1644 project_left(&project, &session);
1645 room_updated(&room, &session.peer);
1646
1647 Ok(())
1648}
1649
1650async fn update_project(
1651 request: proto::UpdateProject,
1652 response: Response<proto::UpdateProject>,
1653 session: Session,
1654) -> Result<()> {
1655 let project_id = ProjectId::from_proto(request.project_id);
1656 let (room, guest_connection_ids) = &*session
1657 .db()
1658 .await
1659 .update_project(project_id, session.connection_id, &request.worktrees)
1660 .await?;
1661 broadcast(
1662 Some(session.connection_id),
1663 guest_connection_ids.iter().copied(),
1664 |connection_id| {
1665 session
1666 .peer
1667 .forward_send(session.connection_id, connection_id, request.clone())
1668 },
1669 );
1670 room_updated(&room, &session.peer);
1671 response.send(proto::Ack {})?;
1672
1673 Ok(())
1674}
1675
1676async fn update_worktree(
1677 request: proto::UpdateWorktree,
1678 response: Response<proto::UpdateWorktree>,
1679 session: Session,
1680) -> Result<()> {
1681 let guest_connection_ids = session
1682 .db()
1683 .await
1684 .update_worktree(&request, session.connection_id)
1685 .await?;
1686
1687 broadcast(
1688 Some(session.connection_id),
1689 guest_connection_ids.iter().copied(),
1690 |connection_id| {
1691 session
1692 .peer
1693 .forward_send(session.connection_id, connection_id, request.clone())
1694 },
1695 );
1696 response.send(proto::Ack {})?;
1697 Ok(())
1698}
1699
1700async fn update_diagnostic_summary(
1701 message: proto::UpdateDiagnosticSummary,
1702 session: Session,
1703) -> Result<()> {
1704 let guest_connection_ids = session
1705 .db()
1706 .await
1707 .update_diagnostic_summary(&message, session.connection_id)
1708 .await?;
1709
1710 broadcast(
1711 Some(session.connection_id),
1712 guest_connection_ids.iter().copied(),
1713 |connection_id| {
1714 session
1715 .peer
1716 .forward_send(session.connection_id, connection_id, message.clone())
1717 },
1718 );
1719
1720 Ok(())
1721}
1722
1723async fn update_worktree_settings(
1724 message: proto::UpdateWorktreeSettings,
1725 session: Session,
1726) -> Result<()> {
1727 let guest_connection_ids = session
1728 .db()
1729 .await
1730 .update_worktree_settings(&message, session.connection_id)
1731 .await?;
1732
1733 broadcast(
1734 Some(session.connection_id),
1735 guest_connection_ids.iter().copied(),
1736 |connection_id| {
1737 session
1738 .peer
1739 .forward_send(session.connection_id, connection_id, message.clone())
1740 },
1741 );
1742
1743 Ok(())
1744}
1745
1746async fn start_language_server(
1747 request: proto::StartLanguageServer,
1748 session: Session,
1749) -> Result<()> {
1750 let guest_connection_ids = session
1751 .db()
1752 .await
1753 .start_language_server(&request, session.connection_id)
1754 .await?;
1755
1756 broadcast(
1757 Some(session.connection_id),
1758 guest_connection_ids.iter().copied(),
1759 |connection_id| {
1760 session
1761 .peer
1762 .forward_send(session.connection_id, connection_id, request.clone())
1763 },
1764 );
1765 Ok(())
1766}
1767
1768async fn update_language_server(
1769 request: proto::UpdateLanguageServer,
1770 session: Session,
1771) -> Result<()> {
1772 let project_id = ProjectId::from_proto(request.project_id);
1773 let project_connection_ids = session
1774 .db()
1775 .await
1776 .project_connection_ids(project_id, session.connection_id)
1777 .await?;
1778 broadcast(
1779 Some(session.connection_id),
1780 project_connection_ids.iter().copied(),
1781 |connection_id| {
1782 session
1783 .peer
1784 .forward_send(session.connection_id, connection_id, request.clone())
1785 },
1786 );
1787 Ok(())
1788}
1789
1790async fn forward_read_only_project_request<T>(
1791 request: T,
1792 response: Response<T>,
1793 session: Session,
1794) -> Result<()>
1795where
1796 T: EntityMessage + RequestMessage,
1797{
1798 let project_id = ProjectId::from_proto(request.remote_entity_id());
1799 let host_connection_id = session
1800 .db()
1801 .await
1802 .host_for_read_only_project_request(project_id, session.connection_id)
1803 .await?;
1804 let payload = session
1805 .peer
1806 .forward_request(session.connection_id, host_connection_id, request)
1807 .await?;
1808 response.send(payload)?;
1809 Ok(())
1810}
1811
1812async fn forward_mutating_project_request<T>(
1813 request: T,
1814 response: Response<T>,
1815 session: Session,
1816) -> Result<()>
1817where
1818 T: EntityMessage + RequestMessage,
1819{
1820 let project_id = ProjectId::from_proto(request.remote_entity_id());
1821 let host_connection_id = session
1822 .db()
1823 .await
1824 .host_for_mutating_project_request(project_id, session.connection_id)
1825 .await?;
1826 let payload = session
1827 .peer
1828 .forward_request(session.connection_id, host_connection_id, request)
1829 .await?;
1830 response.send(payload)?;
1831 Ok(())
1832}
1833
1834async fn create_buffer_for_peer(
1835 request: proto::CreateBufferForPeer,
1836 session: Session,
1837) -> Result<()> {
1838 session
1839 .db()
1840 .await
1841 .check_user_is_project_host(
1842 ProjectId::from_proto(request.project_id),
1843 session.connection_id,
1844 )
1845 .await?;
1846 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1847 session
1848 .peer
1849 .forward_send(session.connection_id, peer_id.into(), request)?;
1850 Ok(())
1851}
1852
1853async fn update_buffer(
1854 request: proto::UpdateBuffer,
1855 response: Response<proto::UpdateBuffer>,
1856 session: Session,
1857) -> Result<()> {
1858 let project_id = ProjectId::from_proto(request.project_id);
1859 let mut guest_connection_ids;
1860 let mut host_connection_id = None;
1861
1862 let mut requires_write_permission = false;
1863
1864 for op in request.operations.iter() {
1865 match op.variant {
1866 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1867 Some(_) => requires_write_permission = true,
1868 }
1869 }
1870
1871 {
1872 let collaborators = session
1873 .db()
1874 .await
1875 .project_collaborators_for_buffer_update(
1876 project_id,
1877 session.connection_id,
1878 requires_write_permission,
1879 )
1880 .await?;
1881 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1882 for collaborator in collaborators.iter() {
1883 if collaborator.is_host {
1884 host_connection_id = Some(collaborator.connection_id);
1885 } else {
1886 guest_connection_ids.push(collaborator.connection_id);
1887 }
1888 }
1889 }
1890 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1891
1892 broadcast(
1893 Some(session.connection_id),
1894 guest_connection_ids,
1895 |connection_id| {
1896 session
1897 .peer
1898 .forward_send(session.connection_id, connection_id, request.clone())
1899 },
1900 );
1901 if host_connection_id != session.connection_id {
1902 session
1903 .peer
1904 .forward_request(session.connection_id, host_connection_id, request.clone())
1905 .await?;
1906 }
1907
1908 response.send(proto::Ack {})?;
1909 Ok(())
1910}
1911
1912async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1913 request: T,
1914 session: Session,
1915) -> Result<()> {
1916 let project_id = ProjectId::from_proto(request.remote_entity_id());
1917 let project_connection_ids = session
1918 .db()
1919 .await
1920 .project_connection_ids(project_id, session.connection_id)
1921 .await?;
1922
1923 broadcast(
1924 Some(session.connection_id),
1925 project_connection_ids.iter().copied(),
1926 |connection_id| {
1927 session
1928 .peer
1929 .forward_send(session.connection_id, connection_id, request.clone())
1930 },
1931 );
1932 Ok(())
1933}
1934
1935async fn follow(
1936 request: proto::Follow,
1937 response: Response<proto::Follow>,
1938 session: Session,
1939) -> Result<()> {
1940 let room_id = RoomId::from_proto(request.room_id);
1941 let project_id = request.project_id.map(ProjectId::from_proto);
1942 let leader_id = request
1943 .leader_id
1944 .ok_or_else(|| anyhow!("invalid leader id"))?
1945 .into();
1946 let follower_id = session.connection_id;
1947
1948 session
1949 .db()
1950 .await
1951 .check_room_participants(room_id, leader_id, session.connection_id)
1952 .await?;
1953
1954 let response_payload = session
1955 .peer
1956 .forward_request(session.connection_id, leader_id, request)
1957 .await?;
1958 response.send(response_payload)?;
1959
1960 if let Some(project_id) = project_id {
1961 let room = session
1962 .db()
1963 .await
1964 .follow(room_id, project_id, leader_id, follower_id)
1965 .await?;
1966 room_updated(&room, &session.peer);
1967 }
1968
1969 Ok(())
1970}
1971
1972async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1973 let room_id = RoomId::from_proto(request.room_id);
1974 let project_id = request.project_id.map(ProjectId::from_proto);
1975 let leader_id = request
1976 .leader_id
1977 .ok_or_else(|| anyhow!("invalid leader id"))?
1978 .into();
1979 let follower_id = session.connection_id;
1980
1981 session
1982 .db()
1983 .await
1984 .check_room_participants(room_id, leader_id, session.connection_id)
1985 .await?;
1986
1987 session
1988 .peer
1989 .forward_send(session.connection_id, leader_id, request)?;
1990
1991 if let Some(project_id) = project_id {
1992 let room = session
1993 .db()
1994 .await
1995 .unfollow(room_id, project_id, leader_id, follower_id)
1996 .await?;
1997 room_updated(&room, &session.peer);
1998 }
1999
2000 Ok(())
2001}
2002
2003async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2004 let room_id = RoomId::from_proto(request.room_id);
2005 let database = session.db.lock().await;
2006
2007 let connection_ids = if let Some(project_id) = request.project_id {
2008 let project_id = ProjectId::from_proto(project_id);
2009 database
2010 .project_connection_ids(project_id, session.connection_id)
2011 .await?
2012 } else {
2013 database
2014 .room_connection_ids(room_id, session.connection_id)
2015 .await?
2016 };
2017
2018 // For now, don't send view update messages back to that view's current leader.
2019 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2020 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2021 _ => None,
2022 });
2023
2024 for follower_peer_id in request.follower_ids.iter().copied() {
2025 let follower_connection_id = follower_peer_id.into();
2026 if Some(follower_peer_id) != connection_id_to_omit
2027 && connection_ids.contains(&follower_connection_id)
2028 {
2029 session.peer.forward_send(
2030 session.connection_id,
2031 follower_connection_id,
2032 request.clone(),
2033 )?;
2034 }
2035 }
2036 Ok(())
2037}
2038
2039async fn get_users(
2040 request: proto::GetUsers,
2041 response: Response<proto::GetUsers>,
2042 session: Session,
2043) -> Result<()> {
2044 let user_ids = request
2045 .user_ids
2046 .into_iter()
2047 .map(UserId::from_proto)
2048 .collect();
2049 let users = session
2050 .db()
2051 .await
2052 .get_users_by_ids(user_ids)
2053 .await?
2054 .into_iter()
2055 .map(|user| proto::User {
2056 id: user.id.to_proto(),
2057 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2058 github_login: user.github_login,
2059 })
2060 .collect();
2061 response.send(proto::UsersResponse { users })?;
2062 Ok(())
2063}
2064
2065async fn fuzzy_search_users(
2066 request: proto::FuzzySearchUsers,
2067 response: Response<proto::FuzzySearchUsers>,
2068 session: Session,
2069) -> Result<()> {
2070 let query = request.query;
2071 let users = match query.len() {
2072 0 => vec![],
2073 1 | 2 => session
2074 .db()
2075 .await
2076 .get_user_by_github_login(&query)
2077 .await?
2078 .into_iter()
2079 .collect(),
2080 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2081 };
2082 let users = users
2083 .into_iter()
2084 .filter(|user| user.id != session.user_id)
2085 .map(|user| proto::User {
2086 id: user.id.to_proto(),
2087 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2088 github_login: user.github_login,
2089 })
2090 .collect();
2091 response.send(proto::UsersResponse { users })?;
2092 Ok(())
2093}
2094
2095async fn request_contact(
2096 request: proto::RequestContact,
2097 response: Response<proto::RequestContact>,
2098 session: Session,
2099) -> Result<()> {
2100 let requester_id = session.user_id;
2101 let responder_id = UserId::from_proto(request.responder_id);
2102 if requester_id == responder_id {
2103 return Err(anyhow!("cannot add yourself as a contact"))?;
2104 }
2105
2106 let notifications = session
2107 .db()
2108 .await
2109 .send_contact_request(requester_id, responder_id)
2110 .await?;
2111
2112 // Update outgoing contact requests of requester
2113 let mut update = proto::UpdateContacts::default();
2114 update.outgoing_requests.push(responder_id.to_proto());
2115 for connection_id in session
2116 .connection_pool()
2117 .await
2118 .user_connection_ids(requester_id)
2119 {
2120 session.peer.send(connection_id, update.clone())?;
2121 }
2122
2123 // Update incoming contact requests of responder
2124 let mut update = proto::UpdateContacts::default();
2125 update
2126 .incoming_requests
2127 .push(proto::IncomingContactRequest {
2128 requester_id: requester_id.to_proto(),
2129 });
2130 let connection_pool = session.connection_pool().await;
2131 for connection_id in connection_pool.user_connection_ids(responder_id) {
2132 session.peer.send(connection_id, update.clone())?;
2133 }
2134
2135 send_notifications(&*connection_pool, &session.peer, notifications);
2136
2137 response.send(proto::Ack {})?;
2138 Ok(())
2139}
2140
2141async fn respond_to_contact_request(
2142 request: proto::RespondToContactRequest,
2143 response: Response<proto::RespondToContactRequest>,
2144 session: Session,
2145) -> Result<()> {
2146 let responder_id = session.user_id;
2147 let requester_id = UserId::from_proto(request.requester_id);
2148 let db = session.db().await;
2149 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2150 db.dismiss_contact_notification(responder_id, requester_id)
2151 .await?;
2152 } else {
2153 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2154
2155 let notifications = db
2156 .respond_to_contact_request(responder_id, requester_id, accept)
2157 .await?;
2158 let requester_busy = db.is_user_busy(requester_id).await?;
2159 let responder_busy = db.is_user_busy(responder_id).await?;
2160
2161 let pool = session.connection_pool().await;
2162 // Update responder with new contact
2163 let mut update = proto::UpdateContacts::default();
2164 if accept {
2165 update
2166 .contacts
2167 .push(contact_for_user(requester_id, requester_busy, &pool));
2168 }
2169 update
2170 .remove_incoming_requests
2171 .push(requester_id.to_proto());
2172 for connection_id in pool.user_connection_ids(responder_id) {
2173 session.peer.send(connection_id, update.clone())?;
2174 }
2175
2176 // Update requester with new contact
2177 let mut update = proto::UpdateContacts::default();
2178 if accept {
2179 update
2180 .contacts
2181 .push(contact_for_user(responder_id, responder_busy, &pool));
2182 }
2183 update
2184 .remove_outgoing_requests
2185 .push(responder_id.to_proto());
2186
2187 for connection_id in pool.user_connection_ids(requester_id) {
2188 session.peer.send(connection_id, update.clone())?;
2189 }
2190
2191 send_notifications(&*pool, &session.peer, notifications);
2192 }
2193
2194 response.send(proto::Ack {})?;
2195 Ok(())
2196}
2197
2198async fn remove_contact(
2199 request: proto::RemoveContact,
2200 response: Response<proto::RemoveContact>,
2201 session: Session,
2202) -> Result<()> {
2203 let requester_id = session.user_id;
2204 let responder_id = UserId::from_proto(request.user_id);
2205 let db = session.db().await;
2206 let (contact_accepted, deleted_notification_id) =
2207 db.remove_contact(requester_id, responder_id).await?;
2208
2209 let pool = session.connection_pool().await;
2210 // Update outgoing contact requests of requester
2211 let mut update = proto::UpdateContacts::default();
2212 if contact_accepted {
2213 update.remove_contacts.push(responder_id.to_proto());
2214 } else {
2215 update
2216 .remove_outgoing_requests
2217 .push(responder_id.to_proto());
2218 }
2219 for connection_id in pool.user_connection_ids(requester_id) {
2220 session.peer.send(connection_id, update.clone())?;
2221 }
2222
2223 // Update incoming contact requests of responder
2224 let mut update = proto::UpdateContacts::default();
2225 if contact_accepted {
2226 update.remove_contacts.push(requester_id.to_proto());
2227 } else {
2228 update
2229 .remove_incoming_requests
2230 .push(requester_id.to_proto());
2231 }
2232 for connection_id in pool.user_connection_ids(responder_id) {
2233 session.peer.send(connection_id, update.clone())?;
2234 if let Some(notification_id) = deleted_notification_id {
2235 session.peer.send(
2236 connection_id,
2237 proto::DeleteNotification {
2238 notification_id: notification_id.to_proto(),
2239 },
2240 )?;
2241 }
2242 }
2243
2244 response.send(proto::Ack {})?;
2245 Ok(())
2246}
2247
2248async fn create_channel(
2249 request: proto::CreateChannel,
2250 response: Response<proto::CreateChannel>,
2251 session: Session,
2252) -> Result<()> {
2253 let db = session.db().await;
2254
2255 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2256 let CreateChannelResult {
2257 channel,
2258 participants_to_update,
2259 } = db
2260 .create_channel(&request.name, parent_id, session.user_id)
2261 .await?;
2262
2263 response.send(proto::CreateChannelResponse {
2264 channel: Some(channel.to_proto()),
2265 parent_id: request.parent_id,
2266 })?;
2267
2268 let connection_pool = session.connection_pool().await;
2269 for (user_id, channels) in participants_to_update {
2270 let update = build_channels_update(channels, vec![]);
2271 for connection_id in connection_pool.user_connection_ids(user_id) {
2272 if user_id == session.user_id {
2273 continue;
2274 }
2275 session.peer.send(connection_id, update.clone())?;
2276 }
2277 }
2278
2279 Ok(())
2280}
2281
2282async fn delete_channel(
2283 request: proto::DeleteChannel,
2284 response: Response<proto::DeleteChannel>,
2285 session: Session,
2286) -> Result<()> {
2287 let db = session.db().await;
2288
2289 let channel_id = request.channel_id;
2290 let (removed_channels, member_ids) = db
2291 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2292 .await?;
2293 response.send(proto::Ack {})?;
2294
2295 // Notify members of removed channels
2296 let mut update = proto::UpdateChannels::default();
2297 update
2298 .delete_channels
2299 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2300
2301 let connection_pool = session.connection_pool().await;
2302 for member_id in member_ids {
2303 for connection_id in connection_pool.user_connection_ids(member_id) {
2304 session.peer.send(connection_id, update.clone())?;
2305 }
2306 }
2307
2308 Ok(())
2309}
2310
2311async fn invite_channel_member(
2312 request: proto::InviteChannelMember,
2313 response: Response<proto::InviteChannelMember>,
2314 session: Session,
2315) -> Result<()> {
2316 let db = session.db().await;
2317 let channel_id = ChannelId::from_proto(request.channel_id);
2318 let invitee_id = UserId::from_proto(request.user_id);
2319 let InviteMemberResult {
2320 channel,
2321 notifications,
2322 } = db
2323 .invite_channel_member(
2324 channel_id,
2325 invitee_id,
2326 session.user_id,
2327 request.role().into(),
2328 )
2329 .await?;
2330
2331 let update = proto::UpdateChannels {
2332 channel_invitations: vec![channel.to_proto()],
2333 ..Default::default()
2334 };
2335
2336 let connection_pool = session.connection_pool().await;
2337 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2338 session.peer.send(connection_id, update.clone())?;
2339 }
2340
2341 send_notifications(&*connection_pool, &session.peer, notifications);
2342
2343 response.send(proto::Ack {})?;
2344 Ok(())
2345}
2346
2347async fn remove_channel_member(
2348 request: proto::RemoveChannelMember,
2349 response: Response<proto::RemoveChannelMember>,
2350 session: Session,
2351) -> Result<()> {
2352 let db = session.db().await;
2353 let channel_id = ChannelId::from_proto(request.channel_id);
2354 let member_id = UserId::from_proto(request.user_id);
2355
2356 let RemoveChannelMemberResult {
2357 membership_update,
2358 notification_id,
2359 } = db
2360 .remove_channel_member(channel_id, member_id, session.user_id)
2361 .await?;
2362
2363 let connection_pool = &session.connection_pool().await;
2364 notify_membership_updated(
2365 &connection_pool,
2366 membership_update,
2367 member_id,
2368 &session.peer,
2369 );
2370 for connection_id in connection_pool.user_connection_ids(member_id) {
2371 if let Some(notification_id) = notification_id {
2372 session
2373 .peer
2374 .send(
2375 connection_id,
2376 proto::DeleteNotification {
2377 notification_id: notification_id.to_proto(),
2378 },
2379 )
2380 .trace_err();
2381 }
2382 }
2383
2384 response.send(proto::Ack {})?;
2385 Ok(())
2386}
2387
2388async fn set_channel_visibility(
2389 request: proto::SetChannelVisibility,
2390 response: Response<proto::SetChannelVisibility>,
2391 session: Session,
2392) -> Result<()> {
2393 let db = session.db().await;
2394 let channel_id = ChannelId::from_proto(request.channel_id);
2395 let visibility = request.visibility().into();
2396
2397 let SetChannelVisibilityResult {
2398 participants_to_update,
2399 participants_to_remove,
2400 channels_to_remove,
2401 } = db
2402 .set_channel_visibility(channel_id, visibility, session.user_id)
2403 .await?;
2404
2405 let connection_pool = session.connection_pool().await;
2406 for (user_id, channels) in participants_to_update {
2407 let update = build_channels_update(channels, vec![]);
2408 for connection_id in connection_pool.user_connection_ids(user_id) {
2409 session.peer.send(connection_id, update.clone())?;
2410 }
2411 }
2412 for user_id in participants_to_remove {
2413 let update = proto::UpdateChannels {
2414 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2415 ..Default::default()
2416 };
2417 for connection_id in connection_pool.user_connection_ids(user_id) {
2418 session.peer.send(connection_id, update.clone())?;
2419 }
2420 }
2421
2422 response.send(proto::Ack {})?;
2423 Ok(())
2424}
2425
2426async fn set_channel_member_role(
2427 request: proto::SetChannelMemberRole,
2428 response: Response<proto::SetChannelMemberRole>,
2429 session: Session,
2430) -> Result<()> {
2431 let db = session.db().await;
2432 let channel_id = ChannelId::from_proto(request.channel_id);
2433 let member_id = UserId::from_proto(request.user_id);
2434 let result = db
2435 .set_channel_member_role(
2436 channel_id,
2437 session.user_id,
2438 member_id,
2439 request.role().into(),
2440 )
2441 .await?;
2442
2443 match result {
2444 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2445 let connection_pool = session.connection_pool().await;
2446 notify_membership_updated(
2447 &connection_pool,
2448 membership_update,
2449 member_id,
2450 &session.peer,
2451 )
2452 }
2453 db::SetMemberRoleResult::InviteUpdated(channel) => {
2454 let update = proto::UpdateChannels {
2455 channel_invitations: vec![channel.to_proto()],
2456 ..Default::default()
2457 };
2458
2459 for connection_id in session
2460 .connection_pool()
2461 .await
2462 .user_connection_ids(member_id)
2463 {
2464 session.peer.send(connection_id, update.clone())?;
2465 }
2466 }
2467 }
2468
2469 response.send(proto::Ack {})?;
2470 Ok(())
2471}
2472
2473async fn rename_channel(
2474 request: proto::RenameChannel,
2475 response: Response<proto::RenameChannel>,
2476 session: Session,
2477) -> Result<()> {
2478 let db = session.db().await;
2479 let channel_id = ChannelId::from_proto(request.channel_id);
2480 let RenameChannelResult {
2481 channel,
2482 participants_to_update,
2483 } = db
2484 .rename_channel(channel_id, session.user_id, &request.name)
2485 .await?;
2486
2487 response.send(proto::RenameChannelResponse {
2488 channel: Some(channel.to_proto()),
2489 })?;
2490
2491 let connection_pool = session.connection_pool().await;
2492 for (user_id, channel) in participants_to_update {
2493 for connection_id in connection_pool.user_connection_ids(user_id) {
2494 let update = proto::UpdateChannels {
2495 channels: vec![channel.to_proto()],
2496 ..Default::default()
2497 };
2498
2499 session.peer.send(connection_id, update.clone())?;
2500 }
2501 }
2502
2503 Ok(())
2504}
2505
2506async fn move_channel(
2507 request: proto::MoveChannel,
2508 response: Response<proto::MoveChannel>,
2509 session: Session,
2510) -> Result<()> {
2511 let channel_id = ChannelId::from_proto(request.channel_id);
2512 let to = request.to.map(ChannelId::from_proto);
2513
2514 let result = session
2515 .db()
2516 .await
2517 .move_channel(channel_id, to, session.user_id)
2518 .await?;
2519
2520 notify_channel_moved(result, session).await?;
2521
2522 response.send(Ack {})?;
2523 Ok(())
2524}
2525
2526async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2527 let Some(MoveChannelResult {
2528 participants_to_remove,
2529 participants_to_update,
2530 moved_channels,
2531 }) = result
2532 else {
2533 return Ok(());
2534 };
2535 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2536
2537 let connection_pool = session.connection_pool().await;
2538 for (user_id, channels) in participants_to_update {
2539 let mut update = build_channels_update(channels, vec![]);
2540 update.delete_channels = moved_channels.clone();
2541 for connection_id in connection_pool.user_connection_ids(user_id) {
2542 session.peer.send(connection_id, update.clone())?;
2543 }
2544 }
2545
2546 for user_id in participants_to_remove {
2547 let update = proto::UpdateChannels {
2548 delete_channels: moved_channels.clone(),
2549 ..Default::default()
2550 };
2551 for connection_id in connection_pool.user_connection_ids(user_id) {
2552 session.peer.send(connection_id, update.clone())?;
2553 }
2554 }
2555 Ok(())
2556}
2557
2558async fn get_channel_members(
2559 request: proto::GetChannelMembers,
2560 response: Response<proto::GetChannelMembers>,
2561 session: Session,
2562) -> Result<()> {
2563 let db = session.db().await;
2564 let channel_id = ChannelId::from_proto(request.channel_id);
2565 let members = db
2566 .get_channel_participant_details(channel_id, session.user_id)
2567 .await?;
2568 response.send(proto::GetChannelMembersResponse { members })?;
2569 Ok(())
2570}
2571
2572async fn respond_to_channel_invite(
2573 request: proto::RespondToChannelInvite,
2574 response: Response<proto::RespondToChannelInvite>,
2575 session: Session,
2576) -> Result<()> {
2577 let db = session.db().await;
2578 let channel_id = ChannelId::from_proto(request.channel_id);
2579 let RespondToChannelInvite {
2580 membership_update,
2581 notifications,
2582 } = db
2583 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2584 .await?;
2585
2586 let connection_pool = session.connection_pool().await;
2587 if let Some(membership_update) = membership_update {
2588 notify_membership_updated(
2589 &connection_pool,
2590 membership_update,
2591 session.user_id,
2592 &session.peer,
2593 );
2594 } else {
2595 let update = proto::UpdateChannels {
2596 remove_channel_invitations: vec![channel_id.to_proto()],
2597 ..Default::default()
2598 };
2599
2600 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2601 session.peer.send(connection_id, update.clone())?;
2602 }
2603 };
2604
2605 send_notifications(&*connection_pool, &session.peer, notifications);
2606
2607 response.send(proto::Ack {})?;
2608
2609 Ok(())
2610}
2611
2612async fn join_channel(
2613 request: proto::JoinChannel,
2614 response: Response<proto::JoinChannel>,
2615 session: Session,
2616) -> Result<()> {
2617 let channel_id = ChannelId::from_proto(request.channel_id);
2618 join_channel_internal(channel_id, Box::new(response), session).await
2619}
2620
/// A response handle that can reply with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// `Response` has an inherent `send` (used directly by the other handlers in this
// file). The fully qualified `Response::<...>::send` calls below resolve to that
// inherent method, not to this trait's `send`, so they forward rather than recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2634
/// Joins the calling connection to the call room of `channel_id` and replies
/// through `response`.
///
/// Steps, in order:
/// 1. Leaves whatever room this connection is currently in.
/// 2. Joins the channel's room in the database. This also reports any change to
///    the user's channel membership and the user's role in the channel.
/// 3. Sends the `JoinRoomResponse`. When a LiveKit client is configured, the
///    response includes connection info: guests get a guest token with
///    `can_publish: false`, everyone else a room token with `can_publish: true`.
///    If token generation fails (`trace_err()` turns the error into `None`),
///    the response carries no LiveKit info.
/// 4. Notifies the user of any membership change and broadcasts the updated
///    room to its participants.
/// 5. Broadcasts the channel's new participant list to channel members and
///    refreshes this user's contact entry for their accepted contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that the `db` and `connection_pool` guards taken inside this
    // block are released before `channel_updated` and `update_user_contacts`
    // acquire them again below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when there is no LiveKit client or token generation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join but not publish media.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2715
2716async fn join_channel_buffer(
2717 request: proto::JoinChannelBuffer,
2718 response: Response<proto::JoinChannelBuffer>,
2719 session: Session,
2720) -> Result<()> {
2721 let db = session.db().await;
2722 let channel_id = ChannelId::from_proto(request.channel_id);
2723
2724 let open_response = db
2725 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2726 .await?;
2727
2728 let collaborators = open_response.collaborators.clone();
2729 response.send(open_response)?;
2730
2731 let update = UpdateChannelBufferCollaborators {
2732 channel_id: channel_id.to_proto(),
2733 collaborators: collaborators.clone(),
2734 };
2735 channel_buffer_updated(
2736 session.connection_id,
2737 collaborators
2738 .iter()
2739 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2740 &update,
2741 &session.peer,
2742 );
2743
2744 Ok(())
2745}
2746
2747async fn update_channel_buffer(
2748 request: proto::UpdateChannelBuffer,
2749 session: Session,
2750) -> Result<()> {
2751 let db = session.db().await;
2752 let channel_id = ChannelId::from_proto(request.channel_id);
2753
2754 let (collaborators, non_collaborators, epoch, version) = db
2755 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2756 .await?;
2757
2758 channel_buffer_updated(
2759 session.connection_id,
2760 collaborators,
2761 &proto::UpdateChannelBuffer {
2762 channel_id: channel_id.to_proto(),
2763 operations: request.operations,
2764 },
2765 &session.peer,
2766 );
2767
2768 let pool = &*session.connection_pool().await;
2769
2770 broadcast(
2771 None,
2772 non_collaborators
2773 .iter()
2774 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2775 |peer_id| {
2776 session.peer.send(
2777 peer_id.into(),
2778 proto::UpdateChannels {
2779 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2780 channel_id: channel_id.to_proto(),
2781 epoch: epoch as u64,
2782 version: version.clone(),
2783 }],
2784 ..Default::default()
2785 },
2786 )
2787 },
2788 );
2789
2790 Ok(())
2791}
2792
2793async fn rejoin_channel_buffers(
2794 request: proto::RejoinChannelBuffers,
2795 response: Response<proto::RejoinChannelBuffers>,
2796 session: Session,
2797) -> Result<()> {
2798 let db = session.db().await;
2799 let buffers = db
2800 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2801 .await?;
2802
2803 for rejoined_buffer in &buffers {
2804 let collaborators_to_notify = rejoined_buffer
2805 .buffer
2806 .collaborators
2807 .iter()
2808 .filter_map(|c| Some(c.peer_id?.into()));
2809 channel_buffer_updated(
2810 session.connection_id,
2811 collaborators_to_notify,
2812 &proto::UpdateChannelBufferCollaborators {
2813 channel_id: rejoined_buffer.buffer.channel_id,
2814 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2815 },
2816 &session.peer,
2817 );
2818 }
2819
2820 response.send(proto::RejoinChannelBuffersResponse {
2821 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2822 })?;
2823
2824 Ok(())
2825}
2826
2827async fn leave_channel_buffer(
2828 request: proto::LeaveChannelBuffer,
2829 response: Response<proto::LeaveChannelBuffer>,
2830 session: Session,
2831) -> Result<()> {
2832 let db = session.db().await;
2833 let channel_id = ChannelId::from_proto(request.channel_id);
2834
2835 let left_buffer = db
2836 .leave_channel_buffer(channel_id, session.connection_id)
2837 .await?;
2838
2839 response.send(Ack {})?;
2840
2841 channel_buffer_updated(
2842 session.connection_id,
2843 left_buffer.connections,
2844 &proto::UpdateChannelBufferCollaborators {
2845 channel_id: channel_id.to_proto(),
2846 collaborators: left_buffer.collaborators,
2847 },
2848 &session.peer,
2849 );
2850
2851 Ok(())
2852}
2853
2854fn channel_buffer_updated<T: EnvelopedMessage>(
2855 sender_id: ConnectionId,
2856 collaborators: impl IntoIterator<Item = ConnectionId>,
2857 message: &T,
2858 peer: &Peer,
2859) {
2860 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2861 peer.send(peer_id.into(), message.clone())
2862 });
2863}
2864
2865fn send_notifications(
2866 connection_pool: &ConnectionPool,
2867 peer: &Peer,
2868 notifications: db::NotificationBatch,
2869) {
2870 for (user_id, notification) in notifications {
2871 for connection_id in connection_pool.user_connection_ids(user_id) {
2872 if let Err(error) = peer.send(
2873 connection_id,
2874 proto::AddNotification {
2875 notification: Some(notification.clone()),
2876 },
2877 ) {
2878 tracing::error!(
2879 "failed to send notification to {:?} {}",
2880 connection_id,
2881 error
2882 );
2883 }
2884 }
2885 }
2886}
2887
2888async fn send_channel_message(
2889 request: proto::SendChannelMessage,
2890 response: Response<proto::SendChannelMessage>,
2891 session: Session,
2892) -> Result<()> {
2893 // Validate the message body.
2894 let body = request.body.trim().to_string();
2895 if body.len() > MAX_MESSAGE_LEN {
2896 return Err(anyhow!("message is too long"))?;
2897 }
2898 if body.is_empty() {
2899 return Err(anyhow!("message can't be blank"))?;
2900 }
2901
2902 // TODO: adjust mentions if body is trimmed
2903
2904 let timestamp = OffsetDateTime::now_utc();
2905 let nonce = request
2906 .nonce
2907 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2908
2909 let channel_id = ChannelId::from_proto(request.channel_id);
2910 let CreatedChannelMessage {
2911 message_id,
2912 participant_connection_ids,
2913 channel_members,
2914 notifications,
2915 } = session
2916 .db()
2917 .await
2918 .create_channel_message(
2919 channel_id,
2920 session.user_id,
2921 &body,
2922 &request.mentions,
2923 timestamp,
2924 nonce.clone().into(),
2925 )
2926 .await?;
2927 let message = proto::ChannelMessage {
2928 sender_id: session.user_id.to_proto(),
2929 id: message_id.to_proto(),
2930 body,
2931 mentions: request.mentions,
2932 timestamp: timestamp.unix_timestamp() as u64,
2933 nonce: Some(nonce),
2934 };
2935 broadcast(
2936 Some(session.connection_id),
2937 participant_connection_ids,
2938 |connection| {
2939 session.peer.send(
2940 connection,
2941 proto::ChannelMessageSent {
2942 channel_id: channel_id.to_proto(),
2943 message: Some(message.clone()),
2944 },
2945 )
2946 },
2947 );
2948 response.send(proto::SendChannelMessageResponse {
2949 message: Some(message),
2950 })?;
2951
2952 let pool = &*session.connection_pool().await;
2953 broadcast(
2954 None,
2955 channel_members
2956 .iter()
2957 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2958 |peer_id| {
2959 session.peer.send(
2960 peer_id.into(),
2961 proto::UpdateChannels {
2962 unseen_channel_messages: vec![proto::UnseenChannelMessage {
2963 channel_id: channel_id.to_proto(),
2964 message_id: message_id.to_proto(),
2965 }],
2966 ..Default::default()
2967 },
2968 )
2969 },
2970 );
2971 send_notifications(pool, &session.peer, notifications);
2972
2973 Ok(())
2974}
2975
2976async fn remove_channel_message(
2977 request: proto::RemoveChannelMessage,
2978 response: Response<proto::RemoveChannelMessage>,
2979 session: Session,
2980) -> Result<()> {
2981 let channel_id = ChannelId::from_proto(request.channel_id);
2982 let message_id = MessageId::from_proto(request.message_id);
2983 let connection_ids = session
2984 .db()
2985 .await
2986 .remove_channel_message(channel_id, message_id, session.user_id)
2987 .await?;
2988 broadcast(Some(session.connection_id), connection_ids, |connection| {
2989 session.peer.send(connection, request.clone())
2990 });
2991 response.send(proto::Ack {})?;
2992 Ok(())
2993}
2994
2995async fn acknowledge_channel_message(
2996 request: proto::AckChannelMessage,
2997 session: Session,
2998) -> Result<()> {
2999 let channel_id = ChannelId::from_proto(request.channel_id);
3000 let message_id = MessageId::from_proto(request.message_id);
3001 let notifications = session
3002 .db()
3003 .await
3004 .observe_channel_message(channel_id, session.user_id, message_id)
3005 .await?;
3006 send_notifications(
3007 &*session.connection_pool().await,
3008 &session.peer,
3009 notifications,
3010 );
3011 Ok(())
3012}
3013
3014async fn acknowledge_buffer_version(
3015 request: proto::AckBufferOperation,
3016 session: Session,
3017) -> Result<()> {
3018 let buffer_id = BufferId::from_proto(request.buffer_id);
3019 session
3020 .db()
3021 .await
3022 .observe_buffer_version(
3023 buffer_id,
3024 session.user_id,
3025 request.epoch as i32,
3026 &request.version,
3027 )
3028 .await?;
3029 Ok(())
3030}
3031
3032async fn join_channel_chat(
3033 request: proto::JoinChannelChat,
3034 response: Response<proto::JoinChannelChat>,
3035 session: Session,
3036) -> Result<()> {
3037 let channel_id = ChannelId::from_proto(request.channel_id);
3038
3039 let db = session.db().await;
3040 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3041 .await?;
3042 let messages = db
3043 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3044 .await?;
3045 response.send(proto::JoinChannelChatResponse {
3046 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3047 messages,
3048 })?;
3049 Ok(())
3050}
3051
3052async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3053 let channel_id = ChannelId::from_proto(request.channel_id);
3054 session
3055 .db()
3056 .await
3057 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3058 .await?;
3059 Ok(())
3060}
3061
3062async fn get_channel_messages(
3063 request: proto::GetChannelMessages,
3064 response: Response<proto::GetChannelMessages>,
3065 session: Session,
3066) -> Result<()> {
3067 let channel_id = ChannelId::from_proto(request.channel_id);
3068 let messages = session
3069 .db()
3070 .await
3071 .get_channel_messages(
3072 channel_id,
3073 session.user_id,
3074 MESSAGE_COUNT_PER_PAGE,
3075 Some(MessageId::from_proto(request.before_message_id)),
3076 )
3077 .await?;
3078 response.send(proto::GetChannelMessagesResponse {
3079 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3080 messages,
3081 })?;
3082 Ok(())
3083}
3084
3085async fn get_channel_messages_by_id(
3086 request: proto::GetChannelMessagesById,
3087 response: Response<proto::GetChannelMessagesById>,
3088 session: Session,
3089) -> Result<()> {
3090 let message_ids = request
3091 .message_ids
3092 .iter()
3093 .map(|id| MessageId::from_proto(*id))
3094 .collect::<Vec<_>>();
3095 let messages = session
3096 .db()
3097 .await
3098 .get_channel_messages_by_id(session.user_id, &message_ids)
3099 .await?;
3100 response.send(proto::GetChannelMessagesResponse {
3101 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3102 messages,
3103 })?;
3104 Ok(())
3105}
3106
3107async fn get_notifications(
3108 request: proto::GetNotifications,
3109 response: Response<proto::GetNotifications>,
3110 session: Session,
3111) -> Result<()> {
3112 let notifications = session
3113 .db()
3114 .await
3115 .get_notifications(
3116 session.user_id,
3117 NOTIFICATION_COUNT_PER_PAGE,
3118 request
3119 .before_id
3120 .map(|id| db::NotificationId::from_proto(id)),
3121 )
3122 .await?;
3123 response.send(proto::GetNotificationsResponse {
3124 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3125 notifications,
3126 })?;
3127 Ok(())
3128}
3129
3130async fn mark_notification_as_read(
3131 request: proto::MarkNotificationRead,
3132 response: Response<proto::MarkNotificationRead>,
3133 session: Session,
3134) -> Result<()> {
3135 let database = &session.db().await;
3136 let notifications = database
3137 .mark_notification_as_read_by_id(
3138 session.user_id,
3139 NotificationId::from_proto(request.notification_id),
3140 )
3141 .await?;
3142 send_notifications(
3143 &*session.connection_pool().await,
3144 &session.peer,
3145 notifications,
3146 );
3147 response.send(proto::Ack {})?;
3148 Ok(())
3149}
3150
3151async fn get_private_user_info(
3152 _request: proto::GetPrivateUserInfo,
3153 response: Response<proto::GetPrivateUserInfo>,
3154 session: Session,
3155) -> Result<()> {
3156 let db = session.db().await;
3157
3158 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3159 let user = db
3160 .get_user_by_id(session.user_id)
3161 .await?
3162 .ok_or_else(|| anyhow!("user not found"))?;
3163 let flags = db.get_user_flags(session.user_id).await?;
3164
3165 response.send(proto::GetPrivateUserInfoResponse {
3166 metrics_id,
3167 staff: user.admin,
3168 flags,
3169 })?;
3170 Ok(())
3171}
3172
3173fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3174 match message {
3175 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3176 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3177 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3178 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3179 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3180 code: frame.code.into(),
3181 reason: frame.reason,
3182 })),
3183 }
3184}
3185
3186fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3187 match message {
3188 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3189 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3190 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3191 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3192 AxumMessage::Close(frame) => {
3193 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3194 code: frame.code.into(),
3195 reason: frame.reason,
3196 }))
3197 }
3198 }
3199}
3200
3201fn notify_membership_updated(
3202 connection_pool: &ConnectionPool,
3203 result: MembershipUpdated,
3204 user_id: UserId,
3205 peer: &Peer,
3206) {
3207 let mut update = build_channels_update(result.new_channels, vec![]);
3208 update.delete_channels = result
3209 .removed_channels
3210 .into_iter()
3211 .map(|id| id.to_proto())
3212 .collect();
3213 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3214
3215 for connection_id in connection_pool.user_connection_ids(user_id) {
3216 peer.send(connection_id, update.clone()).trace_err();
3217 }
3218}
3219
3220fn build_channels_update(
3221 channels: ChannelsForUser,
3222 channel_invites: Vec<db::Channel>,
3223) -> proto::UpdateChannels {
3224 let mut update = proto::UpdateChannels::default();
3225
3226 for channel in channels.channels {
3227 update.channels.push(channel.to_proto());
3228 }
3229
3230 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3231 update.unseen_channel_messages = channels.channel_messages;
3232
3233 for (channel_id, participants) in channels.channel_participants {
3234 update
3235 .channel_participants
3236 .push(proto::ChannelParticipants {
3237 channel_id: channel_id.to_proto(),
3238 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3239 });
3240 }
3241
3242 for channel in channel_invites {
3243 update.channel_invitations.push(channel.to_proto());
3244 }
3245
3246 update
3247}
3248
3249fn build_initial_contacts_update(
3250 contacts: Vec<db::Contact>,
3251 pool: &ConnectionPool,
3252) -> proto::UpdateContacts {
3253 let mut update = proto::UpdateContacts::default();
3254
3255 for contact in contacts {
3256 match contact {
3257 db::Contact::Accepted { user_id, busy } => {
3258 update.contacts.push(contact_for_user(user_id, busy, &pool));
3259 }
3260 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3261 db::Contact::Incoming { user_id } => {
3262 update
3263 .incoming_requests
3264 .push(proto::IncomingContactRequest {
3265 requester_id: user_id.to_proto(),
3266 })
3267 }
3268 }
3269 }
3270
3271 update
3272}
3273
3274fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3275 proto::Contact {
3276 user_id: user_id.to_proto(),
3277 online: pool.is_user_online(user_id),
3278 busy,
3279 }
3280}
3281
3282fn room_updated(room: &proto::Room, peer: &Peer) {
3283 broadcast(
3284 None,
3285 room.participants
3286 .iter()
3287 .filter_map(|participant| Some(participant.peer_id?.into())),
3288 |peer_id| {
3289 peer.send(
3290 peer_id.into(),
3291 proto::RoomUpdated {
3292 room: Some(room.clone()),
3293 },
3294 )
3295 },
3296 );
3297}
3298
3299fn channel_updated(
3300 channel_id: ChannelId,
3301 room: &proto::Room,
3302 channel_members: &[UserId],
3303 peer: &Peer,
3304 pool: &ConnectionPool,
3305) {
3306 let participants = room
3307 .participants
3308 .iter()
3309 .map(|p| p.user_id)
3310 .collect::<Vec<_>>();
3311
3312 broadcast(
3313 None,
3314 channel_members
3315 .iter()
3316 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3317 |peer_id| {
3318 peer.send(
3319 peer_id.into(),
3320 proto::UpdateChannels {
3321 channel_participants: vec![proto::ChannelParticipants {
3322 channel_id: channel_id.to_proto(),
3323 participant_user_ids: participants.clone(),
3324 }],
3325 ..Default::default()
3326 },
3327 )
3328 },
3329 );
3330}
3331
3332async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3333 let db = session.db().await;
3334
3335 let contacts = db.get_contacts(user_id).await?;
3336 let busy = db.is_user_busy(user_id).await?;
3337
3338 let pool = session.connection_pool().await;
3339 let updated_contact = contact_for_user(user_id, busy, &pool);
3340 for contact in contacts {
3341 if let db::Contact::Accepted {
3342 user_id: contact_user_id,
3343 ..
3344 } = contact
3345 {
3346 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3347 session
3348 .peer
3349 .send(
3350 contact_conn_id,
3351 proto::UpdateContacts {
3352 contacts: vec![updated_contact.clone()],
3353 remove_contacts: Default::default(),
3354 incoming_requests: Default::default(),
3355 remove_incoming_requests: Default::default(),
3356 outgoing_requests: Default::default(),
3357 remove_outgoing_requests: Default::default(),
3358 },
3359 )
3360 .trace_err();
3361 }
3362 }
3363 }
3364 Ok(())
3365}
3366
3367async fn leave_room_for_session(session: &Session) -> Result<()> {
3368 let mut contacts_to_update = HashSet::default();
3369
3370 let room_id;
3371 let canceled_calls_to_user_ids;
3372 let live_kit_room;
3373 let delete_live_kit_room;
3374 let room;
3375 let channel_members;
3376 let channel_id;
3377
3378 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3379 contacts_to_update.insert(session.user_id);
3380
3381 for project in left_room.left_projects.values() {
3382 project_left(project, session);
3383 }
3384
3385 room_id = RoomId::from_proto(left_room.room.id);
3386 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3387 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3388 delete_live_kit_room = left_room.deleted;
3389 room = mem::take(&mut left_room.room);
3390 channel_members = mem::take(&mut left_room.channel_members);
3391 channel_id = left_room.channel_id;
3392
3393 room_updated(&room, &session.peer);
3394 } else {
3395 return Ok(());
3396 }
3397
3398 if let Some(channel_id) = channel_id {
3399 channel_updated(
3400 channel_id,
3401 &room,
3402 &channel_members,
3403 &session.peer,
3404 &*session.connection_pool().await,
3405 );
3406 }
3407
3408 {
3409 let pool = session.connection_pool().await;
3410 for canceled_user_id in canceled_calls_to_user_ids {
3411 for connection_id in pool.user_connection_ids(canceled_user_id) {
3412 session
3413 .peer
3414 .send(
3415 connection_id,
3416 proto::CallCanceled {
3417 room_id: room_id.to_proto(),
3418 },
3419 )
3420 .trace_err();
3421 }
3422 contacts_to_update.insert(canceled_user_id);
3423 }
3424 }
3425
3426 for contact_user_id in contacts_to_update {
3427 update_user_contacts(contact_user_id, &session).await?;
3428 }
3429
3430 if let Some(live_kit) = session.live_kit_client.as_ref() {
3431 live_kit
3432 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3433 .await
3434 .trace_err();
3435
3436 if delete_live_kit_room {
3437 live_kit.delete_room(live_kit_room).await.trace_err();
3438 }
3439 }
3440
3441 Ok(())
3442}
3443
3444async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3445 let left_channel_buffers = session
3446 .db()
3447 .await
3448 .leave_channel_buffers(session.connection_id)
3449 .await?;
3450
3451 for left_buffer in left_channel_buffers {
3452 channel_buffer_updated(
3453 session.connection_id,
3454 left_buffer.connections,
3455 &proto::UpdateChannelBufferCollaborators {
3456 channel_id: left_buffer.channel_id.to_proto(),
3457 collaborators: left_buffer.collaborators,
3458 },
3459 &session.peer,
3460 );
3461 }
3462
3463 Ok(())
3464}
3465
3466fn project_left(project: &db::LeftProject, session: &Session) {
3467 for connection_id in &project.connection_ids {
3468 if project.host_user_id == session.user_id {
3469 session
3470 .peer
3471 .send(
3472 *connection_id,
3473 proto::UnshareProject {
3474 project_id: project.id.to_proto(),
3475 },
3476 )
3477 .trace_err();
3478 } else {
3479 session
3480 .peer
3481 .send(
3482 *connection_id,
3483 proto::RemoveProjectCollaborator {
3484 project_id: project.id.to_proto(),
3485 peer_id: Some(session.connection_id.into()),
3486 },
3487 )
3488 .trace_err();
3489 }
3490 }
3491}
3492
3493pub trait ResultExt {
3494 type Ok;
3495
3496 fn trace_err(self) -> Option<Self::Ok>;
3497}
3498
3499impl<T, E> ResultExt for Result<T, E>
3500where
3501 E: std::fmt::Debug,
3502{
3503 type Ok = T;
3504
3505 fn trace_err(self) -> Option<T> {
3506 match self {
3507 Ok(value) => Some(value),
3508 Err(error) => {
3509 tracing::error!("{:?}", error);
3510 None
3511 }
3512 }
3513 }
3514}