1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage,
7 Database, MessageId, ProjectId, RoomId, ServerId, User, UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
42 LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{info_span, instrument, Instrument};
66use util::channel::RELEASE_CHANNEL_NAME;
67
/// Grace period a lost connection has to come back before `connection_lost`
/// releases the user's rooms and channel buffers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before stale server resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);

// NOTE(review): the three limits below are not used in this part of the file;
// presumably they bound channel-message pages, message length and notification
// pages respectively — confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
74
// Prometheus gauges refreshed by `handle_metrics`. They are registered lazily on
// first access; the `unwrap` panics at that point if registration fails.
lazy_static! {
    // Set by `handle_metrics` to the number of non-admin connections in the pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set by `handle_metrics` from `project_count_excluding_admins()`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
84
/// Type-erased message handler stored in `Server::handlers`, keyed by message
/// `TypeId`. It takes the boxed incoming envelope plus the sender's [`Session`]
/// and returns a boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
87
88struct Response<R> {
89 peer: Arc<Peer>,
90 receipt: Receipt<R>,
91 responded: Arc<AtomicBool>,
92}
93
94impl<R: RequestMessage> Response<R> {
95 fn send(self, payload: R::Response) -> Result<()> {
96 self.responded.store(true, SeqCst);
97 self.peer.respond(self.receipt, payload)?;
98 Ok(())
99 }
100}
101
102#[derive(Clone)]
103struct Session {
104 user_id: UserId,
105 connection_id: ConnectionId,
106 db: Arc<tokio::sync::Mutex<DbHandle>>,
107 peer: Arc<Peer>,
108 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
109 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
110 executor: Executor,
111}
112
113impl Session {
114 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.db.lock().await;
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 guard
121 }
122
123 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
124 #[cfg(test)]
125 tokio::task::yield_now().await;
126 let guard = self.connection_pool.lock();
127 ConnectionPoolGuard {
128 guard,
129 _not_send: PhantomData,
130 }
131 }
132}
133
134impl fmt::Debug for Session {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("Session")
137 .field("user_id", &self.user_id)
138 .field("connection_id", &self.connection_id)
139 .finish()
140 }
141}
142
143struct DbHandle(Arc<Database>);
144
145impl Deref for DbHandle {
146 type Target = Database;
147
148 fn deref(&self) -> &Self::Target {
149 self.0.as_ref()
150 }
151}
152
/// Collaboration RPC server: owns the peer, the connection pool and the
/// registered message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace the id in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown()`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}

/// Lock guard over the connection pool that cannot be sent across threads.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, which makes this guard `!Send` too, so it cannot be
    // held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
