1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage,
7 Database, MessageId, NotificationId, ProjectId, RoomId, ServerId, User, UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
42 LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66use util::channel::RELEASE_CHANNEL_NAME;
67
/// How long a dropped connection's rooms, buffers and calls are kept before being
/// cleaned up by `connection_lost`.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before stale resources from older servers are purged.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
74
75lazy_static! {
76 static ref METRIC_CONNECTIONS: IntGauge =
77 register_int_gauge!("connections", "number of connections").unwrap();
78 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
79 "shared_projects",
80 "number of open projects with one or more guests"
81 )
82 .unwrap();
83}
84
85type MessageHandler =
86 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
87
88struct Response<R> {
89 peer: Arc<Peer>,
90 receipt: Receipt<R>,
91 responded: Arc<AtomicBool>,
92}
93
94impl<R: RequestMessage> Response<R> {
95 fn send(self, payload: R::Response) -> Result<()> {
96 self.responded.store(true, SeqCst);
97 self.peer.respond(self.receipt, payload)?;
98 Ok(())
99 }
100}
101
102#[derive(Clone)]
103struct Session {
104 user_id: UserId,
105 connection_id: ConnectionId,
106 db: Arc<tokio::sync::Mutex<DbHandle>>,
107 peer: Arc<Peer>,
108 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
109 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
110 executor: Executor,
111}
112
113impl Session {
114 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.db.lock().await;
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 guard
121 }
122
123 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
124 #[cfg(test)]
125 tokio::task::yield_now().await;
126 let guard = self.connection_pool.lock();
127 ConnectionPoolGuard {
128 guard,
129 _not_send: PhantomData,
130 }
131 }
132}
133
134impl fmt::Debug for Session {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("Session")
137 .field("user_id", &self.user_id)
138 .field("connection_id", &self.connection_id)
139 .finish()
140 }
141}
142
143struct DbHandle(Arc<Database>);
144
145impl Deref for DbHandle {
146 type Target = Database;
147
148 fn deref(&self) -> &Self::Target {
149 self.0.as_ref()
150 }
151}
152
/// The collab RPC server: owns the RPC peer, the pool of live connections, and the
/// registry of message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to send messages and responses on every client connection.
    peer: Arc<Peer>,
    /// Live connections per user; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `Server::teardown`; connection loops subscribe to it.
    teardown: watch::Sender<()>,
}
162
/// Lock guard over the shared [`ConnectionPool`].
///
/// `PhantomData<Rc<()>>` makes the guard `!Send`. Any future that holds it across an
/// `.await` is therefore not `Send`. Handler futures must be `Send`, so this keeps the
/// synchronous `parking_lot` lock from being held across suspension points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
167
/// Serializable view of the server's peer and connection pool.
///
/// The connection-pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` itself rather than as the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
174
175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
176where
177 S: Serializer,
178 T: Deref<Target = U>,
179 U: Serialize,
180{
181 Serialize::serialize(value.deref(), serializer)
182}
183
impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// The peer is created with `id.0 as u32`. Requests that expect a reply are
    /// registered with `add_request_handler`; fire-and-forget messages use
    /// `add_message_handler`. Registering two handlers for the same message type
    /// panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(link_channel)
            .add_request_handler(unlink_channel)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Schedules cleanup of resources that the database still attributes to stale
    /// servers, and returns `Ok(())` immediately.
    ///
    /// A detached task waits [`CLEANUP_TIMEOUT`] and then calls
    /// `stale_server_resource_ids` for this environment and server id. It then:
    /// - clears stale collaborators from each affected channel buffer and sends the
    ///   refreshed collaborator list to the buffer's remaining connections;
    /// - clears stale participants from each affected room. It calls `room_updated` (and
    ///   `channel_updated` when the room belongs to a channel), sends `CallCanceled` to
    ///   users whose pending calls were canceled, and pushes an updated contact entry for
    ///   each affected user to their accepted contacts. It deletes the LiveKit room when
    ///   no participants remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Database and send errors are traced and skipped; they never abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room succeeds; otherwise the
                        // steps below become no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the
                        // awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups must succeed; otherwise this user is skipped.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears down the peer, resets the connection pool, and notifies every subscriber
    /// of the `teardown` watch channel (the connection loops and `connection_lost`).
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and reinitializes the id and peer with `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for messages of type `M`, keyed by `TypeId::of::<M>()`.
    ///
    /// The stored closure downcasts the type-erased envelope to `TypedEnvelope<M>`. It
    /// runs the handler inside a "handle message" span and logs the elapsed time in
    /// milliseconds together with the error, if any.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // NOTE(review): presumably cannot fail because dispatch looks handlers up
                // by the envelope's payload type id — confirm against `AnyTypedEnvelope`.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a fire-and-forget message.
    ///
    /// The handler receives only the payload; there is no receipt to reply to.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request `M` that the client expects an answer to.
    ///
    /// The handler replies through the supplied [`Response`]. The outcome is handled as
    /// follows:
    /// - `Ok` without a prior `Response::send`: returns the error "handler did not send a
    ///   response", and the client receives no reply.
    /// - `Err`: the error's message goes back to the client as a `proto::Error`, and the
    ///   error is returned so `add_handler` logs it.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response` so we can tell afterwards whether it replied.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives one client connection until it closes, its I/O fails, or the server is
    /// torn down.
    ///
    /// The connection is registered with the peer and sent `Hello` (plus `ShowContacts`
    /// the first time the user connects). If `send_connection_id` is given, the
    /// assigned `ConnectionId` is sent through it. The connection is then added to the
    /// pool, and the client receives its initial contacts, channels and any pending
    /// incoming call.
    ///
    /// After that, incoming messages are dispatched to the registered handlers:
    /// - background messages are spawned detached; foreground ones run concurrently
    ///   inside this task;
    /// - at most 256 handlers hold a semaphore permit at a time.
    ///
    /// When the connection closes or its I/O ends, `connection_lost` runs. On teardown
    /// the function returns immediately and does not call `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // NOTE(review): the `?` early returns between here and the message loop skip
            // `connection_lost`, the only cleanup this function performs. An error after
            // `pool.add_connection` below would leave this connection in the pool, and
            // possibly with the peer. Confirm whether that is handled elsewhere.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completes, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken before reading the next message, so reading pauses
                // while all 256 permits are held by in-flight handlers.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    // Teardown returns immediately, without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are cancelled here, before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies an inviter that `invitee_id` redeemed their invite code.
    ///
    /// Every connection of `inviter_id` receives the invitee as a contact (not busy) and
    /// refreshed invite info (link and `invite_count`). Nothing is sent if the inviter
    /// does not exist or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Pushes refreshed invite info (link and `invite_count`) to every connection of
    /// `user_id`, provided the user exists and has an invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The returned value holds the connection-pool lock until it is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
