1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, GetChannelBufferResponse,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits after a connection drops before giving up on it:
/// once this elapses (and the server has not been torn down in the meantime) the session
/// leaves its room, `decline_call` is run if the user has no other online connection, and
/// the user's contacts are refreshed.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the background task spawned by `Server::start` sleeps before refreshing the
/// rooms that the database reports as stale for this server.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

lazy_static! {
    // Prometheus gauges exported by `handle_metrics`, which overwrites both on every scrape.
    // Registration happens on first access; `unwrap` panics if it fails.

    // Set to the number of non-admin connections in the pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set from `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
/// Type-erased message handler stored in `Server::handlers`, keyed by the message's
/// `TypeId`. It receives the boxed envelope plus the sending connection's `Session` and
/// returns a boxed `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection context handed to every message handler.
///
/// `handle_connection` builds one `Session` per connection and clones it for each incoming
/// message, so all clones share the same `db` mutex and connection pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Created fresh per connection in `handle_connection`; since clones share it, database
    // access from handlers of the same connection is serialized through this mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}

impl Session {
    /// Locks this session's database handle, waiting asynchronously if another handler
    /// of the same connection holds it.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        // In test builds, yield before and after acquiring the lock — presumably to give
        // the test scheduler extra interleaving points. NOTE(review): confirm intent.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool.
    ///
    /// The lock is a synchronous `parking_lot` mutex; the returned guard is `!Send` (see
    /// `ConnectionPoolGuard`), so it must be dropped before the next `.await` in a future
    /// that needs to be `Send`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        // Test-only yield point, mirroring `db()`.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server.
///
/// Owns the `Peer` used to talk to clients, the shared connection pool, and the table of
/// handlers that `handle_connection` dispatches incoming messages to.
pub struct Server {
    // Behind a mutex so the test-only `reset` can overwrite it after construction.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // At most one handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Fired by `teardown()`; each connection subscribes to it in `handle_connection`.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (because `Rc` is `!Send`), so
/// the compiler rejects holding it across an `.await` in futures that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server's state, produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(get_channel_buffer)
252 .add_request_handler(get_channel_members)
253 .add_request_handler(respond_to_channel_invite)
254 .add_request_handler(join_channel)
255 .add_request_handler(follow)
256 .add_message_handler(unfollow)
257 .add_message_handler(update_followers)
258 .add_message_handler(update_diff_base)
259 .add_request_handler(get_private_user_info);
260
261 Arc::new(server)
262 }
263
264 pub async fn start(&self) -> Result<()> {
265 let server_id = *self.id.lock();
266 let app_state = self.app_state.clone();
267 let peer = self.peer.clone();
268 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
269 let pool = self.connection_pool.clone();
270 let live_kit_client = self.app_state.live_kit_client.clone();
271
272 let span = info_span!("start server");
273 self.executor.spawn_detached(
274 async move {
275 tracing::info!("waiting for cleanup timeout");
276 timeout.await;
277 tracing::info!("cleanup timeout expired, retrieving stale rooms");
278 if let Some(room_ids) = app_state
279 .db
280 .stale_room_ids(&app_state.config.zed_environment, server_id)
281 .await
282 .trace_err()
283 {
284 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
285 for room_id in room_ids {
286 let mut contacts_to_update = HashSet::default();
287 let mut canceled_calls_to_user_ids = Vec::new();
288 let mut live_kit_room = String::new();
289 let mut delete_live_kit_room = false;
290
291 if let Some(mut refreshed_room) = app_state
292 .db
293 .refresh_room(room_id, server_id)
294 .await
295 .trace_err()
296 {
297 tracing::info!(
298 room_id = room_id.0,
299 new_participant_count = refreshed_room.room.participants.len(),
300 "refreshed room"
301 );
302 room_updated(&refreshed_room.room, &peer);
303 if let Some(channel_id) = refreshed_room.channel_id {
304 channel_updated(
305 channel_id,
306 &refreshed_room.room,
307 &refreshed_room.channel_members,
308 &peer,
309 &*pool.lock(),
310 );
311 }
312 contacts_to_update
313 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
314 contacts_to_update
315 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
316 canceled_calls_to_user_ids =
317 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
318 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
319 delete_live_kit_room = refreshed_room.room.participants.is_empty();
320 }
321
322 {
323 let pool = pool.lock();
324 for canceled_user_id in canceled_calls_to_user_ids {
325 for connection_id in pool.user_connection_ids(canceled_user_id) {
326 peer.send(
327 connection_id,
328 proto::CallCanceled {
329 room_id: room_id.to_proto(),
330 },
331 )
332 .trace_err();
333 }
334 }
335 }
336
337 for user_id in contacts_to_update {
338 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
339 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
340 if let Some((busy, contacts)) = busy.zip(contacts) {
341 let pool = pool.lock();
342 let updated_contact = contact_for_user(user_id, false, busy, &pool);
343 for contact in contacts {
344 if let db::Contact::Accepted {
345 user_id: contact_user_id,
346 ..
347 } = contact
348 {
349 for contact_conn_id in
350 pool.user_connection_ids(contact_user_id)
351 {
352 peer.send(
353 contact_conn_id,
354 proto::UpdateContacts {
355 contacts: vec![updated_contact.clone()],
356 remove_contacts: Default::default(),
357 incoming_requests: Default::default(),
358 remove_incoming_requests: Default::default(),
359 outgoing_requests: Default::default(),
360 remove_outgoing_requests: Default::default(),
361 },
362 )
363 .trace_err();
364 }
365 }
366 }
367 }
368 }
369
370 if let Some(live_kit) = live_kit_client.as_ref() {
371 if delete_live_kit_room {
372 live_kit.delete_room(live_kit_room).await.trace_err();
373 }
374 }
375 }
376 }
377
378 app_state
379 .db
380 .delete_stale_servers(&app_state.config.zed_environment, server_id)
381 .await
382 .trace_err();
383 }
384 .instrument(span),
385 );
386 Ok(())
387 }
388
389 pub fn teardown(&self) {
390 self.peer.teardown();
391 self.connection_pool.lock().reset();
392 let _ = self.teardown.send(());
393 }
394
395 #[cfg(test)]
396 pub fn reset(&self, id: ServerId) {
397 self.teardown();
398 *self.id.lock() = id;
399 self.peer.reset(id.0 as u32);
400 }
401
402 #[cfg(test)]
403 pub fn id(&self) -> ServerId {
404 *self.id.lock()
405 }
406
407 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
408 where
409 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
410 Fut: 'static + Send + Future<Output = Result<()>>,
411 M: EnvelopedMessage,
412 {
413 let prev_handler = self.handlers.insert(
414 TypeId::of::<M>(),
415 Box::new(move |envelope, session| {
416 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
417 let span = info_span!(
418 "handle message",
419 payload_type = envelope.payload_type_name()
420 );
421 span.in_scope(|| {
422 tracing::info!(
423 payload_type = envelope.payload_type_name(),
424 "message received"
425 );
426 });
427 let start_time = Instant::now();
428 let future = (handler)(*envelope, session);
429 async move {
430 let result = future.await;
431 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
432 match result {
433 Err(error) => {
434 tracing::error!(%error, ?duration_ms, "error handling message")
435 }
436 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
437 }
438 }
439 .instrument(span)
440 .boxed()
441 }),
442 );
443 if prev_handler.is_some() {
444 panic!("registered a handler for the same message twice");
445 }
446 self
447 }
448
449 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
450 where
451 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
452 Fut: 'static + Send + Future<Output = Result<()>>,
453 M: EnvelopedMessage,
454 {
455 self.add_handler(move |envelope, session| handler(envelope.payload, session));
456 self
457 }
458
459 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
460 where
461 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
462 Fut: Send + Future<Output = Result<()>>,
463 M: RequestMessage,
464 {
465 let handler = Arc::new(handler);
466 self.add_handler(move |envelope, session| {
467 let receipt = envelope.receipt();
468 let handler = handler.clone();
469 async move {
470 let peer = session.peer.clone();
471 let responded = Arc::new(AtomicBool::default());
472 let response = Response {
473 peer: peer.clone(),
474 responded: responded.clone(),
475 receipt,
476 };
477 match (handler)(envelope.payload, response, session).await {
478 Ok(()) => {
479 if responded.load(std::sync::atomic::Ordering::SeqCst) {
480 Ok(())
481 } else {
482 Err(anyhow!("handler did not send a response"))?
483 }
484 }
485 Err(error) => {
486 peer.respond_with_error(
487 receipt,
488 proto::Error {
489 message: error.to_string(),
490 },
491 )?;
492 Err(error)
493 }
494 }
495 }
496 })
497 }
498
499 pub fn handle_connection(
500 self: &Arc<Self>,
501 connection: Connection,
502 address: String,
503 user: User,
504 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
505 executor: Executor,
506 ) -> impl Future<Output = Result<()>> {
507 let this = self.clone();
508 let user_id = user.id;
509 let login = user.github_login;
510 let span = info_span!("handle connection", %user_id, %login, %address);
511 let mut teardown = self.teardown.subscribe();
512 async move {
513 let (connection_id, handle_io, mut incoming_rx) = this
514 .peer
515 .add_connection(connection, {
516 let executor = executor.clone();
517 move |duration| executor.sleep(duration)
518 });
519
520 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
521 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
522 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
523
524 if let Some(send_connection_id) = send_connection_id.take() {
525 let _ = send_connection_id.send(connection_id);
526 }
527
528 if !user.connected_once {
529 this.peer.send(connection_id, proto::ShowContacts {})?;
530 this.app_state.db.set_user_connected_once(user_id, true).await?;
531 }
532
533 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
534 this.app_state.db.get_contacts(user_id),
535 this.app_state.db.get_invite_code_for_user(user_id),
536 this.app_state.db.get_channels_for_user(user_id),
537 this.app_state.db.get_channel_invites_for_user(user_id)
538 ).await?;
539
540 {
541 let mut pool = this.connection_pool.lock();
542 pool.add_connection(connection_id, user_id, user.admin);
543 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
544 this.peer.send(connection_id, build_initial_channels_update(
545 channels_for_user,
546 channel_invites
547 ))?;
548
549 if let Some((code, count)) = invite_code {
550 this.peer.send(connection_id, proto::UpdateInviteInfo {
551 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
552 count: count as u32,
553 })?;
554 }
555 }
556
557 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
558 this.peer.send(connection_id, incoming_call)?;
559 }
560
561 let session = Session {
562 user_id,
563 connection_id,
564 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
565 peer: this.peer.clone(),
566 connection_pool: this.connection_pool.clone(),
567 live_kit_client: this.app_state.live_kit_client.clone(),
568 executor: executor.clone(),
569 };
570 update_user_contacts(user_id, &session).await?;
571
572 let handle_io = handle_io.fuse();
573 futures::pin_mut!(handle_io);
574
575 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
576 // This prevents deadlocks when e.g., client A performs a request to client B and
577 // client B performs a request to client A. If both clients stop processing further
578 // messages until their respective request completes, they won't have a chance to
579 // respond to the other client's request and cause a deadlock.
580 //
581 // This arrangement ensures we will attempt to process earlier messages first, but fall
582 // back to processing messages arrived later in the spirit of making progress.
583 let mut foreground_message_handlers = FuturesUnordered::new();
584 let concurrent_handlers = Arc::new(Semaphore::new(256));
585 loop {
586 let next_message = async {
587 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
588 let message = incoming_rx.next().await;
589 (permit, message)
590 }.fuse();
591 futures::pin_mut!(next_message);
592 futures::select_biased! {
593 _ = teardown.changed().fuse() => return Ok(()),
594 result = handle_io => {
595 if let Err(error) = result {
596 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
597 }
598 break;
599 }
600 _ = foreground_message_handlers.next() => {}
601 next_message = next_message => {
602 let (permit, message) = next_message;
603 if let Some(message) = message {
604 let type_name = message.payload_type_name();
605 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
606 let span_enter = span.enter();
607 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
608 let is_background = message.is_background();
609 let handle_message = (handler)(message, session.clone());
610 drop(span_enter);
611
612 let handle_message = async move {
613 handle_message.await;
614 drop(permit);
615 }.instrument(span);
616 if is_background {
617 executor.spawn_detached(handle_message);
618 } else {
619 foreground_message_handlers.push(handle_message);
620 }
621 } else {
622 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
623 }
624 } else {
625 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
626 break;
627 }
628 }
629 }
630 }
631
632 drop(foreground_message_handlers);
633 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
634 if let Err(error) = connection_lost(session, teardown, executor).await {
635 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
636 }
637
638 Ok(())
639 }.instrument(span)
640 }
641
642 pub async fn invite_code_redeemed(
643 self: &Arc<Self>,
644 inviter_id: UserId,
645 invitee_id: UserId,
646 ) -> Result<()> {
647 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
648 if let Some(code) = &user.invite_code {
649 let pool = self.connection_pool.lock();
650 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
651 for connection_id in pool.user_connection_ids(inviter_id) {
652 self.peer.send(
653 connection_id,
654 proto::UpdateContacts {
655 contacts: vec![invitee_contact.clone()],
656 ..Default::default()
657 },
658 )?;
659 self.peer.send(
660 connection_id,
661 proto::UpdateInviteInfo {
662 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
663 count: user.invite_count as u32,
664 },
665 )?;
666 }
667 }
668 }
669 Ok(())
670 }
671
672 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
673 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
674 if let Some(invite_code) = &user.invite_code {
675 let pool = self.connection_pool.lock();
676 for connection_id in pool.user_connection_ids(user_id) {
677 self.peer.send(
678 connection_id,
679 proto::UpdateInviteInfo {
680 url: format!(
681 "{}{}",
682 self.app_state.config.invite_link_prefix, invite_code
683 ),
684 count: user.invite_count as u32,
685 },
686 )?;
687 }
688 }
689 }
690 Ok(())
691 }
692
693 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
694 ServerSnapshot {
695 connection_pool: ConnectionPoolGuard {
696 guard: self.connection_pool.lock(),
697 _not_send: PhantomData,
698 },
699 peer: &self.peer,
700 }
701 }
702}
703
704impl<'a> Deref for ConnectionPoolGuard<'a> {
705 type Target = ConnectionPool;
706
707 fn deref(&self) -> &Self::Target {
708 &*self.guard
709 }
710}
711
712impl<'a> DerefMut for ConnectionPoolGuard<'a> {
713 fn deref_mut(&mut self) -> &mut Self::Target {
714 &mut *self.guard
715 }
716}
717
718impl<'a> Drop for ConnectionPoolGuard<'a> {
719 fn drop(&mut self) {
720 #[cfg(test)]
721 self.check_invariants();
722 }
723}
724
725fn broadcast<F>(
726 sender_id: Option<ConnectionId>,
727 receiver_ids: impl IntoIterator<Item = ConnectionId>,
728 mut f: F,
729) where
730 F: FnMut(ConnectionId) -> anyhow::Result<()>,
731{
732 for receiver_id in receiver_ids {
733 if Some(receiver_id) != sender_id {
734 if let Err(error) = f(receiver_id) {
735 tracing::error!("failed to send to {:?} {}", receiver_id, error);
736 }
737 }
738 }
739}
740
lazy_static! {
    // Name of the request header carrying the client's RPC protocol version
    // (see `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
744
745pub struct ProtocolVersion(u32);
746
747impl Header for ProtocolVersion {
748 fn name() -> &'static HeaderName {
749 &ZED_PROTOCOL_VERSION
750 }
751
752 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
753 where
754 Self: Sized,
755 I: Iterator<Item = &'i axum::http::HeaderValue>,
756 {
757 let version = values
758 .next()
759 .ok_or_else(axum::headers::Error::invalid)?
760 .to_str()
761 .map_err(|_| axum::headers::Error::invalid())?
762 .parse()
763 .map_err(|_| axum::headers::Error::invalid())?;
764 Ok(Self(version))
765 }
766
767 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
768 values.extend([self.0.to_string().parse().unwrap()]);
769 }
770}
771
772pub fn routes(server: Arc<Server>) -> Router<Body> {
773 Router::new()
774 .route("/rpc", get(handle_websocket_request))
775 .layer(
776 ServiceBuilder::new()
777 .layer(Extension(server.app_state.clone()))
778 .layer(middleware::from_fn(auth::validate_header)),
779 )
780 .route("/metrics", get(handle_metrics))
781 .layer(Extension(server))
782}
783
784pub async fn handle_websocket_request(
785 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
786 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
787 Extension(server): Extension<Arc<Server>>,
788 Extension(user): Extension<User>,
789 ws: WebSocketUpgrade,
790) -> axum::response::Response {
791 if protocol_version != rpc::PROTOCOL_VERSION {
792 return (
793 StatusCode::UPGRADE_REQUIRED,
794 "client must be upgraded".to_string(),
795 )
796 .into_response();
797 }
798 let socket_address = socket_address.to_string();
799 ws.on_upgrade(move |socket| {
800 use util::ResultExt;
801 let socket = socket
802 .map_ok(to_tungstenite_message)
803 .err_into()
804 .with(|message| async move { Ok(to_axum_message(message)) });
805 let connection = Connection::new(Box::pin(socket));
806 async move {
807 server
808 .handle_connection(connection, socket_address, user, None, Executor::Production)
809 .await
810 .log_err();
811 }
812 })
813}
814
815pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
816 let connections = server
817 .connection_pool
818 .lock()
819 .connections()
820 .filter(|connection| !connection.admin)
821 .count();
822
823 METRIC_CONNECTIONS.set(connections as _);
824
825 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
826 METRIC_SHARED_PROJECTS.set(shared_projects as _);
827
828 let encoder = prometheus::TextEncoder::new();
829 let metric_families = prometheus::gather();
830 let encoded_metrics = encoder
831 .encode_to_string(&metric_families)
832 .map_err(|err| anyhow!("{}", err))?;
833 Ok(encoded_metrics)
834}
835
/// Cleans up after a connection whose I/O finished or whose stream was closed.
///
/// Immediate cleanup:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - records the loss in the database.
///
/// It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room, `decline_call` runs if the
///   user has no other online connection, and the user's contacts are refreshed;
/// - the server is torn down: the deferred cleanup is skipped.
///
/// # Errors
/// Returns an error if removing the connection from the pool or refreshing contacts fails.
/// Failures of the database bookkeeping, leaving the room, and `decline_call` are not
/// propagated; they go through `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and teardown are both ready, the timeout branch runs.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary of this condition and is released before the
            // body's `.await`s.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