/// Field order here is the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
174
175pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
176where
177 S: Serializer,
178 T: Deref<Target = U>,
179 U: Serialize,
180{
181 Serialize::serialize(value.deref(), serializer)
182}
183
184impl Server {
185 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
186 let mut server = Self {
187 id: parking_lot::Mutex::new(id),
188 peer: Peer::new(id.0 as u32),
189 app_state,
190 executor,
191 connection_pool: Default::default(),
192 handlers: Default::default(),
193 teardown: watch::channel(()).0,
194 };
195
196 server
197 .add_request_handler(ping)
198 .add_request_handler(create_room)
199 .add_request_handler(join_room)
200 .add_request_handler(rejoin_room)
201 .add_request_handler(leave_room)
202 .add_request_handler(call)
203 .add_request_handler(cancel_call)
204 .add_message_handler(decline_call)
205 .add_request_handler(update_participant_location)
206 .add_request_handler(share_project)
207 .add_message_handler(unshare_project)
208 .add_request_handler(join_project)
209 .add_message_handler(leave_project)
210 .add_request_handler(update_project)
211 .add_request_handler(update_worktree)
212 .add_message_handler(start_language_server)
213 .add_message_handler(update_language_server)
214 .add_message_handler(update_diagnostic_summary)
215 .add_message_handler(update_worktree_settings)
216 .add_message_handler(refresh_inlay_hints)
217 .add_request_handler(forward_project_request::<proto::GetHover>)
218 .add_request_handler(forward_project_request::<proto::GetDefinition>)
219 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
220 .add_request_handler(forward_project_request::<proto::GetReferences>)
221 .add_request_handler(forward_project_request::<proto::SearchProject>)
222 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
223 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
224 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
225 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
226 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
227 .add_request_handler(forward_project_request::<proto::GetCompletions>)
228 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
229 .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
230 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
231 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
232 .add_request_handler(forward_project_request::<proto::PrepareRename>)
233 .add_request_handler(forward_project_request::<proto::PerformRename>)
234 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
235 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
236 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
237 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
238 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
239 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
240 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
241 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
242 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
243 .add_request_handler(forward_project_request::<proto::InlayHints>)
244 .add_message_handler(create_buffer_for_peer)
245 .add_request_handler(update_buffer)
246 .add_message_handler(update_buffer_file)
247 .add_message_handler(buffer_reloaded)
248 .add_message_handler(buffer_saved)
249 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
250 .add_request_handler(get_users)
251 .add_request_handler(fuzzy_search_users)
252 .add_request_handler(request_contact)
253 .add_request_handler(remove_contact)
254 .add_request_handler(respond_to_contact_request)
255 .add_request_handler(create_channel)
256 .add_request_handler(delete_channel)
257 .add_request_handler(invite_channel_member)
258 .add_request_handler(remove_channel_member)
259 .add_request_handler(set_channel_member_role)
260 .add_request_handler(set_channel_visibility)
261 .add_request_handler(rename_channel)
262 .add_request_handler(join_channel_buffer)
263 .add_request_handler(leave_channel_buffer)
264 .add_message_handler(update_channel_buffer)
265 .add_request_handler(rejoin_channel_buffers)
266 .add_request_handler(get_channel_members)
267 .add_request_handler(respond_to_channel_invite)
268 .add_request_handler(join_channel)
269 .add_request_handler(join_channel_chat)
270 .add_message_handler(leave_channel_chat)
271 .add_request_handler(send_channel_message)
272 .add_request_handler(remove_channel_message)
273 .add_request_handler(get_channel_messages)
274 .add_request_handler(get_channel_messages_by_id)
275 .add_request_handler(get_notifications)
276 .add_request_handler(link_channel)
277 .add_request_handler(unlink_channel)
278 .add_request_handler(move_channel)
279 .add_request_handler(follow)
280 .add_message_handler(unfollow)
281 .add_message_handler(update_followers)
282 .add_message_handler(update_diff_base)
283 .add_request_handler(get_private_user_info)
284 .add_message_handler(acknowledge_channel_message)
285 .add_message_handler(acknowledge_buffer_version);
286
287 Arc::new(server)
288 }
289
290 pub async fn start(&self) -> Result<()> {
291 let server_id = *self.id.lock();
292 let app_state = self.app_state.clone();
293 let peer = self.peer.clone();
294 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
295 let pool = self.connection_pool.clone();
296 let live_kit_client = self.app_state.live_kit_client.clone();
297
298 let span = info_span!("start server");
299 self.executor.spawn_detached(
300 async move {
301 tracing::info!("waiting for cleanup timeout");
302 timeout.await;
303 tracing::info!("cleanup timeout expired, retrieving stale rooms");
304 if let Some((room_ids, channel_ids)) = app_state
305 .db
306 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
307 .await
308 .trace_err()
309 {
310 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
311 tracing::info!(
312 stale_channel_buffer_count = channel_ids.len(),
313 "retrieved stale channel buffers"
314 );
315
316 for channel_id in channel_ids {
317 if let Some(refreshed_channel_buffer) = app_state
318 .db
319 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
320 .await
321 .trace_err()
322 {
323 for connection_id in refreshed_channel_buffer.connection_ids {
324 peer.send(
325 connection_id,
326 proto::UpdateChannelBufferCollaborators {
327 channel_id: channel_id.to_proto(),
328 collaborators: refreshed_channel_buffer
329 .collaborators
330 .clone(),
331 },
332 )
333 .trace_err();
334 }
335 }
336 }
337
338 for room_id in room_ids {
339 let mut contacts_to_update = HashSet::default();
340 let mut canceled_calls_to_user_ids = Vec::new();
341 let mut live_kit_room = String::new();
342 let mut delete_live_kit_room = false;
343
344 if let Some(mut refreshed_room) = app_state
345 .db
346 .clear_stale_room_participants(room_id, server_id)
347 .await
348 .trace_err()
349 {
350 tracing::info!(
351 room_id = room_id.0,
352 new_participant_count = refreshed_room.room.participants.len(),
353 "refreshed room"
354 );
355 room_updated(&refreshed_room.room, &peer);
356 if let Some(channel_id) = refreshed_room.channel_id {
357 channel_updated(
358 channel_id,
359 &refreshed_room.room,
360 &refreshed_room.channel_members,
361 &peer,
362 &*pool.lock(),
363 );
364 }
365 contacts_to_update
366 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
367 contacts_to_update
368 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
369 canceled_calls_to_user_ids =
370 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
371 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
372 delete_live_kit_room = refreshed_room.room.participants.is_empty();
373 }
374
375 {
376 let pool = pool.lock();
377 for canceled_user_id in canceled_calls_to_user_ids {
378 for connection_id in pool.user_connection_ids(canceled_user_id) {
379 peer.send(
380 connection_id,
381 proto::CallCanceled {
382 room_id: room_id.to_proto(),
383 },
384 )
385 .trace_err();
386 }
387 }
388 }
389
390 for user_id in contacts_to_update {
391 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
392 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
393 if let Some((busy, contacts)) = busy.zip(contacts) {
394 let pool = pool.lock();
395 let updated_contact = contact_for_user(user_id, busy, &pool);
396 for contact in contacts {
397 if let db::Contact::Accepted {
398 user_id: contact_user_id,
399 ..
400 } = contact
401 {
402 for contact_conn_id in
403 pool.user_connection_ids(contact_user_id)
404 {
405 peer.send(
406 contact_conn_id,
407 proto::UpdateContacts {
408 contacts: vec![updated_contact.clone()],
409 remove_contacts: Default::default(),
410 incoming_requests: Default::default(),
411 remove_incoming_requests: Default::default(),
412 outgoing_requests: Default::default(),
413 remove_outgoing_requests: Default::default(),
414 },
415 )
416 .trace_err();
417 }
418 }
419 }
420 }
421 }
422
423 if let Some(live_kit) = live_kit_client.as_ref() {
424 if delete_live_kit_room {
425 live_kit.delete_room(live_kit_room).await.trace_err();
426 }
427 }
428 }
429 }
430
431 app_state
432 .db
433 .delete_stale_servers(&app_state.config.zed_environment, server_id)
434 .await
435 .trace_err();
436 }
437 .instrument(span),
438 );
439 Ok(())
440 }
441
442 pub fn teardown(&self) {
443 self.peer.teardown();
444 self.connection_pool.lock().reset();
445 let _ = self.teardown.send(());
446 }
447
448 #[cfg(test)]
449 pub fn reset(&self, id: ServerId) {
450 self.teardown();
451 *self.id.lock() = id;
452 self.peer.reset(id.0 as u32);
453 }
454
455 #[cfg(test)]
456 pub fn id(&self) -> ServerId {
457 *self.id.lock()
458 }
459
460 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
461 where
462 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
463 Fut: 'static + Send + Future<Output = Result<()>>,
464 M: EnvelopedMessage,
465 {
466 let prev_handler = self.handlers.insert(
467 TypeId::of::<M>(),
468 Box::new(move |envelope, session| {
469 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
470 let span = info_span!(
471 "handle message",
472 payload_type = envelope.payload_type_name()
473 );
474 span.in_scope(|| {
475 tracing::info!(
476 payload_type = envelope.payload_type_name(),
477 "message received"
478 );
479 });
480 let start_time = Instant::now();
481 let future = (handler)(*envelope, session);
482 async move {
483 let result = future.await;
484 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
485 match result {
486 Err(error) => {
487 tracing::error!(%error, ?duration_ms, "error handling message")
488 }
489 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
490 }
491 }
492 .instrument(span)
493 .boxed()
494 }),
495 );
496 if prev_handler.is_some() {
497 panic!("registered a handler for the same message twice");
498 }
499 self
500 }
501
502 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
503 where
504 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
505 Fut: 'static + Send + Future<Output = Result<()>>,
506 M: EnvelopedMessage,
507 {
508 self.add_handler(move |envelope, session| handler(envelope.payload, session));
509 self
510 }
511
512 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
513 where
514 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
515 Fut: Send + Future<Output = Result<()>>,
516 M: RequestMessage,
517 {
518 let handler = Arc::new(handler);
519 self.add_handler(move |envelope, session| {
520 let receipt = envelope.receipt();
521 let handler = handler.clone();
522 async move {
523 let peer = session.peer.clone();
524 let responded = Arc::new(AtomicBool::default());
525 let response = Response {
526 peer: peer.clone(),
527 responded: responded.clone(),
528 receipt,
529 };
530 match (handler)(envelope.payload, response, session).await {
531 Ok(()) => {
532 if responded.load(std::sync::atomic::Ordering::SeqCst) {
533 Ok(())
534 } else {
535 Err(anyhow!("handler did not send a response"))?
536 }
537 }
538 Err(error) => {
539 peer.respond_with_error(
540 receipt,
541 proto::Error {
542 message: error.to_string(),
543 },
544 )?;
545 Err(error)
546 }
547 }
548 }
549 })
550 }
551
552 pub fn handle_connection(
553 self: &Arc<Self>,
554 connection: Connection,
555 address: String,
556 user: User,
557 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
558 executor: Executor,
559 ) -> impl Future<Output = Result<()>> {
560 let this = self.clone();
561 let user_id = user.id;
562 let login = user.github_login;
563 let span = info_span!("handle connection", %user_id, %login, %address);
564 let mut teardown = self.teardown.subscribe();
565 async move {
566 let (connection_id, handle_io, mut incoming_rx) = this
567 .peer
568 .add_connection(connection, {
569 let executor = executor.clone();
570 move |duration| executor.sleep(duration)
571 });
572
573 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
574 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
575 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
576
577 if let Some(send_connection_id) = send_connection_id.take() {
578 let _ = send_connection_id.send(connection_id);
579 }
580
581 if !user.connected_once {
582 this.peer.send(connection_id, proto::ShowContacts {})?;
583 this.app_state.db.set_user_connected_once(user_id, true).await?;
584 }
585
586 let (contacts, channels_for_user, channel_invites) = future::try_join3(
587 this.app_state.db.get_contacts(user_id),
588 this.app_state.db.get_channels_for_user(user_id),
589 this.app_state.db.get_channel_invites_for_user(user_id),
590 ).await?;
591
592 {
593 let mut pool = this.connection_pool.lock();
594 pool.add_connection(connection_id, user_id, user.admin);
595 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
596 this.peer.send(connection_id, build_initial_channels_update(
597 channels_for_user,
598 channel_invites
599 ))?;
600 }
601
602 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
603 this.peer.send(connection_id, incoming_call)?;
604 }
605
606 let session = Session {
607 user_id,
608 connection_id,
609 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
610 peer: this.peer.clone(),
611 connection_pool: this.connection_pool.clone(),
612 live_kit_client: this.app_state.live_kit_client.clone(),
613 executor: executor.clone(),
614 };
615 update_user_contacts(user_id, &session).await?;
616
617 let handle_io = handle_io.fuse();
618 futures::pin_mut!(handle_io);
619
620 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
621 // This prevents deadlocks when e.g., client A performs a request to client B and
622 // client B performs a request to client A. If both clients stop processing further
623 // messages until their respective request completes, they won't have a chance to
624 // respond to the other client's request and cause a deadlock.
625 //
626 // This arrangement ensures we will attempt to process earlier messages first, but fall
627 // back to processing messages arrived later in the spirit of making progress.
628 let mut foreground_message_handlers = FuturesUnordered::new();
629 let concurrent_handlers = Arc::new(Semaphore::new(256));
630 loop {
631 let next_message = async {
632 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
633 let message = incoming_rx.next().await;
634 (permit, message)
635 }.fuse();
636 futures::pin_mut!(next_message);
637 futures::select_biased! {
638 _ = teardown.changed().fuse() => return Ok(()),
639 result = handle_io => {
640 if let Err(error) = result {
641 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
642 }
643 break;
644 }
645 _ = foreground_message_handlers.next() => {}
646 next_message = next_message => {
647 let (permit, message) = next_message;
648 if let Some(message) = message {
649 let type_name = message.payload_type_name();
650 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
651 let span_enter = span.enter();
652 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
653 let is_background = message.is_background();
654 let handle_message = (handler)(message, session.clone());
655 drop(span_enter);
656
657 let handle_message = async move {
658 handle_message.await;
659 drop(permit);
660 }.instrument(span);
661 if is_background {
662 executor.spawn_detached(handle_message);
663 } else {
664 foreground_message_handlers.push(handle_message);
665 }
666 } else {
667 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
668 }
669 } else {
670 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
671 break;
672 }
673 }
674 }
675 }
676
677 drop(foreground_message_handlers);
678 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
679 if let Err(error) = connection_lost(session, teardown, executor).await {
680 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
681 }
682
683 Ok(())
684 }.instrument(span)
685 }
686
687 pub async fn invite_code_redeemed(
688 self: &Arc<Self>,
689 inviter_id: UserId,
690 invitee_id: UserId,
691 ) -> Result<()> {
692 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
693 if let Some(code) = &user.invite_code {
694 let pool = self.connection_pool.lock();
695 let invitee_contact = contact_for_user(invitee_id, false, &pool);
696 for connection_id in pool.user_connection_ids(inviter_id) {
697 self.peer.send(
698 connection_id,
699 proto::UpdateContacts {
700 contacts: vec![invitee_contact.clone()],
701 ..Default::default()
702 },
703 )?;
704 self.peer.send(
705 connection_id,
706 proto::UpdateInviteInfo {
707 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
708 count: user.invite_count as u32,
709 },
710 )?;
711 }
712 }
713 }
714 Ok(())
715 }
716
717 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
718 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
719 if let Some(invite_code) = &user.invite_code {
720 let pool = self.connection_pool.lock();
721 for connection_id in pool.user_connection_ids(user_id) {
722 self.peer.send(
723 connection_id,
724 proto::UpdateInviteInfo {
725 url: format!(
726 "{}{}",
727 self.app_state.config.invite_link_prefix, invite_code
728 ),
729 count: user.invite_count as u32,
730 },
731 )?;
732 }
733 }
734 }
735 Ok(())
736 }
737
738 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
739 ServerSnapshot {
740 connection_pool: ConnectionPoolGuard {
741 guard: self.connection_pool.lock(),
742 _not_send: PhantomData,
743 },
744 peer: &self.peer,
745 }
746 }
747}
748
749impl<'a> Deref for ConnectionPoolGuard<'a> {
750 type Target = ConnectionPool;
751
752 fn deref(&self) -> &Self::Target {
753 &*self.guard
754 }
755}
756
757impl<'a> DerefMut for ConnectionPoolGuard<'a> {
758 fn deref_mut(&mut self) -> &mut Self::Target {
759 &mut *self.guard
760 }
761}
762
763impl<'a> Drop for ConnectionPoolGuard<'a> {
764 fn drop(&mut self) {
765 #[cfg(test)]
766 self.check_invariants();
767 }
768}
769
770fn broadcast<F>(
771 sender_id: Option<ConnectionId>,
772 receiver_ids: impl IntoIterator<Item = ConnectionId>,
773 mut f: F,
774) where
775 F: FnMut(ConnectionId) -> anyhow::Result<()>,
776{
777 for receiver_id in receiver_ids {
778 if Some(receiver_id) != sender_id {
779 if let Err(error) = f(receiver_id) {
780 tracing::error!("failed to send to {:?} {}", receiver_id, error);
781 }
782 }
783 }
784}
785
786lazy_static! {
787 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
788}
789
790pub struct ProtocolVersion(u32);
791
792impl Header for ProtocolVersion {
793 fn name() -> &'static HeaderName {
794 &ZED_PROTOCOL_VERSION
795 }
796
797 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
798 where
799 Self: Sized,
800 I: Iterator<Item = &'i axum::http::HeaderValue>,
801 {
802 let version = values
803 .next()
804 .ok_or_else(axum::headers::Error::invalid)?
805 .to_str()
806 .map_err(|_| axum::headers::Error::invalid())?
807 .parse()
808 .map_err(|_| axum::headers::Error::invalid())?;
809 Ok(Self(version))
810 }
811
812 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
813 values.extend([self.0.to_string().parse().unwrap()]);
814 }
815}
816
817pub fn routes(server: Arc<Server>) -> Router<Body> {
818 Router::new()
819 .route("/rpc", get(handle_websocket_request))
820 .layer(
821 ServiceBuilder::new()
822 .layer(Extension(server.app_state.clone()))
823 .layer(middleware::from_fn(auth::validate_header)),
824 )
825 .route("/metrics", get(handle_metrics))
826 .layer(Extension(server))
827}
828
829pub async fn handle_websocket_request(
830 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
831 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
832 Extension(server): Extension<Arc<Server>>,
833 Extension(user): Extension<User>,
834 ws: WebSocketUpgrade,
835) -> axum::response::Response {
836 if protocol_version != rpc::PROTOCOL_VERSION {
837 return (
838 StatusCode::UPGRADE_REQUIRED,
839 "client must be upgraded".to_string(),
840 )
841 .into_response();
842 }
843 let socket_address = socket_address.to_string();
844 ws.on_upgrade(move |socket| {
845 use util::ResultExt;
846 let socket = socket
847 .map_ok(to_tungstenite_message)
848 .err_into()
849 .with(|message| async move { Ok(to_axum_message(message)) });
850 let connection = Connection::new(Box::pin(socket));
851 async move {
852 server
853 .handle_connection(connection, socket_address, user, None, Executor::Production)
854 .await
855 .log_err();
856 }
857 })
858}
859
860pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
861 let connections = server
862 .connection_pool
863 .lock()
864 .connections()
865 .filter(|connection| !connection.admin)
866 .count();
867
868 METRIC_CONNECTIONS.set(connections as _);
869
870 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
871 METRIC_SHARED_PROJECTS.set(shared_projects as _);
872
873 let encoder = prometheus::TextEncoder::new();
874 let metric_families = prometheus::gather();
875 let encoded_metrics = encoder
876 .encode_to_string(&metric_families)
877 .map_err(|err| anyhow!("{}", err))?;
878 Ok(encoded_metrics)
879}
880
881#[instrument(err, skip(executor))]
882async fn connection_lost(
883 session: Session,
884 mut teardown: watch::Receiver<()>,
885 executor: Executor,
886) -> Result<()> {
887 session.peer.disconnect(session.connection_id);
888 session
889 .connection_pool()
890 .await
891 .remove_connection(session.connection_id)?;
892
893 session
894 .db()
895 .await
896 .connection_lost(session.connection_id)
897 .await
898 .trace_err();
899
900 futures::select_biased! {
901 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
902 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
903 leave_room_for_session(&session).await.trace_err();
904 leave_channel_buffers_for_session(&session)
905 .await
906 .trace_err();
907
908 if !session
909 .connection_pool()
910 .await
911 .is_user_online(session.user_id)
912 {
913 let db = session.db().await;
914 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
915 room_updated(&room, &session.peer);
916 }
917 }
918
919 update_user_contacts(session.user_id, &session).await?;
920 }
921 _ = teardown.changed().fuse() => {}
922 }
923
924 Ok(())
925}
926
927async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
928 response.send(proto::Ack {})?;
929 Ok(())
930}
931
932async fn create_room(
933 _request: proto::CreateRoom,
934 response: Response<proto::CreateRoom>,
935 session: Session,
936) -> Result<()> {
937 let live_kit_room = nanoid::nanoid!(30);
938
939 let live_kit_connection_info = {
940 let live_kit_room = live_kit_room.clone();
941 let live_kit = session.live_kit_client.as_ref();
942
943 util::async_iife!({
944 let live_kit = live_kit?;
945
946 let token = live_kit
947 .room_token(&live_kit_room, &session.user_id.to_string())
948 .trace_err()?;
949
950 Some(proto::LiveKitConnectionInfo {
951 server_url: live_kit.url().into(),
952 token,
953 })
954 })
955 }
956 .await;
957
958 let room = session
959 .db()
960 .await
961 .create_room(
962 session.user_id,
963 session.connection_id,
964 &live_kit_room,
965 RELEASE_CHANNEL_NAME.as_str(),
966 )
967 .await?;
968
969 response.send(proto::CreateRoomResponse {
970 room: Some(room.clone()),
971 live_kit_connection_info,
972 })?;
973
974 update_user_contacts(session.user_id, &session).await?;
975 Ok(())
976}
977
978async fn join_room(
979 request: proto::JoinRoom,
980 response: Response<proto::JoinRoom>,
981 session: Session,
982) -> Result<()> {
983 let room_id = RoomId::from_proto(request.id);
984
985 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
986
987 if let Some(channel_id) = channel_id {
988 return join_channel_internal(channel_id, Box::new(response), session).await;
989 }
990
991 let joined_room = {
992 let room = session
993 .db()
994 .await
995 .join_room(
996 room_id,
997 session.user_id,
998 session.connection_id,
999 RELEASE_CHANNEL_NAME.as_str(),
1000 )
1001 .await?;
1002 room_updated(&room.room, &session.peer);
1003 room.into_inner()
1004 };
1005
1006 for connection_id in session
1007 .connection_pool()
1008 .await
1009 .user_connection_ids(session.user_id)
1010 {
1011 session
1012 .peer
1013 .send(
1014 connection_id,
1015 proto::CallCanceled {
1016 room_id: room_id.to_proto(),
1017 },
1018 )
1019 .trace_err();
1020 }
1021
1022 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1023 if let Some(token) = live_kit
1024 .room_token(
1025 &joined_room.room.live_kit_room,
1026 &session.user_id.to_string(),
1027 )
1028 .trace_err()
1029 {
1030 Some(proto::LiveKitConnectionInfo {
1031 server_url: live_kit.url().into(),
1032 token,
1033 })
1034 } else {
1035 None
1036 }
1037 } else {
1038 None
1039 };
1040
1041 response.send(proto::JoinRoomResponse {
1042 room: Some(joined_room.room),
1043 channel_id: None,
1044 live_kit_connection_info,
1045 })?;
1046
1047 update_user_contacts(session.user_id, &session).await?;
1048 Ok(())
1049}
1050
1051async fn rejoin_room(
1052 request: proto::RejoinRoom,
1053 response: Response<proto::RejoinRoom>,
1054 session: Session,
1055) -> Result<()> {
1056 let room;
1057 let channel_id;
1058 let channel_members;
1059 {
1060 let mut rejoined_room = session
1061 .db()
1062 .await
1063 .rejoin_room(request, session.user_id, session.connection_id)
1064 .await?;
1065
1066 response.send(proto::RejoinRoomResponse {
1067 room: Some(rejoined_room.room.clone()),
1068 reshared_projects: rejoined_room
1069 .reshared_projects
1070 .iter()
1071 .map(|project| proto::ResharedProject {
1072 id: project.id.to_proto(),
1073 collaborators: project
1074 .collaborators
1075 .iter()
1076 .map(|collaborator| collaborator.to_proto())
1077 .collect(),
1078 })
1079 .collect(),
1080 rejoined_projects: rejoined_room
1081 .rejoined_projects
1082 .iter()
1083 .map(|rejoined_project| proto::RejoinedProject {
1084 id: rejoined_project.id.to_proto(),
1085 worktrees: rejoined_project
1086 .worktrees
1087 .iter()
1088 .map(|worktree| proto::WorktreeMetadata {
1089 id: worktree.id,
1090 root_name: worktree.root_name.clone(),
1091 visible: worktree.visible,
1092 abs_path: worktree.abs_path.clone(),
1093 })
1094 .collect(),
1095 collaborators: rejoined_project
1096 .collaborators
1097 .iter()
1098 .map(|collaborator| collaborator.to_proto())
1099 .collect(),
1100 language_servers: rejoined_project.language_servers.clone(),
1101 })
1102 .collect(),
1103 })?;
1104 room_updated(&rejoined_room.room, &session.peer);
1105
1106 for project in &rejoined_room.reshared_projects {
1107 for collaborator in &project.collaborators {
1108 session
1109 .peer
1110 .send(
1111 collaborator.connection_id,
1112 proto::UpdateProjectCollaborator {
1113 project_id: project.id.to_proto(),
1114 old_peer_id: Some(project.old_connection_id.into()),
1115 new_peer_id: Some(session.connection_id.into()),
1116 },
1117 )
1118 .trace_err();
1119 }
1120
1121 broadcast(
1122 Some(session.connection_id),
1123 project
1124 .collaborators
1125 .iter()
1126 .map(|collaborator| collaborator.connection_id),
1127 |connection_id| {
1128 session.peer.forward_send(
1129 session.connection_id,
1130 connection_id,
1131 proto::UpdateProject {
1132 project_id: project.id.to_proto(),
1133 worktrees: project.worktrees.clone(),
1134 },
1135 )
1136 },
1137 );
1138 }
1139
1140 for project in &rejoined_room.rejoined_projects {
1141 for collaborator in &project.collaborators {
1142 session
1143 .peer
1144 .send(
1145 collaborator.connection_id,
1146 proto::UpdateProjectCollaborator {
1147 project_id: project.id.to_proto(),
1148 old_peer_id: Some(project.old_connection_id.into()),
1149 new_peer_id: Some(session.connection_id.into()),
1150 },
1151 )
1152 .trace_err();
1153 }
1154 }
1155
1156 for project in &mut rejoined_room.rejoined_projects {
1157 for worktree in mem::take(&mut project.worktrees) {
1158 #[cfg(any(test, feature = "test-support"))]
1159 const MAX_CHUNK_SIZE: usize = 2;
1160 #[cfg(not(any(test, feature = "test-support")))]
1161 const MAX_CHUNK_SIZE: usize = 256;
1162
1163 // Stream this worktree's entries.
1164 let message = proto::UpdateWorktree {
1165 project_id: project.id.to_proto(),
1166 worktree_id: worktree.id,
1167 abs_path: worktree.abs_path.clone(),
1168 root_name: worktree.root_name,
1169 updated_entries: worktree.updated_entries,
1170 removed_entries: worktree.removed_entries,
1171 scan_id: worktree.scan_id,
1172 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1173 updated_repositories: worktree.updated_repositories,
1174 removed_repositories: worktree.removed_repositories,
1175 };
1176 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1177 session.peer.send(session.connection_id, update.clone())?;
1178 }
1179
1180 // Stream this worktree's diagnostics.
1181 for summary in worktree.diagnostic_summaries {
1182 session.peer.send(
1183 session.connection_id,
1184 proto::UpdateDiagnosticSummary {
1185 project_id: project.id.to_proto(),
1186 worktree_id: worktree.id,
1187 summary: Some(summary),
1188 },
1189 )?;
1190 }
1191
1192 for settings_file in worktree.settings_files {
1193 session.peer.send(
1194 session.connection_id,
1195 proto::UpdateWorktreeSettings {
1196 project_id: project.id.to_proto(),
1197 worktree_id: worktree.id,
1198 path: settings_file.path,
1199 content: Some(settings_file.content),
1200 },
1201 )?;
1202 }
1203 }
1204
1205 for language_server in &project.language_servers {
1206 session.peer.send(
1207 session.connection_id,
1208 proto::UpdateLanguageServer {
1209 project_id: project.id.to_proto(),
1210 language_server_id: language_server.id,
1211 variant: Some(
1212 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1213 proto::LspDiskBasedDiagnosticsUpdated {},
1214 ),
1215 ),
1216 },
1217 )?;
1218 }
1219 }
1220
1221 let rejoined_room = rejoined_room.into_inner();
1222
1223 room = rejoined_room.room;
1224 channel_id = rejoined_room.channel_id;
1225 channel_members = rejoined_room.channel_members;
1226 }
1227
1228 if let Some(channel_id) = channel_id {
1229 channel_updated(
1230 channel_id,
1231 &room,
1232 &channel_members,
1233 &session.peer,
1234 &*session.connection_pool().await,
1235 );
1236 }
1237
1238 update_user_contacts(session.user_id, &session).await?;
1239 Ok(())
1240}
1241
1242async fn leave_room(
1243 _: proto::LeaveRoom,
1244 response: Response<proto::LeaveRoom>,
1245 session: Session,
1246) -> Result<()> {
1247 leave_room_for_session(&session).await?;
1248 response.send(proto::Ack {})?;
1249 Ok(())
1250}
1251
1252async fn call(
1253 request: proto::Call,
1254 response: Response<proto::Call>,
1255 session: Session,
1256) -> Result<()> {
1257 let room_id = RoomId::from_proto(request.room_id);
1258 let calling_user_id = session.user_id;
1259 let calling_connection_id = session.connection_id;
1260 let called_user_id = UserId::from_proto(request.called_user_id);
1261 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1262 if !session
1263 .db()
1264 .await
1265 .has_contact(calling_user_id, called_user_id)
1266 .await?
1267 {
1268 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1269 }
1270
1271 let incoming_call = {
1272 let (room, incoming_call) = &mut *session
1273 .db()
1274 .await
1275 .call(
1276 room_id,
1277 calling_user_id,
1278 calling_connection_id,
1279 called_user_id,
1280 initial_project_id,
1281 )
1282 .await?;
1283 room_updated(&room, &session.peer);
1284 mem::take(incoming_call)
1285 };
1286 update_user_contacts(called_user_id, &session).await?;
1287
1288 let mut calls = session
1289 .connection_pool()
1290 .await
1291 .user_connection_ids(called_user_id)
1292 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1293 .collect::<FuturesUnordered<_>>();
1294
1295 while let Some(call_response) = calls.next().await {
1296 match call_response.as_ref() {
1297 Ok(_) => {
1298 response.send(proto::Ack {})?;
1299 return Ok(());
1300 }
1301 Err(_) => {
1302 call_response.trace_err();
1303 }
1304 }
1305 }
1306
1307 {
1308 let room = session
1309 .db()
1310 .await
1311 .call_failed(room_id, called_user_id)
1312 .await?;
1313 room_updated(&room, &session.peer);
1314 }
1315 update_user_contacts(called_user_id, &session).await?;
1316
1317 Err(anyhow!("failed to ring user"))?
1318}
1319
1320async fn cancel_call(
1321 request: proto::CancelCall,
1322 response: Response<proto::CancelCall>,
1323 session: Session,
1324) -> Result<()> {
1325 let called_user_id = UserId::from_proto(request.called_user_id);
1326 let room_id = RoomId::from_proto(request.room_id);
1327 {
1328 let room = session
1329 .db()
1330 .await
1331 .cancel_call(room_id, session.connection_id, called_user_id)
1332 .await?;
1333 room_updated(&room, &session.peer);
1334 }
1335
1336 for connection_id in session
1337 .connection_pool()
1338 .await
1339 .user_connection_ids(called_user_id)
1340 {
1341 session
1342 .peer
1343 .send(
1344 connection_id,
1345 proto::CallCanceled {
1346 room_id: room_id.to_proto(),
1347 },
1348 )
1349 .trace_err();
1350 }
1351 response.send(proto::Ack {})?;
1352
1353 update_user_contacts(called_user_id, &session).await?;
1354 Ok(())
1355}
1356
1357async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1358 let room_id = RoomId::from_proto(message.room_id);
1359 {
1360 let room = session
1361 .db()
1362 .await
1363 .decline_call(Some(room_id), session.user_id)
1364 .await?
1365 .ok_or_else(|| anyhow!("failed to decline call"))?;
1366 room_updated(&room, &session.peer);
1367 }
1368
1369 for connection_id in session
1370 .connection_pool()
1371 .await
1372 .user_connection_ids(session.user_id)
1373 {
1374 session
1375 .peer
1376 .send(
1377 connection_id,
1378 proto::CallCanceled {
1379 room_id: room_id.to_proto(),
1380 },
1381 )
1382 .trace_err();
1383 }
1384 update_user_contacts(session.user_id, &session).await?;
1385 Ok(())
1386}
1387
1388async fn update_participant_location(
1389 request: proto::UpdateParticipantLocation,
1390 response: Response<proto::UpdateParticipantLocation>,
1391 session: Session,
1392) -> Result<()> {
1393 let room_id = RoomId::from_proto(request.room_id);
1394 let location = request
1395 .location
1396 .ok_or_else(|| anyhow!("invalid location"))?;
1397
1398 let db = session.db().await;
1399 let room = db
1400 .update_room_participant_location(room_id, session.connection_id, location)
1401 .await?;
1402
1403 room_updated(&room, &session.peer);
1404 response.send(proto::Ack {})?;
1405 Ok(())
1406}
1407
1408async fn share_project(
1409 request: proto::ShareProject,
1410 response: Response<proto::ShareProject>,
1411 session: Session,
1412) -> Result<()> {
1413 let (project_id, room) = &*session
1414 .db()
1415 .await
1416 .share_project(
1417 RoomId::from_proto(request.room_id),
1418 session.connection_id,
1419 &request.worktrees,
1420 )
1421 .await?;
1422 response.send(proto::ShareProjectResponse {
1423 project_id: project_id.to_proto(),
1424 })?;
1425 room_updated(&room, &session.peer);
1426
1427 Ok(())
1428}
1429
1430async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1431 let project_id = ProjectId::from_proto(message.project_id);
1432
1433 let (room, guest_connection_ids) = &*session
1434 .db()
1435 .await
1436 .unshare_project(project_id, session.connection_id)
1437 .await?;
1438
1439 broadcast(
1440 Some(session.connection_id),
1441 guest_connection_ids.iter().copied(),
1442 |conn_id| session.peer.send(conn_id, message.clone()),
1443 );
1444 room_updated(&room, &session.peer);
1445
1446 Ok(())
1447}
1448
1449async fn join_project(
1450 request: proto::JoinProject,
1451 response: Response<proto::JoinProject>,
1452 session: Session,
1453) -> Result<()> {
1454 let project_id = ProjectId::from_proto(request.project_id);
1455 let guest_user_id = session.user_id;
1456
1457 tracing::info!(%project_id, "join project");
1458
1459 let (project, replica_id) = &mut *session
1460 .db()
1461 .await
1462 .join_project(project_id, session.connection_id)
1463 .await?;
1464
1465 let collaborators = project
1466 .collaborators
1467 .iter()
1468 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1469 .map(|collaborator| collaborator.to_proto())
1470 .collect::<Vec<_>>();
1471
1472 let worktrees = project
1473 .worktrees
1474 .iter()
1475 .map(|(id, worktree)| proto::WorktreeMetadata {
1476 id: *id,
1477 root_name: worktree.root_name.clone(),
1478 visible: worktree.visible,
1479 abs_path: worktree.abs_path.clone(),
1480 })
1481 .collect::<Vec<_>>();
1482
1483 for collaborator in &collaborators {
1484 session
1485 .peer
1486 .send(
1487 collaborator.peer_id.unwrap().into(),
1488 proto::AddProjectCollaborator {
1489 project_id: project_id.to_proto(),
1490 collaborator: Some(proto::Collaborator {
1491 peer_id: Some(session.connection_id.into()),
1492 replica_id: replica_id.0 as u32,
1493 user_id: guest_user_id.to_proto(),
1494 }),
1495 },
1496 )
1497 .trace_err();
1498 }
1499
1500 // First, we send the metadata associated with each worktree.
1501 response.send(proto::JoinProjectResponse {
1502 worktrees: worktrees.clone(),
1503 replica_id: replica_id.0 as u32,
1504 collaborators: collaborators.clone(),
1505 language_servers: project.language_servers.clone(),
1506 })?;
1507
1508 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1509 #[cfg(any(test, feature = "test-support"))]
1510 const MAX_CHUNK_SIZE: usize = 2;
1511 #[cfg(not(any(test, feature = "test-support")))]
1512 const MAX_CHUNK_SIZE: usize = 256;
1513
1514 // Stream this worktree's entries.
1515 let message = proto::UpdateWorktree {
1516 project_id: project_id.to_proto(),
1517 worktree_id,
1518 abs_path: worktree.abs_path.clone(),
1519 root_name: worktree.root_name,
1520 updated_entries: worktree.entries,
1521 removed_entries: Default::default(),
1522 scan_id: worktree.scan_id,
1523 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1524 updated_repositories: worktree.repository_entries.into_values().collect(),
1525 removed_repositories: Default::default(),
1526 };
1527 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1528 session.peer.send(session.connection_id, update.clone())?;
1529 }
1530
1531 // Stream this worktree's diagnostics.
1532 for summary in worktree.diagnostic_summaries {
1533 session.peer.send(
1534 session.connection_id,
1535 proto::UpdateDiagnosticSummary {
1536 project_id: project_id.to_proto(),
1537 worktree_id: worktree.id,
1538 summary: Some(summary),
1539 },
1540 )?;
1541 }
1542
1543 for settings_file in worktree.settings_files {
1544 session.peer.send(
1545 session.connection_id,
1546 proto::UpdateWorktreeSettings {
1547 project_id: project_id.to_proto(),
1548 worktree_id: worktree.id,
1549 path: settings_file.path,
1550 content: Some(settings_file.content),
1551 },
1552 )?;
1553 }
1554 }
1555
1556 for language_server in &project.language_servers {
1557 session.peer.send(
1558 session.connection_id,
1559 proto::UpdateLanguageServer {
1560 project_id: project_id.to_proto(),
1561 language_server_id: language_server.id,
1562 variant: Some(
1563 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1564 proto::LspDiskBasedDiagnosticsUpdated {},
1565 ),
1566 ),
1567 },
1568 )?;
1569 }
1570
1571 Ok(())
1572}
1573
1574async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1575 let sender_id = session.connection_id;
1576 let project_id = ProjectId::from_proto(request.project_id);
1577
1578 let (room, project) = &*session
1579 .db()
1580 .await
1581 .leave_project(project_id, sender_id)
1582 .await?;
1583 tracing::info!(
1584 %project_id,
1585 host_user_id = %project.host_user_id,
1586 host_connection_id = %project.host_connection_id,
1587 "leave project"
1588 );
1589
1590 project_left(&project, &session);
1591 room_updated(&room, &session.peer);
1592
1593 Ok(())
1594}
1595
1596async fn update_project(
1597 request: proto::UpdateProject,
1598 response: Response<proto::UpdateProject>,
1599 session: Session,
1600) -> Result<()> {
1601 let project_id = ProjectId::from_proto(request.project_id);
1602 let (room, guest_connection_ids) = &*session
1603 .db()
1604 .await
1605 .update_project(project_id, session.connection_id, &request.worktrees)
1606 .await?;
1607 broadcast(
1608 Some(session.connection_id),
1609 guest_connection_ids.iter().copied(),
1610 |connection_id| {
1611 session
1612 .peer
1613 .forward_send(session.connection_id, connection_id, request.clone())
1614 },
1615 );
1616 room_updated(&room, &session.peer);
1617 response.send(proto::Ack {})?;
1618
1619 Ok(())
1620}
1621
1622async fn update_worktree(
1623 request: proto::UpdateWorktree,
1624 response: Response<proto::UpdateWorktree>,
1625 session: Session,
1626) -> Result<()> {
1627 let guest_connection_ids = session
1628 .db()
1629 .await
1630 .update_worktree(&request, session.connection_id)
1631 .await?;
1632
1633 broadcast(
1634 Some(session.connection_id),
1635 guest_connection_ids.iter().copied(),
1636 |connection_id| {
1637 session
1638 .peer
1639 .forward_send(session.connection_id, connection_id, request.clone())
1640 },
1641 );
1642 response.send(proto::Ack {})?;
1643 Ok(())
1644}
1645
1646async fn update_diagnostic_summary(
1647 message: proto::UpdateDiagnosticSummary,
1648 session: Session,
1649) -> Result<()> {
1650 let guest_connection_ids = session
1651 .db()
1652 .await
1653 .update_diagnostic_summary(&message, session.connection_id)
1654 .await?;
1655
1656 broadcast(
1657 Some(session.connection_id),
1658 guest_connection_ids.iter().copied(),
1659 |connection_id| {
1660 session
1661 .peer
1662 .forward_send(session.connection_id, connection_id, message.clone())
1663 },
1664 );
1665
1666 Ok(())
1667}
1668
1669async fn update_worktree_settings(
1670 message: proto::UpdateWorktreeSettings,
1671 session: Session,
1672) -> Result<()> {
1673 let guest_connection_ids = session
1674 .db()
1675 .await
1676 .update_worktree_settings(&message, session.connection_id)
1677 .await?;
1678
1679 broadcast(
1680 Some(session.connection_id),
1681 guest_connection_ids.iter().copied(),
1682 |connection_id| {
1683 session
1684 .peer
1685 .forward_send(session.connection_id, connection_id, message.clone())
1686 },
1687 );
1688
1689 Ok(())
1690}
1691
1692async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1693 broadcast_project_message(request.project_id, request, session).await
1694}
1695
1696async fn start_language_server(
1697 request: proto::StartLanguageServer,
1698 session: Session,
1699) -> Result<()> {
1700 let guest_connection_ids = session
1701 .db()
1702 .await
1703 .start_language_server(&request, session.connection_id)
1704 .await?;
1705
1706 broadcast(
1707 Some(session.connection_id),
1708 guest_connection_ids.iter().copied(),
1709 |connection_id| {
1710 session
1711 .peer
1712 .forward_send(session.connection_id, connection_id, request.clone())
1713 },
1714 );
1715 Ok(())
1716}
1717
1718async fn update_language_server(
1719 request: proto::UpdateLanguageServer,
1720 session: Session,
1721) -> Result<()> {
1722 session.executor.record_backtrace();
1723 let project_id = ProjectId::from_proto(request.project_id);
1724 let project_connection_ids = session
1725 .db()
1726 .await
1727 .project_connection_ids(project_id, session.connection_id)
1728 .await?;
1729 broadcast(
1730 Some(session.connection_id),
1731 project_connection_ids.iter().copied(),
1732 |connection_id| {
1733 session
1734 .peer
1735 .forward_send(session.connection_id, connection_id, request.clone())
1736 },
1737 );
1738 Ok(())
1739}
1740
1741async fn forward_project_request<T>(
1742 request: T,
1743 response: Response<T>,
1744 session: Session,
1745) -> Result<()>
1746where
1747 T: EntityMessage + RequestMessage,
1748{
1749 session.executor.record_backtrace();
1750 let project_id = ProjectId::from_proto(request.remote_entity_id());
1751 let host_connection_id = {
1752 let collaborators = session
1753 .db()
1754 .await
1755 .project_collaborators(project_id, session.connection_id)
1756 .await?;
1757 collaborators
1758 .iter()
1759 .find(|collaborator| collaborator.is_host)
1760 .ok_or_else(|| anyhow!("host not found"))?
1761 .connection_id
1762 };
1763
1764 let payload = session
1765 .peer
1766 .forward_request(session.connection_id, host_connection_id, request)
1767 .await?;
1768
1769 response.send(payload)?;
1770 Ok(())
1771}
1772
1773async fn create_buffer_for_peer(
1774 request: proto::CreateBufferForPeer,
1775 session: Session,
1776) -> Result<()> {
1777 session.executor.record_backtrace();
1778 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1779 session
1780 .peer
1781 .forward_send(session.connection_id, peer_id.into(), request)?;
1782 Ok(())
1783}
1784
1785async fn update_buffer(
1786 request: proto::UpdateBuffer,
1787 response: Response<proto::UpdateBuffer>,
1788 session: Session,
1789) -> Result<()> {
1790 session.executor.record_backtrace();
1791 let project_id = ProjectId::from_proto(request.project_id);
1792 let mut guest_connection_ids;
1793 let mut host_connection_id = None;
1794 {
1795 let collaborators = session
1796 .db()
1797 .await
1798 .project_collaborators(project_id, session.connection_id)
1799 .await?;
1800 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1801 for collaborator in collaborators.iter() {
1802 if collaborator.is_host {
1803 host_connection_id = Some(collaborator.connection_id);
1804 } else {
1805 guest_connection_ids.push(collaborator.connection_id);
1806 }
1807 }
1808 }
1809 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1810
1811 session.executor.record_backtrace();
1812 broadcast(
1813 Some(session.connection_id),
1814 guest_connection_ids,
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 if host_connection_id != session.connection_id {
1822 session
1823 .peer
1824 .forward_request(session.connection_id, host_connection_id, request.clone())
1825 .await?;
1826 }
1827
1828 response.send(proto::Ack {})?;
1829 Ok(())
1830}
1831
1832async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1833 let project_id = ProjectId::from_proto(request.project_id);
1834 let project_connection_ids = session
1835 .db()
1836 .await
1837 .project_connection_ids(project_id, session.connection_id)
1838 .await?;
1839
1840 broadcast(
1841 Some(session.connection_id),
1842 project_connection_ids.iter().copied(),
1843 |connection_id| {
1844 session
1845 .peer
1846 .forward_send(session.connection_id, connection_id, request.clone())
1847 },
1848 );
1849 Ok(())
1850}
1851
1852async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1853 let project_id = ProjectId::from_proto(request.project_id);
1854 let project_connection_ids = session
1855 .db()
1856 .await
1857 .project_connection_ids(project_id, session.connection_id)
1858 .await?;
1859 broadcast(
1860 Some(session.connection_id),
1861 project_connection_ids.iter().copied(),
1862 |connection_id| {
1863 session
1864 .peer
1865 .forward_send(session.connection_id, connection_id, request.clone())
1866 },
1867 );
1868 Ok(())
1869}
1870
1871async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1872 broadcast_project_message(request.project_id, request, session).await
1873}
1874
1875async fn broadcast_project_message<T: EnvelopedMessage>(
1876 project_id: u64,
1877 request: T,
1878 session: Session,
1879) -> Result<()> {
1880 let project_id = ProjectId::from_proto(project_id);
1881 let project_connection_ids = session
1882 .db()
1883 .await
1884 .project_connection_ids(project_id, session.connection_id)
1885 .await?;
1886 broadcast(
1887 Some(session.connection_id),
1888 project_connection_ids.iter().copied(),
1889 |connection_id| {
1890 session
1891 .peer
1892 .forward_send(session.connection_id, connection_id, request.clone())
1893 },
1894 );
1895 Ok(())
1896}
1897
1898async fn follow(
1899 request: proto::Follow,
1900 response: Response<proto::Follow>,
1901 session: Session,
1902) -> Result<()> {
1903 let room_id = RoomId::from_proto(request.room_id);
1904 let project_id = request.project_id.map(ProjectId::from_proto);
1905 let leader_id = request
1906 .leader_id
1907 .ok_or_else(|| anyhow!("invalid leader id"))?
1908 .into();
1909 let follower_id = session.connection_id;
1910
1911 session
1912 .db()
1913 .await
1914 .check_room_participants(room_id, leader_id, session.connection_id)
1915 .await?;
1916
1917 let response_payload = session
1918 .peer
1919 .forward_request(session.connection_id, leader_id, request)
1920 .await?;
1921 response.send(response_payload)?;
1922
1923 if let Some(project_id) = project_id {
1924 let room = session
1925 .db()
1926 .await
1927 .follow(room_id, project_id, leader_id, follower_id)
1928 .await?;
1929 room_updated(&room, &session.peer);
1930 }
1931
1932 Ok(())
1933}
1934
1935async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1936 let room_id = RoomId::from_proto(request.room_id);
1937 let project_id = request.project_id.map(ProjectId::from_proto);
1938 let leader_id = request
1939 .leader_id
1940 .ok_or_else(|| anyhow!("invalid leader id"))?
1941 .into();
1942 let follower_id = session.connection_id;
1943
1944 session
1945 .db()
1946 .await
1947 .check_room_participants(room_id, leader_id, session.connection_id)
1948 .await?;
1949
1950 session
1951 .peer
1952 .forward_send(session.connection_id, leader_id, request)?;
1953
1954 if let Some(project_id) = project_id {
1955 let room = session
1956 .db()
1957 .await
1958 .unfollow(room_id, project_id, leader_id, follower_id)
1959 .await?;
1960 room_updated(&room, &session.peer);
1961 }
1962
1963 Ok(())
1964}
1965
1966async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1967 let room_id = RoomId::from_proto(request.room_id);
1968 let database = session.db.lock().await;
1969
1970 let connection_ids = if let Some(project_id) = request.project_id {
1971 let project_id = ProjectId::from_proto(project_id);
1972 database
1973 .project_connection_ids(project_id, session.connection_id)
1974 .await?
1975 } else {
1976 database
1977 .room_connection_ids(room_id, session.connection_id)
1978 .await?
1979 };
1980
1981 // For now, don't send view update messages back to that view's current leader.
1982 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
1983 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1984 _ => None,
1985 });
1986
1987 for follower_peer_id in request.follower_ids.iter().copied() {
1988 let follower_connection_id = follower_peer_id.into();
1989 if Some(follower_peer_id) != connection_id_to_omit
1990 && connection_ids.contains(&follower_connection_id)
1991 {
1992 session.peer.forward_send(
1993 session.connection_id,
1994 follower_connection_id,
1995 request.clone(),
1996 )?;
1997 }
1998 }
1999 Ok(())
2000}
2001
2002async fn get_users(
2003 request: proto::GetUsers,
2004 response: Response<proto::GetUsers>,
2005 session: Session,
2006) -> Result<()> {
2007 let user_ids = request
2008 .user_ids
2009 .into_iter()
2010 .map(UserId::from_proto)
2011 .collect();
2012 let users = session
2013 .db()
2014 .await
2015 .get_users_by_ids(user_ids)
2016 .await?
2017 .into_iter()
2018 .map(|user| proto::User {
2019 id: user.id.to_proto(),
2020 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2021 github_login: user.github_login,
2022 })
2023 .collect();
2024 response.send(proto::UsersResponse { users })?;
2025 Ok(())
2026}
2027
2028async fn fuzzy_search_users(
2029 request: proto::FuzzySearchUsers,
2030 response: Response<proto::FuzzySearchUsers>,
2031 session: Session,
2032) -> Result<()> {
2033 let query = request.query;
2034 let users = match query.len() {
2035 0 => vec![],
2036 1 | 2 => session
2037 .db()
2038 .await
2039 .get_user_by_github_login(&query)
2040 .await?
2041 .into_iter()
2042 .collect(),
2043 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2044 };
2045 let users = users
2046 .into_iter()
2047 .filter(|user| user.id != session.user_id)
2048 .map(|user| proto::User {
2049 id: user.id.to_proto(),
2050 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2051 github_login: user.github_login,
2052 })
2053 .collect();
2054 response.send(proto::UsersResponse { users })?;
2055 Ok(())
2056}
2057
2058async fn request_contact(
2059 request: proto::RequestContact,
2060 response: Response<proto::RequestContact>,
2061 session: Session,
2062) -> Result<()> {
2063 let requester_id = session.user_id;
2064 let responder_id = UserId::from_proto(request.responder_id);
2065 if requester_id == responder_id {
2066 return Err(anyhow!("cannot add yourself as a contact"))?;
2067 }
2068
2069 let notifications = session
2070 .db()
2071 .await
2072 .send_contact_request(requester_id, responder_id)
2073 .await?;
2074
2075 // Update outgoing contact requests of requester
2076 let mut update = proto::UpdateContacts::default();
2077 update.outgoing_requests.push(responder_id.to_proto());
2078 for connection_id in session
2079 .connection_pool()
2080 .await
2081 .user_connection_ids(requester_id)
2082 {
2083 session.peer.send(connection_id, update.clone())?;
2084 }
2085
2086 // Update incoming contact requests of responder
2087 let mut update = proto::UpdateContacts::default();
2088 update
2089 .incoming_requests
2090 .push(proto::IncomingContactRequest {
2091 requester_id: requester_id.to_proto(),
2092 });
2093 let connection_pool = session.connection_pool().await;
2094 for connection_id in connection_pool.user_connection_ids(responder_id) {
2095 session.peer.send(connection_id, update.clone())?;
2096 }
2097
2098 send_notifications(&*connection_pool, &session.peer, notifications);
2099
2100 response.send(proto::Ack {})?;
2101 Ok(())
2102}
2103
2104async fn respond_to_contact_request(
2105 request: proto::RespondToContactRequest,
2106 response: Response<proto::RespondToContactRequest>,
2107 session: Session,
2108) -> Result<()> {
2109 let responder_id = session.user_id;
2110 let requester_id = UserId::from_proto(request.requester_id);
2111 let db = session.db().await;
2112 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2113 db.dismiss_contact_notification(responder_id, requester_id)
2114 .await?;
2115 } else {
2116 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2117
2118 let notifications = db
2119 .respond_to_contact_request(responder_id, requester_id, accept)
2120 .await?;
2121 let requester_busy = db.is_user_busy(requester_id).await?;
2122 let responder_busy = db.is_user_busy(responder_id).await?;
2123
2124 let pool = session.connection_pool().await;
2125 // Update responder with new contact
2126 let mut update = proto::UpdateContacts::default();
2127 if accept {
2128 update
2129 .contacts
2130 .push(contact_for_user(requester_id, requester_busy, &pool));
2131 }
2132 update
2133 .remove_incoming_requests
2134 .push(requester_id.to_proto());
2135 for connection_id in pool.user_connection_ids(responder_id) {
2136 session.peer.send(connection_id, update.clone())?;
2137 }
2138
2139 // Update requester with new contact
2140 let mut update = proto::UpdateContacts::default();
2141 if accept {
2142 update
2143 .contacts
2144 .push(contact_for_user(responder_id, responder_busy, &pool));
2145 }
2146 update
2147 .remove_outgoing_requests
2148 .push(responder_id.to_proto());
2149
2150 for connection_id in pool.user_connection_ids(requester_id) {
2151 session.peer.send(connection_id, update.clone())?;
2152 }
2153
2154 send_notifications(&*pool, &session.peer, notifications);
2155 }
2156
2157 response.send(proto::Ack {})?;
2158 Ok(())
2159}
2160
2161async fn remove_contact(
2162 request: proto::RemoveContact,
2163 response: Response<proto::RemoveContact>,
2164 session: Session,
2165) -> Result<()> {
2166 let requester_id = session.user_id;
2167 let responder_id = UserId::from_proto(request.user_id);
2168 let db = session.db().await;
2169 let (contact_accepted, deleted_notification_id) =
2170 db.remove_contact(requester_id, responder_id).await?;
2171
2172 let pool = session.connection_pool().await;
2173 // Update outgoing contact requests of requester
2174 let mut update = proto::UpdateContacts::default();
2175 if contact_accepted {
2176 update.remove_contacts.push(responder_id.to_proto());
2177 } else {
2178 update
2179 .remove_outgoing_requests
2180 .push(responder_id.to_proto());
2181 }
2182 for connection_id in pool.user_connection_ids(requester_id) {
2183 session.peer.send(connection_id, update.clone())?;
2184 }
2185
2186 // Update incoming contact requests of responder
2187 let mut update = proto::UpdateContacts::default();
2188 if contact_accepted {
2189 update.remove_contacts.push(requester_id.to_proto());
2190 } else {
2191 update
2192 .remove_incoming_requests
2193 .push(requester_id.to_proto());
2194 }
2195 for connection_id in pool.user_connection_ids(responder_id) {
2196 session.peer.send(connection_id, update.clone())?;
2197 if let Some(notification_id) = deleted_notification_id {
2198 session.peer.send(
2199 connection_id,
2200 proto::DeleteNotification {
2201 notification_id: notification_id.to_proto(),
2202 },
2203 )?;
2204 }
2205 }
2206
2207 response.send(proto::Ack {})?;
2208 Ok(())
2209}
2210
2211async fn create_channel(
2212 request: proto::CreateChannel,
2213 response: Response<proto::CreateChannel>,
2214 session: Session,
2215) -> Result<()> {
2216 let db = session.db().await;
2217
2218 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2219 let id = db
2220 .create_channel(&request.name, parent_id, session.user_id)
2221 .await?;
2222
2223 let channel = proto::Channel {
2224 id: id.to_proto(),
2225 name: request.name,
2226 visibility: proto::ChannelVisibility::Members as i32,
2227 };
2228
2229 response.send(proto::CreateChannelResponse {
2230 channel: Some(channel.clone()),
2231 parent_id: request.parent_id,
2232 })?;
2233
2234 let Some(parent_id) = parent_id else {
2235 return Ok(());
2236 };
2237
2238 let update = proto::UpdateChannels {
2239 channels: vec![channel],
2240 insert_edge: vec![ChannelEdge {
2241 parent_id: parent_id.to_proto(),
2242 channel_id: id.to_proto(),
2243 }],
2244 ..Default::default()
2245 };
2246
2247 let user_ids_to_notify = db.get_channel_members(parent_id).await?;
2248
2249 let connection_pool = session.connection_pool().await;
2250 for user_id in user_ids_to_notify {
2251 for connection_id in connection_pool.user_connection_ids(user_id) {
2252 if user_id == session.user_id {
2253 continue;
2254 }
2255 session.peer.send(connection_id, update.clone())?;
2256 }
2257 }
2258
2259 Ok(())
2260}
2261
2262async fn delete_channel(
2263 request: proto::DeleteChannel,
2264 response: Response<proto::DeleteChannel>,
2265 session: Session,
2266) -> Result<()> {
2267 let db = session.db().await;
2268
2269 let channel_id = request.channel_id;
2270 let (removed_channels, member_ids) = db
2271 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2272 .await?;
2273 response.send(proto::Ack {})?;
2274
2275 // Notify members of removed channels
2276 let mut update = proto::UpdateChannels::default();
2277 update
2278 .delete_channels
2279 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2280
2281 let connection_pool = session.connection_pool().await;
2282 for member_id in member_ids {
2283 for connection_id in connection_pool.user_connection_ids(member_id) {
2284 session.peer.send(connection_id, update.clone())?;
2285 }
2286 }
2287
2288 Ok(())
2289}
2290
2291async fn invite_channel_member(
2292 request: proto::InviteChannelMember,
2293 response: Response<proto::InviteChannelMember>,
2294 session: Session,
2295) -> Result<()> {
2296 let db = session.db().await;
2297 let channel_id = ChannelId::from_proto(request.channel_id);
2298 let invitee_id = UserId::from_proto(request.user_id);
2299 let notifications = db
2300 .invite_channel_member(
2301 channel_id,
2302 invitee_id,
2303 session.user_id,
2304 request.role().into(),
2305 )
2306 .await?;
2307
2308 let channel = db.get_channel(channel_id, session.user_id).await?;
2309
2310 let mut update = proto::UpdateChannels::default();
2311 update.channel_invitations.push(proto::Channel {
2312 id: channel.id.to_proto(),
2313 visibility: channel.visibility.into(),
2314 name: channel.name,
2315 });
2316
2317 let pool = session.connection_pool().await;
2318 for connection_id in pool.user_connection_ids(invitee_id) {
2319 session.peer.send(connection_id, update.clone())?;
2320 }
2321
2322 send_notifications(&*pool, &session.peer, notifications);
2323
2324 response.send(proto::Ack {})?;
2325 Ok(())
2326}
2327
2328async fn remove_channel_member(
2329 request: proto::RemoveChannelMember,
2330 response: Response<proto::RemoveChannelMember>,
2331 session: Session,
2332) -> Result<()> {
2333 let db = session.db().await;
2334 let channel_id = ChannelId::from_proto(request.channel_id);
2335 let member_id = UserId::from_proto(request.user_id);
2336
2337 let removed_notification_id = db
2338 .remove_channel_member(channel_id, member_id, session.user_id)
2339 .await?;
2340
2341 let mut update = proto::UpdateChannels::default();
2342 update.delete_channels.push(channel_id.to_proto());
2343
2344 for connection_id in session
2345 .connection_pool()
2346 .await
2347 .user_connection_ids(member_id)
2348 {
2349 session.peer.send(connection_id, update.clone()).trace_err();
2350 if let Some(notification_id) = removed_notification_id {
2351 session
2352 .peer
2353 .send(
2354 connection_id,
2355 proto::DeleteNotification {
2356 notification_id: notification_id.to_proto(),
2357 },
2358 )
2359 .trace_err();
2360 }
2361 }
2362
2363 response.send(proto::Ack {})?;
2364 Ok(())
2365}
2366
2367async fn set_channel_visibility(
2368 request: proto::SetChannelVisibility,
2369 response: Response<proto::SetChannelVisibility>,
2370 session: Session,
2371) -> Result<()> {
2372 let db = session.db().await;
2373 let channel_id = ChannelId::from_proto(request.channel_id);
2374 let visibility = request.visibility().into();
2375
2376 let channel = db
2377 .set_channel_visibility(channel_id, visibility, session.user_id)
2378 .await?;
2379
2380 let mut update = proto::UpdateChannels::default();
2381 update.channels.push(proto::Channel {
2382 id: channel.id.to_proto(),
2383 name: channel.name,
2384 visibility: channel.visibility.into(),
2385 });
2386
2387 let member_ids = db.get_channel_members(channel_id).await?;
2388
2389 let connection_pool = session.connection_pool().await;
2390 for member_id in member_ids {
2391 for connection_id in connection_pool.user_connection_ids(member_id) {
2392 session.peer.send(connection_id, update.clone())?;
2393 }
2394 }
2395
2396 response.send(proto::Ack {})?;
2397 Ok(())
2398}
2399
2400async fn set_channel_member_role(
2401 request: proto::SetChannelMemberRole,
2402 response: Response<proto::SetChannelMemberRole>,
2403 session: Session,
2404) -> Result<()> {
2405 let db = session.db().await;
2406 let channel_id = ChannelId::from_proto(request.channel_id);
2407 let member_id = UserId::from_proto(request.user_id);
2408 let channel_member = db
2409 .set_channel_member_role(
2410 channel_id,
2411 session.user_id,
2412 member_id,
2413 request.role().into(),
2414 )
2415 .await?;
2416
2417 let channel = db.get_channel(channel_id, session.user_id).await?;
2418
2419 let mut update = proto::UpdateChannels::default();
2420 if channel_member.accepted {
2421 update.channel_permissions.push(proto::ChannelPermission {
2422 channel_id: channel.id.to_proto(),
2423 role: request.role,
2424 });
2425 }
2426
2427 for connection_id in session
2428 .connection_pool()
2429 .await
2430 .user_connection_ids(member_id)
2431 {
2432 session.peer.send(connection_id, update.clone())?;
2433 }
2434
2435 response.send(proto::Ack {})?;
2436 Ok(())
2437}
2438
2439async fn rename_channel(
2440 request: proto::RenameChannel,
2441 response: Response<proto::RenameChannel>,
2442 session: Session,
2443) -> Result<()> {
2444 let db = session.db().await;
2445 let channel_id = ChannelId::from_proto(request.channel_id);
2446 let channel = db
2447 .rename_channel(channel_id, session.user_id, &request.name)
2448 .await?;
2449
2450 let channel = proto::Channel {
2451 id: channel.id.to_proto(),
2452 name: channel.name,
2453 visibility: channel.visibility.into(),
2454 };
2455 response.send(proto::RenameChannelResponse {
2456 channel: Some(channel.clone()),
2457 })?;
2458 let mut update = proto::UpdateChannels::default();
2459 update.channels.push(channel);
2460
2461 let member_ids = db.get_channel_members(channel_id).await?;
2462
2463 let connection_pool = session.connection_pool().await;
2464 for member_id in member_ids {
2465 for connection_id in connection_pool.user_connection_ids(member_id) {
2466 session.peer.send(connection_id, update.clone())?;
2467 }
2468 }
2469
2470 Ok(())
2471}
2472
2473async fn link_channel(
2474 request: proto::LinkChannel,
2475 response: Response<proto::LinkChannel>,
2476 session: Session,
2477) -> Result<()> {
2478 let db = session.db().await;
2479 let channel_id = ChannelId::from_proto(request.channel_id);
2480 let to = ChannelId::from_proto(request.to);
2481 let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
2482
2483 let members = db.get_channel_members(to).await?;
2484 let connection_pool = session.connection_pool().await;
2485 let update = proto::UpdateChannels {
2486 channels: channels_to_send
2487 .channels
2488 .into_iter()
2489 .map(|channel| proto::Channel {
2490 id: channel.id.to_proto(),
2491 visibility: channel.visibility.into(),
2492 name: channel.name,
2493 })
2494 .collect(),
2495 insert_edge: channels_to_send.edges,
2496 ..Default::default()
2497 };
2498 for member_id in members {
2499 for connection_id in connection_pool.user_connection_ids(member_id) {
2500 session.peer.send(connection_id, update.clone())?;
2501 }
2502 }
2503
2504 response.send(Ack {})?;
2505
2506 Ok(())
2507}
2508
2509async fn unlink_channel(
2510 request: proto::UnlinkChannel,
2511 response: Response<proto::UnlinkChannel>,
2512 session: Session,
2513) -> Result<()> {
2514 let db = session.db().await;
2515 let channel_id = ChannelId::from_proto(request.channel_id);
2516 let from = ChannelId::from_proto(request.from);
2517
2518 db.unlink_channel(session.user_id, channel_id, from).await?;
2519
2520 let members = db.get_channel_members(from).await?;
2521
2522 let update = proto::UpdateChannels {
2523 delete_edge: vec![proto::ChannelEdge {
2524 channel_id: channel_id.to_proto(),
2525 parent_id: from.to_proto(),
2526 }],
2527 ..Default::default()
2528 };
2529 let connection_pool = session.connection_pool().await;
2530 for member_id in members {
2531 for connection_id in connection_pool.user_connection_ids(member_id) {
2532 session.peer.send(connection_id, update.clone())?;
2533 }
2534 }
2535
2536 response.send(Ack {})?;
2537
2538 Ok(())
2539}
2540
2541async fn move_channel(
2542 request: proto::MoveChannel,
2543 response: Response<proto::MoveChannel>,
2544 session: Session,
2545) -> Result<()> {
2546 let db = session.db().await;
2547 let channel_id = ChannelId::from_proto(request.channel_id);
2548 let from_parent = ChannelId::from_proto(request.from);
2549 let to = ChannelId::from_proto(request.to);
2550
2551 let channels_to_send = db
2552 .move_channel(session.user_id, channel_id, from_parent, to)
2553 .await?;
2554
2555 if channels_to_send.is_empty() {
2556 response.send(Ack {})?;
2557 return Ok(());
2558 }
2559
2560 let members_from = db.get_channel_members(from_parent).await?;
2561 let members_to = db.get_channel_members(to).await?;
2562
2563 let update = proto::UpdateChannels {
2564 delete_edge: vec![proto::ChannelEdge {
2565 channel_id: channel_id.to_proto(),
2566 parent_id: from_parent.to_proto(),
2567 }],
2568 ..Default::default()
2569 };
2570 let connection_pool = session.connection_pool().await;
2571 for member_id in members_from {
2572 for connection_id in connection_pool.user_connection_ids(member_id) {
2573 session.peer.send(connection_id, update.clone())?;
2574 }
2575 }
2576
2577 let update = proto::UpdateChannels {
2578 channels: channels_to_send
2579 .channels
2580 .into_iter()
2581 .map(|channel| proto::Channel {
2582 id: channel.id.to_proto(),
2583 visibility: channel.visibility.into(),
2584 name: channel.name,
2585 })
2586 .collect(),
2587 insert_edge: channels_to_send.edges,
2588 ..Default::default()
2589 };
2590 for member_id in members_to {
2591 for connection_id in connection_pool.user_connection_ids(member_id) {
2592 session.peer.send(connection_id, update.clone())?;
2593 }
2594 }
2595
2596 response.send(Ack {})?;
2597
2598 Ok(())
2599}
2600
2601async fn get_channel_members(
2602 request: proto::GetChannelMembers,
2603 response: Response<proto::GetChannelMembers>,
2604 session: Session,
2605) -> Result<()> {
2606 let db = session.db().await;
2607 let channel_id = ChannelId::from_proto(request.channel_id);
2608 let members = db
2609 .get_channel_participant_details(channel_id, session.user_id)
2610 .await?;
2611 response.send(proto::GetChannelMembersResponse { members })?;
2612 Ok(())
2613}
2614
2615async fn respond_to_channel_invite(
2616 request: proto::RespondToChannelInvite,
2617 response: Response<proto::RespondToChannelInvite>,
2618 session: Session,
2619) -> Result<()> {
2620 let db = session.db().await;
2621 let channel_id = ChannelId::from_proto(request.channel_id);
2622 let notifications = db
2623 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2624 .await?;
2625
2626 if request.accept {
2627 channel_membership_updated(db, channel_id, &session).await?;
2628 } else {
2629 let mut update = proto::UpdateChannels::default();
2630 update
2631 .remove_channel_invitations
2632 .push(channel_id.to_proto());
2633 session.peer.send(session.connection_id, update)?;
2634 }
2635
2636 send_notifications(
2637 &*session.connection_pool().await,
2638 &session.peer,
2639 notifications,
2640 );
2641 response.send(proto::Ack {})?;
2642
2643 Ok(())
2644}
2645
2646async fn channel_membership_updated(
2647 db: tokio::sync::MutexGuard<'_, DbHandle>,
2648 channel_id: ChannelId,
2649 session: &Session,
2650) -> Result<(), crate::Error> {
2651 let mut update = proto::UpdateChannels::default();
2652 update
2653 .remove_channel_invitations
2654 .push(channel_id.to_proto());
2655
2656 let result = db.get_channel_for_user(channel_id, session.user_id).await?;
2657 update.channels.extend(
2658 result
2659 .channels
2660 .channels
2661 .into_iter()
2662 .map(|channel| proto::Channel {
2663 id: channel.id.to_proto(),
2664 visibility: channel.visibility.into(),
2665 name: channel.name,
2666 }),
2667 );
2668 update.unseen_channel_messages = result.channel_messages;
2669 update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
2670 update.insert_edge = result.channels.edges;
2671 update
2672 .channel_participants
2673 .extend(
2674 result
2675 .channel_participants
2676 .into_iter()
2677 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2678 channel_id: channel_id.to_proto(),
2679 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2680 }),
2681 );
2682 update
2683 .channel_permissions
2684 .extend(
2685 result
2686 .channels_with_admin_privileges
2687 .into_iter()
2688 .map(|channel_id| proto::ChannelPermission {
2689 channel_id: channel_id.to_proto(),
2690 role: proto::ChannelRole::Admin.into(),
2691 }),
2692 );
2693 session.peer.send(session.connection_id, update)?;
2694 Ok(())
2695}
2696
2697async fn join_channel(
2698 request: proto::JoinChannel,
2699 response: Response<proto::JoinChannel>,
2700 session: Session,
2701) -> Result<()> {
2702 let channel_id = ChannelId::from_proto(request.channel_id);
2703 join_channel_internal(channel_id, Box::new(response), session).await
2704}
2705
2706trait JoinChannelInternalResponse {
2707 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2708}
2709impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2710 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2711 Response::<proto::JoinChannel>::send(self, result)
2712 }
2713}
2714impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2715 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2716 Response::<proto::JoinRoom>::send(self, result)
2717 }
2718}
2719
/// Shared implementation of joining a channel's room. It works with any
/// response type implementing `JoinChannelInternalResponse`, which is
/// implemented for both `JoinChannel` and `JoinRoom` responses.
///
/// In order, it:
/// 1. leaves whatever room this session is in, via `leave_room_for_session`;
/// 2. joins the channel's room in the database, passing `RELEASE_CHANNEL_NAME`;
/// 3. replies with the room, its channel id and, when a LiveKit client is
///    configured and a token can be created, LiveKit connection info;
/// 4. if the database reports a joined channel, sends this connection the
///    resulting membership update;
/// 5. broadcasts the room to its participants;
/// 6. sends the channel's new participant list to the channel's members;
/// 7. sends this user's updated contact entry to their accepted contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Block scope: the database guard `db` is dropped (or moved into
    // `channel_membership_updated`) before `channel_updated` below locks the
    // connection pool.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, joined_channel) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                RELEASE_CHANNEL_NAME.as_str(),
            )
            .await?;

        // No LiveKit client, or a token error (passed through `trace_err`),
        // yields `None`. The join still succeeds without media credentials.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Only present when the database reports a membership change for this
        // join. Passing `db` by value releases the database guard here.
        if let Some(joined_channel) = joined_channel {
            channel_membership_updated(db, joined_channel, &session).await?
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2778
2779async fn join_channel_buffer(
2780 request: proto::JoinChannelBuffer,
2781 response: Response<proto::JoinChannelBuffer>,
2782 session: Session,
2783) -> Result<()> {
2784 let db = session.db().await;
2785 let channel_id = ChannelId::from_proto(request.channel_id);
2786
2787 let open_response = db
2788 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2789 .await?;
2790
2791 let collaborators = open_response.collaborators.clone();
2792 response.send(open_response)?;
2793
2794 let update = UpdateChannelBufferCollaborators {
2795 channel_id: channel_id.to_proto(),
2796 collaborators: collaborators.clone(),
2797 };
2798 channel_buffer_updated(
2799 session.connection_id,
2800 collaborators
2801 .iter()
2802 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2803 &update,
2804 &session.peer,
2805 );
2806
2807 Ok(())
2808}
2809
2810async fn update_channel_buffer(
2811 request: proto::UpdateChannelBuffer,
2812 session: Session,
2813) -> Result<()> {
2814 let db = session.db().await;
2815 let channel_id = ChannelId::from_proto(request.channel_id);
2816
2817 let (collaborators, non_collaborators, epoch, version) = db
2818 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2819 .await?;
2820
2821 channel_buffer_updated(
2822 session.connection_id,
2823 collaborators,
2824 &proto::UpdateChannelBuffer {
2825 channel_id: channel_id.to_proto(),
2826 operations: request.operations,
2827 },
2828 &session.peer,
2829 );
2830
2831 let pool = &*session.connection_pool().await;
2832
2833 broadcast(
2834 None,
2835 non_collaborators
2836 .iter()
2837 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2838 |peer_id| {
2839 session.peer.send(
2840 peer_id.into(),
2841 proto::UpdateChannels {
2842 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2843 channel_id: channel_id.to_proto(),
2844 epoch: epoch as u64,
2845 version: version.clone(),
2846 }],
2847 ..Default::default()
2848 },
2849 )
2850 },
2851 );
2852
2853 Ok(())
2854}
2855
2856async fn rejoin_channel_buffers(
2857 request: proto::RejoinChannelBuffers,
2858 response: Response<proto::RejoinChannelBuffers>,
2859 session: Session,
2860) -> Result<()> {
2861 let db = session.db().await;
2862 let buffers = db
2863 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2864 .await?;
2865
2866 for rejoined_buffer in &buffers {
2867 let collaborators_to_notify = rejoined_buffer
2868 .buffer
2869 .collaborators
2870 .iter()
2871 .filter_map(|c| Some(c.peer_id?.into()));
2872 channel_buffer_updated(
2873 session.connection_id,
2874 collaborators_to_notify,
2875 &proto::UpdateChannelBufferCollaborators {
2876 channel_id: rejoined_buffer.buffer.channel_id,
2877 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2878 },
2879 &session.peer,
2880 );
2881 }
2882
2883 response.send(proto::RejoinChannelBuffersResponse {
2884 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2885 })?;
2886
2887 Ok(())
2888}
2889
2890async fn leave_channel_buffer(
2891 request: proto::LeaveChannelBuffer,
2892 response: Response<proto::LeaveChannelBuffer>,
2893 session: Session,
2894) -> Result<()> {
2895 let db = session.db().await;
2896 let channel_id = ChannelId::from_proto(request.channel_id);
2897
2898 let left_buffer = db
2899 .leave_channel_buffer(channel_id, session.connection_id)
2900 .await?;
2901
2902 response.send(Ack {})?;
2903
2904 channel_buffer_updated(
2905 session.connection_id,
2906 left_buffer.connections,
2907 &proto::UpdateChannelBufferCollaborators {
2908 channel_id: channel_id.to_proto(),
2909 collaborators: left_buffer.collaborators,
2910 },
2911 &session.peer,
2912 );
2913
2914 Ok(())
2915}
2916
2917fn channel_buffer_updated<T: EnvelopedMessage>(
2918 sender_id: ConnectionId,
2919 collaborators: impl IntoIterator<Item = ConnectionId>,
2920 message: &T,
2921 peer: &Peer,
2922) {
2923 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2924 peer.send(peer_id.into(), message.clone())
2925 });
2926}
2927
2928fn send_notifications(
2929 connection_pool: &ConnectionPool,
2930 peer: &Peer,
2931 notifications: db::NotificationBatch,
2932) {
2933 for (user_id, notification) in notifications {
2934 for connection_id in connection_pool.user_connection_ids(user_id) {
2935 if let Err(error) = peer.send(
2936 connection_id,
2937 proto::NewNotification {
2938 notification: Some(notification.clone()),
2939 },
2940 ) {
2941 tracing::error!(
2942 "failed to send notification to {:?} {}",
2943 connection_id,
2944 error
2945 );
2946 }
2947 }
2948 }
2949}
2950
/// Handles a new chat message posted to a channel.
///
/// The body is trimmed. It is rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or if it is empty, and a client-supplied `nonce` is required. Once the
/// message is stored, this handler:
/// - relays it as `ChannelMessageSent` to the chat participant connections the
///   database returns (the sender's connection is passed to `broadcast`,
///   presumably to skip it);
/// - returns it to the sender in the RPC response;
/// - sends an unseen-message marker to every connection of every channel member;
/// - delivers any notifications the database created for the message.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Note: `len()` is a byte length, not a character count.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): if mention ranges are offsets into the untrimmed
    // `request.body`, trimming leading whitespace shifts them. Confirm the
    // mention format before relying on these ranges.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    // The database guard is a temporary in this statement, so it is released
    // at the end of the statement, before any broadcasting below.
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        channel_members,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    // Live delivery to connections currently in the channel chat.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Unseen-message marker for every connected channel member.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        channel_members
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3038
3039async fn remove_channel_message(
3040 request: proto::RemoveChannelMessage,
3041 response: Response<proto::RemoveChannelMessage>,
3042 session: Session,
3043) -> Result<()> {
3044 let channel_id = ChannelId::from_proto(request.channel_id);
3045 let message_id = MessageId::from_proto(request.message_id);
3046 let connection_ids = session
3047 .db()
3048 .await
3049 .remove_channel_message(channel_id, message_id, session.user_id)
3050 .await?;
3051 broadcast(Some(session.connection_id), connection_ids, |connection| {
3052 session.peer.send(connection, request.clone())
3053 });
3054 response.send(proto::Ack {})?;
3055 Ok(())
3056}
3057
3058async fn acknowledge_channel_message(
3059 request: proto::AckChannelMessage,
3060 session: Session,
3061) -> Result<()> {
3062 let channel_id = ChannelId::from_proto(request.channel_id);
3063 let message_id = MessageId::from_proto(request.message_id);
3064 session
3065 .db()
3066 .await
3067 .observe_channel_message(channel_id, session.user_id, message_id)
3068 .await?;
3069 Ok(())
3070}
3071
3072async fn acknowledge_buffer_version(
3073 request: proto::AckBufferOperation,
3074 session: Session,
3075) -> Result<()> {
3076 let buffer_id = BufferId::from_proto(request.buffer_id);
3077 session
3078 .db()
3079 .await
3080 .observe_buffer_version(
3081 buffer_id,
3082 session.user_id,
3083 request.epoch as i32,
3084 &request.version,
3085 )
3086 .await?;
3087 Ok(())
3088}
3089
3090async fn join_channel_chat(
3091 request: proto::JoinChannelChat,
3092 response: Response<proto::JoinChannelChat>,
3093 session: Session,
3094) -> Result<()> {
3095 let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097 let db = session.db().await;
3098 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3099 .await?;
3100 let messages = db
3101 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3102 .await?;
3103 response.send(proto::JoinChannelChatResponse {
3104 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3105 messages,
3106 })?;
3107 Ok(())
3108}
3109
3110async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3111 let channel_id = ChannelId::from_proto(request.channel_id);
3112 session
3113 .db()
3114 .await
3115 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3116 .await?;
3117 Ok(())
3118}
3119
3120async fn get_channel_messages(
3121 request: proto::GetChannelMessages,
3122 response: Response<proto::GetChannelMessages>,
3123 session: Session,
3124) -> Result<()> {
3125 let channel_id = ChannelId::from_proto(request.channel_id);
3126 let messages = session
3127 .db()
3128 .await
3129 .get_channel_messages(
3130 channel_id,
3131 session.user_id,
3132 MESSAGE_COUNT_PER_PAGE,
3133 Some(MessageId::from_proto(request.before_message_id)),
3134 )
3135 .await?;
3136 response.send(proto::GetChannelMessagesResponse {
3137 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3138 messages,
3139 })?;
3140 Ok(())
3141}
3142
3143async fn get_channel_messages_by_id(
3144 request: proto::GetChannelMessagesById,
3145 response: Response<proto::GetChannelMessagesById>,
3146 session: Session,
3147) -> Result<()> {
3148 let message_ids = request
3149 .message_ids
3150 .iter()
3151 .map(|id| MessageId::from_proto(*id))
3152 .collect::<Vec<_>>();
3153 let messages = session
3154 .db()
3155 .await
3156 .get_channel_messages_by_id(session.user_id, &message_ids)
3157 .await?;
3158 response.send(proto::GetChannelMessagesResponse {
3159 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3160 messages,
3161 })?;
3162 Ok(())
3163}
3164
3165async fn get_notifications(
3166 request: proto::GetNotifications,
3167 response: Response<proto::GetNotifications>,
3168 session: Session,
3169) -> Result<()> {
3170 let notifications = session
3171 .db()
3172 .await
3173 .get_notifications(
3174 session.user_id,
3175 NOTIFICATION_COUNT_PER_PAGE,
3176 request
3177 .before_id
3178 .map(|id| db::NotificationId::from_proto(id)),
3179 )
3180 .await?;
3181 response.send(proto::GetNotificationsResponse { notifications })?;
3182 Ok(())
3183}
3184
3185async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
3186 let project_id = ProjectId::from_proto(request.project_id);
3187 let project_connection_ids = session
3188 .db()
3189 .await
3190 .project_connection_ids(project_id, session.connection_id)
3191 .await?;
3192 broadcast(
3193 Some(session.connection_id),
3194 project_connection_ids.iter().copied(),
3195 |connection_id| {
3196 session
3197 .peer
3198 .forward_send(session.connection_id, connection_id, request.clone())
3199 },
3200 );
3201 Ok(())
3202}
3203
3204async fn get_private_user_info(
3205 _request: proto::GetPrivateUserInfo,
3206 response: Response<proto::GetPrivateUserInfo>,
3207 session: Session,
3208) -> Result<()> {
3209 let db = session.db().await;
3210
3211 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3212 let user = db
3213 .get_user_by_id(session.user_id)
3214 .await?
3215 .ok_or_else(|| anyhow!("user not found"))?;
3216 let flags = db.get_user_flags(session.user_id).await?;
3217
3218 response.send(proto::GetPrivateUserInfoResponse {
3219 metrics_id,
3220 staff: user.admin,
3221 flags,
3222 })?;
3223 Ok(())
3224}
3225
3226fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3227 match message {
3228 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3229 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3230 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3231 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3232 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3233 code: frame.code.into(),
3234 reason: frame.reason,
3235 })),
3236 }
3237}
3238
3239fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3240 match message {
3241 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3242 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3243 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3244 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3245 AxumMessage::Close(frame) => {
3246 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3247 code: frame.code.into(),
3248 reason: frame.reason,
3249 }))
3250 }
3251 }
3252}
3253
3254fn build_initial_channels_update(
3255 channels: ChannelsForUser,
3256 channel_invites: Vec<db::Channel>,
3257) -> proto::UpdateChannels {
3258 let mut update = proto::UpdateChannels::default();
3259
3260 for channel in channels.channels.channels {
3261 update.channels.push(proto::Channel {
3262 id: channel.id.to_proto(),
3263 name: channel.name,
3264 visibility: channel.visibility.into(),
3265 });
3266 }
3267
3268 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3269 update.unseen_channel_messages = channels.channel_messages;
3270 update.insert_edge = channels.channels.edges;
3271
3272 for (channel_id, participants) in channels.channel_participants {
3273 update
3274 .channel_participants
3275 .push(proto::ChannelParticipants {
3276 channel_id: channel_id.to_proto(),
3277 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3278 });
3279 }
3280
3281 update
3282 .channel_permissions
3283 .extend(
3284 channels
3285 .channels_with_admin_privileges
3286 .into_iter()
3287 .map(|id| proto::ChannelPermission {
3288 channel_id: id.to_proto(),
3289 role: proto::ChannelRole::Admin.into(),
3290 }),
3291 );
3292
3293 for channel in channel_invites {
3294 update.channel_invitations.push(proto::Channel {
3295 id: channel.id.to_proto(),
3296 name: channel.name,
3297 // TODO: Visibility
3298 visibility: ChannelVisibility::Public.into(),
3299 });
3300 }
3301
3302 update
3303}
3304
3305fn build_initial_contacts_update(
3306 contacts: Vec<db::Contact>,
3307 pool: &ConnectionPool,
3308) -> proto::UpdateContacts {
3309 let mut update = proto::UpdateContacts::default();
3310
3311 for contact in contacts {
3312 match contact {
3313 db::Contact::Accepted { user_id, busy } => {
3314 update.contacts.push(contact_for_user(user_id, busy, &pool));
3315 }
3316 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3317 db::Contact::Incoming { user_id } => {
3318 update
3319 .incoming_requests
3320 .push(proto::IncomingContactRequest {
3321 requester_id: user_id.to_proto(),
3322 })
3323 }
3324 }
3325 }
3326
3327 update
3328}
3329
3330fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3331 proto::Contact {
3332 user_id: user_id.to_proto(),
3333 online: pool.is_user_online(user_id),
3334 busy,
3335 }
3336}
3337
3338fn room_updated(room: &proto::Room, peer: &Peer) {
3339 broadcast(
3340 None,
3341 room.participants
3342 .iter()
3343 .filter_map(|participant| Some(participant.peer_id?.into())),
3344 |peer_id| {
3345 peer.send(
3346 peer_id.into(),
3347 proto::RoomUpdated {
3348 room: Some(room.clone()),
3349 },
3350 )
3351 },
3352 );
3353}
3354
3355fn channel_updated(
3356 channel_id: ChannelId,
3357 room: &proto::Room,
3358 channel_members: &[UserId],
3359 peer: &Peer,
3360 pool: &ConnectionPool,
3361) {
3362 let participants = room
3363 .participants
3364 .iter()
3365 .map(|p| p.user_id)
3366 .collect::<Vec<_>>();
3367
3368 broadcast(
3369 None,
3370 channel_members
3371 .iter()
3372 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3373 |peer_id| {
3374 peer.send(
3375 peer_id.into(),
3376 proto::UpdateChannels {
3377 channel_participants: vec![proto::ChannelParticipants {
3378 channel_id: channel_id.to_proto(),
3379 participant_user_ids: participants.clone(),
3380 }],
3381 ..Default::default()
3382 },
3383 )
3384 },
3385 );
3386}
3387
3388async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3389 let db = session.db().await;
3390
3391 let contacts = db.get_contacts(user_id).await?;
3392 let busy = db.is_user_busy(user_id).await?;
3393
3394 let pool = session.connection_pool().await;
3395 let updated_contact = contact_for_user(user_id, busy, &pool);
3396 for contact in contacts {
3397 if let db::Contact::Accepted {
3398 user_id: contact_user_id,
3399 ..
3400 } = contact
3401 {
3402 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3403 session
3404 .peer
3405 .send(
3406 contact_conn_id,
3407 proto::UpdateContacts {
3408 contacts: vec![updated_contact.clone()],
3409 remove_contacts: Default::default(),
3410 incoming_requests: Default::default(),
3411 remove_incoming_requests: Default::default(),
3412 outgoing_requests: Default::default(),
3413 remove_outgoing_requests: Default::default(),
3414 },
3415 )
3416 .trace_err();
3417 }
3418 }
3419 }
3420 Ok(())
3421}
3422
3423async fn leave_room_for_session(session: &Session) -> Result<()> {
3424 let mut contacts_to_update = HashSet::default();
3425
3426 let room_id;
3427 let canceled_calls_to_user_ids;
3428 let live_kit_room;
3429 let delete_live_kit_room;
3430 let room;
3431 let channel_members;
3432 let channel_id;
3433
3434 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3435 contacts_to_update.insert(session.user_id);
3436
3437 for project in left_room.left_projects.values() {
3438 project_left(project, session);
3439 }
3440
3441 room_id = RoomId::from_proto(left_room.room.id);
3442 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3443 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3444 delete_live_kit_room = left_room.deleted;
3445 room = mem::take(&mut left_room.room);
3446 channel_members = mem::take(&mut left_room.channel_members);
3447 channel_id = left_room.channel_id;
3448
3449 room_updated(&room, &session.peer);
3450 } else {
3451 return Ok(());
3452 }
3453
3454 if let Some(channel_id) = channel_id {
3455 channel_updated(
3456 channel_id,
3457 &room,
3458 &channel_members,
3459 &session.peer,
3460 &*session.connection_pool().await,
3461 );
3462 }
3463
3464 {
3465 let pool = session.connection_pool().await;
3466 for canceled_user_id in canceled_calls_to_user_ids {
3467 for connection_id in pool.user_connection_ids(canceled_user_id) {
3468 session
3469 .peer
3470 .send(
3471 connection_id,
3472 proto::CallCanceled {
3473 room_id: room_id.to_proto(),
3474 },
3475 )
3476 .trace_err();
3477 }
3478 contacts_to_update.insert(canceled_user_id);
3479 }
3480 }
3481
3482 for contact_user_id in contacts_to_update {
3483 update_user_contacts(contact_user_id, &session).await?;
3484 }
3485
3486 if let Some(live_kit) = session.live_kit_client.as_ref() {
3487 live_kit
3488 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3489 .await
3490 .trace_err();
3491
3492 if delete_live_kit_room {
3493 live_kit.delete_room(live_kit_room).await.trace_err();
3494 }
3495 }
3496
3497 Ok(())
3498}
3499
3500async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3501 let left_channel_buffers = session
3502 .db()
3503 .await
3504 .leave_channel_buffers(session.connection_id)
3505 .await?;
3506
3507 for left_buffer in left_channel_buffers {
3508 channel_buffer_updated(
3509 session.connection_id,
3510 left_buffer.connections,
3511 &proto::UpdateChannelBufferCollaborators {
3512 channel_id: left_buffer.channel_id.to_proto(),
3513 collaborators: left_buffer.collaborators,
3514 },
3515 &session.peer,
3516 );
3517 }
3518
3519 Ok(())
3520}
3521
3522fn project_left(project: &db::LeftProject, session: &Session) {
3523 for connection_id in &project.connection_ids {
3524 if project.host_user_id == session.user_id {
3525 session
3526 .peer
3527 .send(
3528 *connection_id,
3529 proto::UnshareProject {
3530 project_id: project.id.to_proto(),
3531 },
3532 )
3533 .trace_err();
3534 } else {
3535 session
3536 .peer
3537 .send(
3538 *connection_id,
3539 proto::RemoveProjectCollaborator {
3540 project_id: project.id.to_proto(),
3541 peer_id: Some(session.connection_id.into()),
3542 },
3543 )
3544 .trace_err();
3545 }
3546 }
3547}
3548
/// Extension for result-like values whose errors should be logged instead of
/// propagated.
pub trait ResultExt {
    /// The success value yielded when the result is not an error.
    type Ok;

    /// Converts `self` into an `Option`: the success value becomes `Some`.
    /// In the `Result` implementation in this file, an error is logged and
    /// yields `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3554
3555impl<T, E> ResultExt for Result<T, E>
3556where
3557 E: std::fmt::Debug,
3558{
3559 type Ok = T;
3560
3561 fn trace_err(self) -> Option<T> {
3562 match self {
3563 Ok(value) => Some(value),
3564 Err(error) => {
3565 tracing::error!("{:?}", error);
3566 None
3567 }
3568 }
3569 }
3570}