1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// Grace period a dropped connection gets before `connection_lost` removes the user
/// from their room and refreshes their contacts (unless the server tears down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay `Server::start` waits before looking up and refreshing stale rooms.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
// Prometheus gauges exported by `handle_metrics`. Each gauge is registered with the
// default registry the first time it is dereferenced; registration failure (for
// example, a metric with the same name already registered) panics via `unwrap`.
lazy_static! {
    // Set by `handle_metrics` to the number of non-admin connections in the pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set by `handle_metrics` from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(get_channel_members)
252 .add_request_handler(respond_to_channel_invite)
253 .add_request_handler(join_channel)
254 .add_request_handler(follow)
255 .add_message_handler(unfollow)
256 .add_message_handler(update_followers)
257 .add_message_handler(update_diff_base)
258 .add_request_handler(get_private_user_info);
259
260 Arc::new(server)
261 }
262
263 pub async fn start(&self) -> Result<()> {
264 let server_id = *self.id.lock();
265 let app_state = self.app_state.clone();
266 let peer = self.peer.clone();
267 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
268 let pool = self.connection_pool.clone();
269 let live_kit_client = self.app_state.live_kit_client.clone();
270
271 let span = info_span!("start server");
272 self.executor.spawn_detached(
273 async move {
274 tracing::info!("waiting for cleanup timeout");
275 timeout.await;
276 tracing::info!("cleanup timeout expired, retrieving stale rooms");
277 if let Some(room_ids) = app_state
278 .db
279 .stale_room_ids(&app_state.config.zed_environment, server_id)
280 .await
281 .trace_err()
282 {
283 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
284 for room_id in room_ids {
285 let mut contacts_to_update = HashSet::default();
286 let mut canceled_calls_to_user_ids = Vec::new();
287 let mut live_kit_room = String::new();
288 let mut delete_live_kit_room = false;
289
290 if let Some(mut refreshed_room) = app_state
291 .db
292 .refresh_room(room_id, server_id)
293 .await
294 .trace_err()
295 {
296 tracing::info!(
297 room_id = room_id.0,
298 new_participant_count = refreshed_room.room.participants.len(),
299 "refreshed room"
300 );
301 room_updated(&refreshed_room.room, &peer);
302 if let Some(channel_id) = refreshed_room.channel_id {
303 channel_updated(
304 channel_id,
305 &refreshed_room.room,
306 &refreshed_room.channel_members,
307 &peer,
308 &*pool.lock(),
309 );
310 }
311 contacts_to_update
312 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
313 contacts_to_update
314 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
315 canceled_calls_to_user_ids =
316 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
317 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
318 delete_live_kit_room = refreshed_room.room.participants.is_empty();
319 }
320
321 {
322 let pool = pool.lock();
323 for canceled_user_id in canceled_calls_to_user_ids {
324 for connection_id in pool.user_connection_ids(canceled_user_id) {
325 peer.send(
326 connection_id,
327 proto::CallCanceled {
328 room_id: room_id.to_proto(),
329 },
330 )
331 .trace_err();
332 }
333 }
334 }
335
336 for user_id in contacts_to_update {
337 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
338 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
339 if let Some((busy, contacts)) = busy.zip(contacts) {
340 let pool = pool.lock();
341 let updated_contact = contact_for_user(user_id, false, busy, &pool);
342 for contact in contacts {
343 if let db::Contact::Accepted {
344 user_id: contact_user_id,
345 ..
346 } = contact
347 {
348 for contact_conn_id in
349 pool.user_connection_ids(contact_user_id)
350 {
351 peer.send(
352 contact_conn_id,
353 proto::UpdateContacts {
354 contacts: vec![updated_contact.clone()],
355 remove_contacts: Default::default(),
356 incoming_requests: Default::default(),
357 remove_incoming_requests: Default::default(),
358 outgoing_requests: Default::default(),
359 remove_outgoing_requests: Default::default(),
360 },
361 )
362 .trace_err();
363 }
364 }
365 }
366 }
367 }
368
369 if let Some(live_kit) = live_kit_client.as_ref() {
370 if delete_live_kit_room {
371 live_kit.delete_room(live_kit_room).await.trace_err();
372 }
373 }
374 }
375 }
376
377 app_state
378 .db
379 .delete_stale_servers(&app_state.config.zed_environment, server_id)
380 .await
381 .trace_err();
382 }
383 .instrument(span),
384 );
385 Ok(())
386 }
387
388 pub fn teardown(&self) {
389 self.peer.teardown();
390 self.connection_pool.lock().reset();
391 let _ = self.teardown.send(());
392 }
393
394 #[cfg(test)]
395 pub fn reset(&self, id: ServerId) {
396 self.teardown();
397 *self.id.lock() = id;
398 self.peer.reset(id.0 as u32);
399 }
400
401 #[cfg(test)]
402 pub fn id(&self) -> ServerId {
403 *self.id.lock()
404 }
405
406 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
407 where
408 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
409 Fut: 'static + Send + Future<Output = Result<()>>,
410 M: EnvelopedMessage,
411 {
412 let prev_handler = self.handlers.insert(
413 TypeId::of::<M>(),
414 Box::new(move |envelope, session| {
415 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
416 let span = info_span!(
417 "handle message",
418 payload_type = envelope.payload_type_name()
419 );
420 span.in_scope(|| {
421 tracing::info!(
422 payload_type = envelope.payload_type_name(),
423 "message received"
424 );
425 });
426 let start_time = Instant::now();
427 let future = (handler)(*envelope, session);
428 async move {
429 let result = future.await;
430 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
431 match result {
432 Err(error) => {
433 tracing::error!(%error, ?duration_ms, "error handling message")
434 }
435 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
436 }
437 }
438 .instrument(span)
439 .boxed()
440 }),
441 );
442 if prev_handler.is_some() {
443 panic!("registered a handler for the same message twice");
444 }
445 self
446 }
447
448 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
449 where
450 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
451 Fut: 'static + Send + Future<Output = Result<()>>,
452 M: EnvelopedMessage,
453 {
454 self.add_handler(move |envelope, session| handler(envelope.payload, session));
455 self
456 }
457
458 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
459 where
460 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
461 Fut: Send + Future<Output = Result<()>>,
462 M: RequestMessage,
463 {
464 let handler = Arc::new(handler);
465 self.add_handler(move |envelope, session| {
466 let receipt = envelope.receipt();
467 let handler = handler.clone();
468 async move {
469 let peer = session.peer.clone();
470 let responded = Arc::new(AtomicBool::default());
471 let response = Response {
472 peer: peer.clone(),
473 responded: responded.clone(),
474 receipt,
475 };
476 match (handler)(envelope.payload, response, session).await {
477 Ok(()) => {
478 if responded.load(std::sync::atomic::Ordering::SeqCst) {
479 Ok(())
480 } else {
481 Err(anyhow!("handler did not send a response"))?
482 }
483 }
484 Err(error) => {
485 peer.respond_with_error(
486 receipt,
487 proto::Error {
488 message: error.to_string(),
489 },
490 )?;
491 Err(error)
492 }
493 }
494 }
495 })
496 }
497
498 pub fn handle_connection(
499 self: &Arc<Self>,
500 connection: Connection,
501 address: String,
502 user: User,
503 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
504 executor: Executor,
505 ) -> impl Future<Output = Result<()>> {
506 let this = self.clone();
507 let user_id = user.id;
508 let login = user.github_login;
509 let span = info_span!("handle connection", %user_id, %login, %address);
510 let mut teardown = self.teardown.subscribe();
511 async move {
512 let (connection_id, handle_io, mut incoming_rx) = this
513 .peer
514 .add_connection(connection, {
515 let executor = executor.clone();
516 move |duration| executor.sleep(duration)
517 });
518
519 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
520 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
521 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
522
523 if let Some(send_connection_id) = send_connection_id.take() {
524 let _ = send_connection_id.send(connection_id);
525 }
526
527 if !user.connected_once {
528 this.peer.send(connection_id, proto::ShowContacts {})?;
529 this.app_state.db.set_user_connected_once(user_id, true).await?;
530 }
531
532 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
533 this.app_state.db.get_contacts(user_id),
534 this.app_state.db.get_invite_code_for_user(user_id),
535 this.app_state.db.get_channels_for_user(user_id),
536 this.app_state.db.get_channel_invites_for_user(user_id)
537 ).await?;
538
539 {
540 let mut pool = this.connection_pool.lock();
541 pool.add_connection(connection_id, user_id, user.admin);
542 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
543 this.peer.send(connection_id, build_initial_channels_update(
544 channels_for_user,
545 channel_invites
546 ))?;
547
548 if let Some((code, count)) = invite_code {
549 this.peer.send(connection_id, proto::UpdateInviteInfo {
550 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
551 count: count as u32,
552 })?;
553 }
554 }
555
556 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
557 this.peer.send(connection_id, incoming_call)?;
558 }
559
560 let session = Session {
561 user_id,
562 connection_id,
563 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
564 peer: this.peer.clone(),
565 connection_pool: this.connection_pool.clone(),
566 live_kit_client: this.app_state.live_kit_client.clone(),
567 executor: executor.clone(),
568 };
569 update_user_contacts(user_id, &session).await?;
570
571 let handle_io = handle_io.fuse();
572 futures::pin_mut!(handle_io);
573
574 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
575 // This prevents deadlocks when e.g., client A performs a request to client B and
576 // client B performs a request to client A. If both clients stop processing further
577 // messages until their respective request completes, they won't have a chance to
578 // respond to the other client's request and cause a deadlock.
579 //
580 // This arrangement ensures we will attempt to process earlier messages first, but fall
581 // back to processing messages arrived later in the spirit of making progress.
582 let mut foreground_message_handlers = FuturesUnordered::new();
583 let concurrent_handlers = Arc::new(Semaphore::new(256));
584 loop {
585 let next_message = async {
586 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
587 let message = incoming_rx.next().await;
588 (permit, message)
589 }.fuse();
590 futures::pin_mut!(next_message);
591 futures::select_biased! {
592 _ = teardown.changed().fuse() => return Ok(()),
593 result = handle_io => {
594 if let Err(error) = result {
595 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
596 }
597 break;
598 }
599 _ = foreground_message_handlers.next() => {}
600 next_message = next_message => {
601 let (permit, message) = next_message;
602 if let Some(message) = message {
603 let type_name = message.payload_type_name();
604 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
605 let span_enter = span.enter();
606 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
607 let is_background = message.is_background();
608 let handle_message = (handler)(message, session.clone());
609 drop(span_enter);
610
611 let handle_message = async move {
612 handle_message.await;
613 drop(permit);
614 }.instrument(span);
615 if is_background {
616 executor.spawn_detached(handle_message);
617 } else {
618 foreground_message_handlers.push(handle_message);
619 }
620 } else {
621 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
622 }
623 } else {
624 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
625 break;
626 }
627 }
628 }
629 }
630
631 drop(foreground_message_handlers);
632 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
633 if let Err(error) = connection_lost(session, teardown, executor).await {
634 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
635 }
636
637 Ok(())
638 }.instrument(span)
639 }
640
641 pub async fn invite_code_redeemed(
642 self: &Arc<Self>,
643 inviter_id: UserId,
644 invitee_id: UserId,
645 ) -> Result<()> {
646 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
647 if let Some(code) = &user.invite_code {
648 let pool = self.connection_pool.lock();
649 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
650 for connection_id in pool.user_connection_ids(inviter_id) {
651 self.peer.send(
652 connection_id,
653 proto::UpdateContacts {
654 contacts: vec![invitee_contact.clone()],
655 ..Default::default()
656 },
657 )?;
658 self.peer.send(
659 connection_id,
660 proto::UpdateInviteInfo {
661 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
662 count: user.invite_count as u32,
663 },
664 )?;
665 }
666 }
667 }
668 Ok(())
669 }
670
671 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
672 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
673 if let Some(invite_code) = &user.invite_code {
674 let pool = self.connection_pool.lock();
675 for connection_id in pool.user_connection_ids(user_id) {
676 self.peer.send(
677 connection_id,
678 proto::UpdateInviteInfo {
679 url: format!(
680 "{}{}",
681 self.app_state.config.invite_link_prefix, invite_code
682 ),
683 count: user.invite_count as u32,
684 },
685 )?;
686 }
687 }
688 }
689 Ok(())
690 }
691
692 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
693 ServerSnapshot {
694 connection_pool: ConnectionPoolGuard {
695 guard: self.connection_pool.lock(),
696 _not_send: PhantomData,
697 },
698 peer: &self.peer,
699 }
700 }
701}
702
703impl<'a> Deref for ConnectionPoolGuard<'a> {
704 type Target = ConnectionPool;
705
706 fn deref(&self) -> &Self::Target {
707 &*self.guard
708 }
709}
710
711impl<'a> DerefMut for ConnectionPoolGuard<'a> {
712 fn deref_mut(&mut self) -> &mut Self::Target {
713 &mut *self.guard
714 }
715}
716
717impl<'a> Drop for ConnectionPoolGuard<'a> {
718 fn drop(&mut self) {
719 #[cfg(test)]
720 self.check_invariants();
721 }
722}
723
724fn broadcast<F>(
725 sender_id: Option<ConnectionId>,
726 receiver_ids: impl IntoIterator<Item = ConnectionId>,
727 mut f: F,
728) where
729 F: FnMut(ConnectionId) -> anyhow::Result<()>,
730{
731 for receiver_id in receiver_ids {
732 if Some(receiver_id) != sender_id {
733 if let Err(error) = f(receiver_id) {
734 tracing::error!("failed to send to {:?} {}", receiver_id, error);
735 }
736 }
737 }
738}
739
740lazy_static! {
741 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
742}
743
744pub struct ProtocolVersion(u32);
745
746impl Header for ProtocolVersion {
747 fn name() -> &'static HeaderName {
748 &ZED_PROTOCOL_VERSION
749 }
750
751 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
752 where
753 Self: Sized,
754 I: Iterator<Item = &'i axum::http::HeaderValue>,
755 {
756 let version = values
757 .next()
758 .ok_or_else(axum::headers::Error::invalid)?
759 .to_str()
760 .map_err(|_| axum::headers::Error::invalid())?
761 .parse()
762 .map_err(|_| axum::headers::Error::invalid())?;
763 Ok(Self(version))
764 }
765
766 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
767 values.extend([self.0.to_string().parse().unwrap()]);
768 }
769}
770
771pub fn routes(server: Arc<Server>) -> Router<Body> {
772 Router::new()
773 .route("/rpc", get(handle_websocket_request))
774 .layer(
775 ServiceBuilder::new()
776 .layer(Extension(server.app_state.clone()))
777 .layer(middleware::from_fn(auth::validate_header)),
778 )
779 .route("/metrics", get(handle_metrics))
780 .layer(Extension(server))
781}
782
783pub async fn handle_websocket_request(
784 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
785 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
786 Extension(server): Extension<Arc<Server>>,
787 Extension(user): Extension<User>,
788 ws: WebSocketUpgrade,
789) -> axum::response::Response {
790 if protocol_version != rpc::PROTOCOL_VERSION {
791 return (
792 StatusCode::UPGRADE_REQUIRED,
793 "client must be upgraded".to_string(),
794 )
795 .into_response();
796 }
797 let socket_address = socket_address.to_string();
798 ws.on_upgrade(move |socket| {
799 use util::ResultExt;
800 let socket = socket
801 .map_ok(to_tungstenite_message)
802 .err_into()
803 .with(|message| async move { Ok(to_axum_message(message)) });
804 let connection = Connection::new(Box::pin(socket));
805 async move {
806 server
807 .handle_connection(connection, socket_address, user, None, Executor::Production)
808 .await
809 .log_err();
810 }
811 })
812}
813
814pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
815 let connections = server
816 .connection_pool
817 .lock()
818 .connections()
819 .filter(|connection| !connection.admin)
820 .count();
821
822 METRIC_CONNECTIONS.set(connections as _);
823
824 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
825 METRIC_SHARED_PROJECTS.set(shared_projects as _);
826
827 let encoder = prometheus::TextEncoder::new();
828 let metric_families = prometheus::gather();
829 let encoded_metrics = encoder
830 .encode_to_string(&metric_families)
831 .map_err(|err| anyhow!("{}", err))?;
832 Ok(encoded_metrics)
833}
834
/// Cleans up after a client connection ends.
///
/// The connection is immediately:
/// - disconnected from the peer;
/// - removed from the connection pool (a failure here is returned);
/// - reported as lost to the database (a failure there is only logged).
///
/// It then waits for either [`RECONNECT_TIMEOUT`] or server teardown, whichever comes
/// first; `select_biased!` polls the timeout branch first. If the timeout wins, the
/// session leaves its room. If the user then has no other online connection, any
/// pending call is declined and the affected room is broadcast. Finally the user's
/// contacts are refreshed. If teardown wins, nothing further happens.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary in the `if` condition, so it is released
            // before the `.await` inside the body.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                // `db` lives only until the end of this `if` body, so the database lock is
                // released before `update_user_contacts` runs (presumably required because
                // that helper locks the session's database again — confirm).
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
875
876async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
877 response.send(proto::Ack {})?;
878 Ok(())
879}
880
881async fn create_room(
882 _request: proto::CreateRoom,
883 response: Response<proto::CreateRoom>,
884 session: Session,
885) -> Result<()> {
886 let live_kit_room = nanoid::nanoid!(30);
887
888 let live_kit_connection_info = {
889 let live_kit_room = live_kit_room.clone();
890 let live_kit = session.live_kit_client.as_ref();
891
892 util::async_iife!({
893 let live_kit = live_kit?;
894
895 live_kit
896 .create_room(live_kit_room.clone())
897 .await
898 .trace_err()?;
899
900 let token = live_kit
901 .room_token(&live_kit_room, &session.user_id.to_string())
902 .trace_err()?;
903
904 Some(proto::LiveKitConnectionInfo {
905 server_url: live_kit.url().into(),
906 token,
907 })
908 })
909 }
910 .await;
911
912 let room = session
913 .db()
914 .await
915 .create_room(session.user_id, session.connection_id, &live_kit_room)
916 .await?;
917
918 response.send(proto::CreateRoomResponse {
919 room: Some(room.clone()),
920 live_kit_connection_info,
921 })?;
922
923 update_user_contacts(session.user_id, &session).await?;
924 Ok(())
925}
926
927async fn join_room(
928 request: proto::JoinRoom,
929 response: Response<proto::JoinRoom>,
930 session: Session,
931) -> Result<()> {
932 let room_id = RoomId::from_proto(request.id);
933 let room = {
934 let room = session
935 .db()
936 .await
937 .join_room(room_id, session.user_id, None, session.connection_id)
938 .await?;
939 room_updated(&room.room, &session.peer);
940 room.room.clone()
941 };
942
943 for connection_id in session
944 .connection_pool()
945 .await
946 .user_connection_ids(session.user_id)
947 {
948 session
949 .peer
950 .send(
951 connection_id,
952 proto::CallCanceled {
953 room_id: room_id.to_proto(),
954 },
955 )
956 .trace_err();
957 }
958
959 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
960 if let Some(token) = live_kit
961 .room_token(&room.live_kit_room, &session.user_id.to_string())
962 .trace_err()
963 {
964 Some(proto::LiveKitConnectionInfo {
965 server_url: live_kit.url().into(),
966 token,
967 })
968 } else {
969 None
970 }
971 } else {
972 None
973 };
974
975 response.send(proto::JoinRoomResponse {
976 room: Some(room),
977 live_kit_connection_info,
978 })?;
979
980 update_user_contacts(session.user_id, &session).await?;
981 Ok(())
982}
983
984async fn rejoin_room(
985 request: proto::RejoinRoom,
986 response: Response<proto::RejoinRoom>,
987 session: Session,
988) -> Result<()> {
989 let room;
990 let channel_id;
991 let channel_members;
992 {
993 let mut rejoined_room = session
994 .db()
995 .await
996 .rejoin_room(request, session.user_id, session.connection_id)
997 .await?;
998
999 response.send(proto::RejoinRoomResponse {
1000 room: Some(rejoined_room.room.clone()),
1001 reshared_projects: rejoined_room
1002 .reshared_projects
1003 .iter()
1004 .map(|project| proto::ResharedProject {
1005 id: project.id.to_proto(),
1006 collaborators: project
1007 .collaborators
1008 .iter()
1009 .map(|collaborator| collaborator.to_proto())
1010 .collect(),
1011 })
1012 .collect(),
1013 rejoined_projects: rejoined_room
1014 .rejoined_projects
1015 .iter()
1016 .map(|rejoined_project| proto::RejoinedProject {
1017 id: rejoined_project.id.to_proto(),
1018 worktrees: rejoined_project
1019 .worktrees
1020 .iter()
1021 .map(|worktree| proto::WorktreeMetadata {
1022 id: worktree.id,
1023 root_name: worktree.root_name.clone(),
1024 visible: worktree.visible,
1025 abs_path: worktree.abs_path.clone(),
1026 })
1027 .collect(),
1028 collaborators: rejoined_project
1029 .collaborators
1030 .iter()
1031 .map(|collaborator| collaborator.to_proto())
1032 .collect(),
1033 language_servers: rejoined_project.language_servers.clone(),
1034 })
1035 .collect(),
1036 })?;
1037 room_updated(&rejoined_room.room, &session.peer);
1038
1039 for project in &rejoined_room.reshared_projects {
1040 for collaborator in &project.collaborators {
1041 session
1042 .peer
1043 .send(
1044 collaborator.connection_id,
1045 proto::UpdateProjectCollaborator {
1046 project_id: project.id.to_proto(),
1047 old_peer_id: Some(project.old_connection_id.into()),
1048 new_peer_id: Some(session.connection_id.into()),
1049 },
1050 )
1051 .trace_err();
1052 }
1053
1054 broadcast(
1055 Some(session.connection_id),
1056 project
1057 .collaborators
1058 .iter()
1059 .map(|collaborator| collaborator.connection_id),
1060 |connection_id| {
1061 session.peer.forward_send(
1062 session.connection_id,
1063 connection_id,
1064 proto::UpdateProject {
1065 project_id: project.id.to_proto(),
1066 worktrees: project.worktrees.clone(),
1067 },
1068 )
1069 },
1070 );
1071 }
1072
1073 for project in &rejoined_room.rejoined_projects {
1074 for collaborator in &project.collaborators {
1075 session
1076 .peer
1077 .send(
1078 collaborator.connection_id,
1079 proto::UpdateProjectCollaborator {
1080 project_id: project.id.to_proto(),
1081 old_peer_id: Some(project.old_connection_id.into()),
1082 new_peer_id: Some(session.connection_id.into()),
1083 },
1084 )
1085 .trace_err();
1086 }
1087 }
1088
1089 for project in &mut rejoined_room.rejoined_projects {
1090 for worktree in mem::take(&mut project.worktrees) {
1091 #[cfg(any(test, feature = "test-support"))]
1092 const MAX_CHUNK_SIZE: usize = 2;
1093 #[cfg(not(any(test, feature = "test-support")))]
1094 const MAX_CHUNK_SIZE: usize = 256;
1095
1096 // Stream this worktree's entries.
1097 let message = proto::UpdateWorktree {
1098 project_id: project.id.to_proto(),
1099 worktree_id: worktree.id,
1100 abs_path: worktree.abs_path.clone(),
1101 root_name: worktree.root_name,
1102 updated_entries: worktree.updated_entries,
1103 removed_entries: worktree.removed_entries,
1104 scan_id: worktree.scan_id,
1105 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1106 updated_repositories: worktree.updated_repositories,
1107 removed_repositories: worktree.removed_repositories,
1108 };
1109 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1110 session.peer.send(session.connection_id, update.clone())?;
1111 }
1112
1113 // Stream this worktree's diagnostics.
1114 for summary in worktree.diagnostic_summaries {
1115 session.peer.send(
1116 session.connection_id,
1117 proto::UpdateDiagnosticSummary {
1118 project_id: project.id.to_proto(),
1119 worktree_id: worktree.id,
1120 summary: Some(summary),
1121 },
1122 )?;
1123 }
1124
1125 for settings_file in worktree.settings_files {
1126 session.peer.send(
1127 session.connection_id,
1128 proto::UpdateWorktreeSettings {
1129 project_id: project.id.to_proto(),
1130 worktree_id: worktree.id,
1131 path: settings_file.path,
1132 content: Some(settings_file.content),
1133 },
1134 )?;
1135 }
1136 }
1137
1138 for language_server in &project.language_servers {
1139 session.peer.send(
1140 session.connection_id,
1141 proto::UpdateLanguageServer {
1142 project_id: project.id.to_proto(),
1143 language_server_id: language_server.id,
1144 variant: Some(
1145 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1146 proto::LspDiskBasedDiagnosticsUpdated {},
1147 ),
1148 ),
1149 },
1150 )?;
1151 }
1152 }
1153
1154 room = mem::take(&mut rejoined_room.room);
1155 channel_id = rejoined_room.channel_id;
1156 channel_members = mem::take(&mut rejoined_room.channel_members);
1157 }
1158
1159 //TODO: move this into the room guard
1160 if let Some(channel_id) = channel_id {
1161 channel_updated(
1162 channel_id,
1163 &room,
1164 &channel_members,
1165 &session.peer,
1166 &*session.connection_pool().await,
1167 );
1168 }
1169
1170 update_user_contacts(session.user_id, &session).await?;
1171 Ok(())
1172}
1173
1174async fn leave_room(
1175 _: proto::LeaveRoom,
1176 response: Response<proto::LeaveRoom>,
1177 session: Session,
1178) -> Result<()> {
1179 leave_room_for_session(&session).await?;
1180 response.send(proto::Ack {})?;
1181 Ok(())
1182}
1183
1184async fn call(
1185 request: proto::Call,
1186 response: Response<proto::Call>,
1187 session: Session,
1188) -> Result<()> {
1189 let room_id = RoomId::from_proto(request.room_id);
1190 let calling_user_id = session.user_id;
1191 let calling_connection_id = session.connection_id;
1192 let called_user_id = UserId::from_proto(request.called_user_id);
1193 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1194 if !session
1195 .db()
1196 .await
1197 .has_contact(calling_user_id, called_user_id)
1198 .await?
1199 {
1200 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1201 }
1202
1203 let incoming_call = {
1204 let (room, incoming_call) = &mut *session
1205 .db()
1206 .await
1207 .call(
1208 room_id,
1209 calling_user_id,
1210 calling_connection_id,
1211 called_user_id,
1212 initial_project_id,
1213 )
1214 .await?;
1215 room_updated(&room, &session.peer);
1216 mem::take(incoming_call)
1217 };
1218 update_user_contacts(called_user_id, &session).await?;
1219
1220 let mut calls = session
1221 .connection_pool()
1222 .await
1223 .user_connection_ids(called_user_id)
1224 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1225 .collect::<FuturesUnordered<_>>();
1226
1227 while let Some(call_response) = calls.next().await {
1228 match call_response.as_ref() {
1229 Ok(_) => {
1230 response.send(proto::Ack {})?;
1231 return Ok(());
1232 }
1233 Err(_) => {
1234 call_response.trace_err();
1235 }
1236 }
1237 }
1238
1239 {
1240 let room = session
1241 .db()
1242 .await
1243 .call_failed(room_id, called_user_id)
1244 .await?;
1245 room_updated(&room, &session.peer);
1246 }
1247 update_user_contacts(called_user_id, &session).await?;
1248
1249 Err(anyhow!("failed to ring user"))?
1250}
1251
1252async fn cancel_call(
1253 request: proto::CancelCall,
1254 response: Response<proto::CancelCall>,
1255 session: Session,
1256) -> Result<()> {
1257 let called_user_id = UserId::from_proto(request.called_user_id);
1258 let room_id = RoomId::from_proto(request.room_id);
1259 {
1260 let room = session
1261 .db()
1262 .await
1263 .cancel_call(room_id, session.connection_id, called_user_id)
1264 .await?;
1265 room_updated(&room, &session.peer);
1266 }
1267
1268 for connection_id in session
1269 .connection_pool()
1270 .await
1271 .user_connection_ids(called_user_id)
1272 {
1273 session
1274 .peer
1275 .send(
1276 connection_id,
1277 proto::CallCanceled {
1278 room_id: room_id.to_proto(),
1279 },
1280 )
1281 .trace_err();
1282 }
1283 response.send(proto::Ack {})?;
1284
1285 update_user_contacts(called_user_id, &session).await?;
1286 Ok(())
1287}
1288
1289async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1290 let room_id = RoomId::from_proto(message.room_id);
1291 {
1292 let room = session
1293 .db()
1294 .await
1295 .decline_call(Some(room_id), session.user_id)
1296 .await?
1297 .ok_or_else(|| anyhow!("failed to decline call"))?;
1298 room_updated(&room, &session.peer);
1299 }
1300
1301 for connection_id in session
1302 .connection_pool()
1303 .await
1304 .user_connection_ids(session.user_id)
1305 {
1306 session
1307 .peer
1308 .send(
1309 connection_id,
1310 proto::CallCanceled {
1311 room_id: room_id.to_proto(),
1312 },
1313 )
1314 .trace_err();
1315 }
1316 update_user_contacts(session.user_id, &session).await?;
1317 Ok(())
1318}
1319
1320async fn update_participant_location(
1321 request: proto::UpdateParticipantLocation,
1322 response: Response<proto::UpdateParticipantLocation>,
1323 session: Session,
1324) -> Result<()> {
1325 let room_id = RoomId::from_proto(request.room_id);
1326 let location = request
1327 .location
1328 .ok_or_else(|| anyhow!("invalid location"))?;
1329
1330 let db = session.db().await;
1331 let room = db
1332 .update_room_participant_location(room_id, session.connection_id, location)
1333 .await?;
1334
1335 room_updated(&room, &session.peer);
1336 response.send(proto::Ack {})?;
1337 Ok(())
1338}
1339
1340async fn share_project(
1341 request: proto::ShareProject,
1342 response: Response<proto::ShareProject>,
1343 session: Session,
1344) -> Result<()> {
1345 let (project_id, room) = &*session
1346 .db()
1347 .await
1348 .share_project(
1349 RoomId::from_proto(request.room_id),
1350 session.connection_id,
1351 &request.worktrees,
1352 )
1353 .await?;
1354 response.send(proto::ShareProjectResponse {
1355 project_id: project_id.to_proto(),
1356 })?;
1357 room_updated(&room, &session.peer);
1358
1359 Ok(())
1360}
1361
1362async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1363 let project_id = ProjectId::from_proto(message.project_id);
1364
1365 let (room, guest_connection_ids) = &*session
1366 .db()
1367 .await
1368 .unshare_project(project_id, session.connection_id)
1369 .await?;
1370
1371 broadcast(
1372 Some(session.connection_id),
1373 guest_connection_ids.iter().copied(),
1374 |conn_id| session.peer.send(conn_id, message.clone()),
1375 );
1376 room_updated(&room, &session.peer);
1377
1378 Ok(())
1379}
1380
1381async fn join_project(
1382 request: proto::JoinProject,
1383 response: Response<proto::JoinProject>,
1384 session: Session,
1385) -> Result<()> {
1386 let project_id = ProjectId::from_proto(request.project_id);
1387 let guest_user_id = session.user_id;
1388
1389 tracing::info!(%project_id, "join project");
1390
1391 let (project, replica_id) = &mut *session
1392 .db()
1393 .await
1394 .join_project(project_id, session.connection_id)
1395 .await?;
1396
1397 let collaborators = project
1398 .collaborators
1399 .iter()
1400 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1401 .map(|collaborator| collaborator.to_proto())
1402 .collect::<Vec<_>>();
1403
1404 let worktrees = project
1405 .worktrees
1406 .iter()
1407 .map(|(id, worktree)| proto::WorktreeMetadata {
1408 id: *id,
1409 root_name: worktree.root_name.clone(),
1410 visible: worktree.visible,
1411 abs_path: worktree.abs_path.clone(),
1412 })
1413 .collect::<Vec<_>>();
1414
1415 for collaborator in &collaborators {
1416 session
1417 .peer
1418 .send(
1419 collaborator.peer_id.unwrap().into(),
1420 proto::AddProjectCollaborator {
1421 project_id: project_id.to_proto(),
1422 collaborator: Some(proto::Collaborator {
1423 peer_id: Some(session.connection_id.into()),
1424 replica_id: replica_id.0 as u32,
1425 user_id: guest_user_id.to_proto(),
1426 }),
1427 },
1428 )
1429 .trace_err();
1430 }
1431
1432 // First, we send the metadata associated with each worktree.
1433 response.send(proto::JoinProjectResponse {
1434 worktrees: worktrees.clone(),
1435 replica_id: replica_id.0 as u32,
1436 collaborators: collaborators.clone(),
1437 language_servers: project.language_servers.clone(),
1438 })?;
1439
1440 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1441 #[cfg(any(test, feature = "test-support"))]
1442 const MAX_CHUNK_SIZE: usize = 2;
1443 #[cfg(not(any(test, feature = "test-support")))]
1444 const MAX_CHUNK_SIZE: usize = 256;
1445
1446 // Stream this worktree's entries.
1447 let message = proto::UpdateWorktree {
1448 project_id: project_id.to_proto(),
1449 worktree_id,
1450 abs_path: worktree.abs_path.clone(),
1451 root_name: worktree.root_name,
1452 updated_entries: worktree.entries,
1453 removed_entries: Default::default(),
1454 scan_id: worktree.scan_id,
1455 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1456 updated_repositories: worktree.repository_entries.into_values().collect(),
1457 removed_repositories: Default::default(),
1458 };
1459 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1460 session.peer.send(session.connection_id, update.clone())?;
1461 }
1462
1463 // Stream this worktree's diagnostics.
1464 for summary in worktree.diagnostic_summaries {
1465 session.peer.send(
1466 session.connection_id,
1467 proto::UpdateDiagnosticSummary {
1468 project_id: project_id.to_proto(),
1469 worktree_id: worktree.id,
1470 summary: Some(summary),
1471 },
1472 )?;
1473 }
1474
1475 for settings_file in worktree.settings_files {
1476 session.peer.send(
1477 session.connection_id,
1478 proto::UpdateWorktreeSettings {
1479 project_id: project_id.to_proto(),
1480 worktree_id: worktree.id,
1481 path: settings_file.path,
1482 content: Some(settings_file.content),
1483 },
1484 )?;
1485 }
1486 }
1487
1488 for language_server in &project.language_servers {
1489 session.peer.send(
1490 session.connection_id,
1491 proto::UpdateLanguageServer {
1492 project_id: project_id.to_proto(),
1493 language_server_id: language_server.id,
1494 variant: Some(
1495 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1496 proto::LspDiskBasedDiagnosticsUpdated {},
1497 ),
1498 ),
1499 },
1500 )?;
1501 }
1502
1503 Ok(())
1504}
1505
1506async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1507 let sender_id = session.connection_id;
1508 let project_id = ProjectId::from_proto(request.project_id);
1509
1510 let (room, project) = &*session
1511 .db()
1512 .await
1513 .leave_project(project_id, sender_id)
1514 .await?;
1515 tracing::info!(
1516 %project_id,
1517 host_user_id = %project.host_user_id,
1518 host_connection_id = %project.host_connection_id,
1519 "leave project"
1520 );
1521
1522 project_left(&project, &session);
1523 room_updated(&room, &session.peer);
1524
1525 Ok(())
1526}
1527
1528async fn update_project(
1529 request: proto::UpdateProject,
1530 response: Response<proto::UpdateProject>,
1531 session: Session,
1532) -> Result<()> {
1533 let project_id = ProjectId::from_proto(request.project_id);
1534 let (room, guest_connection_ids) = &*session
1535 .db()
1536 .await
1537 .update_project(project_id, session.connection_id, &request.worktrees)
1538 .await?;
1539 broadcast(
1540 Some(session.connection_id),
1541 guest_connection_ids.iter().copied(),
1542 |connection_id| {
1543 session
1544 .peer
1545 .forward_send(session.connection_id, connection_id, request.clone())
1546 },
1547 );
1548 room_updated(&room, &session.peer);
1549 response.send(proto::Ack {})?;
1550
1551 Ok(())
1552}
1553
1554async fn update_worktree(
1555 request: proto::UpdateWorktree,
1556 response: Response<proto::UpdateWorktree>,
1557 session: Session,
1558) -> Result<()> {
1559 let guest_connection_ids = session
1560 .db()
1561 .await
1562 .update_worktree(&request, session.connection_id)
1563 .await?;
1564
1565 broadcast(
1566 Some(session.connection_id),
1567 guest_connection_ids.iter().copied(),
1568 |connection_id| {
1569 session
1570 .peer
1571 .forward_send(session.connection_id, connection_id, request.clone())
1572 },
1573 );
1574 response.send(proto::Ack {})?;
1575 Ok(())
1576}
1577
1578async fn update_diagnostic_summary(
1579 message: proto::UpdateDiagnosticSummary,
1580 session: Session,
1581) -> Result<()> {
1582 let guest_connection_ids = session
1583 .db()
1584 .await
1585 .update_diagnostic_summary(&message, session.connection_id)
1586 .await?;
1587
1588 broadcast(
1589 Some(session.connection_id),
1590 guest_connection_ids.iter().copied(),
1591 |connection_id| {
1592 session
1593 .peer
1594 .forward_send(session.connection_id, connection_id, message.clone())
1595 },
1596 );
1597
1598 Ok(())
1599}
1600
1601async fn update_worktree_settings(
1602 message: proto::UpdateWorktreeSettings,
1603 session: Session,
1604) -> Result<()> {
1605 let guest_connection_ids = session
1606 .db()
1607 .await
1608 .update_worktree_settings(&message, session.connection_id)
1609 .await?;
1610
1611 broadcast(
1612 Some(session.connection_id),
1613 guest_connection_ids.iter().copied(),
1614 |connection_id| {
1615 session
1616 .peer
1617 .forward_send(session.connection_id, connection_id, message.clone())
1618 },
1619 );
1620
1621 Ok(())
1622}
1623
1624async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1625 broadcast_project_message(request.project_id, request, session).await
1626}
1627
1628async fn start_language_server(
1629 request: proto::StartLanguageServer,
1630 session: Session,
1631) -> Result<()> {
1632 let guest_connection_ids = session
1633 .db()
1634 .await
1635 .start_language_server(&request, session.connection_id)
1636 .await?;
1637
1638 broadcast(
1639 Some(session.connection_id),
1640 guest_connection_ids.iter().copied(),
1641 |connection_id| {
1642 session
1643 .peer
1644 .forward_send(session.connection_id, connection_id, request.clone())
1645 },
1646 );
1647 Ok(())
1648}
1649
1650async fn update_language_server(
1651 request: proto::UpdateLanguageServer,
1652 session: Session,
1653) -> Result<()> {
1654 session.executor.record_backtrace();
1655 let project_id = ProjectId::from_proto(request.project_id);
1656 let project_connection_ids = session
1657 .db()
1658 .await
1659 .project_connection_ids(project_id, session.connection_id)
1660 .await?;
1661 broadcast(
1662 Some(session.connection_id),
1663 project_connection_ids.iter().copied(),
1664 |connection_id| {
1665 session
1666 .peer
1667 .forward_send(session.connection_id, connection_id, request.clone())
1668 },
1669 );
1670 Ok(())
1671}
1672
1673async fn forward_project_request<T>(
1674 request: T,
1675 response: Response<T>,
1676 session: Session,
1677) -> Result<()>
1678where
1679 T: EntityMessage + RequestMessage,
1680{
1681 session.executor.record_backtrace();
1682 let project_id = ProjectId::from_proto(request.remote_entity_id());
1683 let host_connection_id = {
1684 let collaborators = session
1685 .db()
1686 .await
1687 .project_collaborators(project_id, session.connection_id)
1688 .await?;
1689 collaborators
1690 .iter()
1691 .find(|collaborator| collaborator.is_host)
1692 .ok_or_else(|| anyhow!("host not found"))?
1693 .connection_id
1694 };
1695
1696 let payload = session
1697 .peer
1698 .forward_request(session.connection_id, host_connection_id, request)
1699 .await?;
1700
1701 response.send(payload)?;
1702 Ok(())
1703}
1704
1705async fn create_buffer_for_peer(
1706 request: proto::CreateBufferForPeer,
1707 session: Session,
1708) -> Result<()> {
1709 session.executor.record_backtrace();
1710 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1711 session
1712 .peer
1713 .forward_send(session.connection_id, peer_id.into(), request)?;
1714 Ok(())
1715}
1716
1717async fn update_buffer(
1718 request: proto::UpdateBuffer,
1719 response: Response<proto::UpdateBuffer>,
1720 session: Session,
1721) -> Result<()> {
1722 session.executor.record_backtrace();
1723 let project_id = ProjectId::from_proto(request.project_id);
1724 let mut guest_connection_ids;
1725 let mut host_connection_id = None;
1726 {
1727 let collaborators = session
1728 .db()
1729 .await
1730 .project_collaborators(project_id, session.connection_id)
1731 .await?;
1732 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1733 for collaborator in collaborators.iter() {
1734 if collaborator.is_host {
1735 host_connection_id = Some(collaborator.connection_id);
1736 } else {
1737 guest_connection_ids.push(collaborator.connection_id);
1738 }
1739 }
1740 }
1741 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1742
1743 session.executor.record_backtrace();
1744 broadcast(
1745 Some(session.connection_id),
1746 guest_connection_ids,
1747 |connection_id| {
1748 session
1749 .peer
1750 .forward_send(session.connection_id, connection_id, request.clone())
1751 },
1752 );
1753 if host_connection_id != session.connection_id {
1754 session
1755 .peer
1756 .forward_request(session.connection_id, host_connection_id, request.clone())
1757 .await?;
1758 }
1759
1760 response.send(proto::Ack {})?;
1761 Ok(())
1762}
1763
1764async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1765 let project_id = ProjectId::from_proto(request.project_id);
1766 let project_connection_ids = session
1767 .db()
1768 .await
1769 .project_connection_ids(project_id, session.connection_id)
1770 .await?;
1771
1772 broadcast(
1773 Some(session.connection_id),
1774 project_connection_ids.iter().copied(),
1775 |connection_id| {
1776 session
1777 .peer
1778 .forward_send(session.connection_id, connection_id, request.clone())
1779 },
1780 );
1781 Ok(())
1782}
1783
1784async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1785 let project_id = ProjectId::from_proto(request.project_id);
1786 let project_connection_ids = session
1787 .db()
1788 .await
1789 .project_connection_ids(project_id, session.connection_id)
1790 .await?;
1791 broadcast(
1792 Some(session.connection_id),
1793 project_connection_ids.iter().copied(),
1794 |connection_id| {
1795 session
1796 .peer
1797 .forward_send(session.connection_id, connection_id, request.clone())
1798 },
1799 );
1800 Ok(())
1801}
1802
1803async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1804 broadcast_project_message(request.project_id, request, session).await
1805}
1806
1807async fn broadcast_project_message<T: EnvelopedMessage>(
1808 project_id: u64,
1809 request: T,
1810 session: Session,
1811) -> Result<()> {
1812 let project_id = ProjectId::from_proto(project_id);
1813 let project_connection_ids = session
1814 .db()
1815 .await
1816 .project_connection_ids(project_id, session.connection_id)
1817 .await?;
1818 broadcast(
1819 Some(session.connection_id),
1820 project_connection_ids.iter().copied(),
1821 |connection_id| {
1822 session
1823 .peer
1824 .forward_send(session.connection_id, connection_id, request.clone())
1825 },
1826 );
1827 Ok(())
1828}
1829
1830async fn follow(
1831 request: proto::Follow,
1832 response: Response<proto::Follow>,
1833 session: Session,
1834) -> Result<()> {
1835 let project_id = ProjectId::from_proto(request.project_id);
1836 let leader_id = request
1837 .leader_id
1838 .ok_or_else(|| anyhow!("invalid leader id"))?
1839 .into();
1840 let follower_id = session.connection_id;
1841
1842 {
1843 let project_connection_ids = session
1844 .db()
1845 .await
1846 .project_connection_ids(project_id, session.connection_id)
1847 .await?;
1848
1849 if !project_connection_ids.contains(&leader_id) {
1850 Err(anyhow!("no such peer"))?;
1851 }
1852 }
1853
1854 let mut response_payload = session
1855 .peer
1856 .forward_request(session.connection_id, leader_id, request)
1857 .await?;
1858 response_payload
1859 .views
1860 .retain(|view| view.leader_id != Some(follower_id.into()));
1861 response.send(response_payload)?;
1862
1863 let room = session
1864 .db()
1865 .await
1866 .follow(project_id, leader_id, follower_id)
1867 .await?;
1868 room_updated(&room, &session.peer);
1869
1870 Ok(())
1871}
1872
1873async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1874 let project_id = ProjectId::from_proto(request.project_id);
1875 let leader_id = request
1876 .leader_id
1877 .ok_or_else(|| anyhow!("invalid leader id"))?
1878 .into();
1879 let follower_id = session.connection_id;
1880
1881 if !session
1882 .db()
1883 .await
1884 .project_connection_ids(project_id, session.connection_id)
1885 .await?
1886 .contains(&leader_id)
1887 {
1888 Err(anyhow!("no such peer"))?;
1889 }
1890
1891 session
1892 .peer
1893 .forward_send(session.connection_id, leader_id, request)?;
1894
1895 let room = session
1896 .db()
1897 .await
1898 .unfollow(project_id, leader_id, follower_id)
1899 .await?;
1900 room_updated(&room, &session.peer);
1901
1902 Ok(())
1903}
1904
1905async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1906 let project_id = ProjectId::from_proto(request.project_id);
1907 let project_connection_ids = session
1908 .db
1909 .lock()
1910 .await
1911 .project_connection_ids(project_id, session.connection_id)
1912 .await?;
1913
1914 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1915 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1916 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1917 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1918 });
1919 for follower_peer_id in request.follower_ids.iter().copied() {
1920 let follower_connection_id = follower_peer_id.into();
1921 if project_connection_ids.contains(&follower_connection_id)
1922 && Some(follower_peer_id) != leader_id
1923 {
1924 session.peer.forward_send(
1925 session.connection_id,
1926 follower_connection_id,
1927 request.clone(),
1928 )?;
1929 }
1930 }
1931 Ok(())
1932}
1933
1934async fn get_users(
1935 request: proto::GetUsers,
1936 response: Response<proto::GetUsers>,
1937 session: Session,
1938) -> Result<()> {
1939 let user_ids = request
1940 .user_ids
1941 .into_iter()
1942 .map(UserId::from_proto)
1943 .collect();
1944 let users = session
1945 .db()
1946 .await
1947 .get_users_by_ids(user_ids)
1948 .await?
1949 .into_iter()
1950 .map(|user| proto::User {
1951 id: user.id.to_proto(),
1952 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1953 github_login: user.github_login,
1954 })
1955 .collect();
1956 response.send(proto::UsersResponse { users })?;
1957 Ok(())
1958}
1959
1960async fn fuzzy_search_users(
1961 request: proto::FuzzySearchUsers,
1962 response: Response<proto::FuzzySearchUsers>,
1963 session: Session,
1964) -> Result<()> {
1965 let query = request.query;
1966 let users = match query.len() {
1967 0 => vec![],
1968 1 | 2 => session
1969 .db()
1970 .await
1971 .get_user_by_github_login(&query)
1972 .await?
1973 .into_iter()
1974 .collect(),
1975 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1976 };
1977 let users = users
1978 .into_iter()
1979 .filter(|user| user.id != session.user_id)
1980 .map(|user| proto::User {
1981 id: user.id.to_proto(),
1982 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1983 github_login: user.github_login,
1984 })
1985 .collect();
1986 response.send(proto::UsersResponse { users })?;
1987 Ok(())
1988}
1989
1990async fn request_contact(
1991 request: proto::RequestContact,
1992 response: Response<proto::RequestContact>,
1993 session: Session,
1994) -> Result<()> {
1995 let requester_id = session.user_id;
1996 let responder_id = UserId::from_proto(request.responder_id);
1997 if requester_id == responder_id {
1998 return Err(anyhow!("cannot add yourself as a contact"))?;
1999 }
2000
2001 session
2002 .db()
2003 .await
2004 .send_contact_request(requester_id, responder_id)
2005 .await?;
2006
2007 // Update outgoing contact requests of requester
2008 let mut update = proto::UpdateContacts::default();
2009 update.outgoing_requests.push(responder_id.to_proto());
2010 for connection_id in session
2011 .connection_pool()
2012 .await
2013 .user_connection_ids(requester_id)
2014 {
2015 session.peer.send(connection_id, update.clone())?;
2016 }
2017
2018 // Update incoming contact requests of responder
2019 let mut update = proto::UpdateContacts::default();
2020 update
2021 .incoming_requests
2022 .push(proto::IncomingContactRequest {
2023 requester_id: requester_id.to_proto(),
2024 should_notify: true,
2025 });
2026 for connection_id in session
2027 .connection_pool()
2028 .await
2029 .user_connection_ids(responder_id)
2030 {
2031 session.peer.send(connection_id, update.clone())?;
2032 }
2033
2034 response.send(proto::Ack {})?;
2035 Ok(())
2036}
2037
2038async fn respond_to_contact_request(
2039 request: proto::RespondToContactRequest,
2040 response: Response<proto::RespondToContactRequest>,
2041 session: Session,
2042) -> Result<()> {
2043 let responder_id = session.user_id;
2044 let requester_id = UserId::from_proto(request.requester_id);
2045 let db = session.db().await;
2046 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2047 db.dismiss_contact_notification(responder_id, requester_id)
2048 .await?;
2049 } else {
2050 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2051
2052 db.respond_to_contact_request(responder_id, requester_id, accept)
2053 .await?;
2054 let requester_busy = db.is_user_busy(requester_id).await?;
2055 let responder_busy = db.is_user_busy(responder_id).await?;
2056
2057 let pool = session.connection_pool().await;
2058 // Update responder with new contact
2059 let mut update = proto::UpdateContacts::default();
2060 if accept {
2061 update
2062 .contacts
2063 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2064 }
2065 update
2066 .remove_incoming_requests
2067 .push(requester_id.to_proto());
2068 for connection_id in pool.user_connection_ids(responder_id) {
2069 session.peer.send(connection_id, update.clone())?;
2070 }
2071
2072 // Update requester with new contact
2073 let mut update = proto::UpdateContacts::default();
2074 if accept {
2075 update
2076 .contacts
2077 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2078 }
2079 update
2080 .remove_outgoing_requests
2081 .push(responder_id.to_proto());
2082 for connection_id in pool.user_connection_ids(requester_id) {
2083 session.peer.send(connection_id, update.clone())?;
2084 }
2085 }
2086
2087 response.send(proto::Ack {})?;
2088 Ok(())
2089}
2090
2091async fn remove_contact(
2092 request: proto::RemoveContact,
2093 response: Response<proto::RemoveContact>,
2094 session: Session,
2095) -> Result<()> {
2096 let requester_id = session.user_id;
2097 let responder_id = UserId::from_proto(request.user_id);
2098 let db = session.db().await;
2099 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2100
2101 let pool = session.connection_pool().await;
2102 // Update outgoing contact requests of requester
2103 let mut update = proto::UpdateContacts::default();
2104 if contact_accepted {
2105 update.remove_contacts.push(responder_id.to_proto());
2106 } else {
2107 update
2108 .remove_outgoing_requests
2109 .push(responder_id.to_proto());
2110 }
2111 for connection_id in pool.user_connection_ids(requester_id) {
2112 session.peer.send(connection_id, update.clone())?;
2113 }
2114
2115 // Update incoming contact requests of responder
2116 let mut update = proto::UpdateContacts::default();
2117 if contact_accepted {
2118 update.remove_contacts.push(requester_id.to_proto());
2119 } else {
2120 update
2121 .remove_incoming_requests
2122 .push(requester_id.to_proto());
2123 }
2124 for connection_id in pool.user_connection_ids(responder_id) {
2125 session.peer.send(connection_id, update.clone())?;
2126 }
2127
2128 response.send(proto::Ack {})?;
2129 Ok(())
2130}
2131
2132async fn create_channel(
2133 request: proto::CreateChannel,
2134 response: Response<proto::CreateChannel>,
2135 session: Session,
2136) -> Result<()> {
2137 let db = session.db().await;
2138 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2139
2140 if let Some(live_kit) = session.live_kit_client.as_ref() {
2141 live_kit.create_room(live_kit_room.clone()).await?;
2142 }
2143
2144 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2145 let id = db
2146 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2147 .await?;
2148
2149 let channel = proto::Channel {
2150 id: id.to_proto(),
2151 name: request.name,
2152 parent_id: request.parent_id,
2153 };
2154
2155 response.send(proto::ChannelResponse {
2156 channel: Some(channel.clone()),
2157 })?;
2158
2159 let mut update = proto::UpdateChannels::default();
2160 update.channels.push(channel);
2161
2162 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2163 db.get_channel_members(parent_id).await?
2164 } else {
2165 vec![session.user_id]
2166 };
2167
2168 let connection_pool = session.connection_pool().await;
2169 for user_id in user_ids_to_notify {
2170 for connection_id in connection_pool.user_connection_ids(user_id) {
2171 let mut update = update.clone();
2172 if user_id == session.user_id {
2173 update.channel_permissions.push(proto::ChannelPermission {
2174 channel_id: id.to_proto(),
2175 is_admin: true,
2176 });
2177 }
2178 session.peer.send(connection_id, update)?;
2179 }
2180 }
2181
2182 Ok(())
2183}
2184
2185async fn remove_channel(
2186 request: proto::RemoveChannel,
2187 response: Response<proto::RemoveChannel>,
2188 session: Session,
2189) -> Result<()> {
2190 let db = session.db().await;
2191
2192 let channel_id = request.channel_id;
2193 let (removed_channels, member_ids) = db
2194 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2195 .await?;
2196 response.send(proto::Ack {})?;
2197
2198 // Notify members of removed channels
2199 let mut update = proto::UpdateChannels::default();
2200 update
2201 .remove_channels
2202 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2203
2204 let connection_pool = session.connection_pool().await;
2205 for member_id in member_ids {
2206 for connection_id in connection_pool.user_connection_ids(member_id) {
2207 session.peer.send(connection_id, update.clone())?;
2208 }
2209 }
2210
2211 Ok(())
2212}
2213
2214async fn invite_channel_member(
2215 request: proto::InviteChannelMember,
2216 response: Response<proto::InviteChannelMember>,
2217 session: Session,
2218) -> Result<()> {
2219 let db = session.db().await;
2220 let channel_id = ChannelId::from_proto(request.channel_id);
2221 let invitee_id = UserId::from_proto(request.user_id);
2222 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2223 .await?;
2224
2225 let (channel, _) = db
2226 .get_channel(channel_id, session.user_id)
2227 .await?
2228 .ok_or_else(|| anyhow!("channel not found"))?;
2229
2230 let mut update = proto::UpdateChannels::default();
2231 update.channel_invitations.push(proto::Channel {
2232 id: channel.id.to_proto(),
2233 name: channel.name,
2234 parent_id: None,
2235 });
2236 for connection_id in session
2237 .connection_pool()
2238 .await
2239 .user_connection_ids(invitee_id)
2240 {
2241 session.peer.send(connection_id, update.clone())?;
2242 }
2243
2244 response.send(proto::Ack {})?;
2245 Ok(())
2246}
2247
2248async fn remove_channel_member(
2249 request: proto::RemoveChannelMember,
2250 response: Response<proto::RemoveChannelMember>,
2251 session: Session,
2252) -> Result<()> {
2253 let db = session.db().await;
2254 let channel_id = ChannelId::from_proto(request.channel_id);
2255 let member_id = UserId::from_proto(request.user_id);
2256
2257 db.remove_channel_member(channel_id, member_id, session.user_id)
2258 .await?;
2259
2260 let mut update = proto::UpdateChannels::default();
2261 update.remove_channels.push(channel_id.to_proto());
2262
2263 for connection_id in session
2264 .connection_pool()
2265 .await
2266 .user_connection_ids(member_id)
2267 {
2268 session.peer.send(connection_id, update.clone())?;
2269 }
2270
2271 response.send(proto::Ack {})?;
2272 Ok(())
2273}
2274
2275async fn set_channel_member_admin(
2276 request: proto::SetChannelMemberAdmin,
2277 response: Response<proto::SetChannelMemberAdmin>,
2278 session: Session,
2279) -> Result<()> {
2280 let db = session.db().await;
2281 let channel_id = ChannelId::from_proto(request.channel_id);
2282 let member_id = UserId::from_proto(request.user_id);
2283 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2284 .await?;
2285
2286 let (channel, has_accepted) = db
2287 .get_channel(channel_id, member_id)
2288 .await?
2289 .ok_or_else(|| anyhow!("channel not found"))?;
2290
2291 let mut update = proto::UpdateChannels::default();
2292 if has_accepted {
2293 update.channel_permissions.push(proto::ChannelPermission {
2294 channel_id: channel.id.to_proto(),
2295 is_admin: request.admin,
2296 });
2297 }
2298
2299 for connection_id in session
2300 .connection_pool()
2301 .await
2302 .user_connection_ids(member_id)
2303 {
2304 session.peer.send(connection_id, update.clone())?;
2305 }
2306
2307 response.send(proto::Ack {})?;
2308 Ok(())
2309}
2310
2311async fn rename_channel(
2312 request: proto::RenameChannel,
2313 response: Response<proto::RenameChannel>,
2314 session: Session,
2315) -> Result<()> {
2316 let db = session.db().await;
2317 let channel_id = ChannelId::from_proto(request.channel_id);
2318 let new_name = db
2319 .rename_channel(channel_id, session.user_id, &request.name)
2320 .await?;
2321
2322 let channel = proto::Channel {
2323 id: request.channel_id,
2324 name: new_name,
2325 parent_id: None,
2326 };
2327 response.send(proto::ChannelResponse {
2328 channel: Some(channel.clone()),
2329 })?;
2330 let mut update = proto::UpdateChannels::default();
2331 update.channels.push(channel);
2332
2333 let member_ids = db.get_channel_members(channel_id).await?;
2334
2335 let connection_pool = session.connection_pool().await;
2336 for member_id in member_ids {
2337 for connection_id in connection_pool.user_connection_ids(member_id) {
2338 session.peer.send(connection_id, update.clone())?;
2339 }
2340 }
2341
2342 Ok(())
2343}
2344
2345async fn get_channel_members(
2346 request: proto::GetChannelMembers,
2347 response: Response<proto::GetChannelMembers>,
2348 session: Session,
2349) -> Result<()> {
2350 let db = session.db().await;
2351 let channel_id = ChannelId::from_proto(request.channel_id);
2352 let members = db
2353 .get_channel_member_details(channel_id, session.user_id)
2354 .await?;
2355 response.send(proto::GetChannelMembersResponse { members })?;
2356 Ok(())
2357}
2358
2359async fn respond_to_channel_invite(
2360 request: proto::RespondToChannelInvite,
2361 response: Response<proto::RespondToChannelInvite>,
2362 session: Session,
2363) -> Result<()> {
2364 let db = session.db().await;
2365 let channel_id = ChannelId::from_proto(request.channel_id);
2366 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2367 .await?;
2368
2369 let mut update = proto::UpdateChannels::default();
2370 update
2371 .remove_channel_invitations
2372 .push(channel_id.to_proto());
2373 if request.accept {
2374 let result = db.get_channels_for_user(session.user_id).await?;
2375 update
2376 .channels
2377 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2378 id: channel.id.to_proto(),
2379 name: channel.name,
2380 parent_id: channel.parent_id.map(ChannelId::to_proto),
2381 }));
2382 update
2383 .channel_participants
2384 .extend(
2385 result
2386 .channel_participants
2387 .into_iter()
2388 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2389 channel_id: channel_id.to_proto(),
2390 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2391 }),
2392 );
2393 update
2394 .channel_permissions
2395 .extend(
2396 result
2397 .channels_with_admin_privileges
2398 .into_iter()
2399 .map(|channel_id| proto::ChannelPermission {
2400 channel_id: channel_id.to_proto(),
2401 is_admin: true,
2402 }),
2403 );
2404 }
2405 session.peer.send(session.connection_id, update)?;
2406 response.send(proto::Ack {})?;
2407
2408 Ok(())
2409}
2410
/// Joins the room associated with a channel on behalf of the session's user.
///
/// The requester receives a `JoinRoomResponse` with the room state and, when
/// a LiveKit client is configured and a token can be created, LiveKit
/// connection info. After that:
/// - the room's participants are sent a `RoomUpdated`;
/// - the channel's members are sent the updated participant list;
/// - the user's accepted contacts are sent the user's refreshed contact entry.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The database handle and the value returned by `join_room` live only
    // inside this block; only a clone of the joined-room data escapes it.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // If token creation fails, `trace_err` logs the error and no LiveKit
        // info is sent. The join itself still succeeds.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Clone the data out so that the value returned by `join_room` (and
        // the db handle) is dropped at the end of this block.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2471