876
877async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
878 response.send(proto::Ack {})?;
879 Ok(())
880}
881
882async fn create_room(
883 _request: proto::CreateRoom,
884 response: Response<proto::CreateRoom>,
885 session: Session,
886) -> Result<()> {
887 let live_kit_room = nanoid::nanoid!(30);
888
889 let live_kit_connection_info = {
890 let live_kit_room = live_kit_room.clone();
891 let live_kit = session.live_kit_client.as_ref();
892
893 util::async_iife!({
894 let live_kit = live_kit?;
895
896 live_kit
897 .create_room(live_kit_room.clone())
898 .await
899 .trace_err()?;
900
901 let token = live_kit
902 .room_token(&live_kit_room, &session.user_id.to_string())
903 .trace_err()?;
904
905 Some(proto::LiveKitConnectionInfo {
906 server_url: live_kit.url().into(),
907 token,
908 })
909 })
910 }
911 .await;
912
913 let room = session
914 .db()
915 .await
916 .create_room(session.user_id, session.connection_id, &live_kit_room)
917 .await?;
918
919 response.send(proto::CreateRoomResponse {
920 room: Some(room.clone()),
921 live_kit_connection_info,
922 })?;
923
924 update_user_contacts(session.user_id, &session).await?;
925 Ok(())
926}
927
928async fn join_room(
929 request: proto::JoinRoom,
930 response: Response<proto::JoinRoom>,
931 session: Session,
932) -> Result<()> {
933 let room_id = RoomId::from_proto(request.id);
934 let joined_room = {
935 let room = session
936 .db()
937 .await
938 .join_room(room_id, session.user_id, session.connection_id)
939 .await?;
940 room_updated(&room.room, &session.peer);
941 room.into_inner()
942 };
943
944 if let Some(channel_id) = joined_room.channel_id {
945 channel_updated(
946 channel_id,
947 &joined_room.room,
948 &joined_room.channel_members,
949 &session.peer,
950 &*session.connection_pool().await,
951 )
952 }
953
954 for connection_id in session
955 .connection_pool()
956 .await
957 .user_connection_ids(session.user_id)
958 {
959 session
960 .peer
961 .send(
962 connection_id,
963 proto::CallCanceled {
964 room_id: room_id.to_proto(),
965 },
966 )
967 .trace_err();
968 }
969
970 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
971 if let Some(token) = live_kit
972 .room_token(
973 &joined_room.room.live_kit_room,
974 &session.user_id.to_string(),
975 )
976 .trace_err()
977 {
978 Some(proto::LiveKitConnectionInfo {
979 server_url: live_kit.url().into(),
980 token,
981 })
982 } else {
983 None
984 }
985 } else {
986 None
987 };
988
989 response.send(proto::JoinRoomResponse {
990 room: Some(joined_room.room),
991 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
992 live_kit_connection_info,
993 })?;
994
995 update_user_contacts(session.user_id, &session).await?;
996 Ok(())
997}
998
999async fn rejoin_room(
1000 request: proto::RejoinRoom,
1001 response: Response<proto::RejoinRoom>,
1002 session: Session,
1003) -> Result<()> {
1004 let room;
1005 let channel_id;
1006 let channel_members;
1007 {
1008 let mut rejoined_room = session
1009 .db()
1010 .await
1011 .rejoin_room(request, session.user_id, session.connection_id)
1012 .await?;
1013
1014 response.send(proto::RejoinRoomResponse {
1015 room: Some(rejoined_room.room.clone()),
1016 reshared_projects: rejoined_room
1017 .reshared_projects
1018 .iter()
1019 .map(|project| proto::ResharedProject {
1020 id: project.id.to_proto(),
1021 collaborators: project
1022 .collaborators
1023 .iter()
1024 .map(|collaborator| collaborator.to_proto())
1025 .collect(),
1026 })
1027 .collect(),
1028 rejoined_projects: rejoined_room
1029 .rejoined_projects
1030 .iter()
1031 .map(|rejoined_project| proto::RejoinedProject {
1032 id: rejoined_project.id.to_proto(),
1033 worktrees: rejoined_project
1034 .worktrees
1035 .iter()
1036 .map(|worktree| proto::WorktreeMetadata {
1037 id: worktree.id,
1038 root_name: worktree.root_name.clone(),
1039 visible: worktree.visible,
1040 abs_path: worktree.abs_path.clone(),
1041 })
1042 .collect(),
1043 collaborators: rejoined_project
1044 .collaborators
1045 .iter()
1046 .map(|collaborator| collaborator.to_proto())
1047 .collect(),
1048 language_servers: rejoined_project.language_servers.clone(),
1049 })
1050 .collect(),
1051 })?;
1052 room_updated(&rejoined_room.room, &session.peer);
1053
1054 for project in &rejoined_room.reshared_projects {
1055 for collaborator in &project.collaborators {
1056 session
1057 .peer
1058 .send(
1059 collaborator.connection_id,
1060 proto::UpdateProjectCollaborator {
1061 project_id: project.id.to_proto(),
1062 old_peer_id: Some(project.old_connection_id.into()),
1063 new_peer_id: Some(session.connection_id.into()),
1064 },
1065 )
1066 .trace_err();
1067 }
1068
1069 broadcast(
1070 Some(session.connection_id),
1071 project
1072 .collaborators
1073 .iter()
1074 .map(|collaborator| collaborator.connection_id),
1075 |connection_id| {
1076 session.peer.forward_send(
1077 session.connection_id,
1078 connection_id,
1079 proto::UpdateProject {
1080 project_id: project.id.to_proto(),
1081 worktrees: project.worktrees.clone(),
1082 },
1083 )
1084 },
1085 );
1086 }
1087
1088 for project in &rejoined_room.rejoined_projects {
1089 for collaborator in &project.collaborators {
1090 session
1091 .peer
1092 .send(
1093 collaborator.connection_id,
1094 proto::UpdateProjectCollaborator {
1095 project_id: project.id.to_proto(),
1096 old_peer_id: Some(project.old_connection_id.into()),
1097 new_peer_id: Some(session.connection_id.into()),
1098 },
1099 )
1100 .trace_err();
1101 }
1102 }
1103
1104 for project in &mut rejoined_room.rejoined_projects {
1105 for worktree in mem::take(&mut project.worktrees) {
1106 #[cfg(any(test, feature = "test-support"))]
1107 const MAX_CHUNK_SIZE: usize = 2;
1108 #[cfg(not(any(test, feature = "test-support")))]
1109 const MAX_CHUNK_SIZE: usize = 256;
1110
1111 // Stream this worktree's entries.
1112 let message = proto::UpdateWorktree {
1113 project_id: project.id.to_proto(),
1114 worktree_id: worktree.id,
1115 abs_path: worktree.abs_path.clone(),
1116 root_name: worktree.root_name,
1117 updated_entries: worktree.updated_entries,
1118 removed_entries: worktree.removed_entries,
1119 scan_id: worktree.scan_id,
1120 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1121 updated_repositories: worktree.updated_repositories,
1122 removed_repositories: worktree.removed_repositories,
1123 };
1124 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1125 session.peer.send(session.connection_id, update.clone())?;
1126 }
1127
1128 // Stream this worktree's diagnostics.
1129 for summary in worktree.diagnostic_summaries {
1130 session.peer.send(
1131 session.connection_id,
1132 proto::UpdateDiagnosticSummary {
1133 project_id: project.id.to_proto(),
1134 worktree_id: worktree.id,
1135 summary: Some(summary),
1136 },
1137 )?;
1138 }
1139
1140 for settings_file in worktree.settings_files {
1141 session.peer.send(
1142 session.connection_id,
1143 proto::UpdateWorktreeSettings {
1144 project_id: project.id.to_proto(),
1145 worktree_id: worktree.id,
1146 path: settings_file.path,
1147 content: Some(settings_file.content),
1148 },
1149 )?;
1150 }
1151 }
1152
1153 for language_server in &project.language_servers {
1154 session.peer.send(
1155 session.connection_id,
1156 proto::UpdateLanguageServer {
1157 project_id: project.id.to_proto(),
1158 language_server_id: language_server.id,
1159 variant: Some(
1160 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1161 proto::LspDiskBasedDiagnosticsUpdated {},
1162 ),
1163 ),
1164 },
1165 )?;
1166 }
1167 }
1168
1169 let rejoined_room = rejoined_room.into_inner();
1170
1171 room = rejoined_room.room;
1172 channel_id = rejoined_room.channel_id;
1173 channel_members = rejoined_room.channel_members;
1174 }
1175
1176 if let Some(channel_id) = channel_id {
1177 channel_updated(
1178 channel_id,
1179 &room,
1180 &channel_members,
1181 &session.peer,
1182 &*session.connection_pool().await,
1183 );
1184 }
1185
1186 update_user_contacts(session.user_id, &session).await?;
1187 Ok(())
1188}
1189
1190async fn leave_room(
1191 _: proto::LeaveRoom,
1192 response: Response<proto::LeaveRoom>,
1193 session: Session,
1194) -> Result<()> {
1195 leave_room_for_session(&session).await?;
1196 response.send(proto::Ack {})?;
1197 Ok(())
1198}
1199
/// Handles `proto::Call`: the session's user rings `called_user_id` into room `room_id`,
/// optionally pointing at `initial_project_id`.
///
/// Flow:
/// 1. Reject the call unless the two users are contacts.
/// 2. Record the call in the database, broadcast the room, and refresh the callee's contacts.
/// 3. Send the incoming-call request to every connection of the callee concurrently; the
///    first successful response acknowledges the caller and ends the handler.
/// 4. If no connection responds successfully (including when the callee has no
///    connections), record the failed call, broadcast the room, refresh the callee's
///    contacts again, and return a "failed to ring user" error.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The database result guard lives only inside this block, so it is dropped before
    // `update_user_contacts` runs below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call message out of the guarded result (leaving a default value
        // in its place) so it can outlive the guard.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The connection pool guard is a temporary of this statement, so it is released before
    // any of the ring requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to answer successfully wins; failed attempts are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: record the failed call and publish the resulting room state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1267
