1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// Grace period a dropped connection gets before `connection_lost` removes the
/// user from their room and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay after `Server::start` before stale rooms are looked up and refreshed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(get_channel_members)
251 .add_request_handler(respond_to_channel_invite)
252 .add_request_handler(join_channel)
253 .add_request_handler(follow)
254 .add_message_handler(unfollow)
255 .add_message_handler(update_followers)
256 .add_message_handler(update_diff_base)
257 .add_request_handler(get_private_user_info);
258
259 Arc::new(server)
260 }
261
262 pub async fn start(&self) -> Result<()> {
263 let server_id = *self.id.lock();
264 let app_state = self.app_state.clone();
265 let peer = self.peer.clone();
266 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
267 let pool = self.connection_pool.clone();
268 let live_kit_client = self.app_state.live_kit_client.clone();
269
270 let span = info_span!("start server");
271 self.executor.spawn_detached(
272 async move {
273 tracing::info!("waiting for cleanup timeout");
274 timeout.await;
275 tracing::info!("cleanup timeout expired, retrieving stale rooms");
276 if let Some(room_ids) = app_state
277 .db
278 .stale_room_ids(&app_state.config.zed_environment, server_id)
279 .await
280 .trace_err()
281 {
282 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
283 for room_id in room_ids {
284 let mut contacts_to_update = HashSet::default();
285 let mut canceled_calls_to_user_ids = Vec::new();
286 let mut live_kit_room = String::new();
287 let mut delete_live_kit_room = false;
288
289 if let Some(mut refreshed_room) = app_state
290 .db
291 .refresh_room(room_id, server_id)
292 .await
293 .trace_err()
294 {
295 tracing::info!(
296 room_id = room_id.0,
297 new_participant_count = refreshed_room.room.participants.len(),
298 "refreshed room"
299 );
300 room_updated(&refreshed_room.room, &peer);
301 if let Some(channel_id) = refreshed_room.channel_id {
302 channel_updated(
303 channel_id,
304 &refreshed_room.room,
305 &refreshed_room.channel_members,
306 &peer,
307 &*pool.lock(),
308 );
309 }
310 contacts_to_update
311 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
312 contacts_to_update
313 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
314 canceled_calls_to_user_ids =
315 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
316 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
317 delete_live_kit_room = refreshed_room.room.participants.is_empty();
318 }
319
320 {
321 let pool = pool.lock();
322 for canceled_user_id in canceled_calls_to_user_ids {
323 for connection_id in pool.user_connection_ids(canceled_user_id) {
324 peer.send(
325 connection_id,
326 proto::CallCanceled {
327 room_id: room_id.to_proto(),
328 },
329 )
330 .trace_err();
331 }
332 }
333 }
334
335 for user_id in contacts_to_update {
336 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
337 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
338 if let Some((busy, contacts)) = busy.zip(contacts) {
339 let pool = pool.lock();
340 let updated_contact = contact_for_user(user_id, false, busy, &pool);
341 for contact in contacts {
342 if let db::Contact::Accepted {
343 user_id: contact_user_id,
344 ..
345 } = contact
346 {
347 for contact_conn_id in
348 pool.user_connection_ids(contact_user_id)
349 {
350 peer.send(
351 contact_conn_id,
352 proto::UpdateContacts {
353 contacts: vec![updated_contact.clone()],
354 remove_contacts: Default::default(),
355 incoming_requests: Default::default(),
356 remove_incoming_requests: Default::default(),
357 outgoing_requests: Default::default(),
358 remove_outgoing_requests: Default::default(),
359 },
360 )
361 .trace_err();
362 }
363 }
364 }
365 }
366 }
367
368 if let Some(live_kit) = live_kit_client.as_ref() {
369 if delete_live_kit_room {
370 live_kit.delete_room(live_kit_room).await.trace_err();
371 }
372 }
373 }
374 }
375
376 app_state
377 .db
378 .delete_stale_servers(&app_state.config.zed_environment, server_id)
379 .await
380 .trace_err();
381 }
382 .instrument(span),
383 );
384 Ok(())
385 }
386
387 pub fn teardown(&self) {
388 self.peer.teardown();
389 self.connection_pool.lock().reset();
390 let _ = self.teardown.send(());
391 }
392
393 #[cfg(test)]
394 pub fn reset(&self, id: ServerId) {
395 self.teardown();
396 *self.id.lock() = id;
397 self.peer.reset(id.0 as u32);
398 }
399
400 #[cfg(test)]
401 pub fn id(&self) -> ServerId {
402 *self.id.lock()
403 }
404
405 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
406 where
407 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
408 Fut: 'static + Send + Future<Output = Result<()>>,
409 M: EnvelopedMessage,
410 {
411 let prev_handler = self.handlers.insert(
412 TypeId::of::<M>(),
413 Box::new(move |envelope, session| {
414 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
415 let span = info_span!(
416 "handle message",
417 payload_type = envelope.payload_type_name()
418 );
419 span.in_scope(|| {
420 tracing::info!(
421 payload_type = envelope.payload_type_name(),
422 "message received"
423 );
424 });
425 let start_time = Instant::now();
426 let future = (handler)(*envelope, session);
427 async move {
428 let result = future.await;
429 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
430 match result {
431 Err(error) => {
432 tracing::error!(%error, ?duration_ms, "error handling message")
433 }
434 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
435 }
436 }
437 .instrument(span)
438 .boxed()
439 }),
440 );
441 if prev_handler.is_some() {
442 panic!("registered a handler for the same message twice");
443 }
444 self
445 }
446
447 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
448 where
449 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
450 Fut: 'static + Send + Future<Output = Result<()>>,
451 M: EnvelopedMessage,
452 {
453 self.add_handler(move |envelope, session| handler(envelope.payload, session));
454 self
455 }
456
457 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
458 where
459 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
460 Fut: Send + Future<Output = Result<()>>,
461 M: RequestMessage,
462 {
463 let handler = Arc::new(handler);
464 self.add_handler(move |envelope, session| {
465 let receipt = envelope.receipt();
466 let handler = handler.clone();
467 async move {
468 let peer = session.peer.clone();
469 let responded = Arc::new(AtomicBool::default());
470 let response = Response {
471 peer: peer.clone(),
472 responded: responded.clone(),
473 receipt,
474 };
475 match (handler)(envelope.payload, response, session).await {
476 Ok(()) => {
477 if responded.load(std::sync::atomic::Ordering::SeqCst) {
478 Ok(())
479 } else {
480 Err(anyhow!("handler did not send a response"))?
481 }
482 }
483 Err(error) => {
484 peer.respond_with_error(
485 receipt,
486 proto::Error {
487 message: error.to_string(),
488 },
489 )?;
490 Err(error)
491 }
492 }
493 }
494 })
495 }
496
497 pub fn handle_connection(
498 self: &Arc<Self>,
499 connection: Connection,
500 address: String,
501 user: User,
502 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
503 executor: Executor,
504 ) -> impl Future<Output = Result<()>> {
505 let this = self.clone();
506 let user_id = user.id;
507 let login = user.github_login;
508 let span = info_span!("handle connection", %user_id, %login, %address);
509 let mut teardown = self.teardown.subscribe();
510 async move {
511 let (connection_id, handle_io, mut incoming_rx) = this
512 .peer
513 .add_connection(connection, {
514 let executor = executor.clone();
515 move |duration| executor.sleep(duration)
516 });
517
518 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
519 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
520 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
521
522 if let Some(send_connection_id) = send_connection_id.take() {
523 let _ = send_connection_id.send(connection_id);
524 }
525
526 if !user.connected_once {
527 this.peer.send(connection_id, proto::ShowContacts {})?;
528 this.app_state.db.set_user_connected_once(user_id, true).await?;
529 }
530
531 let (contacts, invite_code, (channels, channel_participants), channel_invites) = future::try_join4(
532 this.app_state.db.get_contacts(user_id),
533 this.app_state.db.get_invite_code_for_user(user_id),
534 this.app_state.db.get_channels_for_user(user_id),
535 this.app_state.db.get_channel_invites_for_user(user_id)
536 ).await?;
537
538 {
539 let mut pool = this.connection_pool.lock();
540 pool.add_connection(connection_id, user_id, user.admin);
541 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
542 this.peer.send(connection_id, build_initial_channels_update(channels, channel_participants, channel_invites))?;
543
544 if let Some((code, count)) = invite_code {
545 this.peer.send(connection_id, proto::UpdateInviteInfo {
546 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
547 count: count as u32,
548 })?;
549 }
550 }
551
552 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
553 this.peer.send(connection_id, incoming_call)?;
554 }
555
556 let session = Session {
557 user_id,
558 connection_id,
559 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
560 peer: this.peer.clone(),
561 connection_pool: this.connection_pool.clone(),
562 live_kit_client: this.app_state.live_kit_client.clone(),
563 executor: executor.clone(),
564 };
565 update_user_contacts(user_id, &session).await?;
566
567 let handle_io = handle_io.fuse();
568 futures::pin_mut!(handle_io);
569
570 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
571 // This prevents deadlocks when e.g., client A performs a request to client B and
572 // client B performs a request to client A. If both clients stop processing further
573 // messages until their respective request completes, they won't have a chance to
574 // respond to the other client's request and cause a deadlock.
575 //
576 // This arrangement ensures we will attempt to process earlier messages first, but fall
577 // back to processing messages arrived later in the spirit of making progress.
578 let mut foreground_message_handlers = FuturesUnordered::new();
579 let concurrent_handlers = Arc::new(Semaphore::new(256));
580 loop {
581 let next_message = async {
582 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
583 let message = incoming_rx.next().await;
584 (permit, message)
585 }.fuse();
586 futures::pin_mut!(next_message);
587 futures::select_biased! {
588 _ = teardown.changed().fuse() => return Ok(()),
589 result = handle_io => {
590 if let Err(error) = result {
591 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
592 }
593 break;
594 }
595 _ = foreground_message_handlers.next() => {}
596 next_message = next_message => {
597 let (permit, message) = next_message;
598 if let Some(message) = message {
599 let type_name = message.payload_type_name();
600 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
601 let span_enter = span.enter();
602 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
603 let is_background = message.is_background();
604 let handle_message = (handler)(message, session.clone());
605 drop(span_enter);
606
607 let handle_message = async move {
608 handle_message.await;
609 drop(permit);
610 }.instrument(span);
611 if is_background {
612 executor.spawn_detached(handle_message);
613 } else {
614 foreground_message_handlers.push(handle_message);
615 }
616 } else {
617 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
618 }
619 } else {
620 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
621 break;
622 }
623 }
624 }
625 }
626
627 drop(foreground_message_handlers);
628 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
629 if let Err(error) = connection_lost(session, teardown, executor).await {
630 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
631 }
632
633 Ok(())
634 }.instrument(span)
635 }
636
637 pub async fn invite_code_redeemed(
638 self: &Arc<Self>,
639 inviter_id: UserId,
640 invitee_id: UserId,
641 ) -> Result<()> {
642 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
643 if let Some(code) = &user.invite_code {
644 let pool = self.connection_pool.lock();
645 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
646 for connection_id in pool.user_connection_ids(inviter_id) {
647 self.peer.send(
648 connection_id,
649 proto::UpdateContacts {
650 contacts: vec![invitee_contact.clone()],
651 ..Default::default()
652 },
653 )?;
654 self.peer.send(
655 connection_id,
656 proto::UpdateInviteInfo {
657 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
658 count: user.invite_count as u32,
659 },
660 )?;
661 }
662 }
663 }
664 Ok(())
665 }
666
667 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
668 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
669 if let Some(invite_code) = &user.invite_code {
670 let pool = self.connection_pool.lock();
671 for connection_id in pool.user_connection_ids(user_id) {
672 self.peer.send(
673 connection_id,
674 proto::UpdateInviteInfo {
675 url: format!(
676 "{}{}",
677 self.app_state.config.invite_link_prefix, invite_code
678 ),
679 count: user.invite_count as u32,
680 },
681 )?;
682 }
683 }
684 }
685 Ok(())
686 }
687
688 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
689 ServerSnapshot {
690 connection_pool: ConnectionPoolGuard {
691 guard: self.connection_pool.lock(),
692 _not_send: PhantomData,
693 },
694 peer: &self.peer,
695 }
696 }
697}
698
699impl<'a> Deref for ConnectionPoolGuard<'a> {
700 type Target = ConnectionPool;
701
702 fn deref(&self) -> &Self::Target {
703 &*self.guard
704 }
705}
706
707impl<'a> DerefMut for ConnectionPoolGuard<'a> {
708 fn deref_mut(&mut self) -> &mut Self::Target {
709 &mut *self.guard
710 }
711}
712
713impl<'a> Drop for ConnectionPoolGuard<'a> {
714 fn drop(&mut self) {
715 #[cfg(test)]
716 self.check_invariants();
717 }
718}
719
720fn broadcast<F>(
721 sender_id: Option<ConnectionId>,
722 receiver_ids: impl IntoIterator<Item = ConnectionId>,
723 mut f: F,
724) where
725 F: FnMut(ConnectionId) -> anyhow::Result<()>,
726{
727 for receiver_id in receiver_ids {
728 if Some(receiver_id) != sender_id {
729 if let Err(error) = f(receiver_id) {
730 tracing::error!("failed to send to {:?} {}", receiver_id, error);
731 }
732 }
733 }
734}
735
736lazy_static! {
737 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
738}
739
/// Typed value of the `x-zed-protocol-version` header.
pub struct ProtocolVersion(
    u32,
);
741
742impl Header for ProtocolVersion {
743 fn name() -> &'static HeaderName {
744 &ZED_PROTOCOL_VERSION
745 }
746
747 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
748 where
749 Self: Sized,
750 I: Iterator<Item = &'i axum::http::HeaderValue>,
751 {
752 let version = values
753 .next()
754 .ok_or_else(axum::headers::Error::invalid)?
755 .to_str()
756 .map_err(|_| axum::headers::Error::invalid())?
757 .parse()
758 .map_err(|_| axum::headers::Error::invalid())?;
759 Ok(Self(version))
760 }
761
762 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
763 values.extend([self.0.to_string().parse().unwrap()]);
764 }
765}
766
767pub fn routes(server: Arc<Server>) -> Router<Body> {
768 Router::new()
769 .route("/rpc", get(handle_websocket_request))
770 .layer(
771 ServiceBuilder::new()
772 .layer(Extension(server.app_state.clone()))
773 .layer(middleware::from_fn(auth::validate_header)),
774 )
775 .route("/metrics", get(handle_metrics))
776 .layer(Extension(server))
777}
778
779pub async fn handle_websocket_request(
780 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
781 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
782 Extension(server): Extension<Arc<Server>>,
783 Extension(user): Extension<User>,
784 ws: WebSocketUpgrade,
785) -> axum::response::Response {
786 if protocol_version != rpc::PROTOCOL_VERSION {
787 return (
788 StatusCode::UPGRADE_REQUIRED,
789 "client must be upgraded".to_string(),
790 )
791 .into_response();
792 }
793 let socket_address = socket_address.to_string();
794 ws.on_upgrade(move |socket| {
795 use util::ResultExt;
796 let socket = socket
797 .map_ok(to_tungstenite_message)
798 .err_into()
799 .with(|message| async move { Ok(to_axum_message(message)) });
800 let connection = Connection::new(Box::pin(socket));
801 async move {
802 server
803 .handle_connection(connection, socket_address, user, None, Executor::Production)
804 .await
805 .log_err();
806 }
807 })
808}
809
810pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
811 let connections = server
812 .connection_pool
813 .lock()
814 .connections()
815 .filter(|connection| !connection.admin)
816 .count();
817
818 METRIC_CONNECTIONS.set(connections as _);
819
820 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
821 METRIC_SHARED_PROJECTS.set(shared_projects as _);
822
823 let encoder = prometheus::TextEncoder::new();
824 let metric_families = prometheus::gather();
825 let encoded_metrics = encoder
826 .encode_to_string(&metric_families)
827 .map_err(|err| anyhow!("{}", err))?;
828 Ok(encoded_metrics)
829}
830
831#[instrument(err, skip(executor))]
832async fn connection_lost(
833 session: Session,
834 mut teardown: watch::Receiver<()>,
835 executor: Executor,
836) -> Result<()> {
837 session.peer.disconnect(session.connection_id);
838 session
839 .connection_pool()
840 .await
841 .remove_connection(session.connection_id)?;
842
843 session
844 .db()
845 .await
846 .connection_lost(session.connection_id)
847 .await
848 .trace_err();
849
850 futures::select_biased! {
851 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
852 leave_room_for_session(&session).await.trace_err();
853
854 if !session
855 .connection_pool()
856 .await
857 .is_user_online(session.user_id)
858 {
859 let db = session.db().await;
860 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
861 room_updated(&room, &session.peer);
862 }
863 }
864 update_user_contacts(session.user_id, &session).await?;
865 }
866 _ = teardown.changed().fuse() => {}
867 }
868
869 Ok(())
870}
871
872async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
873 response.send(proto::Ack {})?;
874 Ok(())
875}
876
877async fn create_room(
878 _request: proto::CreateRoom,
879 response: Response<proto::CreateRoom>,
880 session: Session,
881) -> Result<()> {
882 let live_kit_room = nanoid::nanoid!(30);
883
884 let live_kit_connection_info = {
885 let live_kit_room = live_kit_room.clone();
886 let live_kit = session.live_kit_client.as_ref();
887
888 util::async_iife!({
889 let live_kit = live_kit?;
890
891 live_kit
892 .create_room(live_kit_room.clone())
893 .await
894 .trace_err()?;
895
896 let token = live_kit
897 .room_token(&live_kit_room, &session.user_id.to_string())
898 .trace_err()?;
899
900 Some(proto::LiveKitConnectionInfo {
901 server_url: live_kit.url().into(),
902 token,
903 })
904 })
905 }
906 .await;
907
908 let room = session
909 .db()
910 .await
911 .create_room(session.user_id, session.connection_id, &live_kit_room)
912 .await?;
913
914 response.send(proto::CreateRoomResponse {
915 room: Some(room.clone()),
916 live_kit_connection_info,
917 })?;
918
919 update_user_contacts(session.user_id, &session).await?;
920 Ok(())
921}
922
923async fn join_room(
924 request: proto::JoinRoom,
925 response: Response<proto::JoinRoom>,
926 session: Session,
927) -> Result<()> {
928 let room_id = RoomId::from_proto(request.id);
929 let room = {
930 let room = session
931 .db()
932 .await
933 .join_room(room_id, session.user_id, None, session.connection_id)
934 .await?;
935 room_updated(&room.room, &session.peer);
936 room.room.clone()
937 };
938
939 for connection_id in session
940 .connection_pool()
941 .await
942 .user_connection_ids(session.user_id)
943 {
944 session
945 .peer
946 .send(
947 connection_id,
948 proto::CallCanceled {
949 room_id: room_id.to_proto(),
950 },
951 )
952 .trace_err();
953 }
954
955 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
956 if let Some(token) = live_kit
957 .room_token(&room.live_kit_room, &session.user_id.to_string())
958 .trace_err()
959 {
960 Some(proto::LiveKitConnectionInfo {
961 server_url: live_kit.url().into(),
962 token,
963 })
964 } else {
965 None
966 }
967 } else {
968 None
969 };
970
971 response.send(proto::JoinRoomResponse {
972 room: Some(room),
973 live_kit_connection_info,
974 })?;
975
976 update_user_contacts(session.user_id, &session).await?;
977 Ok(())
978}
979
980async fn rejoin_room(
981 request: proto::RejoinRoom,
982 response: Response<proto::RejoinRoom>,
983 session: Session,
984) -> Result<()> {
985 let room;
986 let channel_id;
987 let channel_members;
988 {
989 let mut rejoined_room = session
990 .db()
991 .await
992 .rejoin_room(request, session.user_id, session.connection_id)
993 .await?;
994
995 response.send(proto::RejoinRoomResponse {
996 room: Some(rejoined_room.room.clone()),
997 reshared_projects: rejoined_room
998 .reshared_projects
999 .iter()
1000 .map(|project| proto::ResharedProject {
1001 id: project.id.to_proto(),
1002 collaborators: project
1003 .collaborators
1004 .iter()
1005 .map(|collaborator| collaborator.to_proto())
1006 .collect(),
1007 })
1008 .collect(),
1009 rejoined_projects: rejoined_room
1010 .rejoined_projects
1011 .iter()
1012 .map(|rejoined_project| proto::RejoinedProject {
1013 id: rejoined_project.id.to_proto(),
1014 worktrees: rejoined_project
1015 .worktrees
1016 .iter()
1017 .map(|worktree| proto::WorktreeMetadata {
1018 id: worktree.id,
1019 root_name: worktree.root_name.clone(),
1020 visible: worktree.visible,
1021 abs_path: worktree.abs_path.clone(),
1022 })
1023 .collect(),
1024 collaborators: rejoined_project
1025 .collaborators
1026 .iter()
1027 .map(|collaborator| collaborator.to_proto())
1028 .collect(),
1029 language_servers: rejoined_project.language_servers.clone(),
1030 })
1031 .collect(),
1032 })?;
1033 room_updated(&rejoined_room.room, &session.peer);
1034
1035 for project in &rejoined_room.reshared_projects {
1036 for collaborator in &project.collaborators {
1037 session
1038 .peer
1039 .send(
1040 collaborator.connection_id,
1041 proto::UpdateProjectCollaborator {
1042 project_id: project.id.to_proto(),
1043 old_peer_id: Some(project.old_connection_id.into()),
1044 new_peer_id: Some(session.connection_id.into()),
1045 },
1046 )
1047 .trace_err();
1048 }
1049
1050 broadcast(
1051 Some(session.connection_id),
1052 project
1053 .collaborators
1054 .iter()
1055 .map(|collaborator| collaborator.connection_id),
1056 |connection_id| {
1057 session.peer.forward_send(
1058 session.connection_id,
1059 connection_id,
1060 proto::UpdateProject {
1061 project_id: project.id.to_proto(),
1062 worktrees: project.worktrees.clone(),
1063 },
1064 )
1065 },
1066 );
1067 }
1068
1069 for project in &rejoined_room.rejoined_projects {
1070 for collaborator in &project.collaborators {
1071 session
1072 .peer
1073 .send(
1074 collaborator.connection_id,
1075 proto::UpdateProjectCollaborator {
1076 project_id: project.id.to_proto(),
1077 old_peer_id: Some(project.old_connection_id.into()),
1078 new_peer_id: Some(session.connection_id.into()),
1079 },
1080 )
1081 .trace_err();
1082 }
1083 }
1084
1085 for project in &mut rejoined_room.rejoined_projects {
1086 for worktree in mem::take(&mut project.worktrees) {
1087 #[cfg(any(test, feature = "test-support"))]
1088 const MAX_CHUNK_SIZE: usize = 2;
1089 #[cfg(not(any(test, feature = "test-support")))]
1090 const MAX_CHUNK_SIZE: usize = 256;
1091
1092 // Stream this worktree's entries.
1093 let message = proto::UpdateWorktree {
1094 project_id: project.id.to_proto(),
1095 worktree_id: worktree.id,
1096 abs_path: worktree.abs_path.clone(),
1097 root_name: worktree.root_name,
1098 updated_entries: worktree.updated_entries,
1099 removed_entries: worktree.removed_entries,
1100 scan_id: worktree.scan_id,
1101 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1102 updated_repositories: worktree.updated_repositories,
1103 removed_repositories: worktree.removed_repositories,
1104 };
1105 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1106 session.peer.send(session.connection_id, update.clone())?;
1107 }
1108
1109 // Stream this worktree's diagnostics.
1110 for summary in worktree.diagnostic_summaries {
1111 session.peer.send(
1112 session.connection_id,
1113 proto::UpdateDiagnosticSummary {
1114 project_id: project.id.to_proto(),
1115 worktree_id: worktree.id,
1116 summary: Some(summary),
1117 },
1118 )?;
1119 }
1120
1121 for settings_file in worktree.settings_files {
1122 session.peer.send(
1123 session.connection_id,
1124 proto::UpdateWorktreeSettings {
1125 project_id: project.id.to_proto(),
1126 worktree_id: worktree.id,
1127 path: settings_file.path,
1128 content: Some(settings_file.content),
1129 },
1130 )?;
1131 }
1132 }
1133
1134 for language_server in &project.language_servers {
1135 session.peer.send(
1136 session.connection_id,
1137 proto::UpdateLanguageServer {
1138 project_id: project.id.to_proto(),
1139 language_server_id: language_server.id,
1140 variant: Some(
1141 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1142 proto::LspDiskBasedDiagnosticsUpdated {},
1143 ),
1144 ),
1145 },
1146 )?;
1147 }
1148 }
1149
1150 room = mem::take(&mut rejoined_room.room);
1151 channel_id = rejoined_room.channel_id;
1152 channel_members = mem::take(&mut rejoined_room.channel_members);
1153 }
1154
1155 //TODO: move this into the room guard
1156 if let Some(channel_id) = channel_id {
1157 channel_updated(
1158 channel_id,
1159 &room,
1160 &channel_members,
1161 &session.peer,
1162 &*session.connection_pool().await,
1163 );
1164 }
1165
1166 update_user_contacts(session.user_id, &session).await?;
1167 Ok(())
1168}
1169
1170async fn leave_room(
1171 _: proto::LeaveRoom,
1172 response: Response<proto::LeaveRoom>,
1173 session: Session,
1174) -> Result<()> {
1175 leave_room_for_session(&session).await?;
1176 response.send(proto::Ack {})?;
1177 Ok(())
1178}
1179
1180async fn call(
1181 request: proto::Call,
1182 response: Response<proto::Call>,
1183 session: Session,
1184) -> Result<()> {
1185 let room_id = RoomId::from_proto(request.room_id);
1186 let calling_user_id = session.user_id;
1187 let calling_connection_id = session.connection_id;
1188 let called_user_id = UserId::from_proto(request.called_user_id);
1189 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1190 if !session
1191 .db()
1192 .await
1193 .has_contact(calling_user_id, called_user_id)
1194 .await?
1195 {
1196 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1197 }
1198
1199 let incoming_call = {
1200 let (room, incoming_call) = &mut *session
1201 .db()
1202 .await
1203 .call(
1204 room_id,
1205 calling_user_id,
1206 calling_connection_id,
1207 called_user_id,
1208 initial_project_id,
1209 )
1210 .await?;
1211 room_updated(&room, &session.peer);
1212 mem::take(incoming_call)
1213 };
1214 update_user_contacts(called_user_id, &session).await?;
1215
1216 let mut calls = session
1217 .connection_pool()
1218 .await
1219 .user_connection_ids(called_user_id)
1220 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1221 .collect::<FuturesUnordered<_>>();
1222
1223 while let Some(call_response) = calls.next().await {
1224 match call_response.as_ref() {
1225 Ok(_) => {
1226 response.send(proto::Ack {})?;
1227 return Ok(());
1228 }
1229 Err(_) => {
1230 call_response.trace_err();
1231 }
1232 }
1233 }
1234
1235 {
1236 let room = session
1237 .db()
1238 .await
1239 .call_failed(room_id, called_user_id)
1240 .await?;
1241 room_updated(&room, &session.peer);
1242 }
1243 update_user_contacts(called_user_id, &session).await?;
1244
1245 Err(anyhow!("failed to ring user"))?
1246}
1247
1248async fn cancel_call(
1249 request: proto::CancelCall,
1250 response: Response<proto::CancelCall>,
1251 session: Session,
1252) -> Result<()> {
1253 let called_user_id = UserId::from_proto(request.called_user_id);
1254 let room_id = RoomId::from_proto(request.room_id);
1255 {
1256 let room = session
1257 .db()
1258 .await
1259 .cancel_call(room_id, session.connection_id, called_user_id)
1260 .await?;
1261 room_updated(&room, &session.peer);
1262 }
1263
1264 for connection_id in session
1265 .connection_pool()
1266 .await
1267 .user_connection_ids(called_user_id)
1268 {
1269 session
1270 .peer
1271 .send(
1272 connection_id,
1273 proto::CallCanceled {
1274 room_id: room_id.to_proto(),
1275 },
1276 )
1277 .trace_err();
1278 }
1279 response.send(proto::Ack {})?;
1280
1281 update_user_contacts(called_user_id, &session).await?;
1282 Ok(())
1283}
1284
1285async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1286 let room_id = RoomId::from_proto(message.room_id);
1287 {
1288 let room = session
1289 .db()
1290 .await
1291 .decline_call(Some(room_id), session.user_id)
1292 .await?
1293 .ok_or_else(|| anyhow!("failed to decline call"))?;
1294 room_updated(&room, &session.peer);
1295 }
1296
1297 for connection_id in session
1298 .connection_pool()
1299 .await
1300 .user_connection_ids(session.user_id)
1301 {
1302 session
1303 .peer
1304 .send(
1305 connection_id,
1306 proto::CallCanceled {
1307 room_id: room_id.to_proto(),
1308 },
1309 )
1310 .trace_err();
1311 }
1312 update_user_contacts(session.user_id, &session).await?;
1313 Ok(())
1314}
1315
1316async fn update_participant_location(
1317 request: proto::UpdateParticipantLocation,
1318 response: Response<proto::UpdateParticipantLocation>,
1319 session: Session,
1320) -> Result<()> {
1321 let room_id = RoomId::from_proto(request.room_id);
1322 let location = request
1323 .location
1324 .ok_or_else(|| anyhow!("invalid location"))?;
1325
1326 let db = session.db().await;
1327 let room = db
1328 .update_room_participant_location(room_id, session.connection_id, location)
1329 .await?;
1330
1331 room_updated(&room, &session.peer);
1332 response.send(proto::Ack {})?;
1333 Ok(())
1334}
1335
1336async fn share_project(
1337 request: proto::ShareProject,
1338 response: Response<proto::ShareProject>,
1339 session: Session,
1340) -> Result<()> {
1341 let (project_id, room) = &*session
1342 .db()
1343 .await
1344 .share_project(
1345 RoomId::from_proto(request.room_id),
1346 session.connection_id,
1347 &request.worktrees,
1348 )
1349 .await?;
1350 response.send(proto::ShareProjectResponse {
1351 project_id: project_id.to_proto(),
1352 })?;
1353 room_updated(&room, &session.peer);
1354
1355 Ok(())
1356}
1357
1358async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1359 let project_id = ProjectId::from_proto(message.project_id);
1360
1361 let (room, guest_connection_ids) = &*session
1362 .db()
1363 .await
1364 .unshare_project(project_id, session.connection_id)
1365 .await?;
1366
1367 broadcast(
1368 Some(session.connection_id),
1369 guest_connection_ids.iter().copied(),
1370 |conn_id| session.peer.send(conn_id, message.clone()),
1371 );
1372 room_updated(&room, &session.peer);
1373
1374 Ok(())
1375}
1376
1377async fn join_project(
1378 request: proto::JoinProject,
1379 response: Response<proto::JoinProject>,
1380 session: Session,
1381) -> Result<()> {
1382 let project_id = ProjectId::from_proto(request.project_id);
1383 let guest_user_id = session.user_id;
1384
1385 tracing::info!(%project_id, "join project");
1386
1387 let (project, replica_id) = &mut *session
1388 .db()
1389 .await
1390 .join_project(project_id, session.connection_id)
1391 .await?;
1392
1393 let collaborators = project
1394 .collaborators
1395 .iter()
1396 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1397 .map(|collaborator| collaborator.to_proto())
1398 .collect::<Vec<_>>();
1399
1400 let worktrees = project
1401 .worktrees
1402 .iter()
1403 .map(|(id, worktree)| proto::WorktreeMetadata {
1404 id: *id,
1405 root_name: worktree.root_name.clone(),
1406 visible: worktree.visible,
1407 abs_path: worktree.abs_path.clone(),
1408 })
1409 .collect::<Vec<_>>();
1410
1411 for collaborator in &collaborators {
1412 session
1413 .peer
1414 .send(
1415 collaborator.peer_id.unwrap().into(),
1416 proto::AddProjectCollaborator {
1417 project_id: project_id.to_proto(),
1418 collaborator: Some(proto::Collaborator {
1419 peer_id: Some(session.connection_id.into()),
1420 replica_id: replica_id.0 as u32,
1421 user_id: guest_user_id.to_proto(),
1422 }),
1423 },
1424 )
1425 .trace_err();
1426 }
1427
1428 // First, we send the metadata associated with each worktree.
1429 response.send(proto::JoinProjectResponse {
1430 worktrees: worktrees.clone(),
1431 replica_id: replica_id.0 as u32,
1432 collaborators: collaborators.clone(),
1433 language_servers: project.language_servers.clone(),
1434 })?;
1435
1436 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1437 #[cfg(any(test, feature = "test-support"))]
1438 const MAX_CHUNK_SIZE: usize = 2;
1439 #[cfg(not(any(test, feature = "test-support")))]
1440 const MAX_CHUNK_SIZE: usize = 256;
1441
1442 // Stream this worktree's entries.
1443 let message = proto::UpdateWorktree {
1444 project_id: project_id.to_proto(),
1445 worktree_id,
1446 abs_path: worktree.abs_path.clone(),
1447 root_name: worktree.root_name,
1448 updated_entries: worktree.entries,
1449 removed_entries: Default::default(),
1450 scan_id: worktree.scan_id,
1451 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1452 updated_repositories: worktree.repository_entries.into_values().collect(),
1453 removed_repositories: Default::default(),
1454 };
1455 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1456 session.peer.send(session.connection_id, update.clone())?;
1457 }
1458
1459 // Stream this worktree's diagnostics.
1460 for summary in worktree.diagnostic_summaries {
1461 session.peer.send(
1462 session.connection_id,
1463 proto::UpdateDiagnosticSummary {
1464 project_id: project_id.to_proto(),
1465 worktree_id: worktree.id,
1466 summary: Some(summary),
1467 },
1468 )?;
1469 }
1470
1471 for settings_file in worktree.settings_files {
1472 session.peer.send(
1473 session.connection_id,
1474 proto::UpdateWorktreeSettings {
1475 project_id: project_id.to_proto(),
1476 worktree_id: worktree.id,
1477 path: settings_file.path,
1478 content: Some(settings_file.content),
1479 },
1480 )?;
1481 }
1482 }
1483
1484 for language_server in &project.language_servers {
1485 session.peer.send(
1486 session.connection_id,
1487 proto::UpdateLanguageServer {
1488 project_id: project_id.to_proto(),
1489 language_server_id: language_server.id,
1490 variant: Some(
1491 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1492 proto::LspDiskBasedDiagnosticsUpdated {},
1493 ),
1494 ),
1495 },
1496 )?;
1497 }
1498
1499 Ok(())
1500}
1501
1502async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1503 let sender_id = session.connection_id;
1504 let project_id = ProjectId::from_proto(request.project_id);
1505
1506 let (room, project) = &*session
1507 .db()
1508 .await
1509 .leave_project(project_id, sender_id)
1510 .await?;
1511 tracing::info!(
1512 %project_id,
1513 host_user_id = %project.host_user_id,
1514 host_connection_id = %project.host_connection_id,
1515 "leave project"
1516 );
1517
1518 project_left(&project, &session);
1519 room_updated(&room, &session.peer);
1520
1521 Ok(())
1522}
1523
1524async fn update_project(
1525 request: proto::UpdateProject,
1526 response: Response<proto::UpdateProject>,
1527 session: Session,
1528) -> Result<()> {
1529 let project_id = ProjectId::from_proto(request.project_id);
1530 let (room, guest_connection_ids) = &*session
1531 .db()
1532 .await
1533 .update_project(project_id, session.connection_id, &request.worktrees)
1534 .await?;
1535 broadcast(
1536 Some(session.connection_id),
1537 guest_connection_ids.iter().copied(),
1538 |connection_id| {
1539 session
1540 .peer
1541 .forward_send(session.connection_id, connection_id, request.clone())
1542 },
1543 );
1544 room_updated(&room, &session.peer);
1545 response.send(proto::Ack {})?;
1546
1547 Ok(())
1548}
1549
1550async fn update_worktree(
1551 request: proto::UpdateWorktree,
1552 response: Response<proto::UpdateWorktree>,
1553 session: Session,
1554) -> Result<()> {
1555 let guest_connection_ids = session
1556 .db()
1557 .await
1558 .update_worktree(&request, session.connection_id)
1559 .await?;
1560
1561 broadcast(
1562 Some(session.connection_id),
1563 guest_connection_ids.iter().copied(),
1564 |connection_id| {
1565 session
1566 .peer
1567 .forward_send(session.connection_id, connection_id, request.clone())
1568 },
1569 );
1570 response.send(proto::Ack {})?;
1571 Ok(())
1572}
1573
1574async fn update_diagnostic_summary(
1575 message: proto::UpdateDiagnosticSummary,
1576 session: Session,
1577) -> Result<()> {
1578 let guest_connection_ids = session
1579 .db()
1580 .await
1581 .update_diagnostic_summary(&message, session.connection_id)
1582 .await?;
1583
1584 broadcast(
1585 Some(session.connection_id),
1586 guest_connection_ids.iter().copied(),
1587 |connection_id| {
1588 session
1589 .peer
1590 .forward_send(session.connection_id, connection_id, message.clone())
1591 },
1592 );
1593
1594 Ok(())
1595}
1596
1597async fn update_worktree_settings(
1598 message: proto::UpdateWorktreeSettings,
1599 session: Session,
1600) -> Result<()> {
1601 let guest_connection_ids = session
1602 .db()
1603 .await
1604 .update_worktree_settings(&message, session.connection_id)
1605 .await?;
1606
1607 broadcast(
1608 Some(session.connection_id),
1609 guest_connection_ids.iter().copied(),
1610 |connection_id| {
1611 session
1612 .peer
1613 .forward_send(session.connection_id, connection_id, message.clone())
1614 },
1615 );
1616
1617 Ok(())
1618}
1619
1620async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1621 broadcast_project_message(request.project_id, request, session).await
1622}
1623
1624async fn start_language_server(
1625 request: proto::StartLanguageServer,
1626 session: Session,
1627) -> Result<()> {
1628 let guest_connection_ids = session
1629 .db()
1630 .await
1631 .start_language_server(&request, session.connection_id)
1632 .await?;
1633
1634 broadcast(
1635 Some(session.connection_id),
1636 guest_connection_ids.iter().copied(),
1637 |connection_id| {
1638 session
1639 .peer
1640 .forward_send(session.connection_id, connection_id, request.clone())
1641 },
1642 );
1643 Ok(())
1644}
1645
1646async fn update_language_server(
1647 request: proto::UpdateLanguageServer,
1648 session: Session,
1649) -> Result<()> {
1650 session.executor.record_backtrace();
1651 let project_id = ProjectId::from_proto(request.project_id);
1652 let project_connection_ids = session
1653 .db()
1654 .await
1655 .project_connection_ids(project_id, session.connection_id)
1656 .await?;
1657 broadcast(
1658 Some(session.connection_id),
1659 project_connection_ids.iter().copied(),
1660 |connection_id| {
1661 session
1662 .peer
1663 .forward_send(session.connection_id, connection_id, request.clone())
1664 },
1665 );
1666 Ok(())
1667}
1668
1669async fn forward_project_request<T>(
1670 request: T,
1671 response: Response<T>,
1672 session: Session,
1673) -> Result<()>
1674where
1675 T: EntityMessage + RequestMessage,
1676{
1677 session.executor.record_backtrace();
1678 let project_id = ProjectId::from_proto(request.remote_entity_id());
1679 let host_connection_id = {
1680 let collaborators = session
1681 .db()
1682 .await
1683 .project_collaborators(project_id, session.connection_id)
1684 .await?;
1685 collaborators
1686 .iter()
1687 .find(|collaborator| collaborator.is_host)
1688 .ok_or_else(|| anyhow!("host not found"))?
1689 .connection_id
1690 };
1691
1692 let payload = session
1693 .peer
1694 .forward_request(session.connection_id, host_connection_id, request)
1695 .await?;
1696
1697 response.send(payload)?;
1698 Ok(())
1699}
1700
1701async fn create_buffer_for_peer(
1702 request: proto::CreateBufferForPeer,
1703 session: Session,
1704) -> Result<()> {
1705 session.executor.record_backtrace();
1706 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1707 session
1708 .peer
1709 .forward_send(session.connection_id, peer_id.into(), request)?;
1710 Ok(())
1711}
1712
1713async fn update_buffer(
1714 request: proto::UpdateBuffer,
1715 response: Response<proto::UpdateBuffer>,
1716 session: Session,
1717) -> Result<()> {
1718 session.executor.record_backtrace();
1719 let project_id = ProjectId::from_proto(request.project_id);
1720 let mut guest_connection_ids;
1721 let mut host_connection_id = None;
1722 {
1723 let collaborators = session
1724 .db()
1725 .await
1726 .project_collaborators(project_id, session.connection_id)
1727 .await?;
1728 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1729 for collaborator in collaborators.iter() {
1730 if collaborator.is_host {
1731 host_connection_id = Some(collaborator.connection_id);
1732 } else {
1733 guest_connection_ids.push(collaborator.connection_id);
1734 }
1735 }
1736 }
1737 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1738
1739 session.executor.record_backtrace();
1740 broadcast(
1741 Some(session.connection_id),
1742 guest_connection_ids,
1743 |connection_id| {
1744 session
1745 .peer
1746 .forward_send(session.connection_id, connection_id, request.clone())
1747 },
1748 );
1749 if host_connection_id != session.connection_id {
1750 session
1751 .peer
1752 .forward_request(session.connection_id, host_connection_id, request.clone())
1753 .await?;
1754 }
1755
1756 response.send(proto::Ack {})?;
1757 Ok(())
1758}
1759
1760async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1761 let project_id = ProjectId::from_proto(request.project_id);
1762 let project_connection_ids = session
1763 .db()
1764 .await
1765 .project_connection_ids(project_id, session.connection_id)
1766 .await?;
1767
1768 broadcast(
1769 Some(session.connection_id),
1770 project_connection_ids.iter().copied(),
1771 |connection_id| {
1772 session
1773 .peer
1774 .forward_send(session.connection_id, connection_id, request.clone())
1775 },
1776 );
1777 Ok(())
1778}
1779
1780async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1781 let project_id = ProjectId::from_proto(request.project_id);
1782 let project_connection_ids = session
1783 .db()
1784 .await
1785 .project_connection_ids(project_id, session.connection_id)
1786 .await?;
1787 broadcast(
1788 Some(session.connection_id),
1789 project_connection_ids.iter().copied(),
1790 |connection_id| {
1791 session
1792 .peer
1793 .forward_send(session.connection_id, connection_id, request.clone())
1794 },
1795 );
1796 Ok(())
1797}
1798
1799async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1800 broadcast_project_message(request.project_id, request, session).await
1801}
1802
1803async fn broadcast_project_message<T: EnvelopedMessage>(
1804 project_id: u64,
1805 request: T,
1806 session: Session,
1807) -> Result<()> {
1808 let project_id = ProjectId::from_proto(project_id);
1809 let project_connection_ids = session
1810 .db()
1811 .await
1812 .project_connection_ids(project_id, session.connection_id)
1813 .await?;
1814 broadcast(
1815 Some(session.connection_id),
1816 project_connection_ids.iter().copied(),
1817 |connection_id| {
1818 session
1819 .peer
1820 .forward_send(session.connection_id, connection_id, request.clone())
1821 },
1822 );
1823 Ok(())
1824}
1825
1826async fn follow(
1827 request: proto::Follow,
1828 response: Response<proto::Follow>,
1829 session: Session,
1830) -> Result<()> {
1831 let project_id = ProjectId::from_proto(request.project_id);
1832 let leader_id = request
1833 .leader_id
1834 .ok_or_else(|| anyhow!("invalid leader id"))?
1835 .into();
1836 let follower_id = session.connection_id;
1837
1838 {
1839 let project_connection_ids = session
1840 .db()
1841 .await
1842 .project_connection_ids(project_id, session.connection_id)
1843 .await?;
1844
1845 if !project_connection_ids.contains(&leader_id) {
1846 Err(anyhow!("no such peer"))?;
1847 }
1848 }
1849
1850 let mut response_payload = session
1851 .peer
1852 .forward_request(session.connection_id, leader_id, request)
1853 .await?;
1854 response_payload
1855 .views
1856 .retain(|view| view.leader_id != Some(follower_id.into()));
1857 response.send(response_payload)?;
1858
1859 let room = session
1860 .db()
1861 .await
1862 .follow(project_id, leader_id, follower_id)
1863 .await?;
1864 room_updated(&room, &session.peer);
1865
1866 Ok(())
1867}
1868
1869async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1870 let project_id = ProjectId::from_proto(request.project_id);
1871 let leader_id = request
1872 .leader_id
1873 .ok_or_else(|| anyhow!("invalid leader id"))?
1874 .into();
1875 let follower_id = session.connection_id;
1876
1877 if !session
1878 .db()
1879 .await
1880 .project_connection_ids(project_id, session.connection_id)
1881 .await?
1882 .contains(&leader_id)
1883 {
1884 Err(anyhow!("no such peer"))?;
1885 }
1886
1887 session
1888 .peer
1889 .forward_send(session.connection_id, leader_id, request)?;
1890
1891 let room = session
1892 .db()
1893 .await
1894 .unfollow(project_id, leader_id, follower_id)
1895 .await?;
1896 room_updated(&room, &session.peer);
1897
1898 Ok(())
1899}
1900
1901async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1902 let project_id = ProjectId::from_proto(request.project_id);
1903 let project_connection_ids = session
1904 .db
1905 .lock()
1906 .await
1907 .project_connection_ids(project_id, session.connection_id)
1908 .await?;
1909
1910 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1911 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1912 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1913 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1914 });
1915 for follower_peer_id in request.follower_ids.iter().copied() {
1916 let follower_connection_id = follower_peer_id.into();
1917 if project_connection_ids.contains(&follower_connection_id)
1918 && Some(follower_peer_id) != leader_id
1919 {
1920 session.peer.forward_send(
1921 session.connection_id,
1922 follower_connection_id,
1923 request.clone(),
1924 )?;
1925 }
1926 }
1927 Ok(())
1928}
1929
1930async fn get_users(
1931 request: proto::GetUsers,
1932 response: Response<proto::GetUsers>,
1933 session: Session,
1934) -> Result<()> {
1935 let user_ids = request
1936 .user_ids
1937 .into_iter()
1938 .map(UserId::from_proto)
1939 .collect();
1940 let users = session
1941 .db()
1942 .await
1943 .get_users_by_ids(user_ids)
1944 .await?
1945 .into_iter()
1946 .map(|user| proto::User {
1947 id: user.id.to_proto(),
1948 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1949 github_login: user.github_login,
1950 })
1951 .collect();
1952 response.send(proto::UsersResponse { users })?;
1953 Ok(())
1954}
1955
1956async fn fuzzy_search_users(
1957 request: proto::FuzzySearchUsers,
1958 response: Response<proto::FuzzySearchUsers>,
1959 session: Session,
1960) -> Result<()> {
1961 let query = request.query;
1962 let users = match query.len() {
1963 0 => vec![],
1964 1 | 2 => session
1965 .db()
1966 .await
1967 .get_user_by_github_login(&query)
1968 .await?
1969 .into_iter()
1970 .collect(),
1971 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1972 };
1973 let users = users
1974 .into_iter()
1975 .filter(|user| user.id != session.user_id)
1976 .map(|user| proto::User {
1977 id: user.id.to_proto(),
1978 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1979 github_login: user.github_login,
1980 })
1981 .collect();
1982 response.send(proto::UsersResponse { users })?;
1983 Ok(())
1984}
1985
1986async fn request_contact(
1987 request: proto::RequestContact,
1988 response: Response<proto::RequestContact>,
1989 session: Session,
1990) -> Result<()> {
1991 let requester_id = session.user_id;
1992 let responder_id = UserId::from_proto(request.responder_id);
1993 if requester_id == responder_id {
1994 return Err(anyhow!("cannot add yourself as a contact"))?;
1995 }
1996
1997 session
1998 .db()
1999 .await
2000 .send_contact_request(requester_id, responder_id)
2001 .await?;
2002
2003 // Update outgoing contact requests of requester
2004 let mut update = proto::UpdateContacts::default();
2005 update.outgoing_requests.push(responder_id.to_proto());
2006 for connection_id in session
2007 .connection_pool()
2008 .await
2009 .user_connection_ids(requester_id)
2010 {
2011 session.peer.send(connection_id, update.clone())?;
2012 }
2013
2014 // Update incoming contact requests of responder
2015 let mut update = proto::UpdateContacts::default();
2016 update
2017 .incoming_requests
2018 .push(proto::IncomingContactRequest {
2019 requester_id: requester_id.to_proto(),
2020 should_notify: true,
2021 });
2022 for connection_id in session
2023 .connection_pool()
2024 .await
2025 .user_connection_ids(responder_id)
2026 {
2027 session.peer.send(connection_id, update.clone())?;
2028 }
2029
2030 response.send(proto::Ack {})?;
2031 Ok(())
2032}
2033
2034async fn respond_to_contact_request(
2035 request: proto::RespondToContactRequest,
2036 response: Response<proto::RespondToContactRequest>,
2037 session: Session,
2038) -> Result<()> {
2039 let responder_id = session.user_id;
2040 let requester_id = UserId::from_proto(request.requester_id);
2041 let db = session.db().await;
2042 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2043 db.dismiss_contact_notification(responder_id, requester_id)
2044 .await?;
2045 } else {
2046 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2047
2048 db.respond_to_contact_request(responder_id, requester_id, accept)
2049 .await?;
2050 let requester_busy = db.is_user_busy(requester_id).await?;
2051 let responder_busy = db.is_user_busy(responder_id).await?;
2052
2053 let pool = session.connection_pool().await;
2054 // Update responder with new contact
2055 let mut update = proto::UpdateContacts::default();
2056 if accept {
2057 update
2058 .contacts
2059 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2060 }
2061 update
2062 .remove_incoming_requests
2063 .push(requester_id.to_proto());
2064 for connection_id in pool.user_connection_ids(responder_id) {
2065 session.peer.send(connection_id, update.clone())?;
2066 }
2067
2068 // Update requester with new contact
2069 let mut update = proto::UpdateContacts::default();
2070 if accept {
2071 update
2072 .contacts
2073 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2074 }
2075 update
2076 .remove_outgoing_requests
2077 .push(responder_id.to_proto());
2078 for connection_id in pool.user_connection_ids(requester_id) {
2079 session.peer.send(connection_id, update.clone())?;
2080 }
2081 }
2082
2083 response.send(proto::Ack {})?;
2084 Ok(())
2085}
2086
2087async fn remove_contact(
2088 request: proto::RemoveContact,
2089 response: Response<proto::RemoveContact>,
2090 session: Session,
2091) -> Result<()> {
2092 let requester_id = session.user_id;
2093 let responder_id = UserId::from_proto(request.user_id);
2094 let db = session.db().await;
2095 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2096
2097 let pool = session.connection_pool().await;
2098 // Update outgoing contact requests of requester
2099 let mut update = proto::UpdateContacts::default();
2100 if contact_accepted {
2101 update.remove_contacts.push(responder_id.to_proto());
2102 } else {
2103 update
2104 .remove_outgoing_requests
2105 .push(responder_id.to_proto());
2106 }
2107 for connection_id in pool.user_connection_ids(requester_id) {
2108 session.peer.send(connection_id, update.clone())?;
2109 }
2110
2111 // Update incoming contact requests of responder
2112 let mut update = proto::UpdateContacts::default();
2113 if contact_accepted {
2114 update.remove_contacts.push(requester_id.to_proto());
2115 } else {
2116 update
2117 .remove_incoming_requests
2118 .push(requester_id.to_proto());
2119 }
2120 for connection_id in pool.user_connection_ids(responder_id) {
2121 session.peer.send(connection_id, update.clone())?;
2122 }
2123
2124 response.send(proto::Ack {})?;
2125 Ok(())
2126}
2127
2128async fn create_channel(
2129 request: proto::CreateChannel,
2130 response: Response<proto::CreateChannel>,
2131 session: Session,
2132) -> Result<()> {
2133 let db = session.db().await;
2134 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2135
2136 if let Some(live_kit) = session.live_kit_client.as_ref() {
2137 live_kit.create_room(live_kit_room.clone()).await?;
2138 }
2139
2140 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2141 let id = db
2142 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2143 .await?;
2144
2145 response.send(proto::CreateChannelResponse {
2146 channel_id: id.to_proto(),
2147 })?;
2148
2149 let mut update = proto::UpdateChannels::default();
2150 update.channels.push(proto::Channel {
2151 id: id.to_proto(),
2152 name: request.name,
2153 parent_id: request.parent_id,
2154 user_is_admin: false,
2155 });
2156
2157 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2158 db.get_channel_members(parent_id).await?
2159 } else {
2160 vec![session.user_id]
2161 };
2162
2163 let connection_pool = session.connection_pool().await;
2164 for user_id in user_ids_to_notify {
2165 for connection_id in connection_pool.user_connection_ids(user_id) {
2166 let mut update = update.clone();
2167 if user_id == session.user_id {
2168 update.channels[0].user_is_admin = true;
2169 }
2170 session.peer.send(connection_id, update)?;
2171 }
2172 }
2173
2174 Ok(())
2175}
2176
2177async fn remove_channel(
2178 request: proto::RemoveChannel,
2179 response: Response<proto::RemoveChannel>,
2180 session: Session,
2181) -> Result<()> {
2182 let db = session.db().await;
2183
2184 let channel_id = request.channel_id;
2185 let (removed_channels, member_ids) = db
2186 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2187 .await?;
2188 response.send(proto::Ack {})?;
2189
2190 // Notify members of removed channels
2191 let mut update = proto::UpdateChannels::default();
2192 update
2193 .remove_channels
2194 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2195
2196 let connection_pool = session.connection_pool().await;
2197 for member_id in member_ids {
2198 for connection_id in connection_pool.user_connection_ids(member_id) {
2199 session.peer.send(connection_id, update.clone())?;
2200 }
2201 }
2202
2203 Ok(())
2204}
2205
2206async fn invite_channel_member(
2207 request: proto::InviteChannelMember,
2208 response: Response<proto::InviteChannelMember>,
2209 session: Session,
2210) -> Result<()> {
2211 let db = session.db().await;
2212 let channel_id = ChannelId::from_proto(request.channel_id);
2213 let channel = db
2214 .get_channel(channel_id, session.user_id)
2215 .await?
2216 .ok_or_else(|| anyhow!("channel not found"))?;
2217 let invitee_id = UserId::from_proto(request.user_id);
2218 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2219 .await?;
2220
2221 let mut update = proto::UpdateChannels::default();
2222 update.channel_invitations.push(proto::Channel {
2223 id: channel.id.to_proto(),
2224 name: channel.name,
2225 parent_id: None,
2226 user_is_admin: false,
2227 });
2228 for connection_id in session
2229 .connection_pool()
2230 .await
2231 .user_connection_ids(invitee_id)
2232 {
2233 session.peer.send(connection_id, update.clone())?;
2234 }
2235
2236 response.send(proto::Ack {})?;
2237 Ok(())
2238}
2239
2240async fn remove_channel_member(
2241 request: proto::RemoveChannelMember,
2242 response: Response<proto::RemoveChannelMember>,
2243 session: Session,
2244) -> Result<()> {
2245 let db = session.db().await;
2246 let channel_id = ChannelId::from_proto(request.channel_id);
2247 let member_id = UserId::from_proto(request.user_id);
2248
2249 db.remove_channel_member(channel_id, member_id, session.user_id)
2250 .await?;
2251
2252 let mut update = proto::UpdateChannels::default();
2253 update.remove_channels.push(channel_id.to_proto());
2254
2255 for connection_id in session
2256 .connection_pool()
2257 .await
2258 .user_connection_ids(member_id)
2259 {
2260 session.peer.send(connection_id, update.clone())?;
2261 }
2262
2263 response.send(proto::Ack {})?;
2264 Ok(())
2265}
2266
2267async fn set_channel_member_admin(
2268 request: proto::SetChannelMemberAdmin,
2269 response: Response<proto::SetChannelMemberAdmin>,
2270 session: Session,
2271) -> Result<()> {
2272 let db = session.db().await;
2273 let channel_id = ChannelId::from_proto(request.channel_id);
2274 let member_id = UserId::from_proto(request.user_id);
2275 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2276 .await?;
2277
2278 let channel = db
2279 .get_channel(channel_id, member_id)
2280 .await?
2281 .ok_or_else(|| anyhow!("channel not found"))?;
2282
2283 let mut update = proto::UpdateChannels::default();
2284 update.channels.push(proto::Channel {
2285 id: channel.id.to_proto(),
2286 name: channel.name,
2287 parent_id: None,
2288 user_is_admin: request.admin,
2289 });
2290
2291 for connection_id in session
2292 .connection_pool()
2293 .await
2294 .user_connection_ids(member_id)
2295 {
2296 session.peer.send(connection_id, update.clone())?;
2297 }
2298
2299 response.send(proto::Ack {})?;
2300 Ok(())
2301}
2302
2303async fn get_channel_members(
2304 request: proto::GetChannelMembers,
2305 response: Response<proto::GetChannelMembers>,
2306 session: Session,
2307) -> Result<()> {
2308 let db = session.db().await;
2309 let channel_id = ChannelId::from_proto(request.channel_id);
2310 let members = db
2311 .get_channel_member_details(channel_id, session.user_id)
2312 .await?;
2313 response.send(proto::GetChannelMembersResponse { members })?;
2314 Ok(())
2315}
2316
2317async fn respond_to_channel_invite(
2318 request: proto::RespondToChannelInvite,
2319 response: Response<proto::RespondToChannelInvite>,
2320 session: Session,
2321) -> Result<()> {
2322 let db = session.db().await;
2323 let channel_id = ChannelId::from_proto(request.channel_id);
2324 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2325 .await?;
2326
2327 let mut update = proto::UpdateChannels::default();
2328 update
2329 .remove_channel_invitations
2330 .push(channel_id.to_proto());
2331 if request.accept {
2332 let (channels, participants) = db.get_channels_for_user(session.user_id).await?;
2333 update
2334 .channels
2335 .extend(channels.into_iter().map(|channel| proto::Channel {
2336 id: channel.id.to_proto(),
2337 name: channel.name,
2338 user_is_admin: channel.user_is_admin,
2339 parent_id: channel.parent_id.map(ChannelId::to_proto),
2340 }));
2341 update
2342 .channel_participants
2343 .extend(participants.into_iter().map(|(channel_id, user_ids)| {
2344 proto::ChannelParticipants {
2345 channel_id: channel_id.to_proto(),
2346 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2347 }
2348 }));
2349 }
2350 session.peer.send(session.connection_id, update)?;
2351 response.send(proto::Ack {})?;
2352
2353 Ok(())
2354}
2355
/// Joins the caller's connection to the room that belongs to a channel.
///
/// The caller receives a `JoinRoomResponse` with the room state. It also gets
/// LiveKit connection info when a LiveKit client is configured and a token
/// could be minted. The updated room is then broadcast to its participants,
/// channel members get the channel's new participant list, and the caller's
/// presence is refreshed for their contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Scoped block: `db` and the value returned by `join_room` are dropped when
    // it ends. Only a clone of the joined room escapes.
    let joined_room = {
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(
                room_id,
                session.user_id,
                Some(channel_id),
                session.connection_id,
            )
            .await?;

        // Token minting failures are logged by `trace_err` and yield `None`
        // (no LiveKit info in the response) rather than failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Clone the data out so that what `join_room` returned (the "room
        // guard" per the TODO below) is released at the end of this block.
        joined_room.clone()
    };

    // TODO - do this while still holding the room guard,
    // currently there's a possible race condition if someone joins the channel
    // after we've dropped the lock but before we finish sending these updates
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2416
2417async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2418 let project_id = ProjectId::from_proto(request.project_id);
2419 let project_connection_ids = session
2420 .db()
2421 .await
2422 .project_connection_ids(project_id, session.connection_id)
2423 .await?;
2424 broadcast(
2425 Some(session.connection_id),
2426 project_connection_ids.iter().copied(),
2427 |connection_id| {
2428 session
2429 .peer
2430 .forward_send(session.connection_id, connection_id, request.clone())
2431 },
2432 );
2433 Ok(())
2434}
2435
2436async fn get_private_user_info(
2437 _request: proto::GetPrivateUserInfo,
2438 response: Response<proto::GetPrivateUserInfo>,
2439 session: Session,
2440) -> Result<()> {
2441 let metrics_id = session
2442 .db()
2443 .await
2444 .get_user_metrics_id(session.user_id)
2445 .await?;
2446 let user = session
2447 .db()
2448 .await
2449 .get_user_by_id(session.user_id)
2450 .await?
2451 .ok_or_else(|| anyhow!("user not found"))?;
2452 response.send(proto::GetPrivateUserInfoResponse {
2453 metrics_id,
2454 staff: user.admin,
2455 })?;
2456 Ok(())
2457}
2458
2459fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2460 match message {
2461 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2462 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2463 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2464 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2465 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2466 code: frame.code.into(),
2467 reason: frame.reason,
2468 })),
2469 }
2470}
2471
2472fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2473 match message {
2474 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2475 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2476 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2477 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2478 AxumMessage::Close(frame) => {
2479 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2480 code: frame.code.into(),
2481 reason: frame.reason,
2482 }))
2483 }
2484 }
2485}
2486
2487fn build_initial_channels_update(
2488 channels: Vec<db::Channel>,
2489 channel_participants: HashMap<db::ChannelId, Vec<UserId>>,
2490 channel_invites: Vec<db::Channel>,
2491) -> proto::UpdateChannels {
2492 let mut update = proto::UpdateChannels::default();
2493
2494 for channel in channels {
2495 update.channels.push(proto::Channel {
2496 id: channel.id.to_proto(),
2497 name: channel.name,
2498 user_is_admin: channel.user_is_admin,
2499 parent_id: channel.parent_id.map(|id| id.to_proto()),
2500 });
2501 }
2502
2503 for (channel_id, participants) in channel_participants {
2504 update
2505 .channel_participants
2506 .push(proto::ChannelParticipants {
2507 channel_id: channel_id.to_proto(),
2508 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2509 });
2510 }
2511
2512 for channel in channel_invites {
2513 update.channel_invitations.push(proto::Channel {
2514 id: channel.id.to_proto(),
2515 name: channel.name,
2516 user_is_admin: false,
2517 parent_id: None,
2518 });
2519 }
2520
2521 update
2522}
2523
2524fn build_initial_contacts_update(
2525 contacts: Vec<db::Contact>,
2526 pool: &ConnectionPool,
2527) -> proto::UpdateContacts {
2528 let mut update = proto::UpdateContacts::default();
2529
2530 for contact in contacts {
2531 match contact {
2532 db::Contact::Accepted {
2533 user_id,
2534 should_notify,
2535 busy,
2536 } => {
2537 update
2538 .contacts
2539 .push(contact_for_user(user_id, should_notify, busy, &pool));
2540 }
2541 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2542 db::Contact::Incoming {
2543 user_id,
2544 should_notify,
2545 } => update
2546 .incoming_requests
2547 .push(proto::IncomingContactRequest {
2548 requester_id: user_id.to_proto(),
2549 should_notify,
2550 }),
2551 }
2552 }
2553
2554 update
2555}
2556
2557fn contact_for_user(
2558 user_id: UserId,
2559 should_notify: bool,
2560 busy: bool,
2561 pool: &ConnectionPool,
2562) -> proto::Contact {
2563 proto::Contact {
2564 user_id: user_id.to_proto(),
2565 online: pool.is_user_online(user_id),
2566 busy,
2567 should_notify,
2568 }
2569}
2570
2571fn room_updated(room: &proto::Room, peer: &Peer) {
2572 broadcast(
2573 None,
2574 room.participants
2575 .iter()
2576 .filter_map(|participant| Some(participant.peer_id?.into())),
2577 |peer_id| {
2578 peer.send(
2579 peer_id.into(),
2580 proto::RoomUpdated {
2581 room: Some(room.clone()),
2582 },
2583 )
2584 },
2585 );
2586}
2587
2588fn channel_updated(
2589 channel_id: ChannelId,
2590 room: &proto::Room,
2591 channel_members: &[UserId],
2592 peer: &Peer,
2593 pool: &ConnectionPool,
2594) {
2595 let participants = room
2596 .participants
2597 .iter()
2598 .map(|p| p.user_id)
2599 .collect::<Vec<_>>();
2600
2601 broadcast(
2602 None,
2603 channel_members
2604 .iter()
2605 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2606 |peer_id| {
2607 peer.send(
2608 peer_id.into(),
2609 proto::UpdateChannels {
2610 channel_participants: vec![proto::ChannelParticipants {
2611 channel_id: channel_id.to_proto(),
2612 participant_user_ids: participants.clone(),
2613 }],
2614 ..Default::default()
2615 },
2616 )
2617 },
2618 );
2619}
2620
2621async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2622 let db = session.db().await;
2623
2624 let contacts = db.get_contacts(user_id).await?;
2625 let busy = db.is_user_busy(user_id).await?;
2626
2627 let pool = session.connection_pool().await;
2628 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2629 for contact in contacts {
2630 if let db::Contact::Accepted {
2631 user_id: contact_user_id,
2632 ..
2633 } = contact
2634 {
2635 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2636 session
2637 .peer
2638 .send(
2639 contact_conn_id,
2640 proto::UpdateContacts {
2641 contacts: vec![updated_contact.clone()],
2642 remove_contacts: Default::default(),
2643 incoming_requests: Default::default(),
2644 remove_incoming_requests: Default::default(),
2645 outgoing_requests: Default::default(),
2646 remove_outgoing_requests: Default::default(),
2647 },
2648 )
2649 .trace_err();
2650 }
2651 }
2652 }
2653 Ok(())
2654}
2655
2656async fn leave_room_for_session(session: &Session) -> Result<()> {
2657 let mut contacts_to_update = HashSet::default();
2658
2659 let room_id;
2660 let canceled_calls_to_user_ids;
2661 let live_kit_room;
2662 let delete_live_kit_room;
2663 let room;
2664 let channel_members;
2665 let channel_id;
2666
2667 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2668 contacts_to_update.insert(session.user_id);
2669
2670 for project in left_room.left_projects.values() {
2671 project_left(project, session);
2672 }
2673
2674 room_id = RoomId::from_proto(left_room.room.id);
2675 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2676 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2677 delete_live_kit_room = left_room.deleted;
2678 room = mem::take(&mut left_room.room);
2679 channel_members = mem::take(&mut left_room.channel_members);
2680 channel_id = left_room.channel_id;
2681
2682 room_updated(&room, &session.peer);
2683 } else {
2684 return Ok(());
2685 }
2686
2687 // TODO - do this while holding the room guard.
2688 if let Some(channel_id) = channel_id {
2689 channel_updated(
2690 channel_id,
2691 &room,
2692 &channel_members,
2693 &session.peer,
2694 &*session.connection_pool().await,
2695 );
2696 }
2697
2698 {
2699 let pool = session.connection_pool().await;
2700 for canceled_user_id in canceled_calls_to_user_ids {
2701 for connection_id in pool.user_connection_ids(canceled_user_id) {
2702 session
2703 .peer
2704 .send(
2705 connection_id,
2706 proto::CallCanceled {
2707 room_id: room_id.to_proto(),
2708 },
2709 )
2710 .trace_err();
2711 }
2712 contacts_to_update.insert(canceled_user_id);
2713 }
2714 }
2715
2716 for contact_user_id in contacts_to_update {
2717 update_user_contacts(contact_user_id, &session).await?;
2718 }
2719
2720 if let Some(live_kit) = session.live_kit_client.as_ref() {
2721 live_kit
2722 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2723 .await
2724 .trace_err();
2725
2726 if delete_live_kit_room {
2727 live_kit.delete_room(live_kit_room).await.trace_err();
2728 }
2729 }
2730
2731 Ok(())
2732}
2733
2734fn project_left(project: &db::LeftProject, session: &Session) {
2735 for connection_id in &project.connection_ids {
2736 if project.host_user_id == session.user_id {
2737 session
2738 .peer
2739 .send(
2740 *connection_id,
2741 proto::UnshareProject {
2742 project_id: project.id.to_proto(),
2743 },
2744 )
2745 .trace_err();
2746 } else {
2747 session
2748 .peer
2749 .send(
2750 *connection_id,
2751 proto::RemoveProjectCollaborator {
2752 project_id: project.id.to_proto(),
2753 peer_id: Some(session.connection_id.into()),
2754 },
2755 )
2756 .trace_err();
2757 }
2758 }
2759}
2760
2761pub trait ResultExt {
2762 type Ok;
2763
2764 fn trace_err(self) -> Option<Self::Ok>;
2765}
2766
2767impl<T, E> ResultExt for Result<T, E>
2768where
2769 E: std::fmt::Debug,
2770{
2771 type Ok = T;
2772
2773 fn trace_err(self) -> Option<T> {
2774 match self {
2775 Ok(value) => Some(value),
2776 Err(error) => {
2777 tracing::error!("{:?}", error);
2778 None
2779 }
2780 }
2781 }
2782}