749
750impl<'a> Deref for ConnectionPoolGuard<'a> {
751 type Target = ConnectionPool;
752
753 fn deref(&self) -> &Self::Target {
754 &*self.guard
755 }
756}
757
758impl<'a> DerefMut for ConnectionPoolGuard<'a> {
759 fn deref_mut(&mut self) -> &mut Self::Target {
760 &mut *self.guard
761 }
762}
763
764impl<'a> Drop for ConnectionPoolGuard<'a> {
765 fn drop(&mut self) {
766 #[cfg(test)]
767 self.check_invariants();
768 }
769}
770
771fn broadcast<F>(
772 sender_id: Option<ConnectionId>,
773 receiver_ids: impl IntoIterator<Item = ConnectionId>,
774 mut f: F,
775) where
776 F: FnMut(ConnectionId) -> anyhow::Result<()>,
777{
778 for receiver_id in receiver_ids {
779 if Some(receiver_id) != sender_id {
780 if let Err(error) = f(receiver_id) {
781 tracing::error!("failed to send to {:?} {}", receiver_id, error);
782 }
783 }
784 }
785}
786
787lazy_static! {
788 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
789}
790
791pub struct ProtocolVersion(u32);
792
793impl Header for ProtocolVersion {
794 fn name() -> &'static HeaderName {
795 &ZED_PROTOCOL_VERSION
796 }
797
798 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
799 where
800 Self: Sized,
801 I: Iterator<Item = &'i axum::http::HeaderValue>,
802 {
803 let version = values
804 .next()
805 .ok_or_else(axum::headers::Error::invalid)?
806 .to_str()
807 .map_err(|_| axum::headers::Error::invalid())?
808 .parse()
809 .map_err(|_| axum::headers::Error::invalid())?;
810 Ok(Self(version))
811 }
812
813 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
814 values.extend([self.0.to_string().parse().unwrap()]);
815 }
816}
817
818pub fn routes(server: Arc<Server>) -> Router<Body> {
819 Router::new()
820 .route("/rpc", get(handle_websocket_request))
821 .layer(
822 ServiceBuilder::new()
823 .layer(Extension(server.app_state.clone()))
824 .layer(middleware::from_fn(auth::validate_header)),
825 )
826 .route("/metrics", get(handle_metrics))
827 .layer(Extension(server))
828}
829
830pub async fn handle_websocket_request(
831 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
832 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
833 Extension(server): Extension<Arc<Server>>,
834 Extension(user): Extension<User>,
835 ws: WebSocketUpgrade,
836) -> axum::response::Response {
837 if protocol_version != rpc::PROTOCOL_VERSION {
838 return (
839 StatusCode::UPGRADE_REQUIRED,
840 "client must be upgraded".to_string(),
841 )
842 .into_response();
843 }
844 let socket_address = socket_address.to_string();
845 ws.on_upgrade(move |socket| {
846 use util::ResultExt;
847 let socket = socket
848 .map_ok(to_tungstenite_message)
849 .err_into()
850 .with(|message| async move { Ok(to_axum_message(message)) });
851 let connection = Connection::new(Box::pin(socket));
852 async move {
853 server
854 .handle_connection(connection, socket_address, user, None, Executor::Production)
855 .await
856 .log_err();
857 }
858 })
859}
860
861pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
862 let connections = server
863 .connection_pool
864 .lock()
865 .connections()
866 .filter(|connection| !connection.admin)
867 .count();
868
869 METRIC_CONNECTIONS.set(connections as _);
870
871 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
872 METRIC_SHARED_PROJECTS.set(shared_projects as _);
873
874 let encoder = prometheus::TextEncoder::new();
875 let metric_families = prometheus::gather();
876 let encoded_metrics = encoder
877 .encode_to_string(&metric_families)
878 .map_err(|err| anyhow!("{}", err))?;
879 Ok(encoded_metrics)
880}
881
/// Cleans up after a client connection has ended.
///
/// Three steps happen immediately:
/// - the peer is disconnected;
/// - the connection is removed from the pool (an error here is returned);
/// - the database is told the connection was lost (errors are traced).
///
/// It then waits [`RECONNECT_TIMEOUT`], presumably to give the client a chance to
/// reconnect. If the timeout elapses first, it:
/// - removes the user from their room and channel buffers;
/// - declines pending calls if the user has no other live connection;
/// - refreshes the user's contacts.
///
/// If the server is torn down first, that deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
927
928async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
929 response.send(proto::Ack {})?;
930 Ok(())
931}
932
933async fn create_room(
934 _request: proto::CreateRoom,
935 response: Response<proto::CreateRoom>,
936 session: Session,
937) -> Result<()> {
938 let live_kit_room = nanoid::nanoid!(30);
939
940 let live_kit_connection_info = {
941 let live_kit_room = live_kit_room.clone();
942 let live_kit = session.live_kit_client.as_ref();
943
944 util::async_iife!({
945 let live_kit = live_kit?;
946
947 let token = live_kit
948 .room_token(&live_kit_room, &session.user_id.to_string())
949 .trace_err()?;
950
951 Some(proto::LiveKitConnectionInfo {
952 server_url: live_kit.url().into(),
953 token,
954 })
955 })
956 }
957 .await;
958
959 let room = session
960 .db()
961 .await
962 .create_room(
963 session.user_id,
964 session.connection_id,
965 &live_kit_room,
966 RELEASE_CHANNEL_NAME.as_str(),
967 )
968 .await?;
969
970 response.send(proto::CreateRoomResponse {
971 room: Some(room.clone()),
972 live_kit_connection_info,
973 })?;
974
975 update_user_contacts(session.user_id, &session).await?;
976 Ok(())
977}
978
979async fn join_room(
980 request: proto::JoinRoom,
981 response: Response<proto::JoinRoom>,
982 session: Session,
983) -> Result<()> {
984 let room_id = RoomId::from_proto(request.id);
985
986 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
987
988 if let Some(channel_id) = channel_id {
989 return join_channel_internal(channel_id, Box::new(response), session).await;
990 }
991
992 let joined_room = {
993 let room = session
994 .db()
995 .await
996 .join_room(
997 room_id,
998 session.user_id,
999 session.connection_id,
1000 RELEASE_CHANNEL_NAME.as_str(),
1001 )
1002 .await?;
1003 room_updated(&room.room, &session.peer);
1004 room.into_inner()
1005 };
1006
1007 for connection_id in session
1008 .connection_pool()
1009 .await
1010 .user_connection_ids(session.user_id)
1011 {
1012 session
1013 .peer
1014 .send(
1015 connection_id,
1016 proto::CallCanceled {
1017 room_id: room_id.to_proto(),
1018 },
1019 )
1020 .trace_err();
1021 }
1022
1023 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1024 if let Some(token) = live_kit
1025 .room_token(
1026 &joined_room.room.live_kit_room,
1027 &session.user_id.to_string(),
1028 )
1029 .trace_err()
1030 {
1031 Some(proto::LiveKitConnectionInfo {
1032 server_url: live_kit.url().into(),
1033 token,
1034 })
1035 } else {
1036 None
1037 }
1038 } else {
1039 None
1040 };
1041
1042 response.send(proto::JoinRoomResponse {
1043 room: Some(joined_room.room),
1044 channel_id: None,
1045 live_kit_connection_info,
1046 })?;
1047
1048 update_user_contacts(session.user_id, &session).await?;
1049 Ok(())
1050}
1051
1052async fn rejoin_room(
1053 request: proto::RejoinRoom,
1054 response: Response<proto::RejoinRoom>,
1055 session: Session,
1056) -> Result<()> {
1057 let room;
1058 let channel_id;
1059 let channel_members;
1060 {
1061 let mut rejoined_room = session
1062 .db()
1063 .await
1064 .rejoin_room(request, session.user_id, session.connection_id)
1065 .await?;
1066
1067 response.send(proto::RejoinRoomResponse {
1068 room: Some(rejoined_room.room.clone()),
1069 reshared_projects: rejoined_room
1070 .reshared_projects
1071 .iter()
1072 .map(|project| proto::ResharedProject {
1073 id: project.id.to_proto(),
1074 collaborators: project
1075 .collaborators
1076 .iter()
1077 .map(|collaborator| collaborator.to_proto())
1078 .collect(),
1079 })
1080 .collect(),
1081 rejoined_projects: rejoined_room
1082 .rejoined_projects
1083 .iter()
1084 .map(|rejoined_project| proto::RejoinedProject {
1085 id: rejoined_project.id.to_proto(),
1086 worktrees: rejoined_project
1087 .worktrees
1088 .iter()
1089 .map(|worktree| proto::WorktreeMetadata {
1090 id: worktree.id,
1091 root_name: worktree.root_name.clone(),
1092 visible: worktree.visible,
1093 abs_path: worktree.abs_path.clone(),
1094 })
1095 .collect(),
1096 collaborators: rejoined_project
1097 .collaborators
1098 .iter()
1099 .map(|collaborator| collaborator.to_proto())
1100 .collect(),
1101 language_servers: rejoined_project.language_servers.clone(),
1102 })
1103 .collect(),
1104 })?;
1105 room_updated(&rejoined_room.room, &session.peer);
1106
1107 for project in &rejoined_room.reshared_projects {
1108 for collaborator in &project.collaborators {
1109 session
1110 .peer
1111 .send(
1112 collaborator.connection_id,
1113 proto::UpdateProjectCollaborator {
1114 project_id: project.id.to_proto(),
1115 old_peer_id: Some(project.old_connection_id.into()),
1116 new_peer_id: Some(session.connection_id.into()),
1117 },
1118 )
1119 .trace_err();
1120 }
1121
1122 broadcast(
1123 Some(session.connection_id),
1124 project
1125 .collaborators
1126 .iter()
1127 .map(|collaborator| collaborator.connection_id),
1128 |connection_id| {
1129 session.peer.forward_send(
1130 session.connection_id,
1131 connection_id,
1132 proto::UpdateProject {
1133 project_id: project.id.to_proto(),
1134 worktrees: project.worktrees.clone(),
1135 },
1136 )
1137 },
1138 );
1139 }
1140
1141 for project in &rejoined_room.rejoined_projects {
1142 for collaborator in &project.collaborators {
1143 session
1144 .peer
1145 .send(
1146 collaborator.connection_id,
1147 proto::UpdateProjectCollaborator {
1148 project_id: project.id.to_proto(),
1149 old_peer_id: Some(project.old_connection_id.into()),
1150 new_peer_id: Some(session.connection_id.into()),
1151 },
1152 )
1153 .trace_err();
1154 }
1155 }
1156
1157 for project in &mut rejoined_room.rejoined_projects {
1158 for worktree in mem::take(&mut project.worktrees) {
1159 #[cfg(any(test, feature = "test-support"))]
1160 const MAX_CHUNK_SIZE: usize = 2;
1161 #[cfg(not(any(test, feature = "test-support")))]
1162 const MAX_CHUNK_SIZE: usize = 256;
1163
1164 // Stream this worktree's entries.
1165 let message = proto::UpdateWorktree {
1166 project_id: project.id.to_proto(),
1167 worktree_id: worktree.id,
1168 abs_path: worktree.abs_path.clone(),
1169 root_name: worktree.root_name,
1170 updated_entries: worktree.updated_entries,
1171 removed_entries: worktree.removed_entries,
1172 scan_id: worktree.scan_id,
1173 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1174 updated_repositories: worktree.updated_repositories,
1175 removed_repositories: worktree.removed_repositories,
1176 };
1177 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1178 session.peer.send(session.connection_id, update.clone())?;
1179 }
1180
1181 // Stream this worktree's diagnostics.
1182 for summary in worktree.diagnostic_summaries {
1183 session.peer.send(
1184 session.connection_id,
1185 proto::UpdateDiagnosticSummary {
1186 project_id: project.id.to_proto(),
1187 worktree_id: worktree.id,
1188 summary: Some(summary),
1189 },
1190 )?;
1191 }
1192
1193 for settings_file in worktree.settings_files {
1194 session.peer.send(
1195 session.connection_id,
1196 proto::UpdateWorktreeSettings {
1197 project_id: project.id.to_proto(),
1198 worktree_id: worktree.id,
1199 path: settings_file.path,
1200 content: Some(settings_file.content),
1201 },
1202 )?;
1203 }
1204 }
1205
1206 for language_server in &project.language_servers {
1207 session.peer.send(
1208 session.connection_id,
1209 proto::UpdateLanguageServer {
1210 project_id: project.id.to_proto(),
1211 language_server_id: language_server.id,
1212 variant: Some(
1213 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1214 proto::LspDiskBasedDiagnosticsUpdated {},
1215 ),
1216 ),
1217 },
1218 )?;
1219 }
1220 }
1221
1222 let rejoined_room = rejoined_room.into_inner();
1223
1224 room = rejoined_room.room;
1225 channel_id = rejoined_room.channel_id;
1226 channel_members = rejoined_room.channel_members;
1227 }
1228
1229 if let Some(channel_id) = channel_id {
1230 channel_updated(
1231 channel_id,
1232 &room,
1233 &channel_members,
1234 &session.peer,
1235 &*session.connection_pool().await,
1236 );
1237 }
1238
1239 update_user_contacts(session.user_id, &session).await?;
1240 Ok(())
1241}
1242
1243async fn leave_room(
1244 _: proto::LeaveRoom,
1245 response: Response<proto::LeaveRoom>,
1246 session: Session,
1247) -> Result<()> {
1248 leave_room_for_session(&session).await?;
1249 response.send(proto::Ack {})?;
1250 Ok(())
1251}
1252
1253async fn call(
1254 request: proto::Call,
1255 response: Response<proto::Call>,
1256 session: Session,
1257) -> Result<()> {
1258 let room_id = RoomId::from_proto(request.room_id);
1259 let calling_user_id = session.user_id;
1260 let calling_connection_id = session.connection_id;
1261 let called_user_id = UserId::from_proto(request.called_user_id);
1262 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1263 if !session
1264 .db()
1265 .await
1266 .has_contact(calling_user_id, called_user_id)
1267 .await?
1268 {
1269 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1270 }
1271
1272 let incoming_call = {
1273 let (room, incoming_call) = &mut *session
1274 .db()
1275 .await
1276 .call(
1277 room_id,
1278 calling_user_id,
1279 calling_connection_id,
1280 called_user_id,
1281 initial_project_id,
1282 )
1283 .await?;
1284 room_updated(&room, &session.peer);
1285 mem::take(incoming_call)
1286 };
1287 update_user_contacts(called_user_id, &session).await?;
1288
1289 let mut calls = session
1290 .connection_pool()
1291 .await
1292 .user_connection_ids(called_user_id)
1293 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1294 .collect::<FuturesUnordered<_>>();
1295
1296 while let Some(call_response) = calls.next().await {
1297 match call_response.as_ref() {
1298 Ok(_) => {
1299 response.send(proto::Ack {})?;
1300 return Ok(());
1301 }
1302 Err(_) => {
1303 call_response.trace_err();
1304 }
1305 }
1306 }
1307
1308 {
1309 let room = session
1310 .db()
1311 .await
1312 .call_failed(room_id, called_user_id)
1313 .await?;
1314 room_updated(&room, &session.peer);
1315 }
1316 update_user_contacts(called_user_id, &session).await?;
1317
1318 Err(anyhow!("failed to ring user"))?
1319}
1320
1321async fn cancel_call(
1322 request: proto::CancelCall,
1323 response: Response<proto::CancelCall>,
1324 session: Session,
1325) -> Result<()> {
1326 let called_user_id = UserId::from_proto(request.called_user_id);
1327 let room_id = RoomId::from_proto(request.room_id);
1328 {
1329 let room = session
1330 .db()
1331 .await
1332 .cancel_call(room_id, session.connection_id, called_user_id)
1333 .await?;
1334 room_updated(&room, &session.peer);
1335 }
1336
1337 for connection_id in session
1338 .connection_pool()
1339 .await
1340 .user_connection_ids(called_user_id)
1341 {
1342 session
1343 .peer
1344 .send(
1345 connection_id,
1346 proto::CallCanceled {
1347 room_id: room_id.to_proto(),
1348 },
1349 )
1350 .trace_err();
1351 }
1352 response.send(proto::Ack {})?;
1353
1354 update_user_contacts(called_user_id, &session).await?;
1355 Ok(())
1356}
1357
1358async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1359 let room_id = RoomId::from_proto(message.room_id);
1360 {
1361 let room = session
1362 .db()
1363 .await
1364 .decline_call(Some(room_id), session.user_id)
1365 .await?
1366 .ok_or_else(|| anyhow!("failed to decline call"))?;
1367 room_updated(&room, &session.peer);
1368 }
1369
1370 for connection_id in session
1371 .connection_pool()
1372 .await
1373 .user_connection_ids(session.user_id)
1374 {
1375 session
1376 .peer
1377 .send(
1378 connection_id,
1379 proto::CallCanceled {
1380 room_id: room_id.to_proto(),
1381 },
1382 )
1383 .trace_err();
1384 }
1385 update_user_contacts(session.user_id, &session).await?;
1386 Ok(())
1387}
1388
1389async fn update_participant_location(
1390 request: proto::UpdateParticipantLocation,
1391 response: Response<proto::UpdateParticipantLocation>,
1392 session: Session,
1393) -> Result<()> {
1394 let room_id = RoomId::from_proto(request.room_id);
1395 let location = request
1396 .location
1397 .ok_or_else(|| anyhow!("invalid location"))?;
1398
1399 let db = session.db().await;
1400 let room = db
1401 .update_room_participant_location(room_id, session.connection_id, location)
1402 .await?;
1403
1404 room_updated(&room, &session.peer);
1405 response.send(proto::Ack {})?;
1406 Ok(())
1407}
1408
1409async fn share_project(
1410 request: proto::ShareProject,
1411 response: Response<proto::ShareProject>,
1412 session: Session,
1413) -> Result<()> {
1414 let (project_id, room) = &*session
1415 .db()
1416 .await
1417 .share_project(
1418 RoomId::from_proto(request.room_id),
1419 session.connection_id,
1420 &request.worktrees,
1421 )
1422 .await?;
1423 response.send(proto::ShareProjectResponse {
1424 project_id: project_id.to_proto(),
1425 })?;
1426 room_updated(&room, &session.peer);
1427
1428 Ok(())
1429}
1430
1431async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1432 let project_id = ProjectId::from_proto(message.project_id);
1433
1434 let (room, guest_connection_ids) = &*session
1435 .db()
1436 .await
1437 .unshare_project(project_id, session.connection_id)
1438 .await?;
1439
1440 broadcast(
1441 Some(session.connection_id),
1442 guest_connection_ids.iter().copied(),
1443 |conn_id| session.peer.send(conn_id, message.clone()),
1444 );
1445 room_updated(&room, &session.peer);
1446
1447 Ok(())
1448}
1449
1450async fn join_project(
1451 request: proto::JoinProject,
1452 response: Response<proto::JoinProject>,
1453 session: Session,
1454) -> Result<()> {
1455 let project_id = ProjectId::from_proto(request.project_id);
1456 let guest_user_id = session.user_id;
1457
1458 tracing::info!(%project_id, "join project");
1459
1460 let (project, replica_id) = &mut *session
1461 .db()
1462 .await
1463 .join_project(project_id, session.connection_id)
1464 .await?;
1465
1466 let collaborators = project
1467 .collaborators
1468 .iter()
1469 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1470 .map(|collaborator| collaborator.to_proto())
1471 .collect::<Vec<_>>();
1472
1473 let worktrees = project
1474 .worktrees
1475 .iter()
1476 .map(|(id, worktree)| proto::WorktreeMetadata {
1477 id: *id,
1478 root_name: worktree.root_name.clone(),
1479 visible: worktree.visible,
1480 abs_path: worktree.abs_path.clone(),
1481 })
1482 .collect::<Vec<_>>();
1483
1484 for collaborator in &collaborators {
1485 session
1486 .peer
1487 .send(
1488 collaborator.peer_id.unwrap().into(),
1489 proto::AddProjectCollaborator {
1490 project_id: project_id.to_proto(),
1491 collaborator: Some(proto::Collaborator {
1492 peer_id: Some(session.connection_id.into()),
1493 replica_id: replica_id.0 as u32,
1494 user_id: guest_user_id.to_proto(),
1495 }),
1496 },
1497 )
1498 .trace_err();
1499 }
1500
1501 // First, we send the metadata associated with each worktree.
1502 response.send(proto::JoinProjectResponse {
1503 worktrees: worktrees.clone(),
1504 replica_id: replica_id.0 as u32,
1505 collaborators: collaborators.clone(),
1506 language_servers: project.language_servers.clone(),
1507 })?;
1508
1509 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1510 #[cfg(any(test, feature = "test-support"))]
1511 const MAX_CHUNK_SIZE: usize = 2;
1512 #[cfg(not(any(test, feature = "test-support")))]
1513 const MAX_CHUNK_SIZE: usize = 256;
1514
1515 // Stream this worktree's entries.
1516 let message = proto::UpdateWorktree {
1517 project_id: project_id.to_proto(),
1518 worktree_id,
1519 abs_path: worktree.abs_path.clone(),
1520 root_name: worktree.root_name,
1521 updated_entries: worktree.entries,
1522 removed_entries: Default::default(),
1523 scan_id: worktree.scan_id,
1524 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1525 updated_repositories: worktree.repository_entries.into_values().collect(),
1526 removed_repositories: Default::default(),
1527 };
1528 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1529 session.peer.send(session.connection_id, update.clone())?;
1530 }
1531
1532 // Stream this worktree's diagnostics.
1533 for summary in worktree.diagnostic_summaries {
1534 session.peer.send(
1535 session.connection_id,
1536 proto::UpdateDiagnosticSummary {
1537 project_id: project_id.to_proto(),
1538 worktree_id: worktree.id,
1539 summary: Some(summary),
1540 },
1541 )?;
1542 }
1543
1544 for settings_file in worktree.settings_files {
1545 session.peer.send(
1546 session.connection_id,
1547 proto::UpdateWorktreeSettings {
1548 project_id: project_id.to_proto(),
1549 worktree_id: worktree.id,
1550 path: settings_file.path,
1551 content: Some(settings_file.content),
1552 },
1553 )?;
1554 }
1555 }
1556
1557 for language_server in &project.language_servers {
1558 session.peer.send(
1559 session.connection_id,
1560 proto::UpdateLanguageServer {
1561 project_id: project_id.to_proto(),
1562 language_server_id: language_server.id,
1563 variant: Some(
1564 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1565 proto::LspDiskBasedDiagnosticsUpdated {},
1566 ),
1567 ),
1568 },
1569 )?;
1570 }
1571
1572 Ok(())
1573}
1574
1575async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1576 let sender_id = session.connection_id;
1577 let project_id = ProjectId::from_proto(request.project_id);
1578
1579 let (room, project) = &*session
1580 .db()
1581 .await
1582 .leave_project(project_id, sender_id)
1583 .await?;
1584 tracing::info!(
1585 %project_id,
1586 host_user_id = %project.host_user_id,
1587 host_connection_id = %project.host_connection_id,
1588 "leave project"
1589 );
1590
1591 project_left(&project, &session);
1592 room_updated(&room, &session.peer);
1593
1594 Ok(())
1595}
1596
1597async fn update_project(
1598 request: proto::UpdateProject,
1599 response: Response<proto::UpdateProject>,
1600 session: Session,
1601) -> Result<()> {
1602 let project_id = ProjectId::from_proto(request.project_id);
1603 let (room, guest_connection_ids) = &*session
1604 .db()
1605 .await
1606 .update_project(project_id, session.connection_id, &request.worktrees)
1607 .await?;
1608 broadcast(
1609 Some(session.connection_id),
1610 guest_connection_ids.iter().copied(),
1611 |connection_id| {
1612 session
1613 .peer
1614 .forward_send(session.connection_id, connection_id, request.clone())
1615 },
1616 );
1617 room_updated(&room, &session.peer);
1618 response.send(proto::Ack {})?;
1619
1620 Ok(())
1621}
1622
1623async fn update_worktree(
1624 request: proto::UpdateWorktree,
1625 response: Response<proto::UpdateWorktree>,
1626 session: Session,
1627) -> Result<()> {
1628 let guest_connection_ids = session
1629 .db()
1630 .await
1631 .update_worktree(&request, session.connection_id)
1632 .await?;
1633
1634 broadcast(
1635 Some(session.connection_id),
1636 guest_connection_ids.iter().copied(),
1637 |connection_id| {
1638 session
1639 .peer
1640 .forward_send(session.connection_id, connection_id, request.clone())
1641 },
1642 );
1643 response.send(proto::Ack {})?;
1644 Ok(())
1645}
1646
1647async fn update_diagnostic_summary(
1648 message: proto::UpdateDiagnosticSummary,
1649 session: Session,
1650) -> Result<()> {
1651 let guest_connection_ids = session
1652 .db()
1653 .await
1654 .update_diagnostic_summary(&message, session.connection_id)
1655 .await?;
1656
1657 broadcast(
1658 Some(session.connection_id),
1659 guest_connection_ids.iter().copied(),
1660 |connection_id| {
1661 session
1662 .peer
1663 .forward_send(session.connection_id, connection_id, message.clone())
1664 },
1665 );
1666
1667 Ok(())
1668}
1669
1670async fn update_worktree_settings(
1671 message: proto::UpdateWorktreeSettings,
1672 session: Session,
1673) -> Result<()> {
1674 let guest_connection_ids = session
1675 .db()
1676 .await
1677 .update_worktree_settings(&message, session.connection_id)
1678 .await?;
1679
1680 broadcast(
1681 Some(session.connection_id),
1682 guest_connection_ids.iter().copied(),
1683 |connection_id| {
1684 session
1685 .peer
1686 .forward_send(session.connection_id, connection_id, message.clone())
1687 },
1688 );
1689
1690 Ok(())
1691}
1692
1693async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1694 broadcast_project_message(request.project_id, request, session).await
1695}
1696
1697async fn start_language_server(
1698 request: proto::StartLanguageServer,
1699 session: Session,
1700) -> Result<()> {
1701 let guest_connection_ids = session
1702 .db()
1703 .await
1704 .start_language_server(&request, session.connection_id)
1705 .await?;
1706
1707 broadcast(
1708 Some(session.connection_id),
1709 guest_connection_ids.iter().copied(),
1710 |connection_id| {
1711 session
1712 .peer
1713 .forward_send(session.connection_id, connection_id, request.clone())
1714 },
1715 );
1716 Ok(())
1717}
1718
1719async fn update_language_server(
1720 request: proto::UpdateLanguageServer,
1721 session: Session,
1722) -> Result<()> {
1723 session.executor.record_backtrace();
1724 let project_id = ProjectId::from_proto(request.project_id);
1725 let project_connection_ids = session
1726 .db()
1727 .await
1728 .project_connection_ids(project_id, session.connection_id)
1729 .await?;
1730 broadcast(
1731 Some(session.connection_id),
1732 project_connection_ids.iter().copied(),
1733 |connection_id| {
1734 session
1735 .peer
1736 .forward_send(session.connection_id, connection_id, request.clone())
1737 },
1738 );
1739 Ok(())
1740}
1741
1742async fn forward_project_request<T>(
1743 request: T,
1744 response: Response<T>,
1745 session: Session,
1746) -> Result<()>
1747where
1748 T: EntityMessage + RequestMessage,
1749{
1750 session.executor.record_backtrace();
1751 let project_id = ProjectId::from_proto(request.remote_entity_id());
1752 let host_connection_id = {
1753 let collaborators = session
1754 .db()
1755 .await
1756 .project_collaborators(project_id, session.connection_id)
1757 .await?;
1758 collaborators
1759 .iter()
1760 .find(|collaborator| collaborator.is_host)
1761 .ok_or_else(|| anyhow!("host not found"))?
1762 .connection_id
1763 };
1764
1765 let payload = session
1766 .peer
1767 .forward_request(session.connection_id, host_connection_id, request)
1768 .await?;
1769
1770 response.send(payload)?;
1771 Ok(())
1772}
1773
1774async fn create_buffer_for_peer(
1775 request: proto::CreateBufferForPeer,
1776 session: Session,
1777) -> Result<()> {
1778 session.executor.record_backtrace();
1779 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1780 session
1781 .peer
1782 .forward_send(session.connection_id, peer_id.into(), request)?;
1783 Ok(())
1784}
1785
1786async fn update_buffer(
1787 request: proto::UpdateBuffer,
1788 response: Response<proto::UpdateBuffer>,
1789 session: Session,
1790) -> Result<()> {
1791 session.executor.record_backtrace();
1792 let project_id = ProjectId::from_proto(request.project_id);
1793 let mut guest_connection_ids;
1794 let mut host_connection_id = None;
1795 {
1796 let collaborators = session
1797 .db()
1798 .await
1799 .project_collaborators(project_id, session.connection_id)
1800 .await?;
1801 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1802 for collaborator in collaborators.iter() {
1803 if collaborator.is_host {
1804 host_connection_id = Some(collaborator.connection_id);
1805 } else {
1806 guest_connection_ids.push(collaborator.connection_id);
1807 }
1808 }
1809 }
1810 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1811
1812 session.executor.record_backtrace();
1813 broadcast(
1814 Some(session.connection_id),
1815 guest_connection_ids,
1816 |connection_id| {
1817 session
1818 .peer
1819 .forward_send(session.connection_id, connection_id, request.clone())
1820 },
1821 );
1822 if host_connection_id != session.connection_id {
1823 session
1824 .peer
1825 .forward_request(session.connection_id, host_connection_id, request.clone())
1826 .await?;
1827 }
1828
1829 response.send(proto::Ack {})?;
1830 Ok(())
1831}
1832
1833async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1834 let project_id = ProjectId::from_proto(request.project_id);
1835 let project_connection_ids = session
1836 .db()
1837 .await
1838 .project_connection_ids(project_id, session.connection_id)
1839 .await?;
1840
1841 broadcast(
1842 Some(session.connection_id),
1843 project_connection_ids.iter().copied(),
1844 |connection_id| {
1845 session
1846 .peer
1847 .forward_send(session.connection_id, connection_id, request.clone())
1848 },
1849 );
1850 Ok(())
1851}
1852
1853async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1854 let project_id = ProjectId::from_proto(request.project_id);
1855 let project_connection_ids = session
1856 .db()
1857 .await
1858 .project_connection_ids(project_id, session.connection_id)
1859 .await?;
1860 broadcast(
1861 Some(session.connection_id),
1862 project_connection_ids.iter().copied(),
1863 |connection_id| {
1864 session
1865 .peer
1866 .forward_send(session.connection_id, connection_id, request.clone())
1867 },
1868 );
1869 Ok(())
1870}
1871
1872async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1873 broadcast_project_message(request.project_id, request, session).await
1874}
1875
1876async fn broadcast_project_message<T: EnvelopedMessage>(
1877 project_id: u64,
1878 request: T,
1879 session: Session,
1880) -> Result<()> {
1881 let project_id = ProjectId::from_proto(project_id);
1882 let project_connection_ids = session
1883 .db()
1884 .await
1885 .project_connection_ids(project_id, session.connection_id)
1886 .await?;
1887 broadcast(
1888 Some(session.connection_id),
1889 project_connection_ids.iter().copied(),
1890 |connection_id| {
1891 session
1892 .peer
1893 .forward_send(session.connection_id, connection_id, request.clone())
1894 },
1895 );
1896 Ok(())
1897}
1898
1899async fn follow(
1900 request: proto::Follow,
1901 response: Response<proto::Follow>,
1902 session: Session,
1903) -> Result<()> {
1904 let room_id = RoomId::from_proto(request.room_id);
1905 let project_id = request.project_id.map(ProjectId::from_proto);
1906 let leader_id = request
1907 .leader_id
1908 .ok_or_else(|| anyhow!("invalid leader id"))?
1909 .into();
1910 let follower_id = session.connection_id;
1911
1912 session
1913 .db()
1914 .await
1915 .check_room_participants(room_id, leader_id, session.connection_id)
1916 .await?;
1917
1918 let response_payload = session
1919 .peer
1920 .forward_request(session.connection_id, leader_id, request)
1921 .await?;
1922 response.send(response_payload)?;
1923
1924 if let Some(project_id) = project_id {
1925 let room = session
1926 .db()
1927 .await
1928 .follow(room_id, project_id, leader_id, follower_id)
1929 .await?;
1930 room_updated(&room, &session.peer);
1931 }
1932
1933 Ok(())
1934}
1935
1936async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1937 let room_id = RoomId::from_proto(request.room_id);
1938 let project_id = request.project_id.map(ProjectId::from_proto);
1939 let leader_id = request
1940 .leader_id
1941 .ok_or_else(|| anyhow!("invalid leader id"))?
1942 .into();
1943 let follower_id = session.connection_id;
1944
1945 session
1946 .db()
1947 .await
1948 .check_room_participants(room_id, leader_id, session.connection_id)
1949 .await?;
1950
1951 session
1952 .peer
1953 .forward_send(session.connection_id, leader_id, request)?;
1954
1955 if let Some(project_id) = project_id {
1956 let room = session
1957 .db()
1958 .await
1959 .unfollow(room_id, project_id, leader_id, follower_id)
1960 .await?;
1961 room_updated(&room, &session.peer);
1962 }
1963
1964 Ok(())
1965}
1966
1967async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1968 let room_id = RoomId::from_proto(request.room_id);
1969 let database = session.db.lock().await;
1970
1971 let connection_ids = if let Some(project_id) = request.project_id {
1972 let project_id = ProjectId::from_proto(project_id);
1973 database
1974 .project_connection_ids(project_id, session.connection_id)
1975 .await?
1976 } else {
1977 database
1978 .room_connection_ids(room_id, session.connection_id)
1979 .await?
1980 };
1981
1982 // For now, don't send view update messages back to that view's current leader.
1983 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1984 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1985 _ => None,
1986 });
1987
1988 for follower_peer_id in request.follower_ids.iter().copied() {
1989 let follower_connection_id = follower_peer_id.into();
1990 if Some(follower_peer_id) != connection_id_to_omit
1991 && connection_ids.contains(&follower_connection_id)
1992 {
1993 session.peer.forward_send(
1994 session.connection_id,
1995 follower_connection_id,
1996 request.clone(),
1997 )?;
1998 }
1999 }
2000 Ok(())
2001}
2002
2003async fn get_users(
2004 request: proto::GetUsers,
2005 response: Response<proto::GetUsers>,
2006 session: Session,
2007) -> Result<()> {
2008 let user_ids = request
2009 .user_ids
2010 .into_iter()
2011 .map(UserId::from_proto)
2012 .collect();
2013 let users = session
2014 .db()
2015 .await
2016 .get_users_by_ids(user_ids)
2017 .await?
2018 .into_iter()
2019 .map(|user| proto::User {
2020 id: user.id.to_proto(),
2021 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2022 github_login: user.github_login,
2023 })
2024 .collect();
2025 response.send(proto::UsersResponse { users })?;
2026 Ok(())
2027}
2028
2029async fn fuzzy_search_users(
2030 request: proto::FuzzySearchUsers,
2031 response: Response<proto::FuzzySearchUsers>,
2032 session: Session,
2033) -> Result<()> {
2034 let query = request.query;
2035 let users = match query.len() {
2036 0 => vec![],
2037 1 | 2 => session
2038 .db()
2039 .await
2040 .get_user_by_github_login(&query)
2041 .await?
2042 .into_iter()
2043 .collect(),
2044 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2045 };
2046 let users = users
2047 .into_iter()
2048 .filter(|user| user.id != session.user_id)
2049 .map(|user| proto::User {
2050 id: user.id.to_proto(),
2051 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2052 github_login: user.github_login,
2053 })
2054 .collect();
2055 response.send(proto::UsersResponse { users })?;
2056 Ok(())
2057}
2058
2059async fn request_contact(
2060 request: proto::RequestContact,
2061 response: Response<proto::RequestContact>,
2062 session: Session,
2063) -> Result<()> {
2064 let requester_id = session.user_id;
2065 let responder_id = UserId::from_proto(request.responder_id);
2066 if requester_id == responder_id {
2067 return Err(anyhow!("cannot add yourself as a contact"))?;
2068 }
2069
2070 let notifications = session
2071 .db()
2072 .await
2073 .send_contact_request(requester_id, responder_id)
2074 .await?;
2075
2076 // Update outgoing contact requests of requester
2077 let mut update = proto::UpdateContacts::default();
2078 update.outgoing_requests.push(responder_id.to_proto());
2079 for connection_id in session
2080 .connection_pool()
2081 .await
2082 .user_connection_ids(requester_id)
2083 {
2084 session.peer.send(connection_id, update.clone())?;
2085 }
2086
2087 // Update incoming contact requests of responder
2088 let mut update = proto::UpdateContacts::default();
2089 update
2090 .incoming_requests
2091 .push(proto::IncomingContactRequest {
2092 requester_id: requester_id.to_proto(),
2093 });
2094 let connection_pool = session.connection_pool().await;
2095 for connection_id in connection_pool.user_connection_ids(responder_id) {
2096 session.peer.send(connection_id, update.clone())?;
2097 }
2098
2099 send_notifications(&*connection_pool, &session.peer, notifications);
2100
2101 response.send(proto::Ack {})?;
2102 Ok(())
2103}
2104
2105async fn respond_to_contact_request(
2106 request: proto::RespondToContactRequest,
2107 response: Response<proto::RespondToContactRequest>,
2108 session: Session,
2109) -> Result<()> {
2110 let responder_id = session.user_id;
2111 let requester_id = UserId::from_proto(request.requester_id);
2112 let db = session.db().await;
2113 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2114 db.dismiss_contact_notification(responder_id, requester_id)
2115 .await?;
2116 } else {
2117 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2118
2119 let notifications = db
2120 .respond_to_contact_request(responder_id, requester_id, accept)
2121 .await?;
2122 let requester_busy = db.is_user_busy(requester_id).await?;
2123 let responder_busy = db.is_user_busy(responder_id).await?;
2124
2125 let pool = session.connection_pool().await;
2126 // Update responder with new contact
2127 let mut update = proto::UpdateContacts::default();
2128 if accept {
2129 update
2130 .contacts
2131 .push(contact_for_user(requester_id, requester_busy, &pool));
2132 }
2133 update
2134 .remove_incoming_requests
2135 .push(requester_id.to_proto());
2136 for connection_id in pool.user_connection_ids(responder_id) {
2137 session.peer.send(connection_id, update.clone())?;
2138 }
2139
2140 // Update requester with new contact
2141 let mut update = proto::UpdateContacts::default();
2142 if accept {
2143 update
2144 .contacts
2145 .push(contact_for_user(responder_id, responder_busy, &pool));
2146 }
2147 update
2148 .remove_outgoing_requests
2149 .push(responder_id.to_proto());
2150
2151 for connection_id in pool.user_connection_ids(requester_id) {
2152 session.peer.send(connection_id, update.clone())?;
2153 }
2154
2155 send_notifications(&*pool, &session.peer, notifications);
2156 }
2157
2158 response.send(proto::Ack {})?;
2159 Ok(())
2160}
2161
2162async fn remove_contact(
2163 request: proto::RemoveContact,
2164 response: Response<proto::RemoveContact>,
2165 session: Session,
2166) -> Result<()> {
2167 let requester_id = session.user_id;
2168 let responder_id = UserId::from_proto(request.user_id);
2169 let db = session.db().await;
2170 let (contact_accepted, deleted_notification_id) =
2171 db.remove_contact(requester_id, responder_id).await?;
2172
2173 let pool = session.connection_pool().await;
2174 // Update outgoing contact requests of requester
2175 let mut update = proto::UpdateContacts::default();
2176 if contact_accepted {
2177 update.remove_contacts.push(responder_id.to_proto());
2178 } else {
2179 update
2180 .remove_outgoing_requests
2181 .push(responder_id.to_proto());
2182 }
2183 for connection_id in pool.user_connection_ids(requester_id) {
2184 session.peer.send(connection_id, update.clone())?;
2185 }
2186
2187 // Update incoming contact requests of responder
2188 let mut update = proto::UpdateContacts::default();
2189 if contact_accepted {
2190 update.remove_contacts.push(requester_id.to_proto());
2191 } else {
2192 update
2193 .remove_incoming_requests
2194 .push(requester_id.to_proto());
2195 }
2196 for connection_id in pool.user_connection_ids(responder_id) {
2197 session.peer.send(connection_id, update.clone())?;
2198 if let Some(notification_id) = deleted_notification_id {
2199 session.peer.send(
2200 connection_id,
2201 proto::DeleteNotification {
2202 notification_id: notification_id.to_proto(),
2203 },
2204 )?;
2205 }
2206 }
2207
2208 response.send(proto::Ack {})?;
2209 Ok(())
2210}
2211
2212async fn create_channel(
2213 request: proto::CreateChannel,
2214 response: Response<proto::CreateChannel>,
2215 session: Session,
2216) -> Result<()> {
2217 let db = session.db().await;
2218
2219 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2220 let id = db
2221 .create_channel(&request.name, parent_id, session.user_id)
2222 .await?;
2223
2224 let channel = proto::Channel {
2225 id: id.to_proto(),
2226 name: request.name,
2227 visibility: proto::ChannelVisibility::Members as i32,
2228 };
2229
2230 response.send(proto::CreateChannelResponse {
2231 channel: Some(channel.clone()),
2232 parent_id: request.parent_id,
2233 })?;
2234
2235 let Some(parent_id) = parent_id else {
2236 return Ok(());
2237 };
2238
2239 let update = proto::UpdateChannels {
2240 channels: vec![channel],
2241 insert_edge: vec![ChannelEdge {
2242 parent_id: parent_id.to_proto(),
2243 channel_id: id.to_proto(),
2244 }],
2245 ..Default::default()
2246 };
2247
2248 let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2249
2250 let connection_pool = session.connection_pool().await;
2251 for user_id in user_ids_to_notify {
2252 for connection_id in connection_pool.user_connection_ids(user_id) {
2253 if user_id == session.user_id {
2254 continue;
2255 }
2256 session.peer.send(connection_id, update.clone())?;
2257 }
2258 }
2259
2260 Ok(())
2261}
2262
2263async fn delete_channel(
2264 request: proto::DeleteChannel,
2265 response: Response<proto::DeleteChannel>,
2266 session: Session,
2267) -> Result<()> {
2268 let db = session.db().await;
2269
2270 let channel_id = request.channel_id;
2271 let (removed_channels, member_ids) = db
2272 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2273 .await?;
2274 response.send(proto::Ack {})?;
2275
2276 // Notify members of removed channels
2277 let mut update = proto::UpdateChannels::default();
2278 update
2279 .delete_channels
2280 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2281
2282 let connection_pool = session.connection_pool().await;
2283 for member_id in member_ids {
2284 for connection_id in connection_pool.user_connection_ids(member_id) {
2285 session.peer.send(connection_id, update.clone())?;
2286 }
2287 }
2288
2289 Ok(())
2290}
2291
2292async fn invite_channel_member(
2293 request: proto::InviteChannelMember,
2294 response: Response<proto::InviteChannelMember>,
2295 session: Session,
2296) -> Result<()> {
2297 let db = session.db().await;
2298 let channel_id = ChannelId::from_proto(request.channel_id);
2299 let invitee_id = UserId::from_proto(request.user_id);
2300 let notifications = db
2301 .invite_channel_member(
2302 channel_id,
2303 invitee_id,
2304 session.user_id,
2305 request.role().into(),
2306 )
2307 .await?;
2308
2309 let channel = db.get_channel(channel_id, session.user_id).await?;
2310
2311 let mut update = proto::UpdateChannels::default();
2312 update.channel_invitations.push(proto::Channel {
2313 id: channel.id.to_proto(),
2314 visibility: channel.visibility.into(),
2315 name: channel.name,
2316 });
2317
2318 let pool = session.connection_pool().await;
2319 for connection_id in pool.user_connection_ids(invitee_id) {
2320 session.peer.send(connection_id, update.clone())?;
2321 }
2322
2323 send_notifications(&*pool, &session.peer, notifications);
2324
2325 response.send(proto::Ack {})?;
2326 Ok(())
2327}
2328
2329async fn remove_channel_member(
2330 request: proto::RemoveChannelMember,
2331 response: Response<proto::RemoveChannelMember>,
2332 session: Session,
2333) -> Result<()> {
2334 let db = session.db().await;
2335 let channel_id = ChannelId::from_proto(request.channel_id);
2336 let member_id = UserId::from_proto(request.user_id);
2337
2338 let removed_notification_id = db
2339 .remove_channel_member(channel_id, member_id, session.user_id)
2340 .await?;
2341
2342 let mut update = proto::UpdateChannels::default();
2343 update.delete_channels.push(channel_id.to_proto());
2344
2345 for connection_id in session
2346 .connection_pool()
2347 .await
2348 .user_connection_ids(member_id)
2349 {
2350 session.peer.send(connection_id, update.clone()).trace_err();
2351 if let Some(notification_id) = removed_notification_id {
2352 session
2353 .peer
2354 .send(
2355 connection_id,
2356 proto::DeleteNotification {
2357 notification_id: notification_id.to_proto(),
2358 },
2359 )
2360 .trace_err();
2361 }
2362 }
2363
2364 response.send(proto::Ack {})?;
2365 Ok(())
2366}
2367
2368async fn set_channel_visibility(
2369 request: proto::SetChannelVisibility,
2370 response: Response<proto::SetChannelVisibility>,
2371 session: Session,
2372) -> Result<()> {
2373 let db = session.db().await;
2374 let channel_id = ChannelId::from_proto(request.channel_id);
2375 let visibility = request.visibility().into();
2376
2377 let channel = db
2378 .set_channel_visibility(channel_id, visibility, session.user_id)
2379 .await?;
2380
2381 let mut update = proto::UpdateChannels::default();
2382 update.channels.push(proto::Channel {
2383 id: channel.id.to_proto(),
2384 name: channel.name,
2385 visibility: channel.visibility.into(),
2386 });
2387
2388 let member_ids = db.get_channel_members(channel_id).await?;
2389
2390 let connection_pool = session.connection_pool().await;
2391 for member_id in member_ids {
2392 for connection_id in connection_pool.user_connection_ids(member_id) {
2393 session.peer.send(connection_id, update.clone())?;
2394 }
2395 }
2396
2397 response.send(proto::Ack {})?;
2398 Ok(())
2399}
2400
2401async fn set_channel_member_role(
2402 request: proto::SetChannelMemberRole,
2403 response: Response<proto::SetChannelMemberRole>,
2404 session: Session,
2405) -> Result<()> {
2406 let db = session.db().await;
2407 let channel_id = ChannelId::from_proto(request.channel_id);
2408 let member_id = UserId::from_proto(request.user_id);
2409 let channel_member = db
2410 .set_channel_member_role(
2411 channel_id,
2412 session.user_id,
2413 member_id,
2414 request.role().into(),
2415 )
2416 .await?;
2417
2418 let channel = db.get_channel(channel_id, session.user_id).await?;
2419
2420 let mut update = proto::UpdateChannels::default();
2421 if channel_member.accepted {
2422 update.channel_permissions.push(proto::ChannelPermission {
2423 channel_id: channel.id.to_proto(),
2424 role: request.role,
2425 });
2426 }
2427
2428 for connection_id in session
2429 .connection_pool()
2430 .await
2431 .user_connection_ids(member_id)
2432 {
2433 session.peer.send(connection_id, update.clone())?;
2434 }
2435
2436 response.send(proto::Ack {})?;
2437 Ok(())
2438}
2439
2440async fn rename_channel(
2441 request: proto::RenameChannel,
2442 response: Response<proto::RenameChannel>,
2443 session: Session,
2444) -> Result<()> {
2445 let db = session.db().await;
2446 let channel_id = ChannelId::from_proto(request.channel_id);
2447 let channel = db
2448 .rename_channel(channel_id, session.user_id, &request.name)
2449 .await?;
2450
2451 let channel = proto::Channel {
2452 id: channel.id.to_proto(),
2453 name: channel.name,
2454 visibility: channel.visibility.into(),
2455 };
2456 response.send(proto::RenameChannelResponse {
2457 channel: Some(channel.clone()),
2458 })?;
2459 let mut update = proto::UpdateChannels::default();
2460 update.channels.push(channel);
2461
2462 let member_ids = db.get_channel_members(channel_id).await?;
2463
2464 let connection_pool = session.connection_pool().await;
2465 for member_id in member_ids {
2466 for connection_id in connection_pool.user_connection_ids(member_id) {
2467 session.peer.send(connection_id, update.clone())?;
2468 }
2469 }
2470
2471 Ok(())
2472}
2473
2474async fn link_channel(
2475 request: proto::LinkChannel,
2476 response: Response<proto::LinkChannel>,
2477 session: Session,
2478) -> Result<()> {
2479 let db = session.db().await;
2480 let channel_id = ChannelId::from_proto(request.channel_id);
2481 let to = ChannelId::from_proto(request.to);
2482 let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2483
2484 let members = db.get_channel_members(to).await?;
2485 let connection_pool = session.connection_pool().await;
2486 let update = proto::UpdateChannels {
2487 channels: channels_to_send
2488 .channels
2489 .into_iter()
2490 .map(|channel| proto::Channel {
2491 id: channel.id.to_proto(),
2492 visibility: channel.visibility.into(),
2493 name: channel.name,
2494 })
2495 .collect(),
2496 insert_edge: channels_to_send.edges,
2497 ..Default::default()
2498 };
2499 for member_id in members {
2500 for connection_id in connection_pool.user_connection_ids(member_id) {
2501 session.peer.send(connection_id, update.clone())?;
2502 }
2503 }
2504
2505 response.send(Ack {})?;
2506
2507 Ok(())
2508}
2509
2510async fn unlink_channel(
2511 request: proto::UnlinkChannel,
2512 response: Response<proto::UnlinkChannel>,
2513 session: Session,
2514) -> Result<()> {
2515 let db = session.db().await;
2516 let channel_id = ChannelId::from_proto(request.channel_id);
2517 let from = ChannelId::from_proto(request.from);
2518
2519 db.unlink_channel(session.user_id, channel_id, from).await?;
2520
2521 let members = db.get_channel_members(from).await?;
2522
2523 let update = proto::UpdateChannels {
2524 delete_edge: vec![proto::ChannelEdge {
2525 channel_id: channel_id.to_proto(),
2526 parent_id: from.to_proto(),
2527 }],
2528 ..Default::default()
2529 };
2530 let connection_pool = session.connection_pool().await;
2531 for member_id in members {
2532 for connection_id in connection_pool.user_connection_ids(member_id) {
2533 session.peer.send(connection_id, update.clone())?;
2534 }
2535 }
2536
2537 response.send(Ack {})?;
2538
2539 Ok(())
2540}
2541
2542async fn move_channel(
2543 request: proto::MoveChannel,
2544 response: Response<proto::MoveChannel>,
2545 session: Session,
2546) -> Result<()> {
2547 let db = session.db().await;
2548 let channel_id = ChannelId::from_proto(request.channel_id);
2549 let from_parent = ChannelId::from_proto(request.from);
2550 let to = ChannelId::from_proto(request.to);
2551
2552 let channels_to_send = db
2553 .move_channel(session.user_id, channel_id, from_parent, to)
2554 .await?;
2555
2556 if channels_to_send.is_empty() {
2557 response.send(Ack {})?;
2558 return Ok(());
2559 }
2560
2561 let members_from = db.get_channel_members(from_parent).await?;
2562 let members_to = db.get_channel_members(to).await?;
2563
2564 let update = proto::UpdateChannels {
2565 delete_edge: vec![proto::ChannelEdge {
2566 channel_id: channel_id.to_proto(),
2567 parent_id: from_parent.to_proto(),
2568 }],
2569 ..Default::default()
2570 };
2571 let connection_pool = session.connection_pool().await;
2572 for member_id in members_from {
2573 for connection_id in connection_pool.user_connection_ids(member_id) {
2574 session.peer.send(connection_id, update.clone())?;
2575 }
2576 }
2577
2578 let update = proto::UpdateChannels {
2579 channels: channels_to_send
2580 .channels
2581 .into_iter()
2582 .map(|channel| proto::Channel {
2583 id: channel.id.to_proto(),
2584 visibility: channel.visibility.into(),
2585 name: channel.name,
2586 })
2587 .collect(),
2588 insert_edge: channels_to_send.edges,
2589 ..Default::default()
2590 };
2591 for member_id in members_to {
2592 for connection_id in connection_pool.user_connection_ids(member_id) {
2593 session.peer.send(connection_id, update.clone())?;
2594 }
2595 }
2596
2597 response.send(Ack {})?;
2598
2599 Ok(())
2600}
2601
2602async fn get_channel_members(
2603 request: proto::GetChannelMembers,
2604 response: Response<proto::GetChannelMembers>,
2605 session: Session,
2606) -> Result<()> {
2607 let db = session.db().await;
2608 let channel_id = ChannelId::from_proto(request.channel_id);
2609 let members = db
2610 .get_channel_participant_details(channel_id, session.user_id)
2611 .await?;
2612 response.send(proto::GetChannelMembersResponse { members })?;
2613 Ok(())
2614}
2615
2616async fn respond_to_channel_invite(
2617 request: proto::RespondToChannelInvite,
2618 response: Response<proto::RespondToChannelInvite>,
2619 session: Session,
2620) -> Result<()> {
2621 let db = session.db().await;
2622 let channel_id = ChannelId::from_proto(request.channel_id);
2623 let notifications = db
2624 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2625 .await?;
2626
2627 if request.accept {
2628 channel_membership_updated(db, channel_id, &session).await?;
2629 } else {
2630 let mut update = proto::UpdateChannels::default();
2631 update
2632 .remove_channel_invitations
2633 .push(channel_id.to_proto());
2634 session.peer.send(session.connection_id, update)?;
2635 }
2636
2637 send_notifications(
2638 &*session.connection_pool().await,
2639 &session.peer,
2640 notifications,
2641 );
2642 response.send(proto::Ack {})?;
2643
2644 Ok(())
2645}
2646
2647async fn channel_membership_updated(
2648 db: tokio::sync::MutexGuard<'_, DbHandle>,
2649 channel_id: ChannelId,
2650 session: &Session,
2651) -> Result<(), crate::Error> {
2652 let mut update = proto::UpdateChannels::default();
2653 update
2654 .remove_channel_invitations
2655 .push(channel_id.to_proto());
2656
2657 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2658 update.channels.extend(
2659 result
2660 .channels
2661 .channels
2662 .into_iter()
2663 .map(|channel| proto::Channel {
2664 id: channel.id.to_proto(),
2665 visibility: channel.visibility.into(),
2666 name: channel.name,
2667 }),
2668 );
2669 update.unseen_channel_messages = result.channel_messages;
2670 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2671 update.insert_edge = result.channels.edges;
2672 update
2673 .channel_participants
2674 .extend(
2675 result
2676 .channel_participants
2677 .into_iter()
2678 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2679 channel_id: channel_id.to_proto(),
2680 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2681 }),
2682 );
2683 update
2684 .channel_permissions
2685 .extend(
2686 result
2687 .channels_with_admin_privileges
2688 .into_iter()
2689 .map(|channel_id| proto::ChannelPermission {
2690 channel_id: channel_id.to_proto(),
2691 role: proto::ChannelRole::Admin.into(),
2692 }),
2693 );
2694 session.peer.send(session.connection_id, update)?;
2695 Ok(())
2696}
2697
2698async fn join_channel(
2699 request: proto::JoinChannel,
2700 response: Response<proto::JoinChannel>,
2701 session: Session,
2702) -> Result<()> {
2703 let channel_id = ChannelId::from_proto(request.channel_id);
2704 join_channel_internal(channel_id, Box::new(response), session).await
2705}
2706
/// A response handle that `join_channel_internal` can reply through.
///
/// It is implemented for both the `JoinChannel` and `JoinRoom` responses,
/// presumably so that both requests can route into `join_channel_internal`.
/// Both implementations reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified, presumably to make it explicit that `Response`'s own
        // `send` is called here rather than this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2720