1268async fn cancel_call(
1269 request: proto::CancelCall,
1270 response: Response<proto::CancelCall>,
1271 session: Session,
1272) -> Result<()> {
1273 let called_user_id = UserId::from_proto(request.called_user_id);
1274 let room_id = RoomId::from_proto(request.room_id);
1275 {
1276 let room = session
1277 .db()
1278 .await
1279 .cancel_call(room_id, session.connection_id, called_user_id)
1280 .await?;
1281 room_updated(&room, &session.peer);
1282 }
1283
1284 for connection_id in session
1285 .connection_pool()
1286 .await
1287 .user_connection_ids(called_user_id)
1288 {
1289 session
1290 .peer
1291 .send(
1292 connection_id,
1293 proto::CallCanceled {
1294 room_id: room_id.to_proto(),
1295 },
1296 )
1297 .trace_err();
1298 }
1299 response.send(proto::Ack {})?;
1300
1301 update_user_contacts(called_user_id, &session).await?;
1302 Ok(())
1303}
1304
1305async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1306 let room_id = RoomId::from_proto(message.room_id);
1307 {
1308 let room = session
1309 .db()
1310 .await
1311 .decline_call(Some(room_id), session.user_id)
1312 .await?
1313 .ok_or_else(|| anyhow!("failed to decline call"))?;
1314 room_updated(&room, &session.peer);
1315 }
1316
1317 for connection_id in session
1318 .connection_pool()
1319 .await
1320 .user_connection_ids(session.user_id)
1321 {
1322 session
1323 .peer
1324 .send(
1325 connection_id,
1326 proto::CallCanceled {
1327 room_id: room_id.to_proto(),
1328 },
1329 )
1330 .trace_err();
1331 }
1332 update_user_contacts(session.user_id, &session).await?;
1333 Ok(())
1334}
1335
1336async fn update_participant_location(
1337 request: proto::UpdateParticipantLocation,
1338 response: Response<proto::UpdateParticipantLocation>,
1339 session: Session,
1340) -> Result<()> {
1341 let room_id = RoomId::from_proto(request.room_id);
1342 let location = request
1343 .location
1344 .ok_or_else(|| anyhow!("invalid location"))?;
1345
1346 let db = session.db().await;
1347 let room = db
1348 .update_room_participant_location(room_id, session.connection_id, location)
1349 .await?;
1350
1351 room_updated(&room, &session.peer);
1352 response.send(proto::Ack {})?;
1353 Ok(())
1354}
1355
1356async fn share_project(
1357 request: proto::ShareProject,
1358 response: Response<proto::ShareProject>,
1359 session: Session,
1360) -> Result<()> {
1361 let (project_id, room) = &*session
1362 .db()
1363 .await
1364 .share_project(
1365 RoomId::from_proto(request.room_id),
1366 session.connection_id,
1367 &request.worktrees,
1368 )
1369 .await?;
1370 response.send(proto::ShareProjectResponse {
1371 project_id: project_id.to_proto(),
1372 })?;
1373 room_updated(&room, &session.peer);
1374
1375 Ok(())
1376}
1377
1378async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1379 let project_id = ProjectId::from_proto(message.project_id);
1380
1381 let (room, guest_connection_ids) = &*session
1382 .db()
1383 .await
1384 .unshare_project(project_id, session.connection_id)
1385 .await?;
1386
1387 broadcast(
1388 Some(session.connection_id),
1389 guest_connection_ids.iter().copied(),
1390 |conn_id| session.peer.send(conn_id, message.clone()),
1391 );
1392 room_updated(&room, &session.peer);
1393
1394 Ok(())
1395}
1396
1397async fn join_project(
1398 request: proto::JoinProject,
1399 response: Response<proto::JoinProject>,
1400 session: Session,
1401) -> Result<()> {
1402 let project_id = ProjectId::from_proto(request.project_id);
1403 let guest_user_id = session.user_id;
1404
1405 tracing::info!(%project_id, "join project");
1406
1407 let (project, replica_id) = &mut *session
1408 .db()
1409 .await
1410 .join_project(project_id, session.connection_id)
1411 .await?;
1412
1413 let collaborators = project
1414 .collaborators
1415 .iter()
1416 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1417 .map(|collaborator| collaborator.to_proto())
1418 .collect::<Vec<_>>();
1419
1420 let worktrees = project
1421 .worktrees
1422 .iter()
1423 .map(|(id, worktree)| proto::WorktreeMetadata {
1424 id: *id,
1425 root_name: worktree.root_name.clone(),
1426 visible: worktree.visible,
1427 abs_path: worktree.abs_path.clone(),
1428 })
1429 .collect::<Vec<_>>();
1430
1431 for collaborator in &collaborators {
1432 session
1433 .peer
1434 .send(
1435 collaborator.peer_id.unwrap().into(),
1436 proto::AddProjectCollaborator {
1437 project_id: project_id.to_proto(),
1438 collaborator: Some(proto::Collaborator {
1439 peer_id: Some(session.connection_id.into()),
1440 replica_id: replica_id.0 as u32,
1441 user_id: guest_user_id.to_proto(),
1442 }),
1443 },
1444 )
1445 .trace_err();
1446 }
1447
1448 // First, we send the metadata associated with each worktree.
1449 response.send(proto::JoinProjectResponse {
1450 worktrees: worktrees.clone(),
1451 replica_id: replica_id.0 as u32,
1452 collaborators: collaborators.clone(),
1453 language_servers: project.language_servers.clone(),
1454 })?;
1455
1456 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1457 #[cfg(any(test, feature = "test-support"))]
1458 const MAX_CHUNK_SIZE: usize = 2;
1459 #[cfg(not(any(test, feature = "test-support")))]
1460 const MAX_CHUNK_SIZE: usize = 256;
1461
1462 // Stream this worktree's entries.
1463 let message = proto::UpdateWorktree {
1464 project_id: project_id.to_proto(),
1465 worktree_id,
1466 abs_path: worktree.abs_path.clone(),
1467 root_name: worktree.root_name,
1468 updated_entries: worktree.entries,
1469 removed_entries: Default::default(),
1470 scan_id: worktree.scan_id,
1471 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1472 updated_repositories: worktree.repository_entries.into_values().collect(),
1473 removed_repositories: Default::default(),
1474 };
1475 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1476 session.peer.send(session.connection_id, update.clone())?;
1477 }
1478
1479 // Stream this worktree's diagnostics.
1480 for summary in worktree.diagnostic_summaries {
1481 session.peer.send(
1482 session.connection_id,
1483 proto::UpdateDiagnosticSummary {
1484 project_id: project_id.to_proto(),
1485 worktree_id: worktree.id,
1486 summary: Some(summary),
1487 },
1488 )?;
1489 }
1490
1491 for settings_file in worktree.settings_files {
1492 session.peer.send(
1493 session.connection_id,
1494 proto::UpdateWorktreeSettings {
1495 project_id: project_id.to_proto(),
1496 worktree_id: worktree.id,
1497 path: settings_file.path,
1498 content: Some(settings_file.content),
1499 },
1500 )?;
1501 }
1502 }
1503
1504 for language_server in &project.language_servers {
1505 session.peer.send(
1506 session.connection_id,
1507 proto::UpdateLanguageServer {
1508 project_id: project_id.to_proto(),
1509 language_server_id: language_server.id,
1510 variant: Some(
1511 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1512 proto::LspDiskBasedDiagnosticsUpdated {},
1513 ),
1514 ),
1515 },
1516 )?;
1517 }
1518
1519 Ok(())
1520}
1521
1522async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1523 let sender_id = session.connection_id;
1524 let project_id = ProjectId::from_proto(request.project_id);
1525
1526 let (room, project) = &*session
1527 .db()
1528 .await
1529 .leave_project(project_id, sender_id)
1530 .await?;
1531 tracing::info!(
1532 %project_id,
1533 host_user_id = %project.host_user_id,
1534 host_connection_id = %project.host_connection_id,
1535 "leave project"
1536 );
1537
1538 project_left(&project, &session);
1539 room_updated(&room, &session.peer);
1540
1541 Ok(())
1542}
1543
1544async fn update_project(
1545 request: proto::UpdateProject,
1546 response: Response<proto::UpdateProject>,
1547 session: Session,
1548) -> Result<()> {
1549 let project_id = ProjectId::from_proto(request.project_id);
1550 let (room, guest_connection_ids) = &*session
1551 .db()
1552 .await
1553 .update_project(project_id, session.connection_id, &request.worktrees)
1554 .await?;
1555 broadcast(
1556 Some(session.connection_id),
1557 guest_connection_ids.iter().copied(),
1558 |connection_id| {
1559 session
1560 .peer
1561 .forward_send(session.connection_id, connection_id, request.clone())
1562 },
1563 );
1564 room_updated(&room, &session.peer);
1565 response.send(proto::Ack {})?;
1566
1567 Ok(())
1568}
1569
1570async fn update_worktree(
1571 request: proto::UpdateWorktree,
1572 response: Response<proto::UpdateWorktree>,
1573 session: Session,
1574) -> Result<()> {
1575 let guest_connection_ids = session
1576 .db()
1577 .await
1578 .update_worktree(&request, session.connection_id)
1579 .await?;
1580
1581 broadcast(
1582 Some(session.connection_id),
1583 guest_connection_ids.iter().copied(),
1584 |connection_id| {
1585 session
1586 .peer
1587 .forward_send(session.connection_id, connection_id, request.clone())
1588 },
1589 );
1590 response.send(proto::Ack {})?;
1591 Ok(())
1592}
1593
1594async fn update_diagnostic_summary(
1595 message: proto::UpdateDiagnosticSummary,
1596 session: Session,
1597) -> Result<()> {
1598 let guest_connection_ids = session
1599 .db()
1600 .await
1601 .update_diagnostic_summary(&message, session.connection_id)
1602 .await?;
1603
1604 broadcast(
1605 Some(session.connection_id),
1606 guest_connection_ids.iter().copied(),
1607 |connection_id| {
1608 session
1609 .peer
1610 .forward_send(session.connection_id, connection_id, message.clone())
1611 },
1612 );
1613
1614 Ok(())
1615}
1616
1617async fn update_worktree_settings(
1618 message: proto::UpdateWorktreeSettings,
1619 session: Session,
1620) -> Result<()> {
1621 let guest_connection_ids = session
1622 .db()
1623 .await
1624 .update_worktree_settings(&message, session.connection_id)
1625 .await?;
1626
1627 broadcast(
1628 Some(session.connection_id),
1629 guest_connection_ids.iter().copied(),
1630 |connection_id| {
1631 session
1632 .peer
1633 .forward_send(session.connection_id, connection_id, message.clone())
1634 },
1635 );
1636
1637 Ok(())
1638}
1639
1640async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1641 broadcast_project_message(request.project_id, request, session).await
1642}
1643
1644async fn start_language_server(
1645 request: proto::StartLanguageServer,
1646 session: Session,
1647) -> Result<()> {
1648 let guest_connection_ids = session
1649 .db()
1650 .await
1651 .start_language_server(&request, session.connection_id)
1652 .await?;
1653
1654 broadcast(
1655 Some(session.connection_id),
1656 guest_connection_ids.iter().copied(),
1657 |connection_id| {
1658 session
1659 .peer
1660 .forward_send(session.connection_id, connection_id, request.clone())
1661 },
1662 );
1663 Ok(())
1664}
1665
1666async fn update_language_server(
1667 request: proto::UpdateLanguageServer,
1668 session: Session,
1669) -> Result<()> {
1670 session.executor.record_backtrace();
1671 let project_id = ProjectId::from_proto(request.project_id);
1672 let project_connection_ids = session
1673 .db()
1674 .await
1675 .project_connection_ids(project_id, session.connection_id)
1676 .await?;
1677 broadcast(
1678 Some(session.connection_id),
1679 project_connection_ids.iter().copied(),
1680 |connection_id| {
1681 session
1682 .peer
1683 .forward_send(session.connection_id, connection_id, request.clone())
1684 },
1685 );
1686 Ok(())
1687}
1688
1689async fn forward_project_request<T>(
1690 request: T,
1691 response: Response<T>,
1692 session: Session,
1693) -> Result<()>
1694where
1695 T: EntityMessage + RequestMessage,
1696{
1697 session.executor.record_backtrace();
1698 let project_id = ProjectId::from_proto(request.remote_entity_id());
1699 let host_connection_id = {
1700 let collaborators = session
1701 .db()
1702 .await
1703 .project_collaborators(project_id, session.connection_id)
1704 .await?;
1705 collaborators
1706 .iter()
1707 .find(|collaborator| collaborator.is_host)
1708 .ok_or_else(|| anyhow!("host not found"))?
1709 .connection_id
1710 };
1711
1712 let payload = session
1713 .peer
1714 .forward_request(session.connection_id, host_connection_id, request)
1715 .await?;
1716
1717 response.send(payload)?;
1718 Ok(())
1719}
1720
1721async fn create_buffer_for_peer(
1722 request: proto::CreateBufferForPeer,
1723 session: Session,
1724) -> Result<()> {
1725 session.executor.record_backtrace();
1726 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1727 session
1728 .peer
1729 .forward_send(session.connection_id, peer_id.into(), request)?;
1730 Ok(())
1731}
1732
1733async fn update_buffer(
1734 request: proto::UpdateBuffer,
1735 response: Response<proto::UpdateBuffer>,
1736 session: Session,
1737) -> Result<()> {
1738 session.executor.record_backtrace();
1739 let project_id = ProjectId::from_proto(request.project_id);
1740 let mut guest_connection_ids;
1741 let mut host_connection_id = None;
1742 {
1743 let collaborators = session
1744 .db()
1745 .await
1746 .project_collaborators(project_id, session.connection_id)
1747 .await?;
1748 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1749 for collaborator in collaborators.iter() {
1750 if collaborator.is_host {
1751 host_connection_id = Some(collaborator.connection_id);
1752 } else {
1753 guest_connection_ids.push(collaborator.connection_id);
1754 }
1755 }
1756 }
1757 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1758
1759 session.executor.record_backtrace();
1760 broadcast(
1761 Some(session.connection_id),
1762 guest_connection_ids,
1763 |connection_id| {
1764 session
1765 .peer
1766 .forward_send(session.connection_id, connection_id, request.clone())
1767 },
1768 );
1769 if host_connection_id != session.connection_id {
1770 session
1771 .peer
1772 .forward_request(session.connection_id, host_connection_id, request.clone())
1773 .await?;
1774 }
1775
1776 response.send(proto::Ack {})?;
1777 Ok(())
1778}
1779
1780async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1781 let project_id = ProjectId::from_proto(request.project_id);
1782 let project_connection_ids = session
1783 .db()
1784 .await
1785 .project_connection_ids(project_id, session.connection_id)
1786 .await?;
1787
1788 broadcast(
1789 Some(session.connection_id),
1790 project_connection_ids.iter().copied(),
1791 |connection_id| {
1792 session
1793 .peer
1794 .forward_send(session.connection_id, connection_id, request.clone())
1795 },
1796 );
1797 Ok(())
1798}
1799
1800async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1801 let project_id = ProjectId::from_proto(request.project_id);
1802 let project_connection_ids = session
1803 .db()
1804 .await
1805 .project_connection_ids(project_id, session.connection_id)
1806 .await?;
1807 broadcast(
1808 Some(session.connection_id),
1809 project_connection_ids.iter().copied(),
1810 |connection_id| {
1811 session
1812 .peer
1813 .forward_send(session.connection_id, connection_id, request.clone())
1814 },
1815 );
1816 Ok(())
1817}
1818
1819async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1820 broadcast_project_message(request.project_id, request, session).await
1821}
1822
1823async fn broadcast_project_message<T: EnvelopedMessage>(
1824 project_id: u64,
1825 request: T,
1826 session: Session,
1827) -> Result<()> {
1828 let project_id = ProjectId::from_proto(project_id);
1829 let project_connection_ids = session
1830 .db()
1831 .await
1832 .project_connection_ids(project_id, session.connection_id)
1833 .await?;
1834 broadcast(
1835 Some(session.connection_id),
1836 project_connection_ids.iter().copied(),
1837 |connection_id| {
1838 session
1839 .peer
1840 .forward_send(session.connection_id, connection_id, request.clone())
1841 },
1842 );
1843 Ok(())
1844}
1845
1846async fn follow(
1847 request: proto::Follow,
1848 response: Response<proto::Follow>,
1849 session: Session,
1850) -> Result<()> {
1851 let project_id = ProjectId::from_proto(request.project_id);
1852 let leader_id = request
1853 .leader_id
1854 .ok_or_else(|| anyhow!("invalid leader id"))?
1855 .into();
1856 let follower_id = session.connection_id;
1857
1858 {
1859 let project_connection_ids = session
1860 .db()
1861 .await
1862 .project_connection_ids(project_id, session.connection_id)
1863 .await?;
1864
1865 if !project_connection_ids.contains(&leader_id) {
1866 Err(anyhow!("no such peer"))?;
1867 }
1868 }
1869
1870 let mut response_payload = session
1871 .peer
1872 .forward_request(session.connection_id, leader_id, request)
1873 .await?;
1874 response_payload
1875 .views
1876 .retain(|view| view.leader_id != Some(follower_id.into()));
1877 response.send(response_payload)?;
1878
1879 let room = session
1880 .db()
1881 .await
1882 .follow(project_id, leader_id, follower_id)
1883 .await?;
1884 room_updated(&room, &session.peer);
1885
1886 Ok(())
1887}
1888
1889async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1890 let project_id = ProjectId::from_proto(request.project_id);
1891 let leader_id = request
1892 .leader_id
1893 .ok_or_else(|| anyhow!("invalid leader id"))?
1894 .into();
1895 let follower_id = session.connection_id;
1896
1897 if !session
1898 .db()
1899 .await
1900 .project_connection_ids(project_id, session.connection_id)
1901 .await?
1902 .contains(&leader_id)
1903 {
1904 Err(anyhow!("no such peer"))?;
1905 }
1906
1907 session
1908 .peer
1909 .forward_send(session.connection_id, leader_id, request)?;
1910
1911 let room = session
1912 .db()
1913 .await
1914 .unfollow(project_id, leader_id, follower_id)
1915 .await?;
1916 room_updated(&room, &session.peer);
1917
1918 Ok(())
1919}
1920
1921async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1922 let project_id = ProjectId::from_proto(request.project_id);
1923 let project_connection_ids = session
1924 .db
1925 .lock()
1926 .await
1927 .project_connection_ids(project_id, session.connection_id)
1928 .await?;
1929
1930 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1931 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1932 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1933 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1934 });
1935 for follower_peer_id in request.follower_ids.iter().copied() {
1936 let follower_connection_id = follower_peer_id.into();
1937 if project_connection_ids.contains(&follower_connection_id)
1938 && Some(follower_peer_id) != leader_id
1939 {
1940 session.peer.forward_send(
1941 session.connection_id,
1942 follower_connection_id,
1943 request.clone(),
1944 )?;
1945 }
1946 }
1947 Ok(())
1948}
1949
1950async fn get_users(
1951 request: proto::GetUsers,
1952 response: Response<proto::GetUsers>,
1953 session: Session,
1954) -> Result<()> {
1955 let user_ids = request
1956 .user_ids
1957 .into_iter()
1958 .map(UserId::from_proto)
1959 .collect();
1960 let users = session
1961 .db()
1962 .await
1963 .get_users_by_ids(user_ids)
1964 .await?
1965 .into_iter()
1966 .map(|user| proto::User {
1967 id: user.id.to_proto(),
1968 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1969 github_login: user.github_login,
1970 })
1971 .collect();
1972 response.send(proto::UsersResponse { users })?;
1973 Ok(())
1974}
1975
1976async fn fuzzy_search_users(
1977 request: proto::FuzzySearchUsers,
1978 response: Response<proto::FuzzySearchUsers>,
1979 session: Session,
1980) -> Result<()> {
1981 let query = request.query;
1982 let users = match query.len() {
1983 0 => vec![],
1984 1 | 2 => session
1985 .db()
1986 .await
1987 .get_user_by_github_login(&query)
1988 .await?
1989 .into_iter()
1990 .collect(),
1991 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1992 };
1993 let users = users
1994 .into_iter()
1995 .filter(|user| user.id != session.user_id)
1996 .map(|user| proto::User {
1997 id: user.id.to_proto(),
1998 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1999 github_login: user.github_login,
2000 })
2001 .collect();
2002 response.send(proto::UsersResponse { users })?;
2003 Ok(())
2004}
2005
2006async fn request_contact(
2007 request: proto::RequestContact,
2008 response: Response<proto::RequestContact>,
2009 session: Session,
2010) -> Result<()> {
2011 let requester_id = session.user_id;
2012 let responder_id = UserId::from_proto(request.responder_id);
2013 if requester_id == responder_id {
2014 return Err(anyhow!("cannot add yourself as a contact"))?;
2015 }
2016
2017 session
2018 .db()
2019 .await
2020 .send_contact_request(requester_id, responder_id)
2021 .await?;
2022
2023 // Update outgoing contact requests of requester
2024 let mut update = proto::UpdateContacts::default();
2025 update.outgoing_requests.push(responder_id.to_proto());
2026 for connection_id in session
2027 .connection_pool()
2028 .await
2029 .user_connection_ids(requester_id)
2030 {
2031 session.peer.send(connection_id, update.clone())?;
2032 }
2033
2034 // Update incoming contact requests of responder
2035 let mut update = proto::UpdateContacts::default();
2036 update
2037 .incoming_requests
2038 .push(proto::IncomingContactRequest {
2039 requester_id: requester_id.to_proto(),
2040 should_notify: true,
2041 });
2042 for connection_id in session
2043 .connection_pool()
2044 .await
2045 .user_connection_ids(responder_id)
2046 {
2047 session.peer.send(connection_id, update.clone())?;
2048 }
2049
2050 response.send(proto::Ack {})?;
2051 Ok(())
2052}
2053
2054async fn respond_to_contact_request(
2055 request: proto::RespondToContactRequest,
2056 response: Response<proto::RespondToContactRequest>,
2057 session: Session,
2058) -> Result<()> {
2059 let responder_id = session.user_id;
2060 let requester_id = UserId::from_proto(request.requester_id);
2061 let db = session.db().await;
2062 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2063 db.dismiss_contact_notification(responder_id, requester_id)
2064 .await?;
2065 } else {
2066 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2067
2068 db.respond_to_contact_request(responder_id, requester_id, accept)
2069 .await?;
2070 let requester_busy = db.is_user_busy(requester_id).await?;
2071 let responder_busy = db.is_user_busy(responder_id).await?;
2072
2073 let pool = session.connection_pool().await;
2074 // Update responder with new contact
2075 let mut update = proto::UpdateContacts::default();
2076 if accept {
2077 update
2078 .contacts
2079 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2080 }
2081 update
2082 .remove_incoming_requests
2083 .push(requester_id.to_proto());
2084 for connection_id in pool.user_connection_ids(responder_id) {
2085 session.peer.send(connection_id, update.clone())?;
2086 }
2087
2088 // Update requester with new contact
2089 let mut update = proto::UpdateContacts::default();
2090 if accept {
2091 update
2092 .contacts
2093 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2094 }
2095 update
2096 .remove_outgoing_requests
2097 .push(responder_id.to_proto());
2098 for connection_id in pool.user_connection_ids(requester_id) {
2099 session.peer.send(connection_id, update.clone())?;
2100 }
2101 }
2102
2103 response.send(proto::Ack {})?;
2104 Ok(())
2105}
2106
2107async fn remove_contact(
2108 request: proto::RemoveContact,
2109 response: Response<proto::RemoveContact>,
2110 session: Session,
2111) -> Result<()> {
2112 let requester_id = session.user_id;
2113 let responder_id = UserId::from_proto(request.user_id);
2114 let db = session.db().await;
2115 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2116
2117 let pool = session.connection_pool().await;
2118 // Update outgoing contact requests of requester
2119 let mut update = proto::UpdateContacts::default();
2120 if contact_accepted {
2121 update.remove_contacts.push(responder_id.to_proto());
2122 } else {
2123 update
2124 .remove_outgoing_requests
2125 .push(responder_id.to_proto());
2126 }
2127 for connection_id in pool.user_connection_ids(requester_id) {
2128 session.peer.send(connection_id, update.clone())?;
2129 }
2130
2131 // Update incoming contact requests of responder
2132 let mut update = proto::UpdateContacts::default();
2133 if contact_accepted {
2134 update.remove_contacts.push(requester_id.to_proto());
2135 } else {
2136 update
2137 .remove_incoming_requests
2138 .push(requester_id.to_proto());
2139 }
2140 for connection_id in pool.user_connection_ids(responder_id) {
2141 session.peer.send(connection_id, update.clone())?;
2142 }
2143
2144 response.send(proto::Ack {})?;
2145 Ok(())
2146}
2147
2148async fn create_channel(
2149 request: proto::CreateChannel,
2150 response: Response<proto::CreateChannel>,
2151 session: Session,
2152) -> Result<()> {
2153 let db = session.db().await;
2154 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2155
2156 if let Some(live_kit) = session.live_kit_client.as_ref() {
2157 live_kit.create_room(live_kit_room.clone()).await?;
2158 }
2159
2160 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2161 let id = db
2162 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2163 .await?;
2164
2165 let channel = proto::Channel {
2166 id: id.to_proto(),
2167 name: request.name,
2168 parent_id: request.parent_id,
2169 };
2170
2171 response.send(proto::ChannelResponse {
2172 channel: Some(channel.clone()),
2173 })?;
2174
2175 let mut update = proto::UpdateChannels::default();
2176 update.channels.push(channel);
2177
2178 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2179 db.get_channel_members(parent_id).await?
2180 } else {
2181 vec![session.user_id]
2182 };
2183
2184 let connection_pool = session.connection_pool().await;
2185 for user_id in user_ids_to_notify {
2186 for connection_id in connection_pool.user_connection_ids(user_id) {
2187 let mut update = update.clone();
2188 if user_id == session.user_id {
2189 update.channel_permissions.push(proto::ChannelPermission {
2190 channel_id: id.to_proto(),
2191 is_admin: true,
2192 });
2193 }
2194 session.peer.send(connection_id, update)?;
2195 }
2196 }
2197
2198 Ok(())
2199}
2200
2201async fn remove_channel(
2202 request: proto::RemoveChannel,
2203 response: Response<proto::RemoveChannel>,
2204 session: Session,
2205) -> Result<()> {
2206 let db = session.db().await;
2207
2208 let channel_id = request.channel_id;
2209 let (removed_channels, member_ids) = db
2210 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2211 .await?;
2212 response.send(proto::Ack {})?;
2213
2214 // Notify members of removed channels
2215 let mut update = proto::UpdateChannels::default();
2216 update
2217 .remove_channels
2218 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2219
2220 let connection_pool = session.connection_pool().await;
2221 for member_id in member_ids {
2222 for connection_id in connection_pool.user_connection_ids(member_id) {
2223 session.peer.send(connection_id, update.clone())?;
2224 }
2225 }
2226
2227 Ok(())
2228}
2229
2230async fn invite_channel_member(
2231 request: proto::InviteChannelMember,
2232 response: Response<proto::InviteChannelMember>,
2233 session: Session,
2234) -> Result<()> {
2235 let db = session.db().await;
2236 let channel_id = ChannelId::from_proto(request.channel_id);
2237 let invitee_id = UserId::from_proto(request.user_id);
2238 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2239 .await?;
2240
2241 let (channel, _) = db
2242 .get_channel(channel_id, session.user_id)
2243 .await?
2244 .ok_or_else(|| anyhow!("channel not found"))?;
2245
2246 let mut update = proto::UpdateChannels::default();
2247 update.channel_invitations.push(proto::Channel {
2248 id: channel.id.to_proto(),
2249 name: channel.name,
2250 parent_id: None,
2251 });
2252 for connection_id in session
2253 .connection_pool()
2254 .await
2255 .user_connection_ids(invitee_id)
2256 {
2257 session.peer.send(connection_id, update.clone())?;
2258 }
2259
2260 response.send(proto::Ack {})?;
2261 Ok(())
2262}
2263
2264async fn remove_channel_member(
2265 request: proto::RemoveChannelMember,
2266 response: Response<proto::RemoveChannelMember>,
2267 session: Session,
2268) -> Result<()> {
2269 let db = session.db().await;
2270 let channel_id = ChannelId::from_proto(request.channel_id);
2271 let member_id = UserId::from_proto(request.user_id);
2272
2273 db.remove_channel_member(channel_id, member_id, session.user_id)
2274 .await?;
2275
2276 let mut update = proto::UpdateChannels::default();
2277 update.remove_channels.push(channel_id.to_proto());
2278
2279 for connection_id in session
2280 .connection_pool()
2281 .await
2282 .user_connection_ids(member_id)
2283 {
2284 session.peer.send(connection_id, update.clone())?;
2285 }
2286
2287 response.send(proto::Ack {})?;
2288 Ok(())
2289}
2290
2291async fn set_channel_member_admin(
2292 request: proto::SetChannelMemberAdmin,
2293 response: Response<proto::SetChannelMemberAdmin>,
2294 session: Session,
2295) -> Result<()> {
2296 let db = session.db().await;
2297 let channel_id = ChannelId::from_proto(request.channel_id);
2298 let member_id = UserId::from_proto(request.user_id);
2299 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2300 .await?;
2301
2302 let (channel, has_accepted) = db
2303 .get_channel(channel_id, member_id)
2304 .await?
2305 .ok_or_else(|| anyhow!("channel not found"))?;
2306
2307 let mut update = proto::UpdateChannels::default();
2308 if has_accepted {
2309 update.channel_permissions.push(proto::ChannelPermission {
2310 channel_id: channel.id.to_proto(),
2311 is_admin: request.admin,
2312 });
2313 }
2314
2315 for connection_id in session
2316 .connection_pool()
2317 .await
2318 .user_connection_ids(member_id)
2319 {
2320 session.peer.send(connection_id, update.clone())?;
2321 }
2322
2323 response.send(proto::Ack {})?;
2324 Ok(())
2325}
2326
2327async fn rename_channel(
2328 request: proto::RenameChannel,
2329 response: Response<proto::RenameChannel>,
2330 session: Session,
2331) -> Result<()> {
2332 let db = session.db().await;
2333 let channel_id = ChannelId::from_proto(request.channel_id);
2334 let new_name = db
2335 .rename_channel(channel_id, session.user_id, &request.name)
2336 .await?;
2337
2338 let channel = proto::Channel {
2339 id: request.channel_id,
2340 name: new_name,
2341 parent_id: None,
2342 };
2343 response.send(proto::ChannelResponse {
2344 channel: Some(channel.clone()),
2345 })?;
2346 let mut update = proto::UpdateChannels::default();
2347 update.channels.push(channel);
2348
2349 let member_ids = db.get_channel_members(channel_id).await?;
2350
2351 let connection_pool = session.connection_pool().await;
2352 for member_id in member_ids {
2353 for connection_id in connection_pool.user_connection_ids(member_id) {
2354 session.peer.send(connection_id, update.clone())?;
2355 }
2356 }
2357
2358 Ok(())
2359}
2360
2361async fn get_channel_members(
2362 request: proto::GetChannelMembers,
2363 response: Response<proto::GetChannelMembers>,
2364 session: Session,
2365) -> Result<()> {
2366 let db = session.db().await;
2367 let channel_id = ChannelId::from_proto(request.channel_id);
2368 let members = db
2369 .get_channel_member_details(channel_id, session.user_id)
2370 .await?;
2371 response.send(proto::GetChannelMembersResponse { members })?;
2372 Ok(())
2373}
2374
2375async fn respond_to_channel_invite(
2376 request: proto::RespondToChannelInvite,
2377 response: Response<proto::RespondToChannelInvite>,
2378 session: Session,
2379) -> Result<()> {
2380 let db = session.db().await;
2381 let channel_id = ChannelId::from_proto(request.channel_id);
2382 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2383 .await?;
2384
2385 let mut update = proto::UpdateChannels::default();
2386 update
2387 .remove_channel_invitations
2388 .push(channel_id.to_proto());
2389 if request.accept {
2390 let result = db.get_channels_for_user(session.user_id).await?;
2391 update
2392 .channels
2393 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2394 id: channel.id.to_proto(),
2395 name: channel.name,
2396 parent_id: channel.parent_id.map(ChannelId::to_proto),
2397 }));
2398 update
2399 .channel_participants
2400 .extend(
2401 result
2402 .channel_participants
2403 .into_iter()
2404 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2405 channel_id: channel_id.to_proto(),
2406 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2407 }),
2408 );
2409 update
2410 .channel_permissions
2411 .extend(
2412 result
2413 .channels_with_admin_privileges
2414 .into_iter()
2415 .map(|channel_id| proto::ChannelPermission {
2416 channel_id: channel_id.to_proto(),
2417 is_admin: true,
2418 }),
2419 );
2420 }
2421 session.peer.send(session.connection_id, update)?;
2422 response.send(proto::Ack {})?;
2423
2424 Ok(())
2425}
2426
/// Joins the session user to the room backing `request.channel_id`, first leaving any room
/// the session is currently in.
///
/// Replies with a `JoinRoomResponse`. That response includes LiveKit connection info only when
/// a LiveKit client is configured and a room token could be created. The handler then:
/// - broadcasts the updated room to its participants,
/// - sends the channel's members the new participant list, and
/// - pushes the user's refreshed contact entry to their contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // A token-creation failure is logged by `trace_err` and yields `None`, so the client
        // simply gets no LiveKit info instead of the whole join failing.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // The broadcast happens before `into_inner()` below.
        // NOTE(review): `join_room` presumably returns a guard that serializes room updates,
        // and `into_inner` releases it; confirm before reordering these two statements.
        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2481