2472async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2473 let project_id = ProjectId::from_proto(request.project_id);
2474 let project_connection_ids = session
2475 .db()
2476 .await
2477 .project_connection_ids(project_id, session.connection_id)
2478 .await?;
2479 broadcast(
2480 Some(session.connection_id),
2481 project_connection_ids.iter().copied(),
2482 |connection_id| {
2483 session
2484 .peer
2485 .forward_send(session.connection_id, connection_id, request.clone())
2486 },
2487 );
2488 Ok(())
2489}
2490
2491async fn get_private_user_info(
2492 _request: proto::GetPrivateUserInfo,
2493 response: Response<proto::GetPrivateUserInfo>,
2494 session: Session,
2495) -> Result<()> {
2496 let metrics_id = session
2497 .db()
2498 .await
2499 .get_user_metrics_id(session.user_id)
2500 .await?;
2501 let user = session
2502 .db()
2503 .await
2504 .get_user_by_id(session.user_id)
2505 .await?
2506 .ok_or_else(|| anyhow!("user not found"))?;
2507 response.send(proto::GetPrivateUserInfoResponse {
2508 metrics_id,
2509 staff: user.admin,
2510 })?;
2511 Ok(())
2512}
2513
2514fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2515 match message {
2516 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2517 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2518 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2519 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2520 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2521 code: frame.code.into(),
2522 reason: frame.reason,
2523 })),
2524 }
2525}
2526
2527fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2528 match message {
2529 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2530 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2531 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2532 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2533 AxumMessage::Close(frame) => {
2534 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2535 code: frame.code.into(),
2536 reason: frame.reason,
2537 }))
2538 }
2539 }
2540}
2541
2542fn build_initial_channels_update(
2543 channels: ChannelsForUser,
2544 channel_invites: Vec<db::Channel>,
2545) -> proto::UpdateChannels {
2546 let mut update = proto::UpdateChannels::default();
2547
2548 for channel in channels.channels {
2549 update.channels.push(proto::Channel {
2550 id: channel.id.to_proto(),
2551 name: channel.name,
2552 parent_id: channel.parent_id.map(|id| id.to_proto()),
2553 });
2554 }
2555
2556 for (channel_id, participants) in channels.channel_participants {
2557 update
2558 .channel_participants
2559 .push(proto::ChannelParticipants {
2560 channel_id: channel_id.to_proto(),
2561 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2562 });
2563 }
2564
2565 update
2566 .channel_permissions
2567 .extend(
2568 channels
2569 .channels_with_admin_privileges
2570 .into_iter()
2571 .map(|id| proto::ChannelPermission {
2572 channel_id: id.to_proto(),
2573 is_admin: true,
2574 }),
2575 );
2576
2577 for channel in channel_invites {
2578 update.channel_invitations.push(proto::Channel {
2579 id: channel.id.to_proto(),
2580 name: channel.name,
2581 parent_id: None,
2582 });
2583 }
2584
2585 update
2586}
2587
2588fn build_initial_contacts_update(
2589 contacts: Vec<db::Contact>,
2590 pool: &ConnectionPool,
2591) -> proto::UpdateContacts {
2592 let mut update = proto::UpdateContacts::default();
2593
2594 for contact in contacts {
2595 match contact {
2596 db::Contact::Accepted {
2597 user_id,
2598 should_notify,
2599 busy,
2600 } => {
2601 update
2602 .contacts
2603 .push(contact_for_user(user_id, should_notify, busy, &pool));
2604 }
2605 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2606 db::Contact::Incoming {
2607 user_id,
2608 should_notify,
2609 } => update
2610 .incoming_requests
2611 .push(proto::IncomingContactRequest {
2612 requester_id: user_id.to_proto(),
2613 should_notify,
2614 }),
2615 }
2616 }
2617
2618 update
2619}
2620
2621fn contact_for_user(
2622 user_id: UserId,
2623 should_notify: bool,
2624 busy: bool,
2625 pool: &ConnectionPool,
2626) -> proto::Contact {
2627 proto::Contact {
2628 user_id: user_id.to_proto(),
2629 online: pool.is_user_online(user_id),
2630 busy,
2631 should_notify,
2632 }
2633}
2634
2635fn room_updated(room: &proto::Room, peer: &Peer) {
2636 broadcast(
2637 None,
2638 room.participants
2639 .iter()
2640 .filter_map(|participant| Some(participant.peer_id?.into())),
2641 |peer_id| {
2642 peer.send(
2643 peer_id.into(),
2644 proto::RoomUpdated {
2645 room: Some(room.clone()),
2646 },
2647 )
2648 },
2649 );
2650}
2651
2652fn channel_updated(
2653 channel_id: ChannelId,
2654 room: &proto::Room,
2655 channel_members: &[UserId],
2656 peer: &Peer,
2657 pool: &ConnectionPool,
2658) {
2659 let participants = room
2660 .participants
2661 .iter()
2662 .map(|p| p.user_id)
2663 .collect::<Vec<_>>();
2664
2665 broadcast(
2666 None,
2667 channel_members
2668 .iter()
2669 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2670 |peer_id| {
2671 peer.send(
2672 peer_id.into(),
2673 proto::UpdateChannels {
2674 channel_participants: vec![proto::ChannelParticipants {
2675 channel_id: channel_id.to_proto(),
2676 participant_user_ids: participants.clone(),
2677 }],
2678 ..Default::default()
2679 },
2680 )
2681 },
2682 );
2683}
2684
2685async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2686 let db = session.db().await;
2687
2688 let contacts = db.get_contacts(user_id).await?;
2689 let busy = db.is_user_busy(user_id).await?;
2690
2691 let pool = session.connection_pool().await;
2692 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2693 for contact in contacts {
2694 if let db::Contact::Accepted {
2695 user_id: contact_user_id,
2696 ..
2697 } = contact
2698 {
2699 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2700 session
2701 .peer
2702 .send(
2703 contact_conn_id,
2704 proto::UpdateContacts {
2705 contacts: vec![updated_contact.clone()],
2706 remove_contacts: Default::default(),
2707 incoming_requests: Default::default(),
2708 remove_incoming_requests: Default::default(),
2709 outgoing_requests: Default::default(),
2710 remove_outgoing_requests: Default::default(),
2711 },
2712 )
2713 .trace_err();
2714 }
2715 }
2716 }
2717 Ok(())
2718}
2719
2720async fn leave_room_for_session(session: &Session) -> Result<()> {
2721 let mut contacts_to_update = HashSet::default();
2722
2723 let room_id;
2724 let canceled_calls_to_user_ids;
2725 let live_kit_room;
2726 let delete_live_kit_room;
2727 let room;
2728 let channel_members;
2729 let channel_id;
2730
2731 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2732 contacts_to_update.insert(session.user_id);
2733
2734 for project in left_room.left_projects.values() {
2735 project_left(project, session);
2736 }
2737
2738 room_id = RoomId::from_proto(left_room.room.id);
2739 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2740 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2741 delete_live_kit_room = left_room.deleted;
2742 room = mem::take(&mut left_room.room);
2743 channel_members = mem::take(&mut left_room.channel_members);
2744 channel_id = left_room.channel_id;
2745
2746 room_updated(&room, &session.peer);
2747 } else {
2748 return Ok(());
2749 }
2750
2751 // TODO - do this while holding the room guard.
2752 if let Some(channel_id) = channel_id {
2753 channel_updated(
2754 channel_id,
2755 &room,
2756 &channel_members,
2757 &session.peer,
2758 &*session.connection_pool().await,
2759 );
2760 }
2761
2762 {
2763 let pool = session.connection_pool().await;
2764 for canceled_user_id in canceled_calls_to_user_ids {
2765 for connection_id in pool.user_connection_ids(canceled_user_id) {
2766 session
2767 .peer
2768 .send(
2769 connection_id,
2770 proto::CallCanceled {
2771 room_id: room_id.to_proto(),
2772 },
2773 )
2774 .trace_err();
2775 }
2776 contacts_to_update.insert(canceled_user_id);
2777 }
2778 }
2779
2780 for contact_user_id in contacts_to_update {
2781 update_user_contacts(contact_user_id, &session).await?;
2782 }
2783
2784 if let Some(live_kit) = session.live_kit_client.as_ref() {
2785 live_kit
2786 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2787 .await
2788 .trace_err();
2789
2790 if delete_live_kit_room {
2791 live_kit.delete_room(live_kit_room).await.trace_err();
2792 }
2793 }
2794
2795 Ok(())
2796}
2797
2798fn project_left(project: &db::LeftProject, session: &Session) {
2799 for connection_id in &project.connection_ids {
2800 if project.host_user_id == session.user_id {
2801 session
2802 .peer
2803 .send(
2804 *connection_id,
2805 proto::UnshareProject {
2806 project_id: project.id.to_proto(),
2807 },
2808 )
2809 .trace_err();
2810 } else {
2811 session
2812 .peer
2813 .send(
2814 *connection_id,
2815 proto::RemoveProjectCollaborator {
2816 project_id: project.id.to_proto(),
2817 peer_id: Some(session.connection_id.into()),
2818 },
2819 )
2820 .trace_err();
2821 }
2822 }
2823}
2824
2825pub trait ResultExt {
2826 type Ok;
2827
2828 fn trace_err(self) -> Option<Self::Ok>;
2829}
2830
2831impl<T, E> ResultExt for Result<T, E>
2832where
2833 E: std::fmt::Debug,
2834{
2835 type Ok = T;
2836
2837 fn trace_err(self) -> Option<T> {
2838 match self {
2839 Ok(value) => Some(value),
2840 Err(error) => {
2841 tracing::error!("{:?}", error);
2842 None
2843 }
2844 }
2845 }
2846}