1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
39 RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(get_channel_members)
252 .add_request_handler(respond_to_channel_invite)
253 .add_request_handler(join_channel)
254 .add_request_handler(follow)
255 .add_message_handler(unfollow)
256 .add_message_handler(update_followers)
257 .add_message_handler(update_diff_base)
258 .add_request_handler(get_private_user_info);
259
260 Arc::new(server)
261 }
262
263 pub async fn start(&self) -> Result<()> {
264 let server_id = *self.id.lock();
265 let app_state = self.app_state.clone();
266 let peer = self.peer.clone();
267 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
268 let pool = self.connection_pool.clone();
269 let live_kit_client = self.app_state.live_kit_client.clone();
270
271 let span = info_span!("start server");
272 self.executor.spawn_detached(
273 async move {
274 tracing::info!("waiting for cleanup timeout");
275 timeout.await;
276 tracing::info!("cleanup timeout expired, retrieving stale rooms");
277 if let Some(room_ids) = app_state
278 .db
279 .stale_room_ids(&app_state.config.zed_environment, server_id)
280 .await
281 .trace_err()
282 {
283 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
284 for room_id in room_ids {
285 let mut contacts_to_update = HashSet::default();
286 let mut canceled_calls_to_user_ids = Vec::new();
287 let mut live_kit_room = String::new();
288 let mut delete_live_kit_room = false;
289
290 if let Some(mut refreshed_room) = app_state
291 .db
292 .refresh_room(room_id, server_id)
293 .await
294 .trace_err()
295 {
296 tracing::info!(
297 room_id = room_id.0,
298 new_participant_count = refreshed_room.room.participants.len(),
299 "refreshed room"
300 );
301 room_updated(&refreshed_room.room, &peer);
302 if let Some(channel_id) = refreshed_room.channel_id {
303 channel_updated(
304 channel_id,
305 &refreshed_room.room,
306 &refreshed_room.channel_members,
307 &peer,
308 &*pool.lock(),
309 );
310 }
311 contacts_to_update
312 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
313 contacts_to_update
314 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
315 canceled_calls_to_user_ids =
316 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
317 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
318 delete_live_kit_room = refreshed_room.room.participants.is_empty();
319 }
320
321 {
322 let pool = pool.lock();
323 for canceled_user_id in canceled_calls_to_user_ids {
324 for connection_id in pool.user_connection_ids(canceled_user_id) {
325 peer.send(
326 connection_id,
327 proto::CallCanceled {
328 room_id: room_id.to_proto(),
329 },
330 )
331 .trace_err();
332 }
333 }
334 }
335
336 for user_id in contacts_to_update {
337 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
338 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
339 if let Some((busy, contacts)) = busy.zip(contacts) {
340 let pool = pool.lock();
341 let updated_contact = contact_for_user(user_id, false, busy, &pool);
342 for contact in contacts {
343 if let db::Contact::Accepted {
344 user_id: contact_user_id,
345 ..
346 } = contact
347 {
348 for contact_conn_id in
349 pool.user_connection_ids(contact_user_id)
350 {
351 peer.send(
352 contact_conn_id,
353 proto::UpdateContacts {
354 contacts: vec![updated_contact.clone()],
355 remove_contacts: Default::default(),
356 incoming_requests: Default::default(),
357 remove_incoming_requests: Default::default(),
358 outgoing_requests: Default::default(),
359 remove_outgoing_requests: Default::default(),
360 },
361 )
362 .trace_err();
363 }
364 }
365 }
366 }
367 }
368
369 if let Some(live_kit) = live_kit_client.as_ref() {
370 if delete_live_kit_room {
371 live_kit.delete_room(live_kit_room).await.trace_err();
372 }
373 }
374 }
375 }
376
377 app_state
378 .db
379 .delete_stale_servers(&app_state.config.zed_environment, server_id)
380 .await
381 .trace_err();
382 }
383 .instrument(span),
384 );
385 Ok(())
386 }
387
388 pub fn teardown(&self) {
389 self.peer.teardown();
390 self.connection_pool.lock().reset();
391 let _ = self.teardown.send(());
392 }
393
394 #[cfg(test)]
395 pub fn reset(&self, id: ServerId) {
396 self.teardown();
397 *self.id.lock() = id;
398 self.peer.reset(id.0 as u32);
399 }
400
401 #[cfg(test)]
402 pub fn id(&self) -> ServerId {
403 *self.id.lock()
404 }
405
406 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
407 where
408 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
409 Fut: 'static + Send + Future<Output = Result<()>>,
410 M: EnvelopedMessage,
411 {
412 let prev_handler = self.handlers.insert(
413 TypeId::of::<M>(),
414 Box::new(move |envelope, session| {
415 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
416 let span = info_span!(
417 "handle message",
418 payload_type = envelope.payload_type_name()
419 );
420 span.in_scope(|| {
421 tracing::info!(
422 payload_type = envelope.payload_type_name(),
423 "message received"
424 );
425 });
426 let start_time = Instant::now();
427 let future = (handler)(*envelope, session);
428 async move {
429 let result = future.await;
430 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
431 match result {
432 Err(error) => {
433 tracing::error!(%error, ?duration_ms, "error handling message")
434 }
435 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
436 }
437 }
438 .instrument(span)
439 .boxed()
440 }),
441 );
442 if prev_handler.is_some() {
443 panic!("registered a handler for the same message twice");
444 }
445 self
446 }
447
448 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
449 where
450 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
451 Fut: 'static + Send + Future<Output = Result<()>>,
452 M: EnvelopedMessage,
453 {
454 self.add_handler(move |envelope, session| handler(envelope.payload, session));
455 self
456 }
457
458 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
459 where
460 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
461 Fut: Send + Future<Output = Result<()>>,
462 M: RequestMessage,
463 {
464 let handler = Arc::new(handler);
465 self.add_handler(move |envelope, session| {
466 let receipt = envelope.receipt();
467 let handler = handler.clone();
468 async move {
469 let peer = session.peer.clone();
470 let responded = Arc::new(AtomicBool::default());
471 let response = Response {
472 peer: peer.clone(),
473 responded: responded.clone(),
474 receipt,
475 };
476 match (handler)(envelope.payload, response, session).await {
477 Ok(()) => {
478 if responded.load(std::sync::atomic::Ordering::SeqCst) {
479 Ok(())
480 } else {
481 Err(anyhow!("handler did not send a response"))?
482 }
483 }
484 Err(error) => {
485 peer.respond_with_error(
486 receipt,
487 proto::Error {
488 message: error.to_string(),
489 },
490 )?;
491 Err(error)
492 }
493 }
494 }
495 })
496 }
497
498 pub fn handle_connection(
499 self: &Arc<Self>,
500 connection: Connection,
501 address: String,
502 user: User,
503 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
504 executor: Executor,
505 ) -> impl Future<Output = Result<()>> {
506 let this = self.clone();
507 let user_id = user.id;
508 let login = user.github_login;
509 let span = info_span!("handle connection", %user_id, %login, %address);
510 let mut teardown = self.teardown.subscribe();
511 async move {
512 let (connection_id, handle_io, mut incoming_rx) = this
513 .peer
514 .add_connection(connection, {
515 let executor = executor.clone();
516 move |duration| executor.sleep(duration)
517 });
518
519 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
520 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
521 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
522
523 if let Some(send_connection_id) = send_connection_id.take() {
524 let _ = send_connection_id.send(connection_id);
525 }
526
527 if !user.connected_once {
528 this.peer.send(connection_id, proto::ShowContacts {})?;
529 this.app_state.db.set_user_connected_once(user_id, true).await?;
530 }
531
532 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
533 this.app_state.db.get_contacts(user_id),
534 this.app_state.db.get_invite_code_for_user(user_id),
535 this.app_state.db.get_channels_for_user(user_id),
536 this.app_state.db.get_channel_invites_for_user(user_id)
537 ).await?;
538
539 {
540 let mut pool = this.connection_pool.lock();
541 pool.add_connection(connection_id, user_id, user.admin);
542 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
543 this.peer.send(connection_id, build_initial_channels_update(
544 channels_for_user,
545 channel_invites
546 ))?;
547
548 if let Some((code, count)) = invite_code {
549 this.peer.send(connection_id, proto::UpdateInviteInfo {
550 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
551 count: count as u32,
552 })?;
553 }
554 }
555
556 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
557 this.peer.send(connection_id, incoming_call)?;
558 }
559
560 let session = Session {
561 user_id,
562 connection_id,
563 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
564 peer: this.peer.clone(),
565 connection_pool: this.connection_pool.clone(),
566 live_kit_client: this.app_state.live_kit_client.clone(),
567 executor: executor.clone(),
568 };
569 update_user_contacts(user_id, &session).await?;
570
571 let handle_io = handle_io.fuse();
572 futures::pin_mut!(handle_io);
573
574 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
575 // This prevents deadlocks when e.g., client A performs a request to client B and
576 // client B performs a request to client A. If both clients stop processing further
577 // messages until their respective request completes, they won't have a chance to
578 // respond to the other client's request and cause a deadlock.
579 //
580 // This arrangement ensures we will attempt to process earlier messages first, but fall
581 // back to processing messages arrived later in the spirit of making progress.
582 let mut foreground_message_handlers = FuturesUnordered::new();
583 let concurrent_handlers = Arc::new(Semaphore::new(256));
584 loop {
585 let next_message = async {
586 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
587 let message = incoming_rx.next().await;
588 (permit, message)
589 }.fuse();
590 futures::pin_mut!(next_message);
591 futures::select_biased! {
592 _ = teardown.changed().fuse() => return Ok(()),
593 result = handle_io => {
594 if let Err(error) = result {
595 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
596 }
597 break;
598 }
599 _ = foreground_message_handlers.next() => {}
600 next_message = next_message => {
601 let (permit, message) = next_message;
602 if let Some(message) = message {
603 let type_name = message.payload_type_name();
604 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
605 let span_enter = span.enter();
606 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
607 let is_background = message.is_background();
608 let handle_message = (handler)(message, session.clone());
609 drop(span_enter);
610
611 let handle_message = async move {
612 handle_message.await;
613 drop(permit);
614 }.instrument(span);
615 if is_background {
616 executor.spawn_detached(handle_message);
617 } else {
618 foreground_message_handlers.push(handle_message);
619 }
620 } else {
621 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
622 }
623 } else {
624 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
625 break;
626 }
627 }
628 }
629 }
630
631 drop(foreground_message_handlers);
632 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
633 if let Err(error) = connection_lost(session, teardown, executor).await {
634 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
635 }
636
637 Ok(())
638 }.instrument(span)
639 }
640
641 pub async fn invite_code_redeemed(
642 self: &Arc<Self>,
643 inviter_id: UserId,
644 invitee_id: UserId,
645 ) -> Result<()> {
646 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
647 if let Some(code) = &user.invite_code {
648 let pool = self.connection_pool.lock();
649 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
650 for connection_id in pool.user_connection_ids(inviter_id) {
651 self.peer.send(
652 connection_id,
653 proto::UpdateContacts {
654 contacts: vec![invitee_contact.clone()],
655 ..Default::default()
656 },
657 )?;
658 self.peer.send(
659 connection_id,
660 proto::UpdateInviteInfo {
661 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
662 count: user.invite_count as u32,
663 },
664 )?;
665 }
666 }
667 }
668 Ok(())
669 }
670
671 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
672 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
673 if let Some(invite_code) = &user.invite_code {
674 let pool = self.connection_pool.lock();
675 for connection_id in pool.user_connection_ids(user_id) {
676 self.peer.send(
677 connection_id,
678 proto::UpdateInviteInfo {
679 url: format!(
680 "{}{}",
681 self.app_state.config.invite_link_prefix, invite_code
682 ),
683 count: user.invite_count as u32,
684 },
685 )?;
686 }
687 }
688 }
689 Ok(())
690 }
691
692 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
693 ServerSnapshot {
694 connection_pool: ConnectionPoolGuard {
695 guard: self.connection_pool.lock(),
696 _not_send: PhantomData,
697 },
698 peer: &self.peer,
699 }
700 }
701}
702
703impl<'a> Deref for ConnectionPoolGuard<'a> {
704 type Target = ConnectionPool;
705
706 fn deref(&self) -> &Self::Target {
707 &*self.guard
708 }
709}
710
711impl<'a> DerefMut for ConnectionPoolGuard<'a> {
712 fn deref_mut(&mut self) -> &mut Self::Target {
713 &mut *self.guard
714 }
715}
716
717impl<'a> Drop for ConnectionPoolGuard<'a> {
718 fn drop(&mut self) {
719 #[cfg(test)]
720 self.check_invariants();
721 }
722}
723
724fn broadcast<F>(
725 sender_id: Option<ConnectionId>,
726 receiver_ids: impl IntoIterator<Item = ConnectionId>,
727 mut f: F,
728) where
729 F: FnMut(ConnectionId) -> anyhow::Result<()>,
730{
731 for receiver_id in receiver_ids {
732 if Some(receiver_id) != sender_id {
733 if let Err(error) = f(receiver_id) {
734 tracing::error!("failed to send to {:?} {}", receiver_id, error);
735 }
736 }
737 }
738}
739
740lazy_static! {
741 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
742}
743
744pub struct ProtocolVersion(u32);
745
746impl Header for ProtocolVersion {
747 fn name() -> &'static HeaderName {
748 &ZED_PROTOCOL_VERSION
749 }
750
751 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
752 where
753 Self: Sized,
754 I: Iterator<Item = &'i axum::http::HeaderValue>,
755 {
756 let version = values
757 .next()
758 .ok_or_else(axum::headers::Error::invalid)?
759 .to_str()
760 .map_err(|_| axum::headers::Error::invalid())?
761 .parse()
762 .map_err(|_| axum::headers::Error::invalid())?;
763 Ok(Self(version))
764 }
765
766 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
767 values.extend([self.0.to_string().parse().unwrap()]);
768 }
769}
770
771pub fn routes(server: Arc<Server>) -> Router<Body> {
772 Router::new()
773 .route("/rpc", get(handle_websocket_request))
774 .layer(
775 ServiceBuilder::new()
776 .layer(Extension(server.app_state.clone()))
777 .layer(middleware::from_fn(auth::validate_header)),
778 )
779 .route("/metrics", get(handle_metrics))
780 .layer(Extension(server))
781}
782
783pub async fn handle_websocket_request(
784 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
785 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
786 Extension(server): Extension<Arc<Server>>,
787 Extension(user): Extension<User>,
788 ws: WebSocketUpgrade,
789) -> axum::response::Response {
790 if protocol_version != rpc::PROTOCOL_VERSION {
791 return (
792 StatusCode::UPGRADE_REQUIRED,
793 "client must be upgraded".to_string(),
794 )
795 .into_response();
796 }
797 let socket_address = socket_address.to_string();
798 ws.on_upgrade(move |socket| {
799 use util::ResultExt;
800 let socket = socket
801 .map_ok(to_tungstenite_message)
802 .err_into()
803 .with(|message| async move { Ok(to_axum_message(message)) });
804 let connection = Connection::new(Box::pin(socket));
805 async move {
806 server
807 .handle_connection(connection, socket_address, user, None, Executor::Production)
808 .await
809 .log_err();
810 }
811 })
812}
813
814pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
815 let connections = server
816 .connection_pool
817 .lock()
818 .connections()
819 .filter(|connection| !connection.admin)
820 .count();
821
822 METRIC_CONNECTIONS.set(connections as _);
823
824 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
825 METRIC_SHARED_PROJECTS.set(shared_projects as _);
826
827 let encoder = prometheus::TextEncoder::new();
828 let metric_families = prometheus::gather();
829 let encoded_metrics = encoder
830 .encode_to_string(&metric_families)
831 .map_err(|err| anyhow!("{}", err))?;
832 Ok(encoded_metrics)
833}
834
835#[instrument(err, skip(executor))]
836async fn connection_lost(
837 session: Session,
838 mut teardown: watch::Receiver<()>,
839 executor: Executor,
840) -> Result<()> {
841 session.peer.disconnect(session.connection_id);
842 session
843 .connection_pool()
844 .await
845 .remove_connection(session.connection_id)?;
846
847 session
848 .db()
849 .await
850 .connection_lost(session.connection_id)
851 .await
852 .trace_err();
853
854 futures::select_biased! {
855 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
856 leave_room_for_session(&session).await.trace_err();
857
858 if !session
859 .connection_pool()
860 .await
861 .is_user_online(session.user_id)
862 {
863 let db = session.db().await;
864 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
865 room_updated(&room, &session.peer);
866 }
867 }
868 update_user_contacts(session.user_id, &session).await?;
869 }
870 _ = teardown.changed().fuse() => {}
871 }
872
873 Ok(())
874}
875
876async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
877 response.send(proto::Ack {})?;
878 Ok(())
879}
880
881async fn create_room(
882 _request: proto::CreateRoom,
883 response: Response<proto::CreateRoom>,
884 session: Session,
885) -> Result<()> {
886 let live_kit_room = nanoid::nanoid!(30);
887
888 let live_kit_connection_info = {
889 let live_kit_room = live_kit_room.clone();
890 let live_kit = session.live_kit_client.as_ref();
891
892 util::async_iife!({
893 let live_kit = live_kit?;
894
895 live_kit
896 .create_room(live_kit_room.clone())
897 .await
898 .trace_err()?;
899
900 let token = live_kit
901 .room_token(&live_kit_room, &session.user_id.to_string())
902 .trace_err()?;
903
904 Some(proto::LiveKitConnectionInfo {
905 server_url: live_kit.url().into(),
906 token,
907 })
908 })
909 }
910 .await;
911
912 let room = session
913 .db()
914 .await
915 .create_room(session.user_id, session.connection_id, &live_kit_room)
916 .await?;
917
918 response.send(proto::CreateRoomResponse {
919 room: Some(room.clone()),
920 live_kit_connection_info,
921 })?;
922
923 update_user_contacts(session.user_id, &session).await?;
924 Ok(())
925}
926
927async fn join_room(
928 request: proto::JoinRoom,
929 response: Response<proto::JoinRoom>,
930 session: Session,
931) -> Result<()> {
932 let room_id = RoomId::from_proto(request.id);
933 let joined_room = {
934 let room = session
935 .db()
936 .await
937 .join_room(room_id, session.user_id, session.connection_id)
938 .await?;
939 room_updated(&room.room, &session.peer);
940 room.into_inner()
941 };
942
943 if let Some(channel_id) = joined_room.channel_id {
944 channel_updated(
945 channel_id,
946 &joined_room.room,
947 &joined_room.channel_members,
948 &session.peer,
949 &*session.connection_pool().await,
950 )
951 }
952
953 for connection_id in session
954 .connection_pool()
955 .await
956 .user_connection_ids(session.user_id)
957 {
958 session
959 .peer
960 .send(
961 connection_id,
962 proto::CallCanceled {
963 room_id: room_id.to_proto(),
964 },
965 )
966 .trace_err();
967 }
968
969 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
970 if let Some(token) = live_kit
971 .room_token(
972 &joined_room.room.live_kit_room,
973 &session.user_id.to_string(),
974 )
975 .trace_err()
976 {
977 Some(proto::LiveKitConnectionInfo {
978 server_url: live_kit.url().into(),
979 token,
980 })
981 } else {
982 None
983 }
984 } else {
985 None
986 };
987
988 response.send(proto::JoinRoomResponse {
989 room: Some(joined_room.room),
990 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
991 live_kit_connection_info,
992 })?;
993
994 update_user_contacts(session.user_id, &session).await?;
995 Ok(())
996}
997
998async fn rejoin_room(
999 request: proto::RejoinRoom,
1000 response: Response<proto::RejoinRoom>,
1001 session: Session,
1002) -> Result<()> {
1003 let room;
1004 let channel_id;
1005 let channel_members;
1006 {
1007 let mut rejoined_room = session
1008 .db()
1009 .await
1010 .rejoin_room(request, session.user_id, session.connection_id)
1011 .await?;
1012
1013 response.send(proto::RejoinRoomResponse {
1014 room: Some(rejoined_room.room.clone()),
1015 reshared_projects: rejoined_room
1016 .reshared_projects
1017 .iter()
1018 .map(|project| proto::ResharedProject {
1019 id: project.id.to_proto(),
1020 collaborators: project
1021 .collaborators
1022 .iter()
1023 .map(|collaborator| collaborator.to_proto())
1024 .collect(),
1025 })
1026 .collect(),
1027 rejoined_projects: rejoined_room
1028 .rejoined_projects
1029 .iter()
1030 .map(|rejoined_project| proto::RejoinedProject {
1031 id: rejoined_project.id.to_proto(),
1032 worktrees: rejoined_project
1033 .worktrees
1034 .iter()
1035 .map(|worktree| proto::WorktreeMetadata {
1036 id: worktree.id,
1037 root_name: worktree.root_name.clone(),
1038 visible: worktree.visible,
1039 abs_path: worktree.abs_path.clone(),
1040 })
1041 .collect(),
1042 collaborators: rejoined_project
1043 .collaborators
1044 .iter()
1045 .map(|collaborator| collaborator.to_proto())
1046 .collect(),
1047 language_servers: rejoined_project.language_servers.clone(),
1048 })
1049 .collect(),
1050 })?;
1051 room_updated(&rejoined_room.room, &session.peer);
1052
1053 for project in &rejoined_room.reshared_projects {
1054 for collaborator in &project.collaborators {
1055 session
1056 .peer
1057 .send(
1058 collaborator.connection_id,
1059 proto::UpdateProjectCollaborator {
1060 project_id: project.id.to_proto(),
1061 old_peer_id: Some(project.old_connection_id.into()),
1062 new_peer_id: Some(session.connection_id.into()),
1063 },
1064 )
1065 .trace_err();
1066 }
1067
1068 broadcast(
1069 Some(session.connection_id),
1070 project
1071 .collaborators
1072 .iter()
1073 .map(|collaborator| collaborator.connection_id),
1074 |connection_id| {
1075 session.peer.forward_send(
1076 session.connection_id,
1077 connection_id,
1078 proto::UpdateProject {
1079 project_id: project.id.to_proto(),
1080 worktrees: project.worktrees.clone(),
1081 },
1082 )
1083 },
1084 );
1085 }
1086
1087 for project in &rejoined_room.rejoined_projects {
1088 for collaborator in &project.collaborators {
1089 session
1090 .peer
1091 .send(
1092 collaborator.connection_id,
1093 proto::UpdateProjectCollaborator {
1094 project_id: project.id.to_proto(),
1095 old_peer_id: Some(project.old_connection_id.into()),
1096 new_peer_id: Some(session.connection_id.into()),
1097 },
1098 )
1099 .trace_err();
1100 }
1101 }
1102
1103 for project in &mut rejoined_room.rejoined_projects {
1104 for worktree in mem::take(&mut project.worktrees) {
1105 #[cfg(any(test, feature = "test-support"))]
1106 const MAX_CHUNK_SIZE: usize = 2;
1107 #[cfg(not(any(test, feature = "test-support")))]
1108 const MAX_CHUNK_SIZE: usize = 256;
1109
1110 // Stream this worktree's entries.
1111 let message = proto::UpdateWorktree {
1112 project_id: project.id.to_proto(),
1113 worktree_id: worktree.id,
1114 abs_path: worktree.abs_path.clone(),
1115 root_name: worktree.root_name,
1116 updated_entries: worktree.updated_entries,
1117 removed_entries: worktree.removed_entries,
1118 scan_id: worktree.scan_id,
1119 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1120 updated_repositories: worktree.updated_repositories,
1121 removed_repositories: worktree.removed_repositories,
1122 };
1123 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1124 session.peer.send(session.connection_id, update.clone())?;
1125 }
1126
1127 // Stream this worktree's diagnostics.
1128 for summary in worktree.diagnostic_summaries {
1129 session.peer.send(
1130 session.connection_id,
1131 proto::UpdateDiagnosticSummary {
1132 project_id: project.id.to_proto(),
1133 worktree_id: worktree.id,
1134 summary: Some(summary),
1135 },
1136 )?;
1137 }
1138
1139 for settings_file in worktree.settings_files {
1140 session.peer.send(
1141 session.connection_id,
1142 proto::UpdateWorktreeSettings {
1143 project_id: project.id.to_proto(),
1144 worktree_id: worktree.id,
1145 path: settings_file.path,
1146 content: Some(settings_file.content),
1147 },
1148 )?;
1149 }
1150 }
1151
1152 for language_server in &project.language_servers {
1153 session.peer.send(
1154 session.connection_id,
1155 proto::UpdateLanguageServer {
1156 project_id: project.id.to_proto(),
1157 language_server_id: language_server.id,
1158 variant: Some(
1159 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1160 proto::LspDiskBasedDiagnosticsUpdated {},
1161 ),
1162 ),
1163 },
1164 )?;
1165 }
1166 }
1167
1168 let rejoined_room = rejoined_room.into_inner();
1169
1170 room = rejoined_room.room;
1171 channel_id = rejoined_room.channel_id;
1172 channel_members = rejoined_room.channel_members;
1173 }
1174
1175 if let Some(channel_id) = channel_id {
1176 channel_updated(
1177 channel_id,
1178 &room,
1179 &channel_members,
1180 &session.peer,
1181 &*session.connection_pool().await,
1182 );
1183 }
1184
1185 update_user_contacts(session.user_id, &session).await?;
1186 Ok(())
1187}
1188
1189async fn leave_room(
1190 _: proto::LeaveRoom,
1191 response: Response<proto::LeaveRoom>,
1192 session: Session,
1193) -> Result<()> {
1194 leave_room_for_session(&session).await?;
1195 response.send(proto::Ack {})?;
1196 Ok(())
1197}
1198
1199async fn call(
1200 request: proto::Call,
1201 response: Response<proto::Call>,
1202 session: Session,
1203) -> Result<()> {
1204 let room_id = RoomId::from_proto(request.room_id);
1205 let calling_user_id = session.user_id;
1206 let calling_connection_id = session.connection_id;
1207 let called_user_id = UserId::from_proto(request.called_user_id);
1208 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1209 if !session
1210 .db()
1211 .await
1212 .has_contact(calling_user_id, called_user_id)
1213 .await?
1214 {
1215 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1216 }
1217
1218 let incoming_call = {
1219 let (room, incoming_call) = &mut *session
1220 .db()
1221 .await
1222 .call(
1223 room_id,
1224 calling_user_id,
1225 calling_connection_id,
1226 called_user_id,
1227 initial_project_id,
1228 )
1229 .await?;
1230 room_updated(&room, &session.peer);
1231 mem::take(incoming_call)
1232 };
1233 update_user_contacts(called_user_id, &session).await?;
1234
1235 let mut calls = session
1236 .connection_pool()
1237 .await
1238 .user_connection_ids(called_user_id)
1239 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1240 .collect::<FuturesUnordered<_>>();
1241
1242 while let Some(call_response) = calls.next().await {
1243 match call_response.as_ref() {
1244 Ok(_) => {
1245 response.send(proto::Ack {})?;
1246 return Ok(());
1247 }
1248 Err(_) => {
1249 call_response.trace_err();
1250 }
1251 }
1252 }
1253
1254 {
1255 let room = session
1256 .db()
1257 .await
1258 .call_failed(room_id, called_user_id)
1259 .await?;
1260 room_updated(&room, &session.peer);
1261 }
1262 update_user_contacts(called_user_id, &session).await?;
1263
1264 Err(anyhow!("failed to ring user"))?
1265}
1266
1267async fn cancel_call(
1268 request: proto::CancelCall,
1269 response: Response<proto::CancelCall>,
1270 session: Session,
1271) -> Result<()> {
1272 let called_user_id = UserId::from_proto(request.called_user_id);
1273 let room_id = RoomId::from_proto(request.room_id);
1274 {
1275 let room = session
1276 .db()
1277 .await
1278 .cancel_call(room_id, session.connection_id, called_user_id)
1279 .await?;
1280 room_updated(&room, &session.peer);
1281 }
1282
1283 for connection_id in session
1284 .connection_pool()
1285 .await
1286 .user_connection_ids(called_user_id)
1287 {
1288 session
1289 .peer
1290 .send(
1291 connection_id,
1292 proto::CallCanceled {
1293 room_id: room_id.to_proto(),
1294 },
1295 )
1296 .trace_err();
1297 }
1298 response.send(proto::Ack {})?;
1299
1300 update_user_contacts(called_user_id, &session).await?;
1301 Ok(())
1302}
1303
1304async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1305 let room_id = RoomId::from_proto(message.room_id);
1306 {
1307 let room = session
1308 .db()
1309 .await
1310 .decline_call(Some(room_id), session.user_id)
1311 .await?
1312 .ok_or_else(|| anyhow!("failed to decline call"))?;
1313 room_updated(&room, &session.peer);
1314 }
1315
1316 for connection_id in session
1317 .connection_pool()
1318 .await
1319 .user_connection_ids(session.user_id)
1320 {
1321 session
1322 .peer
1323 .send(
1324 connection_id,
1325 proto::CallCanceled {
1326 room_id: room_id.to_proto(),
1327 },
1328 )
1329 .trace_err();
1330 }
1331 update_user_contacts(session.user_id, &session).await?;
1332 Ok(())
1333}
1334
1335async fn update_participant_location(
1336 request: proto::UpdateParticipantLocation,
1337 response: Response<proto::UpdateParticipantLocation>,
1338 session: Session,
1339) -> Result<()> {
1340 let room_id = RoomId::from_proto(request.room_id);
1341 let location = request
1342 .location
1343 .ok_or_else(|| anyhow!("invalid location"))?;
1344
1345 let db = session.db().await;
1346 let room = db
1347 .update_room_participant_location(room_id, session.connection_id, location)
1348 .await?;
1349
1350 room_updated(&room, &session.peer);
1351 response.send(proto::Ack {})?;
1352 Ok(())
1353}
1354
1355async fn share_project(
1356 request: proto::ShareProject,
1357 response: Response<proto::ShareProject>,
1358 session: Session,
1359) -> Result<()> {
1360 let (project_id, room) = &*session
1361 .db()
1362 .await
1363 .share_project(
1364 RoomId::from_proto(request.room_id),
1365 session.connection_id,
1366 &request.worktrees,
1367 )
1368 .await?;
1369 response.send(proto::ShareProjectResponse {
1370 project_id: project_id.to_proto(),
1371 })?;
1372 room_updated(&room, &session.peer);
1373
1374 Ok(())
1375}
1376
1377async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1378 let project_id = ProjectId::from_proto(message.project_id);
1379
1380 let (room, guest_connection_ids) = &*session
1381 .db()
1382 .await
1383 .unshare_project(project_id, session.connection_id)
1384 .await?;
1385
1386 broadcast(
1387 Some(session.connection_id),
1388 guest_connection_ids.iter().copied(),
1389 |conn_id| session.peer.send(conn_id, message.clone()),
1390 );
1391 room_updated(&room, &session.peer);
1392
1393 Ok(())
1394}
1395
1396async fn join_project(
1397 request: proto::JoinProject,
1398 response: Response<proto::JoinProject>,
1399 session: Session,
1400) -> Result<()> {
1401 let project_id = ProjectId::from_proto(request.project_id);
1402 let guest_user_id = session.user_id;
1403
1404 tracing::info!(%project_id, "join project");
1405
1406 let (project, replica_id) = &mut *session
1407 .db()
1408 .await
1409 .join_project(project_id, session.connection_id)
1410 .await?;
1411
1412 let collaborators = project
1413 .collaborators
1414 .iter()
1415 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1416 .map(|collaborator| collaborator.to_proto())
1417 .collect::<Vec<_>>();
1418
1419 let worktrees = project
1420 .worktrees
1421 .iter()
1422 .map(|(id, worktree)| proto::WorktreeMetadata {
1423 id: *id,
1424 root_name: worktree.root_name.clone(),
1425 visible: worktree.visible,
1426 abs_path: worktree.abs_path.clone(),
1427 })
1428 .collect::<Vec<_>>();
1429
1430 for collaborator in &collaborators {
1431 session
1432 .peer
1433 .send(
1434 collaborator.peer_id.unwrap().into(),
1435 proto::AddProjectCollaborator {
1436 project_id: project_id.to_proto(),
1437 collaborator: Some(proto::Collaborator {
1438 peer_id: Some(session.connection_id.into()),
1439 replica_id: replica_id.0 as u32,
1440 user_id: guest_user_id.to_proto(),
1441 }),
1442 },
1443 )
1444 .trace_err();
1445 }
1446
1447 // First, we send the metadata associated with each worktree.
1448 response.send(proto::JoinProjectResponse {
1449 worktrees: worktrees.clone(),
1450 replica_id: replica_id.0 as u32,
1451 collaborators: collaborators.clone(),
1452 language_servers: project.language_servers.clone(),
1453 })?;
1454
1455 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1456 #[cfg(any(test, feature = "test-support"))]
1457 const MAX_CHUNK_SIZE: usize = 2;
1458 #[cfg(not(any(test, feature = "test-support")))]
1459 const MAX_CHUNK_SIZE: usize = 256;
1460
1461 // Stream this worktree's entries.
1462 let message = proto::UpdateWorktree {
1463 project_id: project_id.to_proto(),
1464 worktree_id,
1465 abs_path: worktree.abs_path.clone(),
1466 root_name: worktree.root_name,
1467 updated_entries: worktree.entries,
1468 removed_entries: Default::default(),
1469 scan_id: worktree.scan_id,
1470 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1471 updated_repositories: worktree.repository_entries.into_values().collect(),
1472 removed_repositories: Default::default(),
1473 };
1474 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1475 session.peer.send(session.connection_id, update.clone())?;
1476 }
1477
1478 // Stream this worktree's diagnostics.
1479 for summary in worktree.diagnostic_summaries {
1480 session.peer.send(
1481 session.connection_id,
1482 proto::UpdateDiagnosticSummary {
1483 project_id: project_id.to_proto(),
1484 worktree_id: worktree.id,
1485 summary: Some(summary),
1486 },
1487 )?;
1488 }
1489
1490 for settings_file in worktree.settings_files {
1491 session.peer.send(
1492 session.connection_id,
1493 proto::UpdateWorktreeSettings {
1494 project_id: project_id.to_proto(),
1495 worktree_id: worktree.id,
1496 path: settings_file.path,
1497 content: Some(settings_file.content),
1498 },
1499 )?;
1500 }
1501 }
1502
1503 for language_server in &project.language_servers {
1504 session.peer.send(
1505 session.connection_id,
1506 proto::UpdateLanguageServer {
1507 project_id: project_id.to_proto(),
1508 language_server_id: language_server.id,
1509 variant: Some(
1510 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1511 proto::LspDiskBasedDiagnosticsUpdated {},
1512 ),
1513 ),
1514 },
1515 )?;
1516 }
1517
1518 Ok(())
1519}
1520
1521async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1522 let sender_id = session.connection_id;
1523 let project_id = ProjectId::from_proto(request.project_id);
1524
1525 let (room, project) = &*session
1526 .db()
1527 .await
1528 .leave_project(project_id, sender_id)
1529 .await?;
1530 tracing::info!(
1531 %project_id,
1532 host_user_id = %project.host_user_id,
1533 host_connection_id = %project.host_connection_id,
1534 "leave project"
1535 );
1536
1537 project_left(&project, &session);
1538 room_updated(&room, &session.peer);
1539
1540 Ok(())
1541}
1542
1543async fn update_project(
1544 request: proto::UpdateProject,
1545 response: Response<proto::UpdateProject>,
1546 session: Session,
1547) -> Result<()> {
1548 let project_id = ProjectId::from_proto(request.project_id);
1549 let (room, guest_connection_ids) = &*session
1550 .db()
1551 .await
1552 .update_project(project_id, session.connection_id, &request.worktrees)
1553 .await?;
1554 broadcast(
1555 Some(session.connection_id),
1556 guest_connection_ids.iter().copied(),
1557 |connection_id| {
1558 session
1559 .peer
1560 .forward_send(session.connection_id, connection_id, request.clone())
1561 },
1562 );
1563 room_updated(&room, &session.peer);
1564 response.send(proto::Ack {})?;
1565
1566 Ok(())
1567}
1568
1569async fn update_worktree(
1570 request: proto::UpdateWorktree,
1571 response: Response<proto::UpdateWorktree>,
1572 session: Session,
1573) -> Result<()> {
1574 let guest_connection_ids = session
1575 .db()
1576 .await
1577 .update_worktree(&request, session.connection_id)
1578 .await?;
1579
1580 broadcast(
1581 Some(session.connection_id),
1582 guest_connection_ids.iter().copied(),
1583 |connection_id| {
1584 session
1585 .peer
1586 .forward_send(session.connection_id, connection_id, request.clone())
1587 },
1588 );
1589 response.send(proto::Ack {})?;
1590 Ok(())
1591}
1592
1593async fn update_diagnostic_summary(
1594 message: proto::UpdateDiagnosticSummary,
1595 session: Session,
1596) -> Result<()> {
1597 let guest_connection_ids = session
1598 .db()
1599 .await
1600 .update_diagnostic_summary(&message, session.connection_id)
1601 .await?;
1602
1603 broadcast(
1604 Some(session.connection_id),
1605 guest_connection_ids.iter().copied(),
1606 |connection_id| {
1607 session
1608 .peer
1609 .forward_send(session.connection_id, connection_id, message.clone())
1610 },
1611 );
1612
1613 Ok(())
1614}
1615
1616async fn update_worktree_settings(
1617 message: proto::UpdateWorktreeSettings,
1618 session: Session,
1619) -> Result<()> {
1620 let guest_connection_ids = session
1621 .db()
1622 .await
1623 .update_worktree_settings(&message, session.connection_id)
1624 .await?;
1625
1626 broadcast(
1627 Some(session.connection_id),
1628 guest_connection_ids.iter().copied(),
1629 |connection_id| {
1630 session
1631 .peer
1632 .forward_send(session.connection_id, connection_id, message.clone())
1633 },
1634 );
1635
1636 Ok(())
1637}
1638
1639async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1640 broadcast_project_message(request.project_id, request, session).await
1641}
1642
1643async fn start_language_server(
1644 request: proto::StartLanguageServer,
1645 session: Session,
1646) -> Result<()> {
1647 let guest_connection_ids = session
1648 .db()
1649 .await
1650 .start_language_server(&request, session.connection_id)
1651 .await?;
1652
1653 broadcast(
1654 Some(session.connection_id),
1655 guest_connection_ids.iter().copied(),
1656 |connection_id| {
1657 session
1658 .peer
1659 .forward_send(session.connection_id, connection_id, request.clone())
1660 },
1661 );
1662 Ok(())
1663}
1664
1665async fn update_language_server(
1666 request: proto::UpdateLanguageServer,
1667 session: Session,
1668) -> Result<()> {
1669 session.executor.record_backtrace();
1670 let project_id = ProjectId::from_proto(request.project_id);
1671 let project_connection_ids = session
1672 .db()
1673 .await
1674 .project_connection_ids(project_id, session.connection_id)
1675 .await?;
1676 broadcast(
1677 Some(session.connection_id),
1678 project_connection_ids.iter().copied(),
1679 |connection_id| {
1680 session
1681 .peer
1682 .forward_send(session.connection_id, connection_id, request.clone())
1683 },
1684 );
1685 Ok(())
1686}
1687
1688async fn forward_project_request<T>(
1689 request: T,
1690 response: Response<T>,
1691 session: Session,
1692) -> Result<()>
1693where
1694 T: EntityMessage + RequestMessage,
1695{
1696 session.executor.record_backtrace();
1697 let project_id = ProjectId::from_proto(request.remote_entity_id());
1698 let host_connection_id = {
1699 let collaborators = session
1700 .db()
1701 .await
1702 .project_collaborators(project_id, session.connection_id)
1703 .await?;
1704 collaborators
1705 .iter()
1706 .find(|collaborator| collaborator.is_host)
1707 .ok_or_else(|| anyhow!("host not found"))?
1708 .connection_id
1709 };
1710
1711 let payload = session
1712 .peer
1713 .forward_request(session.connection_id, host_connection_id, request)
1714 .await?;
1715
1716 response.send(payload)?;
1717 Ok(())
1718}
1719
1720async fn create_buffer_for_peer(
1721 request: proto::CreateBufferForPeer,
1722 session: Session,
1723) -> Result<()> {
1724 session.executor.record_backtrace();
1725 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1726 session
1727 .peer
1728 .forward_send(session.connection_id, peer_id.into(), request)?;
1729 Ok(())
1730}
1731
1732async fn update_buffer(
1733 request: proto::UpdateBuffer,
1734 response: Response<proto::UpdateBuffer>,
1735 session: Session,
1736) -> Result<()> {
1737 session.executor.record_backtrace();
1738 let project_id = ProjectId::from_proto(request.project_id);
1739 let mut guest_connection_ids;
1740 let mut host_connection_id = None;
1741 {
1742 let collaborators = session
1743 .db()
1744 .await
1745 .project_collaborators(project_id, session.connection_id)
1746 .await?;
1747 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1748 for collaborator in collaborators.iter() {
1749 if collaborator.is_host {
1750 host_connection_id = Some(collaborator.connection_id);
1751 } else {
1752 guest_connection_ids.push(collaborator.connection_id);
1753 }
1754 }
1755 }
1756 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1757
1758 session.executor.record_backtrace();
1759 broadcast(
1760 Some(session.connection_id),
1761 guest_connection_ids,
1762 |connection_id| {
1763 session
1764 .peer
1765 .forward_send(session.connection_id, connection_id, request.clone())
1766 },
1767 );
1768 if host_connection_id != session.connection_id {
1769 session
1770 .peer
1771 .forward_request(session.connection_id, host_connection_id, request.clone())
1772 .await?;
1773 }
1774
1775 response.send(proto::Ack {})?;
1776 Ok(())
1777}
1778
1779async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1780 let project_id = ProjectId::from_proto(request.project_id);
1781 let project_connection_ids = session
1782 .db()
1783 .await
1784 .project_connection_ids(project_id, session.connection_id)
1785 .await?;
1786
1787 broadcast(
1788 Some(session.connection_id),
1789 project_connection_ids.iter().copied(),
1790 |connection_id| {
1791 session
1792 .peer
1793 .forward_send(session.connection_id, connection_id, request.clone())
1794 },
1795 );
1796 Ok(())
1797}
1798
1799async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1800 let project_id = ProjectId::from_proto(request.project_id);
1801 let project_connection_ids = session
1802 .db()
1803 .await
1804 .project_connection_ids(project_id, session.connection_id)
1805 .await?;
1806 broadcast(
1807 Some(session.connection_id),
1808 project_connection_ids.iter().copied(),
1809 |connection_id| {
1810 session
1811 .peer
1812 .forward_send(session.connection_id, connection_id, request.clone())
1813 },
1814 );
1815 Ok(())
1816}
1817
1818async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1819 broadcast_project_message(request.project_id, request, session).await
1820}
1821
1822async fn broadcast_project_message<T: EnvelopedMessage>(
1823 project_id: u64,
1824 request: T,
1825 session: Session,
1826) -> Result<()> {
1827 let project_id = ProjectId::from_proto(project_id);
1828 let project_connection_ids = session
1829 .db()
1830 .await
1831 .project_connection_ids(project_id, session.connection_id)
1832 .await?;
1833 broadcast(
1834 Some(session.connection_id),
1835 project_connection_ids.iter().copied(),
1836 |connection_id| {
1837 session
1838 .peer
1839 .forward_send(session.connection_id, connection_id, request.clone())
1840 },
1841 );
1842 Ok(())
1843}
1844
1845async fn follow(
1846 request: proto::Follow,
1847 response: Response<proto::Follow>,
1848 session: Session,
1849) -> Result<()> {
1850 let project_id = ProjectId::from_proto(request.project_id);
1851 let leader_id = request
1852 .leader_id
1853 .ok_or_else(|| anyhow!("invalid leader id"))?
1854 .into();
1855 let follower_id = session.connection_id;
1856
1857 {
1858 let project_connection_ids = session
1859 .db()
1860 .await
1861 .project_connection_ids(project_id, session.connection_id)
1862 .await?;
1863
1864 if !project_connection_ids.contains(&leader_id) {
1865 Err(anyhow!("no such peer"))?;
1866 }
1867 }
1868
1869 let mut response_payload = session
1870 .peer
1871 .forward_request(session.connection_id, leader_id, request)
1872 .await?;
1873 response_payload
1874 .views
1875 .retain(|view| view.leader_id != Some(follower_id.into()));
1876 response.send(response_payload)?;
1877
1878 let room = session
1879 .db()
1880 .await
1881 .follow(project_id, leader_id, follower_id)
1882 .await?;
1883 room_updated(&room, &session.peer);
1884
1885 Ok(())
1886}
1887
1888async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1889 let project_id = ProjectId::from_proto(request.project_id);
1890 let leader_id = request
1891 .leader_id
1892 .ok_or_else(|| anyhow!("invalid leader id"))?
1893 .into();
1894 let follower_id = session.connection_id;
1895
1896 if !session
1897 .db()
1898 .await
1899 .project_connection_ids(project_id, session.connection_id)
1900 .await?
1901 .contains(&leader_id)
1902 {
1903 Err(anyhow!("no such peer"))?;
1904 }
1905
1906 session
1907 .peer
1908 .forward_send(session.connection_id, leader_id, request)?;
1909
1910 let room = session
1911 .db()
1912 .await
1913 .unfollow(project_id, leader_id, follower_id)
1914 .await?;
1915 room_updated(&room, &session.peer);
1916
1917 Ok(())
1918}
1919
1920async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1921 let project_id = ProjectId::from_proto(request.project_id);
1922 let project_connection_ids = session
1923 .db
1924 .lock()
1925 .await
1926 .project_connection_ids(project_id, session.connection_id)
1927 .await?;
1928
1929 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1930 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1931 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1932 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1933 });
1934 for follower_peer_id in request.follower_ids.iter().copied() {
1935 let follower_connection_id = follower_peer_id.into();
1936 if project_connection_ids.contains(&follower_connection_id)
1937 && Some(follower_peer_id) != leader_id
1938 {
1939 session.peer.forward_send(
1940 session.connection_id,
1941 follower_connection_id,
1942 request.clone(),
1943 )?;
1944 }
1945 }
1946 Ok(())
1947}
1948
1949async fn get_users(
1950 request: proto::GetUsers,
1951 response: Response<proto::GetUsers>,
1952 session: Session,
1953) -> Result<()> {
1954 let user_ids = request
1955 .user_ids
1956 .into_iter()
1957 .map(UserId::from_proto)
1958 .collect();
1959 let users = session
1960 .db()
1961 .await
1962 .get_users_by_ids(user_ids)
1963 .await?
1964 .into_iter()
1965 .map(|user| proto::User {
1966 id: user.id.to_proto(),
1967 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1968 github_login: user.github_login,
1969 })
1970 .collect();
1971 response.send(proto::UsersResponse { users })?;
1972 Ok(())
1973}
1974
1975async fn fuzzy_search_users(
1976 request: proto::FuzzySearchUsers,
1977 response: Response<proto::FuzzySearchUsers>,
1978 session: Session,
1979) -> Result<()> {
1980 let query = request.query;
1981 let users = match query.len() {
1982 0 => vec![],
1983 1 | 2 => session
1984 .db()
1985 .await
1986 .get_user_by_github_login(&query)
1987 .await?
1988 .into_iter()
1989 .collect(),
1990 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1991 };
1992 let users = users
1993 .into_iter()
1994 .filter(|user| user.id != session.user_id)
1995 .map(|user| proto::User {
1996 id: user.id.to_proto(),
1997 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1998 github_login: user.github_login,
1999 })
2000 .collect();
2001 response.send(proto::UsersResponse { users })?;
2002 Ok(())
2003}
2004
2005async fn request_contact(
2006 request: proto::RequestContact,
2007 response: Response<proto::RequestContact>,
2008 session: Session,
2009) -> Result<()> {
2010 let requester_id = session.user_id;
2011 let responder_id = UserId::from_proto(request.responder_id);
2012 if requester_id == responder_id {
2013 return Err(anyhow!("cannot add yourself as a contact"))?;
2014 }
2015
2016 session
2017 .db()
2018 .await
2019 .send_contact_request(requester_id, responder_id)
2020 .await?;
2021
2022 // Update outgoing contact requests of requester
2023 let mut update = proto::UpdateContacts::default();
2024 update.outgoing_requests.push(responder_id.to_proto());
2025 for connection_id in session
2026 .connection_pool()
2027 .await
2028 .user_connection_ids(requester_id)
2029 {
2030 session.peer.send(connection_id, update.clone())?;
2031 }
2032
2033 // Update incoming contact requests of responder
2034 let mut update = proto::UpdateContacts::default();
2035 update
2036 .incoming_requests
2037 .push(proto::IncomingContactRequest {
2038 requester_id: requester_id.to_proto(),
2039 should_notify: true,
2040 });
2041 for connection_id in session
2042 .connection_pool()
2043 .await
2044 .user_connection_ids(responder_id)
2045 {
2046 session.peer.send(connection_id, update.clone())?;
2047 }
2048
2049 response.send(proto::Ack {})?;
2050 Ok(())
2051}
2052
2053async fn respond_to_contact_request(
2054 request: proto::RespondToContactRequest,
2055 response: Response<proto::RespondToContactRequest>,
2056 session: Session,
2057) -> Result<()> {
2058 let responder_id = session.user_id;
2059 let requester_id = UserId::from_proto(request.requester_id);
2060 let db = session.db().await;
2061 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2062 db.dismiss_contact_notification(responder_id, requester_id)
2063 .await?;
2064 } else {
2065 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2066
2067 db.respond_to_contact_request(responder_id, requester_id, accept)
2068 .await?;
2069 let requester_busy = db.is_user_busy(requester_id).await?;
2070 let responder_busy = db.is_user_busy(responder_id).await?;
2071
2072 let pool = session.connection_pool().await;
2073 // Update responder with new contact
2074 let mut update = proto::UpdateContacts::default();
2075 if accept {
2076 update
2077 .contacts
2078 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2079 }
2080 update
2081 .remove_incoming_requests
2082 .push(requester_id.to_proto());
2083 for connection_id in pool.user_connection_ids(responder_id) {
2084 session.peer.send(connection_id, update.clone())?;
2085 }
2086
2087 // Update requester with new contact
2088 let mut update = proto::UpdateContacts::default();
2089 if accept {
2090 update
2091 .contacts
2092 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2093 }
2094 update
2095 .remove_outgoing_requests
2096 .push(responder_id.to_proto());
2097 for connection_id in pool.user_connection_ids(requester_id) {
2098 session.peer.send(connection_id, update.clone())?;
2099 }
2100 }
2101
2102 response.send(proto::Ack {})?;
2103 Ok(())
2104}
2105
2106async fn remove_contact(
2107 request: proto::RemoveContact,
2108 response: Response<proto::RemoveContact>,
2109 session: Session,
2110) -> Result<()> {
2111 let requester_id = session.user_id;
2112 let responder_id = UserId::from_proto(request.user_id);
2113 let db = session.db().await;
2114 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2115
2116 let pool = session.connection_pool().await;
2117 // Update outgoing contact requests of requester
2118 let mut update = proto::UpdateContacts::default();
2119 if contact_accepted {
2120 update.remove_contacts.push(responder_id.to_proto());
2121 } else {
2122 update
2123 .remove_outgoing_requests
2124 .push(responder_id.to_proto());
2125 }
2126 for connection_id in pool.user_connection_ids(requester_id) {
2127 session.peer.send(connection_id, update.clone())?;
2128 }
2129
2130 // Update incoming contact requests of responder
2131 let mut update = proto::UpdateContacts::default();
2132 if contact_accepted {
2133 update.remove_contacts.push(requester_id.to_proto());
2134 } else {
2135 update
2136 .remove_incoming_requests
2137 .push(requester_id.to_proto());
2138 }
2139 for connection_id in pool.user_connection_ids(responder_id) {
2140 session.peer.send(connection_id, update.clone())?;
2141 }
2142
2143 response.send(proto::Ack {})?;
2144 Ok(())
2145}
2146
2147async fn create_channel(
2148 request: proto::CreateChannel,
2149 response: Response<proto::CreateChannel>,
2150 session: Session,
2151) -> Result<()> {
2152 let db = session.db().await;
2153 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2154
2155 if let Some(live_kit) = session.live_kit_client.as_ref() {
2156 live_kit.create_room(live_kit_room.clone()).await?;
2157 }
2158
2159 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2160 let id = db
2161 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2162 .await?;
2163
2164 let channel = proto::Channel {
2165 id: id.to_proto(),
2166 name: request.name,
2167 parent_id: request.parent_id,
2168 };
2169
2170 response.send(proto::ChannelResponse {
2171 channel: Some(channel.clone()),
2172 })?;
2173
2174 let mut update = proto::UpdateChannels::default();
2175 update.channels.push(channel);
2176
2177 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2178 db.get_channel_members(parent_id).await?
2179 } else {
2180 vec![session.user_id]
2181 };
2182
2183 let connection_pool = session.connection_pool().await;
2184 for user_id in user_ids_to_notify {
2185 for connection_id in connection_pool.user_connection_ids(user_id) {
2186 let mut update = update.clone();
2187 if user_id == session.user_id {
2188 update.channel_permissions.push(proto::ChannelPermission {
2189 channel_id: id.to_proto(),
2190 is_admin: true,
2191 });
2192 }
2193 session.peer.send(connection_id, update)?;
2194 }
2195 }
2196
2197 Ok(())
2198}
2199
2200async fn remove_channel(
2201 request: proto::RemoveChannel,
2202 response: Response<proto::RemoveChannel>,
2203 session: Session,
2204) -> Result<()> {
2205 let db = session.db().await;
2206
2207 let channel_id = request.channel_id;
2208 let (removed_channels, member_ids) = db
2209 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2210 .await?;
2211 response.send(proto::Ack {})?;
2212
2213 // Notify members of removed channels
2214 let mut update = proto::UpdateChannels::default();
2215 update
2216 .remove_channels
2217 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2218
2219 let connection_pool = session.connection_pool().await;
2220 for member_id in member_ids {
2221 for connection_id in connection_pool.user_connection_ids(member_id) {
2222 session.peer.send(connection_id, update.clone())?;
2223 }
2224 }
2225
2226 Ok(())
2227}
2228
2229async fn invite_channel_member(
2230 request: proto::InviteChannelMember,
2231 response: Response<proto::InviteChannelMember>,
2232 session: Session,
2233) -> Result<()> {
2234 let db = session.db().await;
2235 let channel_id = ChannelId::from_proto(request.channel_id);
2236 let invitee_id = UserId::from_proto(request.user_id);
2237 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2238 .await?;
2239
2240 let (channel, _) = db
2241 .get_channel(channel_id, session.user_id)
2242 .await?
2243 .ok_or_else(|| anyhow!("channel not found"))?;
2244
2245 let mut update = proto::UpdateChannels::default();
2246 update.channel_invitations.push(proto::Channel {
2247 id: channel.id.to_proto(),
2248 name: channel.name,
2249 parent_id: None,
2250 });
2251 for connection_id in session
2252 .connection_pool()
2253 .await
2254 .user_connection_ids(invitee_id)
2255 {
2256 session.peer.send(connection_id, update.clone())?;
2257 }
2258
2259 response.send(proto::Ack {})?;
2260 Ok(())
2261}
2262
2263async fn remove_channel_member(
2264 request: proto::RemoveChannelMember,
2265 response: Response<proto::RemoveChannelMember>,
2266 session: Session,
2267) -> Result<()> {
2268 let db = session.db().await;
2269 let channel_id = ChannelId::from_proto(request.channel_id);
2270 let member_id = UserId::from_proto(request.user_id);
2271
2272 db.remove_channel_member(channel_id, member_id, session.user_id)
2273 .await?;
2274
2275 let mut update = proto::UpdateChannels::default();
2276 update.remove_channels.push(channel_id.to_proto());
2277
2278 for connection_id in session
2279 .connection_pool()
2280 .await
2281 .user_connection_ids(member_id)
2282 {
2283 session.peer.send(connection_id, update.clone())?;
2284 }
2285
2286 response.send(proto::Ack {})?;
2287 Ok(())
2288}
2289
2290async fn set_channel_member_admin(
2291 request: proto::SetChannelMemberAdmin,
2292 response: Response<proto::SetChannelMemberAdmin>,
2293 session: Session,
2294) -> Result<()> {
2295 let db = session.db().await;
2296 let channel_id = ChannelId::from_proto(request.channel_id);
2297 let member_id = UserId::from_proto(request.user_id);
2298 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2299 .await?;
2300
2301 let (channel, has_accepted) = db
2302 .get_channel(channel_id, member_id)
2303 .await?
2304 .ok_or_else(|| anyhow!("channel not found"))?;
2305
2306 let mut update = proto::UpdateChannels::default();
2307 if has_accepted {
2308 update.channel_permissions.push(proto::ChannelPermission {
2309 channel_id: channel.id.to_proto(),
2310 is_admin: request.admin,
2311 });
2312 }
2313
2314 for connection_id in session
2315 .connection_pool()
2316 .await
2317 .user_connection_ids(member_id)
2318 {
2319 session.peer.send(connection_id, update.clone())?;
2320 }
2321
2322 response.send(proto::Ack {})?;
2323 Ok(())
2324}
2325
2326async fn rename_channel(
2327 request: proto::RenameChannel,
2328 response: Response<proto::RenameChannel>,
2329 session: Session,
2330) -> Result<()> {
2331 let db = session.db().await;
2332 let channel_id = ChannelId::from_proto(request.channel_id);
2333 let new_name = db
2334 .rename_channel(channel_id, session.user_id, &request.name)
2335 .await?;
2336
2337 let channel = proto::Channel {
2338 id: request.channel_id,
2339 name: new_name,
2340 parent_id: None,
2341 };
2342 response.send(proto::ChannelResponse {
2343 channel: Some(channel.clone()),
2344 })?;
2345 let mut update = proto::UpdateChannels::default();
2346 update.channels.push(channel);
2347
2348 let member_ids = db.get_channel_members(channel_id).await?;
2349
2350 let connection_pool = session.connection_pool().await;
2351 for member_id in member_ids {
2352 for connection_id in connection_pool.user_connection_ids(member_id) {
2353 session.peer.send(connection_id, update.clone())?;
2354 }
2355 }
2356
2357 Ok(())
2358}
2359
2360async fn get_channel_members(
2361 request: proto::GetChannelMembers,
2362 response: Response<proto::GetChannelMembers>,
2363 session: Session,
2364) -> Result<()> {
2365 let db = session.db().await;
2366 let channel_id = ChannelId::from_proto(request.channel_id);
2367 let members = db
2368 .get_channel_member_details(channel_id, session.user_id)
2369 .await?;
2370 response.send(proto::GetChannelMembersResponse { members })?;
2371 Ok(())
2372}
2373
2374async fn respond_to_channel_invite(
2375 request: proto::RespondToChannelInvite,
2376 response: Response<proto::RespondToChannelInvite>,
2377 session: Session,
2378) -> Result<()> {
2379 let db = session.db().await;
2380 let channel_id = ChannelId::from_proto(request.channel_id);
2381 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2382 .await?;
2383
2384 let mut update = proto::UpdateChannels::default();
2385 update
2386 .remove_channel_invitations
2387 .push(channel_id.to_proto());
2388 if request.accept {
2389 let result = db.get_channels_for_user(session.user_id).await?;
2390 update
2391 .channels
2392 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2393 id: channel.id.to_proto(),
2394 name: channel.name,
2395 parent_id: channel.parent_id.map(ChannelId::to_proto),
2396 }));
2397 update
2398 .channel_participants
2399 .extend(
2400 result
2401 .channel_participants
2402 .into_iter()
2403 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2404 channel_id: channel_id.to_proto(),
2405 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2406 }),
2407 );
2408 update
2409 .channel_permissions
2410 .extend(
2411 result
2412 .channels_with_admin_privileges
2413 .into_iter()
2414 .map(|channel_id| proto::ChannelPermission {
2415 channel_id: channel_id.to_proto(),
2416 is_admin: true,
2417 }),
2418 );
2419 }
2420 session.peer.send(session.connection_id, update)?;
2421 response.send(proto::Ack {})?;
2422
2423 Ok(())
2424}
2425
2426async fn join_channel(
2427 request: proto::JoinChannel,
2428 response: Response<proto::JoinChannel>,
2429 session: Session,
2430) -> Result<()> {
2431 let channel_id = ChannelId::from_proto(request.channel_id);
2432
2433 let joined_room = {
2434 leave_room_for_session(&session).await?;
2435 let db = session.db().await;
2436
2437 let room_id = db.room_id_for_channel(channel_id).await?;
2438
2439 let joined_room = db
2440 .join_room(room_id, session.user_id, session.connection_id)
2441 .await?;
2442
2443 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2444 let token = live_kit
2445 .room_token(
2446 &joined_room.room.live_kit_room,
2447 &session.user_id.to_string(),
2448 )
2449 .trace_err()?;
2450
2451 Some(LiveKitConnectionInfo {
2452 server_url: live_kit.url().into(),
2453 token,
2454 })
2455 });
2456
2457 response.send(proto::JoinRoomResponse {
2458 room: Some(joined_room.room.clone()),
2459 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2460 live_kit_connection_info,
2461 })?;
2462
2463 room_updated(&joined_room.room, &session.peer);
2464
2465 joined_room.into_inner()
2466 };
2467
2468 channel_updated(
2469 channel_id,
2470 &joined_room.room,
2471 &joined_room.channel_members,
2472 &session.peer,
2473 &*session.connection_pool().await,
2474 );
2475
2476 update_user_contacts(session.user_id, &session).await?;
2477
2478 Ok(())
2479}
2480
2481async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2482 let project_id = ProjectId::from_proto(request.project_id);
2483 let project_connection_ids = session
2484 .db()
2485 .await
2486 .project_connection_ids(project_id, session.connection_id)
2487 .await?;
2488 broadcast(
2489 Some(session.connection_id),
2490 project_connection_ids.iter().copied(),
2491 |connection_id| {
2492 session
2493 .peer
2494 .forward_send(session.connection_id, connection_id, request.clone())
2495 },
2496 );
2497 Ok(())
2498}
2499
2500async fn get_private_user_info(
2501 _request: proto::GetPrivateUserInfo,
2502 response: Response<proto::GetPrivateUserInfo>,
2503 session: Session,
2504) -> Result<()> {
2505 let metrics_id = session
2506 .db()
2507 .await
2508 .get_user_metrics_id(session.user_id)
2509 .await?;
2510 let user = session
2511 .db()
2512 .await
2513 .get_user_by_id(session.user_id)
2514 .await?
2515 .ok_or_else(|| anyhow!("user not found"))?;
2516 response.send(proto::GetPrivateUserInfoResponse {
2517 metrics_id,
2518 staff: user.admin,
2519 })?;
2520 Ok(())
2521}
2522
2523fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2524 match message {
2525 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2526 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2527 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2528 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2529 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2530 code: frame.code.into(),
2531 reason: frame.reason,
2532 })),
2533 }
2534}
2535
2536fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2537 match message {
2538 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2539 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2540 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2541 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2542 AxumMessage::Close(frame) => {
2543 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2544 code: frame.code.into(),
2545 reason: frame.reason,
2546 }))
2547 }
2548 }
2549}
2550
2551fn build_initial_channels_update(
2552 channels: ChannelsForUser,
2553 channel_invites: Vec<db::Channel>,
2554) -> proto::UpdateChannels {
2555 let mut update = proto::UpdateChannels::default();
2556
2557 for channel in channels.channels {
2558 update.channels.push(proto::Channel {
2559 id: channel.id.to_proto(),
2560 name: channel.name,
2561 parent_id: channel.parent_id.map(|id| id.to_proto()),
2562 });
2563 }
2564
2565 for (channel_id, participants) in channels.channel_participants {
2566 update
2567 .channel_participants
2568 .push(proto::ChannelParticipants {
2569 channel_id: channel_id.to_proto(),
2570 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2571 });
2572 }
2573
2574 update
2575 .channel_permissions
2576 .extend(
2577 channels
2578 .channels_with_admin_privileges
2579 .into_iter()
2580 .map(|id| proto::ChannelPermission {
2581 channel_id: id.to_proto(),
2582 is_admin: true,
2583 }),
2584 );
2585
2586 for channel in channel_invites {
2587 update.channel_invitations.push(proto::Channel {
2588 id: channel.id.to_proto(),
2589 name: channel.name,
2590 parent_id: None,
2591 });
2592 }
2593
2594 update
2595}
2596
2597fn build_initial_contacts_update(
2598 contacts: Vec<db::Contact>,
2599 pool: &ConnectionPool,
2600) -> proto::UpdateContacts {
2601 let mut update = proto::UpdateContacts::default();
2602
2603 for contact in contacts {
2604 match contact {
2605 db::Contact::Accepted {
2606 user_id,
2607 should_notify,
2608 busy,
2609 } => {
2610 update
2611 .contacts
2612 .push(contact_for_user(user_id, should_notify, busy, &pool));
2613 }
2614 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2615 db::Contact::Incoming {
2616 user_id,
2617 should_notify,
2618 } => update
2619 .incoming_requests
2620 .push(proto::IncomingContactRequest {
2621 requester_id: user_id.to_proto(),
2622 should_notify,
2623 }),
2624 }
2625 }
2626
2627 update
2628}
2629
2630fn contact_for_user(
2631 user_id: UserId,
2632 should_notify: bool,
2633 busy: bool,
2634 pool: &ConnectionPool,
2635) -> proto::Contact {
2636 proto::Contact {
2637 user_id: user_id.to_proto(),
2638 online: pool.is_user_online(user_id),
2639 busy,
2640 should_notify,
2641 }
2642}
2643
2644fn room_updated(room: &proto::Room, peer: &Peer) {
2645 broadcast(
2646 None,
2647 room.participants
2648 .iter()
2649 .filter_map(|participant| Some(participant.peer_id?.into())),
2650 |peer_id| {
2651 peer.send(
2652 peer_id.into(),
2653 proto::RoomUpdated {
2654 room: Some(room.clone()),
2655 },
2656 )
2657 },
2658 );
2659}
2660
2661fn channel_updated(
2662 channel_id: ChannelId,
2663 room: &proto::Room,
2664 channel_members: &[UserId],
2665 peer: &Peer,
2666 pool: &ConnectionPool,
2667) {
2668 let participants = room
2669 .participants
2670 .iter()
2671 .map(|p| p.user_id)
2672 .collect::<Vec<_>>();
2673
2674 broadcast(
2675 None,
2676 channel_members
2677 .iter()
2678 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2679 |peer_id| {
2680 peer.send(
2681 peer_id.into(),
2682 proto::UpdateChannels {
2683 channel_participants: vec![proto::ChannelParticipants {
2684 channel_id: channel_id.to_proto(),
2685 participant_user_ids: participants.clone(),
2686 }],
2687 ..Default::default()
2688 },
2689 )
2690 },
2691 );
2692}
2693
2694async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2695 let db = session.db().await;
2696
2697 let contacts = db.get_contacts(user_id).await?;
2698 let busy = db.is_user_busy(user_id).await?;
2699
2700 let pool = session.connection_pool().await;
2701 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2702 for contact in contacts {
2703 if let db::Contact::Accepted {
2704 user_id: contact_user_id,
2705 ..
2706 } = contact
2707 {
2708 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2709 session
2710 .peer
2711 .send(
2712 contact_conn_id,
2713 proto::UpdateContacts {
2714 contacts: vec![updated_contact.clone()],
2715 remove_contacts: Default::default(),
2716 incoming_requests: Default::default(),
2717 remove_incoming_requests: Default::default(),
2718 outgoing_requests: Default::default(),
2719 remove_outgoing_requests: Default::default(),
2720 },
2721 )
2722 .trace_err();
2723 }
2724 }
2725 }
2726 Ok(())
2727}
2728
2729async fn leave_room_for_session(session: &Session) -> Result<()> {
2730 let mut contacts_to_update = HashSet::default();
2731
2732 let room_id;
2733 let canceled_calls_to_user_ids;
2734 let live_kit_room;
2735 let delete_live_kit_room;
2736 let room;
2737 let channel_members;
2738 let channel_id;
2739
2740 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2741 contacts_to_update.insert(session.user_id);
2742
2743 for project in left_room.left_projects.values() {
2744 project_left(project, session);
2745 }
2746
2747 room_id = RoomId::from_proto(left_room.room.id);
2748 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2749 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2750 delete_live_kit_room = left_room.deleted;
2751 room = mem::take(&mut left_room.room);
2752 channel_members = mem::take(&mut left_room.channel_members);
2753 channel_id = left_room.channel_id;
2754
2755 room_updated(&room, &session.peer);
2756 } else {
2757 return Ok(());
2758 }
2759
2760 if let Some(channel_id) = channel_id {
2761 channel_updated(
2762 channel_id,
2763 &room,
2764 &channel_members,
2765 &session.peer,
2766 &*session.connection_pool().await,
2767 );
2768 }
2769
2770 {
2771 let pool = session.connection_pool().await;
2772 for canceled_user_id in canceled_calls_to_user_ids {
2773 for connection_id in pool.user_connection_ids(canceled_user_id) {
2774 session
2775 .peer
2776 .send(
2777 connection_id,
2778 proto::CallCanceled {
2779 room_id: room_id.to_proto(),
2780 },
2781 )
2782 .trace_err();
2783 }
2784 contacts_to_update.insert(canceled_user_id);
2785 }
2786 }
2787
2788 for contact_user_id in contacts_to_update {
2789 update_user_contacts(contact_user_id, &session).await?;
2790 }
2791
2792 if let Some(live_kit) = session.live_kit_client.as_ref() {
2793 live_kit
2794 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2795 .await
2796 .trace_err();
2797
2798 if delete_live_kit_room {
2799 live_kit.delete_room(live_kit_room).await.trace_err();
2800 }
2801 }
2802
2803 Ok(())
2804}
2805
2806fn project_left(project: &db::LeftProject, session: &Session) {
2807 for connection_id in &project.connection_ids {
2808 if project.host_user_id == session.user_id {
2809 session
2810 .peer
2811 .send(
2812 *connection_id,
2813 proto::UnshareProject {
2814 project_id: project.id.to_proto(),
2815 },
2816 )
2817 .trace_err();
2818 } else {
2819 session
2820 .peer
2821 .send(
2822 *connection_id,
2823 proto::RemoveProjectCollaborator {
2824 project_id: project.id.to_proto(),
2825 peer_id: Some(session.connection_id.into()),
2826 },
2827 )
2828 .trace_err();
2829 }
2830 }
2831}
2832
2833pub trait ResultExt {
2834 type Ok;
2835
2836 fn trace_err(self) -> Option<Self::Ok>;
2837}
2838
2839impl<T, E> ResultExt for Result<T, E>
2840where
2841 E: std::fmt::Debug,
2842{
2843 type Ok = T;
2844
2845 fn trace_err(self) -> Option<T> {
2846 match self {
2847 Ok(value) => Some(value),
2848 Err(error) => {
2849 tracing::error!("{:?}", error);
2850 None
2851 }
2852 }
2853 }
2854}