2482async fn get_channel_buffer(
2483 request: proto::GetChannelBuffer,
2484 response: Response<proto::GetChannelBuffer>,
2485 session: Session,
2486) -> Result<()> {
2487 let db = session.db().await;
2488 let channel_id = ChannelId::from_proto(request.channel_id);
2489
2490 let buffer_id = db.get_or_create_buffer_for_channel(channel_id).await?;
2491
2492 let buffer = db.get_buffer(buffer_id).await?;
2493
2494 response.send(GetChannelBufferResponse {
2495 base_text: buffer.base_text,
2496 operations: buffer.operations,
2497 })?;
2498
2499 Ok(())
2500}
2501
2502async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2503 let project_id = ProjectId::from_proto(request.project_id);
2504 let project_connection_ids = session
2505 .db()
2506 .await
2507 .project_connection_ids(project_id, session.connection_id)
2508 .await?;
2509 broadcast(
2510 Some(session.connection_id),
2511 project_connection_ids.iter().copied(),
2512 |connection_id| {
2513 session
2514 .peer
2515 .forward_send(session.connection_id, connection_id, request.clone())
2516 },
2517 );
2518 Ok(())
2519}
2520
2521async fn get_private_user_info(
2522 _request: proto::GetPrivateUserInfo,
2523 response: Response<proto::GetPrivateUserInfo>,
2524 session: Session,
2525) -> Result<()> {
2526 let metrics_id = session
2527 .db()
2528 .await
2529 .get_user_metrics_id(session.user_id)
2530 .await?;
2531 let user = session
2532 .db()
2533 .await
2534 .get_user_by_id(session.user_id)
2535 .await?
2536 .ok_or_else(|| anyhow!("user not found"))?;
2537 response.send(proto::GetPrivateUserInfoResponse {
2538 metrics_id,
2539 staff: user.admin,
2540 })?;
2541 Ok(())
2542}
2543
2544fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2545 match message {
2546 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2547 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2548 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2549 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2550 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2551 code: frame.code.into(),
2552 reason: frame.reason,
2553 })),
2554 }
2555}
2556
2557fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2558 match message {
2559 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2560 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2561 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2562 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2563 AxumMessage::Close(frame) => {
2564 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2565 code: frame.code.into(),
2566 reason: frame.reason,
2567 }))
2568 }
2569 }
2570}
2571
2572fn build_initial_channels_update(
2573 channels: ChannelsForUser,
2574 channel_invites: Vec<db::Channel>,
2575) -> proto::UpdateChannels {
2576 let mut update = proto::UpdateChannels::default();
2577
2578 for channel in channels.channels {
2579 update.channels.push(proto::Channel {
2580 id: channel.id.to_proto(),
2581 name: channel.name,
2582 parent_id: channel.parent_id.map(|id| id.to_proto()),
2583 });
2584 }
2585
2586 for (channel_id, participants) in channels.channel_participants {
2587 update
2588 .channel_participants
2589 .push(proto::ChannelParticipants {
2590 channel_id: channel_id.to_proto(),
2591 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2592 });
2593 }
2594
2595 update
2596 .channel_permissions
2597 .extend(
2598 channels
2599 .channels_with_admin_privileges
2600 .into_iter()
2601 .map(|id| proto::ChannelPermission {
2602 channel_id: id.to_proto(),
2603 is_admin: true,
2604 }),
2605 );
2606
2607 for channel in channel_invites {
2608 update.channel_invitations.push(proto::Channel {
2609 id: channel.id.to_proto(),
2610 name: channel.name,
2611 parent_id: None,
2612 });
2613 }
2614
2615 update
2616}
2617
2618fn build_initial_contacts_update(
2619 contacts: Vec<db::Contact>,
2620 pool: &ConnectionPool,
2621) -> proto::UpdateContacts {
2622 let mut update = proto::UpdateContacts::default();
2623
2624 for contact in contacts {
2625 match contact {
2626 db::Contact::Accepted {
2627 user_id,
2628 should_notify,
2629 busy,
2630 } => {
2631 update
2632 .contacts
2633 .push(contact_for_user(user_id, should_notify, busy, &pool));
2634 }
2635 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2636 db::Contact::Incoming {
2637 user_id,
2638 should_notify,
2639 } => update
2640 .incoming_requests
2641 .push(proto::IncomingContactRequest {
2642 requester_id: user_id.to_proto(),
2643 should_notify,
2644 }),
2645 }
2646 }
2647
2648 update
2649}
2650
2651fn contact_for_user(
2652 user_id: UserId,
2653 should_notify: bool,
2654 busy: bool,
2655 pool: &ConnectionPool,
2656) -> proto::Contact {
2657 proto::Contact {
2658 user_id: user_id.to_proto(),
2659 online: pool.is_user_online(user_id),
2660 busy,
2661 should_notify,
2662 }
2663}
2664
2665fn room_updated(room: &proto::Room, peer: &Peer) {
2666 broadcast(
2667 None,
2668 room.participants
2669 .iter()
2670 .filter_map(|participant| Some(participant.peer_id?.into())),
2671 |peer_id| {
2672 peer.send(
2673 peer_id.into(),
2674 proto::RoomUpdated {
2675 room: Some(room.clone()),
2676 },
2677 )
2678 },
2679 );
2680}
2681
2682fn channel_updated(
2683 channel_id: ChannelId,
2684 room: &proto::Room,
2685 channel_members: &[UserId],
2686 peer: &Peer,
2687 pool: &ConnectionPool,
2688) {
2689 let participants = room
2690 .participants
2691 .iter()
2692 .map(|p| p.user_id)
2693 .collect::<Vec<_>>();
2694
2695 broadcast(
2696 None,
2697 channel_members
2698 .iter()
2699 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2700 |peer_id| {
2701 peer.send(
2702 peer_id.into(),
2703 proto::UpdateChannels {
2704 channel_participants: vec![proto::ChannelParticipants {
2705 channel_id: channel_id.to_proto(),
2706 participant_user_ids: participants.clone(),
2707 }],
2708 ..Default::default()
2709 },
2710 )
2711 },
2712 );
2713}
2714
2715async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2716 let db = session.db().await;
2717
2718 let contacts = db.get_contacts(user_id).await?;
2719 let busy = db.is_user_busy(user_id).await?;
2720
2721 let pool = session.connection_pool().await;
2722 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2723 for contact in contacts {
2724 if let db::Contact::Accepted {
2725 user_id: contact_user_id,
2726 ..
2727 } = contact
2728 {
2729 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2730 session
2731 .peer
2732 .send(
2733 contact_conn_id,
2734 proto::UpdateContacts {
2735 contacts: vec![updated_contact.clone()],
2736 remove_contacts: Default::default(),
2737 incoming_requests: Default::default(),
2738 remove_incoming_requests: Default::default(),
2739 outgoing_requests: Default::default(),
2740 remove_outgoing_requests: Default::default(),
2741 },
2742 )
2743 .trace_err();
2744 }
2745 }
2746 }
2747 Ok(())
2748}
2749
/// Removes the session's connection from the room it is in, if any, and propagates the effects.
///
/// Returns `Ok(())` straight away when the database reports no room was left. Otherwise it:
/// 1. notifies collaborators of every project left along with the room;
/// 2. broadcasts the updated room;
/// 3. refreshes the channel's participant list if the room belongs to a channel;
/// 4. sends `CallCanceled` to users whose pending calls were canceled;
/// 5. pushes refreshed contact entries for this user and those callees;
/// 6. removes the user from the LiveKit room, deleting that room when `left_room.deleted` is set.
///
/// Errors from the database and from `update_user_contacts` are returned. `CallCanceled` send
/// failures and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` so these values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): `left_room` and the temporary database handle from the scrutinee stay
    // alive until the end of this `if let` (Rust 2021 temporary scoping). The project and room
    // notifications below are therefore sent before they are released. This is presumably
    // deliberate, to order broadcasts relative to concurrent room changes; confirm before
    // moving these sends out of the block.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room_updated` broadcast below carries an
        // empty `live_kit_room`. The name itself is kept for the LiveKit cleanup at the end.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is released before `update_user_contacts`, which
    // acquires its own.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and do not fail the leave.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2826
2827fn project_left(project: &db::LeftProject, session: &Session) {
2828 for connection_id in &project.connection_ids {
2829 if project.host_user_id == session.user_id {
2830 session
2831 .peer
2832 .send(
2833 *connection_id,
2834 proto::UnshareProject {
2835 project_id: project.id.to_proto(),
2836 },
2837 )
2838 .trace_err();
2839 } else {
2840 session
2841 .peer
2842 .send(
2843 *connection_id,
2844 proto::RemoveProjectCollaborator {
2845 project_id: project.id.to_proto(),
2846 peer_id: Some(session.connection_id.into()),
2847 },
2848 )
2849 .trace_err();
2850 }
2851 }
2852}
2853
2854pub trait ResultExt {
2855 type Ok;
2856
2857 fn trace_err(self) -> Option<Self::Ok>;
2858}
2859
2860impl<T, E> ResultExt for Result<T, E>
2861where
2862 E: std::fmt::Debug,
2863{
2864 type Ok = T;
2865
2866 fn trace_err(self) -> Option<T> {
2867 match self {
2868 Ok(value) => Some(value),
2869 Err(error) => {
2870 tracing::error!("{:?}", error);
2871 None
2872 }
2873 }
2874 }
2875}