/// Joins the session's user to the room of `channel_id` and notifies everyone
/// affected.
///
/// In order, this:
/// 1. leaves any room the session is currently in;
/// 2. joins the channel's room in the database, tagged with
///    `RELEASE_CHANNEL_NAME`;
/// 3. replies through `response` with the room and its channel id, adding
///    LiveKit connection info only when a LiveKit client is configured and a
///    room token could be created;
/// 4. if the database returned a `joined_channel`, sends this connection that
///    channel's state via `channel_membership_updated` (presumably when joining
///    also made the user a channel member; confirm in `db.join_channel`);
/// 5. broadcasts the room to its participants, then the channel's participant
///    list to channel members, and finally refreshes the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard lives only inside this block. It is either moved into
    // `channel_membership_updated` or dropped at the end of the block, so it is
    // released before the connection pool is locked for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // A failure to create a LiveKit token is handed to `trace_err` and yields
        // `None` instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2779
2780async fn join_channel_buffer(
2781 request: proto::JoinChannelBuffer,
2782 response: Response<proto::JoinChannelBuffer>,
2783 session: Session,
2784) -> Result<()> {
2785 let db = session.db().await;
2786 let channel_id = ChannelId::from_proto(request.channel_id);
2787
2788 let open_response = db
2789 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2790 .await?;
2791
2792 let collaborators = open_response.collaborators.clone();
2793 response.send(open_response)?;
2794
2795 let update = UpdateChannelBufferCollaborators {
2796 channel_id: channel_id.to_proto(),
2797 collaborators: collaborators.clone(),
2798 };
2799 channel_buffer_updated(
2800 session.connection_id,
2801 collaborators
2802 .iter()
2803 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2804 &update,
2805 &session.peer,
2806 );
2807
2808 Ok(())
2809}
2810
2811async fn update_channel_buffer(
2812 request: proto::UpdateChannelBuffer,
2813 session: Session,
2814) -> Result<()> {
2815 let db = session.db().await;
2816 let channel_id = ChannelId::from_proto(request.channel_id);
2817
2818 let (collaborators, non_collaborators, epoch, version) = db
2819 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2820 .await?;
2821
2822 channel_buffer_updated(
2823 session.connection_id,
2824 collaborators,
2825 &proto::UpdateChannelBuffer {
2826 channel_id: channel_id.to_proto(),
2827 operations: request.operations,
2828 },
2829 &session.peer,
2830 );
2831
2832 let pool = &*session.connection_pool().await;
2833
2834 broadcast(
2835 None,
2836 non_collaborators
2837 .iter()
2838 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2839 |peer_id| {
2840 session.peer.send(
2841 peer_id.into(),
2842 proto::UpdateChannels {
2843 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2844 channel_id: channel_id.to_proto(),
2845 epoch: epoch as u64,
2846 version: version.clone(),
2847 }],
2848 ..Default::default()
2849 },
2850 )
2851 },
2852 );
2853
2854 Ok(())
2855}
2856
2857async fn rejoin_channel_buffers(
2858 request: proto::RejoinChannelBuffers,
2859 response: Response<proto::RejoinChannelBuffers>,
2860 session: Session,
2861) -> Result<()> {
2862 let db = session.db().await;
2863 let buffers = db
2864 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2865 .await?;
2866
2867 for rejoined_buffer in &buffers {
2868 let collaborators_to_notify = rejoined_buffer
2869 .buffer
2870 .collaborators
2871 .iter()
2872 .filter_map(|c| Some(c.peer_id?.into()));
2873 channel_buffer_updated(
2874 session.connection_id,
2875 collaborators_to_notify,
2876 &proto::UpdateChannelBufferCollaborators {
2877 channel_id: rejoined_buffer.buffer.channel_id,
2878 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2879 },
2880 &session.peer,
2881 );
2882 }
2883
2884 response.send(proto::RejoinChannelBuffersResponse {
2885 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2886 })?;
2887
2888 Ok(())
2889}
2890
2891async fn leave_channel_buffer(
2892 request: proto::LeaveChannelBuffer,
2893 response: Response<proto::LeaveChannelBuffer>,
2894 session: Session,
2895) -> Result<()> {
2896 let db = session.db().await;
2897 let channel_id = ChannelId::from_proto(request.channel_id);
2898
2899 let left_buffer = db
2900 .leave_channel_buffer(channel_id, session.connection_id)
2901 .await?;
2902
2903 response.send(Ack {})?;
2904
2905 channel_buffer_updated(
2906 session.connection_id,
2907 left_buffer.connections,
2908 &proto::UpdateChannelBufferCollaborators {
2909 channel_id: channel_id.to_proto(),
2910 collaborators: left_buffer.collaborators,
2911 },
2912 &session.peer,
2913 );
2914
2915 Ok(())
2916}
2917
2918fn channel_buffer_updated<T: EnvelopedMessage>(
2919 sender_id: ConnectionId,
2920 collaborators: impl IntoIterator<Item = ConnectionId>,
2921 message: &T,
2922 peer: &Peer,
2923) {
2924 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2925 peer.send(peer_id.into(), message.clone())
2926 });
2927}
2928
2929fn send_notifications(
2930 connection_pool: &ConnectionPool,
2931 peer: &Peer,
2932 notifications: db::NotificationBatch,
2933) {
2934 for (user_id, notification) in notifications {
2935 for connection_id in connection_pool.user_connection_ids(user_id) {
2936 if let Err(error) = peer.send(
2937 connection_id,
2938 proto::AddNotification {
2939 notification: Some(notification.clone()),
2940 },
2941 ) {
2942 tracing::error!(
2943 "failed to send notification to {:?} {}",
2944 connection_id,
2945 error
2946 );
2947 }
2948 }
2949 }
2950}
2951
2952async fn send_channel_message(
2953 request: proto::SendChannelMessage,
2954 response: Response<proto::SendChannelMessage>,
2955 session: Session,
2956) -> Result<()> {
2957 // Validate the message body.
2958 let body = request.body.trim().to_string();
2959 if body.len() > MAX_MESSAGE_LEN {
2960 return Err(anyhow!("message is too long"))?;
2961 }
2962 if body.is_empty() {
2963 return Err(anyhow!("message can't be blank"))?;
2964 }
2965
2966 // TODO: adjust mentions if body is trimmed
2967
2968 let timestamp = OffsetDateTime::now_utc();
2969 let nonce = request
2970 .nonce
2971 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2972
2973 let channel_id = ChannelId::from_proto(request.channel_id);
2974 let CreatedChannelMessage {
2975 message_id,
2976 participant_connection_ids,
2977 channel_members,
2978 notifications,
2979 } = session
2980 .db()
2981 .await
2982 .create_channel_message(
2983 channel_id,
2984 session.user_id,
2985 &body,
2986 &request.mentions,
2987 timestamp,
2988 nonce.clone().into(),
2989 )
2990 .await?;
2991 let message = proto::ChannelMessage {
2992 sender_id: session.user_id.to_proto(),
2993 id: message_id.to_proto(),
2994 body,
2995 mentions: request.mentions,
2996 timestamp: timestamp.unix_timestamp() as u64,
2997 nonce: Some(nonce),
2998 };
2999 broadcast(
3000 Some(session.connection_id),
3001 participant_connection_ids,
3002 |connection| {
3003 session.peer.send(
3004 connection,
3005 proto::ChannelMessageSent {
3006 channel_id: channel_id.to_proto(),
3007 message: Some(message.clone()),
3008 },
3009 )
3010 },
3011 );
3012 response.send(proto::SendChannelMessageResponse {
3013 message: Some(message),
3014 })?;
3015
3016 let pool = &*session.connection_pool().await;
3017 broadcast(
3018 None,
3019 channel_members
3020 .iter()
3021 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3022 |peer_id| {
3023 session.peer.send(
3024 peer_id.into(),
3025 proto::UpdateChannels {
3026 unseen_channel_messages: vec![proto::UnseenChannelMessage {
3027 channel_id: channel_id.to_proto(),
3028 message_id: message_id.to_proto(),
3029 }],
3030 ..Default::default()
3031 },
3032 )
3033 },
3034 );
3035 send_notifications(pool, &session.peer, notifications);
3036
3037 Ok(())
3038}
3039
3040async fn remove_channel_message(
3041 request: proto::RemoveChannelMessage,
3042 response: Response<proto::RemoveChannelMessage>,
3043 session: Session,
3044) -> Result<()> {
3045 let channel_id = ChannelId::from_proto(request.channel_id);
3046 let message_id = MessageId::from_proto(request.message_id);
3047 let connection_ids = session
3048 .db()
3049 .await
3050 .remove_channel_message(channel_id, message_id, session.user_id)
3051 .await?;
3052 broadcast(Some(session.connection_id), connection_ids, |connection| {
3053 session.peer.send(connection, request.clone())
3054 });
3055 response.send(proto::Ack {})?;
3056 Ok(())
3057}
3058
3059async fn acknowledge_channel_message(
3060 request: proto::AckChannelMessage,
3061 session: Session,
3062) -> Result<()> {
3063 let channel_id = ChannelId::from_proto(request.channel_id);
3064 let message_id = MessageId::from_proto(request.message_id);
3065 let notifications = session
3066 .db()
3067 .await
3068 .observe_channel_message(channel_id, session.user_id, message_id)
3069 .await?;
3070 send_notifications(
3071 &*session.connection_pool().await,
3072 &session.peer,
3073 notifications,
3074 );
3075 Ok(())
3076}
3077
3078async fn acknowledge_buffer_version(
3079 request: proto::AckBufferOperation,
3080 session: Session,
3081) -> Result<()> {
3082 let buffer_id = BufferId::from_proto(request.buffer_id);
3083 session
3084 .db()
3085 .await
3086 .observe_buffer_version(
3087 buffer_id,
3088 session.user_id,
3089 request.epoch as i32,
3090 &request.version,
3091 )
3092 .await?;
3093 Ok(())
3094}
3095
3096async fn join_channel_chat(
3097 request: proto::JoinChannelChat,
3098 response: Response<proto::JoinChannelChat>,
3099 session: Session,
3100) -> Result<()> {
3101 let channel_id = ChannelId::from_proto(request.channel_id);
3102
3103 let db = session.db().await;
3104 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3105 .await?;
3106 let messages = db
3107 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3108 .await?;
3109 response.send(proto::JoinChannelChatResponse {
3110 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3111 messages,
3112 })?;
3113 Ok(())
3114}
3115
3116async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3117 let channel_id = ChannelId::from_proto(request.channel_id);
3118 session
3119 .db()
3120 .await
3121 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3122 .await?;
3123 Ok(())
3124}
3125
3126async fn get_channel_messages(
3127 request: proto::GetChannelMessages,
3128 response: Response<proto::GetChannelMessages>,
3129 session: Session,
3130) -> Result<()> {
3131 let channel_id = ChannelId::from_proto(request.channel_id);
3132 let messages = session
3133 .db()
3134 .await
3135 .get_channel_messages(
3136 channel_id,
3137 session.user_id,
3138 MESSAGE_COUNT_PER_PAGE,
3139 Some(MessageId::from_proto(request.before_message_id)),
3140 )
3141 .await?;
3142 response.send(proto::GetChannelMessagesResponse {
3143 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3144 messages,
3145 })?;
3146 Ok(())
3147}
3148
3149async fn get_channel_messages_by_id(
3150 request: proto::GetChannelMessagesById,
3151 response: Response<proto::GetChannelMessagesById>,
3152 session: Session,
3153) -> Result<()> {
3154 let message_ids = request
3155 .message_ids
3156 .iter()
3157 .map(|id| MessageId::from_proto(*id))
3158 .collect::<Vec<_>>();
3159 let messages = session
3160 .db()
3161 .await
3162 .get_channel_messages_by_id(session.user_id, &message_ids)
3163 .await?;
3164 response.send(proto::GetChannelMessagesResponse {
3165 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3166 messages,
3167 })?;
3168 Ok(())
3169}
3170
3171async fn get_notifications(
3172 request: proto::GetNotifications,
3173 response: Response<proto::GetNotifications>,
3174 session: Session,
3175) -> Result<()> {
3176 let notifications = session
3177 .db()
3178 .await
3179 .get_notifications(
3180 session.user_id,
3181 NOTIFICATION_COUNT_PER_PAGE,
3182 request
3183 .before_id
3184 .map(|id| db::NotificationId::from_proto(id)),
3185 )
3186 .await?;
3187 response.send(proto::GetNotificationsResponse {
3188 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3189 notifications,
3190 })?;
3191 Ok(())
3192}
3193
3194async fn mark_notification_as_read(
3195 request: proto::MarkNotificationRead,
3196 response: Response<proto::MarkNotificationRead>,
3197 session: Session,
3198) -> Result<()> {
3199 let database = &session.db().await;
3200 let notifications = database
3201 .mark_notification_as_read_by_id(
3202 session.user_id,
3203 NotificationId::from_proto(request.notification_id),
3204 )
3205 .await?;
3206 send_notifications(
3207 &*session.connection_pool().await,
3208 &session.peer,
3209 notifications,
3210 );
3211 response.send(proto::Ack {})?;
3212 Ok(())
3213}
3214
3215async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3216 let project_id = ProjectId::from_proto(request.project_id);
3217 let project_connection_ids = session
3218 .db()
3219 .await
3220 .project_connection_ids(project_id, session.connection_id)
3221 .await?;
3222 broadcast(
3223 Some(session.connection_id),
3224 project_connection_ids.iter().copied(),
3225 |connection_id| {
3226 session
3227 .peer
3228 .forward_send(session.connection_id, connection_id, request.clone())
3229 },
3230 );
3231 Ok(())
3232}
3233
3234async fn get_private_user_info(
3235 _request: proto::GetPrivateUserInfo,
3236 response: Response<proto::GetPrivateUserInfo>,
3237 session: Session,
3238) -> Result<()> {
3239 let db = session.db().await;
3240
3241 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3242 let user = db
3243 .get_user_by_id(session.user_id)
3244 .await?
3245 .ok_or_else(|| anyhow!("user not found"))?;
3246 let flags = db.get_user_flags(session.user_id).await?;
3247
3248 response.send(proto::GetPrivateUserInfoResponse {
3249 metrics_id,
3250 staff: user.admin,
3251 flags,
3252 })?;
3253 Ok(())
3254}
3255
3256fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3257 match message {
3258 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3259 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3260 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3261 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3262 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3263 code: frame.code.into(),
3264 reason: frame.reason,
3265 })),
3266 }
3267}
3268
3269fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3270 match message {
3271 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3272 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3273 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3274 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3275 AxumMessage::Close(frame) => {
3276 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3277 code: frame.code.into(),
3278 reason: frame.reason,
3279 }))
3280 }
3281 }
3282}
3283
3284fn build_initial_channels_update(
3285 channels: ChannelsForUser,
3286 channel_invites: Vec<db::Channel>,
3287) -> proto::UpdateChannels {
3288 let mut update = proto::UpdateChannels::default();
3289
3290 for channel in channels.channels.channels {
3291 update.channels.push(proto::Channel {
3292 id: channel.id.to_proto(),
3293 name: channel.name,
3294 visibility: channel.visibility.into(),
3295 });
3296 }
3297
3298 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3299 update.unseen_channel_messages = channels.channel_messages;
3300 update.insert_edge = channels.channels.edges;
3301
3302 for (channel_id, participants) in channels.channel_participants {
3303 update
3304 .channel_participants
3305 .push(proto::ChannelParticipants {
3306 channel_id: channel_id.to_proto(),
3307 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3308 });
3309 }
3310
3311 update
3312 .channel_permissions
3313 .extend(
3314 channels
3315 .channels_with_admin_privileges
3316 .into_iter()
3317 .map(|id| proto::ChannelPermission {
3318 channel_id: id.to_proto(),
3319 role: proto::ChannelRole::Admin.into(),
3320 }),
3321 );
3322
3323 for channel in channel_invites {
3324 update.channel_invitations.push(proto::Channel {
3325 id: channel.id.to_proto(),
3326 name: channel.name,
3327 // TODO: Visibility
3328 visibility: ChannelVisibility::Public.into(),
3329 });
3330 }
3331
3332 update
3333}
3334
3335fn build_initial_contacts_update(
3336 contacts: Vec<db::Contact>,
3337 pool: &ConnectionPool,
3338) -> proto::UpdateContacts {
3339 let mut update = proto::UpdateContacts::default();
3340
3341 for contact in contacts {
3342 match contact {
3343 db::Contact::Accepted { user_id, busy } => {
3344 update.contacts.push(contact_for_user(user_id, busy, &pool));
3345 }
3346 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3347 db::Contact::Incoming { user_id } => {
3348 update
3349 .incoming_requests
3350 .push(proto::IncomingContactRequest {
3351 requester_id: user_id.to_proto(),
3352 })
3353 }
3354 }
3355 }
3356
3357 update
3358}
3359
3360fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3361 proto::Contact {
3362 user_id: user_id.to_proto(),
3363 online: pool.is_user_online(user_id),
3364 busy,
3365 }
3366}
3367
3368fn room_updated(room: &proto::Room, peer: &Peer) {
3369 broadcast(
3370 None,
3371 room.participants
3372 .iter()
3373 .filter_map(|participant| Some(participant.peer_id?.into())),
3374 |peer_id| {
3375 peer.send(
3376 peer_id.into(),
3377 proto::RoomUpdated {
3378 room: Some(room.clone()),
3379 },
3380 )
3381 },
3382 );
3383}
3384
3385fn channel_updated(
3386 channel_id: ChannelId,
3387 room: &proto::Room,
3388 channel_members: &[UserId],
3389 peer: &Peer,
3390 pool: &ConnectionPool,
3391) {
3392 let participants = room
3393 .participants
3394 .iter()
3395 .map(|p| p.user_id)
3396 .collect::<Vec<_>>();
3397
3398 broadcast(
3399 None,
3400 channel_members
3401 .iter()
3402 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3403 |peer_id| {
3404 peer.send(
3405 peer_id.into(),
3406 proto::UpdateChannels {
3407 channel_participants: vec![proto::ChannelParticipants {
3408 channel_id: channel_id.to_proto(),
3409 participant_user_ids: participants.clone(),
3410 }],
3411 ..Default::default()
3412 },
3413 )
3414 },
3415 );
3416}
3417
3418async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3419 let db = session.db().await;
3420
3421 let contacts = db.get_contacts(user_id).await?;
3422 let busy = db.is_user_busy(user_id).await?;
3423
3424 let pool = session.connection_pool().await;
3425 let updated_contact = contact_for_user(user_id, busy, &pool);
3426 for contact in contacts {
3427 if let db::Contact::Accepted {
3428 user_id: contact_user_id,
3429 ..
3430 } = contact
3431 {
3432 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3433 session
3434 .peer
3435 .send(
3436 contact_conn_id,
3437 proto::UpdateContacts {
3438 contacts: vec![updated_contact.clone()],
3439 remove_contacts: Default::default(),
3440 incoming_requests: Default::default(),
3441 remove_incoming_requests: Default::default(),
3442 outgoing_requests: Default::default(),
3443 remove_outgoing_requests: Default::default(),
3444 },
3445 )
3446 .trace_err();
3447 }
3448 }
3449 }
3450 Ok(())
3451}
3452
/// Removes the session's connection from the room it is currently in, if any,
/// and notifies everyone affected.
///
/// When the database reports that the connection left a room, this:
/// - notifies the collaborators of every project the connection left (`project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - if the room is attached to a channel, sends the new participant list to
///   every connection of the channel's members,
/// - sends `CallCanceled` to every connection of each user whose call into this
///   room was canceled,
/// - pushes refreshed contact status for this user and each canceled callee,
/// - when a LiveKit client is configured, removes the user from the LiveKit room
///   and deletes that room if `left_room.deleted` was set.
///
/// Returns `Ok(())` without doing anything when the connection was not in a room.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. Failures to
/// send messages or to talk to LiveKit are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries must be re-sent once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below and used after it.
    // NOTE(review): the deferred initialization looks deliberate. Everything
    // needed later is moved out of `left_room` (via `mem::take`), so `left_room`
    // and the `session.db().await` temporary in the `if let` scrutinee are both
    // dropped at the end of the `if let`, before any of the awaits further down.
    // Confirm what `left_room` guards before restructuring this function.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // `room` itself is taken, so the `room` broadcast by `room_updated` below
        // has `live_kit_room` left at its `Default` value. Confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    // Rooms attached to a channel also update the channel's participant list.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3529
3530async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3531 let left_channel_buffers = session
3532 .db()
3533 .await
3534 .leave_channel_buffers(session.connection_id)
3535 .await?;
3536
3537 for left_buffer in left_channel_buffers {
3538 channel_buffer_updated(
3539 session.connection_id,
3540 left_buffer.connections,
3541 &proto::UpdateChannelBufferCollaborators {
3542 channel_id: left_buffer.channel_id.to_proto(),
3543 collaborators: left_buffer.collaborators,
3544 },
3545 &session.peer,
3546 );
3547 }
3548
3549 Ok(())
3550}
3551
3552fn project_left(project: &db::LeftProject, session: &Session) {
3553 for connection_id in &project.connection_ids {
3554 if project.host_user_id == session.user_id {
3555 session
3556 .peer
3557 .send(
3558 *connection_id,
3559 proto::UnshareProject {
3560 project_id: project.id.to_proto(),
3561 },
3562 )
3563 .trace_err();
3564 } else {
3565 session
3566 .peer
3567 .send(
3568 *connection_id,
3569 proto::RemoveProjectCollaborator {
3570 project_id: project.id.to_proto(),
3571 peer_id: Some(session.connection_id.into()),
3572 },
3573 )
3574 .trace_err();
3575 }
3576 }
3577}
3578
3579pub trait ResultExt {
3580 type Ok;
3581
3582 fn trace_err(self) -> Option<Self::Ok>;
3583}
3584
3585impl<T, E> ResultExt for Result<T, E>
3586where
3587 E: std::fmt::Debug,
3588{
3589 type Ok = T;
3590
3591 fn trace_err(self) -> Option<T> {
3592 match self {
3593 Ok(value) => Some(value),
3594 Err(error) => {
3595 tracing::error!("{:?}", error);
3596 None
3597 }
3598 }
3599 }
3600}