1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits after a connection drops before it makes
/// the session leave its room and channel buffers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `Server::start` waits before it retrieves and refreshes stale rooms.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
65
lazy_static! {
    /// Gauge set by `handle_metrics` to the number of connections in the pool
    /// whose `admin` flag is false. Registration failure panics (`unwrap`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge set by `handle_metrics` from `project_count_excluding_admins`.
    /// Registration failure panics (`unwrap`).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the message type. It receives the boxed envelope plus the
/// session it arrived on and returns a boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
/// Handle given to request handlers for replying to a request of type `R`.
struct Response<R> {
    /// Peer used to send the reply.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `add_request_handler` reads it to detect
    /// handlers that returned `Ok` without replying.
    responded: Arc<AtomicBool>,
}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Identifier of this connection on the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional LiveKit API client; `None` means no LiveKit client is available.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor used for sleeping and spawning detached work.
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
/// Newtype around the shared `Database`; it derefs to `Database`.
struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The RPC server: owns the peer, the connection pool and the table of
/// message handlers.
pub struct Server {
    /// Current server id; replaced by `reset` in tests.
    id: parking_lot::Mutex<ServerId>,
    /// Peer through which all messages are sent and received.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Registered handlers keyed by the `TypeId` of the message they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender notified by `teardown`; connections subscribe to it.
    teardown: watch::Sender<()>,
}
153
/// Wrapper around a locked `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the wrapper `!Send` as well.
    // NOTE(review): presumably this is to stop the guard being held across an
    // `.await` in a spawned (Send) future — confirm the intent.
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server returned by `Server::snapshot`.
/// It holds the connection pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Locked connection pool, serialized through its `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(join_channel_buffer)
252 .add_request_handler(leave_channel_buffer)
253 .add_message_handler(update_channel_buffer)
254 .add_request_handler(get_channel_members)
255 .add_request_handler(respond_to_channel_invite)
256 .add_request_handler(join_channel)
257 .add_request_handler(follow)
258 .add_message_handler(unfollow)
259 .add_message_handler(update_followers)
260 .add_message_handler(update_diff_base)
261 .add_request_handler(get_private_user_info);
262
263 Arc::new(server)
264 }
265
    /// Spawns a detached background task that cleans up state left behind by
    /// stale servers, then returns `Ok(())` immediately.
    ///
    /// The task waits `CLEANUP_TIMEOUT`, fetches the stale room ids for this
    /// environment, and for each room: refreshes it, broadcasts the refreshed
    /// room (and its channel, if any), sends `CallCanceled` to users whose
    /// calls were canceled, pushes updated contact info for affected users,
    /// and deletes the LiveKit room when no participants remain. Finally it
    /// deletes the stale server records. Every failure along the way is
    /// logged through `trace_err` and otherwise ignored.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The timer is created here, before the task is spawned.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        // Per-room state; these stay at their defaults when
                        // refreshing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so it is
                        // released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user unless both queries succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room
                        // ended up with no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale rooms failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
390
391 pub fn teardown(&self) {
392 self.peer.teardown();
393 self.connection_pool.lock().reset();
394 let _ = self.teardown.send(());
395 }
396
397 #[cfg(test)]
398 pub fn reset(&self, id: ServerId) {
399 self.teardown();
400 *self.id.lock() = id;
401 self.peer.reset(id.0 as u32);
402 }
403
404 #[cfg(test)]
405 pub fn id(&self) -> ServerId {
406 *self.id.lock()
407 }
408
409 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
410 where
411 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
412 Fut: 'static + Send + Future<Output = Result<()>>,
413 M: EnvelopedMessage,
414 {
415 let prev_handler = self.handlers.insert(
416 TypeId::of::<M>(),
417 Box::new(move |envelope, session| {
418 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
419 let span = info_span!(
420 "handle message",
421 payload_type = envelope.payload_type_name()
422 );
423 span.in_scope(|| {
424 tracing::info!(
425 payload_type = envelope.payload_type_name(),
426 "message received"
427 );
428 });
429 let start_time = Instant::now();
430 let future = (handler)(*envelope, session);
431 async move {
432 let result = future.await;
433 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
434 match result {
435 Err(error) => {
436 tracing::error!(%error, ?duration_ms, "error handling message")
437 }
438 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
439 }
440 }
441 .instrument(span)
442 .boxed()
443 }),
444 );
445 if prev_handler.is_some() {
446 panic!("registered a handler for the same message twice");
447 }
448 self
449 }
450
451 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
452 where
453 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
454 Fut: 'static + Send + Future<Output = Result<()>>,
455 M: EnvelopedMessage,
456 {
457 self.add_handler(move |envelope, session| handler(envelope.payload, session));
458 self
459 }
460
461 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
462 where
463 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
464 Fut: Send + Future<Output = Result<()>>,
465 M: RequestMessage,
466 {
467 let handler = Arc::new(handler);
468 self.add_handler(move |envelope, session| {
469 let receipt = envelope.receipt();
470 let handler = handler.clone();
471 async move {
472 let peer = session.peer.clone();
473 let responded = Arc::new(AtomicBool::default());
474 let response = Response {
475 peer: peer.clone(),
476 responded: responded.clone(),
477 receipt,
478 };
479 match (handler)(envelope.payload, response, session).await {
480 Ok(()) => {
481 if responded.load(std::sync::atomic::Ordering::SeqCst) {
482 Ok(())
483 } else {
484 Err(anyhow!("handler did not send a response"))?
485 }
486 }
487 Err(error) => {
488 peer.respond_with_error(
489 receipt,
490 proto::Error {
491 message: error.to_string(),
492 },
493 )?;
494 Err(error)
495 }
496 }
497 }
498 })
499 }
500
    /// Returns a future that drives a single client connection from its
    /// handshake until it closes.
    ///
    /// The future registers the connection with the peer, sends `Hello` and
    /// the user's initial state (contacts, channels, invite info, any
    /// incoming call), then dispatches incoming messages to the registered
    /// handlers until the connection closes, I/O finishes or fails, or the
    /// server is torn down. On close or I/O completion it runs
    /// `connection_lost`; on teardown it returns `Ok(())` right away without
    /// running it.
    ///
    /// `send_connection_id`, when provided, receives the new connection id as
    /// soon as the `Hello` message has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future is returned, not when it is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Load the user's initial state concurrently; any failure aborts
            // the connection.
            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            // The pool lock is confined to this block so it is released
            // before the next `.await`.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A sends a request to client B while
            // client B sends a request to client A. If both clients stopped processing further
            // messages until their own request completed, neither would get a chance to
            // respond to the other's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* the next message is read, so
                // no new message is taken while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O,
                // in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
643
644 pub async fn invite_code_redeemed(
645 self: &Arc<Self>,
646 inviter_id: UserId,
647 invitee_id: UserId,
648 ) -> Result<()> {
649 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
650 if let Some(code) = &user.invite_code {
651 let pool = self.connection_pool.lock();
652 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
653 for connection_id in pool.user_connection_ids(inviter_id) {
654 self.peer.send(
655 connection_id,
656 proto::UpdateContacts {
657 contacts: vec![invitee_contact.clone()],
658 ..Default::default()
659 },
660 )?;
661 self.peer.send(
662 connection_id,
663 proto::UpdateInviteInfo {
664 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
665 count: user.invite_count as u32,
666 },
667 )?;
668 }
669 }
670 }
671 Ok(())
672 }
673
674 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
675 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
676 if let Some(invite_code) = &user.invite_code {
677 let pool = self.connection_pool.lock();
678 for connection_id in pool.user_connection_ids(user_id) {
679 self.peer.send(
680 connection_id,
681 proto::UpdateInviteInfo {
682 url: format!(
683 "{}{}",
684 self.app_state.config.invite_link_prefix, invite_code
685 ),
686 count: user.invite_count as u32,
687 },
688 )?;
689 }
690 }
691 }
692 Ok(())
693 }
694
695 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
696 ServerSnapshot {
697 connection_pool: ConnectionPoolGuard {
698 guard: self.connection_pool.lock(),
699 _not_send: PhantomData,
700 },
701 peer: &self.peer,
702 }
703 }
704}
705
706impl<'a> Deref for ConnectionPoolGuard<'a> {
707 type Target = ConnectionPool;
708
709 fn deref(&self) -> &Self::Target {
710 &*self.guard
711 }
712}
713
714impl<'a> DerefMut for ConnectionPoolGuard<'a> {
715 fn deref_mut(&mut self) -> &mut Self::Target {
716 &mut *self.guard
717 }
718}
719
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` on the pool each time a guard
    /// is released; in other builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
726
727fn broadcast<F>(
728 sender_id: Option<ConnectionId>,
729 receiver_ids: impl IntoIterator<Item = ConnectionId>,
730 mut f: F,
731) where
732 F: FnMut(ConnectionId) -> anyhow::Result<()>,
733{
734 for receiver_id in receiver_ids {
735 if Some(receiver_id) != sender_id {
736 if let Err(error) = f(receiver_id) {
737 tracing::error!("failed to send to {:?} {}", receiver_id, error);
738 }
739 }
740 }
741}
742
lazy_static! {
    /// Name of the HTTP header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
746
/// Typed header wrapping the numeric value of the `x-zed-protocol-version` header.
pub struct ProtocolVersion(u32);
748
749impl Header for ProtocolVersion {
750 fn name() -> &'static HeaderName {
751 &ZED_PROTOCOL_VERSION
752 }
753
754 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
755 where
756 Self: Sized,
757 I: Iterator<Item = &'i axum::http::HeaderValue>,
758 {
759 let version = values
760 .next()
761 .ok_or_else(axum::headers::Error::invalid)?
762 .to_str()
763 .map_err(|_| axum::headers::Error::invalid())?
764 .parse()
765 .map_err(|_| axum::headers::Error::invalid())?;
766 Ok(Self(version))
767 }
768
769 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
770 values.extend([self.0.to_string().parse().unwrap()]);
771 }
772}
773
774pub fn routes(server: Arc<Server>) -> Router<Body> {
775 Router::new()
776 .route("/rpc", get(handle_websocket_request))
777 .layer(
778 ServiceBuilder::new()
779 .layer(Extension(server.app_state.clone()))
780 .layer(middleware::from_fn(auth::validate_header)),
781 )
782 .route("/metrics", get(handle_metrics))
783 .layer(Extension(server))
784}
785
786pub async fn handle_websocket_request(
787 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
788 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
789 Extension(server): Extension<Arc<Server>>,
790 Extension(user): Extension<User>,
791 ws: WebSocketUpgrade,
792) -> axum::response::Response {
793 if protocol_version != rpc::PROTOCOL_VERSION {
794 return (
795 StatusCode::UPGRADE_REQUIRED,
796 "client must be upgraded".to_string(),
797 )
798 .into_response();
799 }
800 let socket_address = socket_address.to_string();
801 ws.on_upgrade(move |socket| {
802 use util::ResultExt;
803 let socket = socket
804 .map_ok(to_tungstenite_message)
805 .err_into()
806 .with(|message| async move { Ok(to_axum_message(message)) });
807 let connection = Connection::new(Box::pin(socket));
808 async move {
809 server
810 .handle_connection(connection, socket_address, user, None, Executor::Production)
811 .await
812 .log_err();
813 }
814 })
815}
816
817pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
818 let connections = server
819 .connection_pool
820 .lock()
821 .connections()
822 .filter(|connection| !connection.admin)
823 .count();
824
825 METRIC_CONNECTIONS.set(connections as _);
826
827 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
828 METRIC_SHARED_PROJECTS.set(shared_projects as _);
829
830 let encoder = prometheus::TextEncoder::new();
831 let metric_families = prometheus::gather();
832 let encoded_metrics = encoder
833 .encode_to_string(&metric_families)
834 .map_err(|err| anyhow!("{}", err))?;
835 Ok(encoded_metrics)
836}
837
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer connection, removes it from the pool,
/// and records the loss in the database (a database error is only logged).
/// It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapsing: the session leaves its room and channel
///   buffers, any call to the user is declined if the user has no
///   connections left, and the user's contacts are updated;
/// - server teardown: nothing further is done.
///
/// Returns an error if removing the connection from the pool fails or if the
/// final contact update fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // The timeout branch is listed first, so it wins if both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session).await.trace_err();

            // The pool guard is a temporary of the condition, so it is
            // released before the body runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;


        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
881
882async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
883 response.send(proto::Ack {})?;
884 Ok(())
885}
886
887async fn create_room(
888 _request: proto::CreateRoom,
889 response: Response<proto::CreateRoom>,
890 session: Session,
891) -> Result<()> {
892 let live_kit_room = nanoid::nanoid!(30);
893
894 let live_kit_connection_info = {
895 let live_kit_room = live_kit_room.clone();
896 let live_kit = session.live_kit_client.as_ref();
897
898 util::async_iife!({
899 let live_kit = live_kit?;
900
901 live_kit
902 .create_room(live_kit_room.clone())
903 .await
904 .trace_err()?;
905
906 let token = live_kit
907 .room_token(&live_kit_room, &session.user_id.to_string())
908 .trace_err()?;
909
910 Some(proto::LiveKitConnectionInfo {
911 server_url: live_kit.url().into(),
912 token,
913 })
914 })
915 }
916 .await;
917
918 let room = session
919 .db()
920 .await
921 .create_room(session.user_id, session.connection_id, &live_kit_room)
922 .await?;
923
924 response.send(proto::CreateRoomResponse {
925 room: Some(room.clone()),
926 live_kit_connection_info,
927 })?;
928
929 update_user_contacts(session.user_id, &session).await?;
930 Ok(())
931}
932
933async fn join_room(
934 request: proto::JoinRoom,
935 response: Response<proto::JoinRoom>,
936 session: Session,
937) -> Result<()> {
938 let room_id = RoomId::from_proto(request.id);
939 let joined_room = {
940 let room = session
941 .db()
942 .await
943 .join_room(room_id, session.user_id, session.connection_id)
944 .await?;
945 room_updated(&room.room, &session.peer);
946 room.into_inner()
947 };
948
949 if let Some(channel_id) = joined_room.channel_id {
950 channel_updated(
951 channel_id,
952 &joined_room.room,
953 &joined_room.channel_members,
954 &session.peer,
955 &*session.connection_pool().await,
956 )
957 }
958
959 for connection_id in session
960 .connection_pool()
961 .await
962 .user_connection_ids(session.user_id)
963 {
964 session
965 .peer
966 .send(
967 connection_id,
968 proto::CallCanceled {
969 room_id: room_id.to_proto(),
970 },
971 )
972 .trace_err();
973 }
974
975 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
976 if let Some(token) = live_kit
977 .room_token(
978 &joined_room.room.live_kit_room,
979 &session.user_id.to_string(),
980 )
981 .trace_err()
982 {
983 Some(proto::LiveKitConnectionInfo {
984 server_url: live_kit.url().into(),
985 token,
986 })
987 } else {
988 None
989 }
990 } else {
991 None
992 };
993
994 response.send(proto::JoinRoomResponse {
995 room: Some(joined_room.room),
996 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
997 live_kit_connection_info,
998 })?;
999
1000 update_user_contacts(session.user_id, &session).await?;
1001 Ok(())
1002}
1003
/// Handles a `RejoinRoom` request.
///
/// In order, this handler:
/// 1. asks the database to rejoin the room for this user and connection;
/// 2. answers the request with the room and its reshared / rejoined projects;
/// 3. calls `room_updated` and tells each project collaborator that this peer's
///    connection id changed (`UpdateProjectCollaborator`);
/// 4. streams worktree entries, diagnostic summaries, settings files and a
///    `DiskBasedDiagnosticsUpdated` notification per language server for every
///    rejoined project back to the requesting connection;
/// 5. calls `channel_updated` when the room has a channel id, then refreshes the
///    user's contacts.
///
/// Sends followed by `.trace_err()` do not propagate their error; sends followed by
/// `?` abort the handler on failure.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared up front so they can be assigned inside the block below and still be
    // used once `rejoined_room` has gone out of scope.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply to the requester first, describing every reshared and rejoined project.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator about the new peer id, then
        // broadcast the project's current worktree list on behalf of this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: collaborators only need to learn the new peer id.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Send the state of each rejoined project to the requesting connection.
        for project in &mut rejoined_room.rejoined_projects {
            // `mem::take` moves the worktrees out of the project so their fields can
            // be moved into the outgoing messages.
            for worktree in mem::take(&mut project.worktrees) {
                // Worktree updates are split into chunks of at most this many items;
                // test builds use a tiny chunk size.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send one `DiskBasedDiagnosticsUpdated` notification per language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // NOTE(review): `rejoined_room` is presumably a guard around the rejoined
        // data; `into_inner` unwraps it so the fields can outlive this block — confirm.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Only rooms that carry a channel id trigger a channel update.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1194
1195async fn leave_room(
1196 _: proto::LeaveRoom,
1197 response: Response<proto::LeaveRoom>,
1198 session: Session,
1199) -> Result<()> {
1200 leave_room_for_session(&session).await?;
1201 response.send(proto::Ack {})?;
1202 Ok(())
1203}
1204
1205async fn call(
1206 request: proto::Call,
1207 response: Response<proto::Call>,
1208 session: Session,
1209) -> Result<()> {
1210 let room_id = RoomId::from_proto(request.room_id);
1211 let calling_user_id = session.user_id;
1212 let calling_connection_id = session.connection_id;
1213 let called_user_id = UserId::from_proto(request.called_user_id);
1214 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1215 if !session
1216 .db()
1217 .await
1218 .has_contact(calling_user_id, called_user_id)
1219 .await?
1220 {
1221 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1222 }
1223
1224 let incoming_call = {
1225 let (room, incoming_call) = &mut *session
1226 .db()
1227 .await
1228 .call(
1229 room_id,
1230 calling_user_id,
1231 calling_connection_id,
1232 called_user_id,
1233 initial_project_id,
1234 )
1235 .await?;
1236 room_updated(&room, &session.peer);
1237 mem::take(incoming_call)
1238 };
1239 update_user_contacts(called_user_id, &session).await?;
1240
1241 let mut calls = session
1242 .connection_pool()
1243 .await
1244 .user_connection_ids(called_user_id)
1245 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1246 .collect::<FuturesUnordered<_>>();
1247
1248 while let Some(call_response) = calls.next().await {
1249 match call_response.as_ref() {
1250 Ok(_) => {
1251 response.send(proto::Ack {})?;
1252 return Ok(());
1253 }
1254 Err(_) => {
1255 call_response.trace_err();
1256 }
1257 }
1258 }
1259
1260 {
1261 let room = session
1262 .db()
1263 .await
1264 .call_failed(room_id, called_user_id)
1265 .await?;
1266 room_updated(&room, &session.peer);
1267 }
1268 update_user_contacts(called_user_id, &session).await?;
1269
1270 Err(anyhow!("failed to ring user"))?
1271}
1272
1273async fn cancel_call(
1274 request: proto::CancelCall,
1275 response: Response<proto::CancelCall>,
1276 session: Session,
1277) -> Result<()> {
1278 let called_user_id = UserId::from_proto(request.called_user_id);
1279 let room_id = RoomId::from_proto(request.room_id);
1280 {
1281 let room = session
1282 .db()
1283 .await
1284 .cancel_call(room_id, session.connection_id, called_user_id)
1285 .await?;
1286 room_updated(&room, &session.peer);
1287 }
1288
1289 for connection_id in session
1290 .connection_pool()
1291 .await
1292 .user_connection_ids(called_user_id)
1293 {
1294 session
1295 .peer
1296 .send(
1297 connection_id,
1298 proto::CallCanceled {
1299 room_id: room_id.to_proto(),
1300 },
1301 )
1302 .trace_err();
1303 }
1304 response.send(proto::Ack {})?;
1305
1306 update_user_contacts(called_user_id, &session).await?;
1307 Ok(())
1308}
1309
1310async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1311 let room_id = RoomId::from_proto(message.room_id);
1312 {
1313 let room = session
1314 .db()
1315 .await
1316 .decline_call(Some(room_id), session.user_id)
1317 .await?
1318 .ok_or_else(|| anyhow!("failed to decline call"))?;
1319 room_updated(&room, &session.peer);
1320 }
1321
1322 for connection_id in session
1323 .connection_pool()
1324 .await
1325 .user_connection_ids(session.user_id)
1326 {
1327 session
1328 .peer
1329 .send(
1330 connection_id,
1331 proto::CallCanceled {
1332 room_id: room_id.to_proto(),
1333 },
1334 )
1335 .trace_err();
1336 }
1337 update_user_contacts(session.user_id, &session).await?;
1338 Ok(())
1339}
1340
1341async fn update_participant_location(
1342 request: proto::UpdateParticipantLocation,
1343 response: Response<proto::UpdateParticipantLocation>,
1344 session: Session,
1345) -> Result<()> {
1346 let room_id = RoomId::from_proto(request.room_id);
1347 let location = request
1348 .location
1349 .ok_or_else(|| anyhow!("invalid location"))?;
1350
1351 let db = session.db().await;
1352 let room = db
1353 .update_room_participant_location(room_id, session.connection_id, location)
1354 .await?;
1355
1356 room_updated(&room, &session.peer);
1357 response.send(proto::Ack {})?;
1358 Ok(())
1359}
1360
1361async fn share_project(
1362 request: proto::ShareProject,
1363 response: Response<proto::ShareProject>,
1364 session: Session,
1365) -> Result<()> {
1366 let (project_id, room) = &*session
1367 .db()
1368 .await
1369 .share_project(
1370 RoomId::from_proto(request.room_id),
1371 session.connection_id,
1372 &request.worktrees,
1373 )
1374 .await?;
1375 response.send(proto::ShareProjectResponse {
1376 project_id: project_id.to_proto(),
1377 })?;
1378 room_updated(&room, &session.peer);
1379
1380 Ok(())
1381}
1382
1383async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1384 let project_id = ProjectId::from_proto(message.project_id);
1385
1386 let (room, guest_connection_ids) = &*session
1387 .db()
1388 .await
1389 .unshare_project(project_id, session.connection_id)
1390 .await?;
1391
1392 broadcast(
1393 Some(session.connection_id),
1394 guest_connection_ids.iter().copied(),
1395 |conn_id| session.peer.send(conn_id, message.clone()),
1396 );
1397 room_updated(&room, &session.peer);
1398
1399 Ok(())
1400}
1401
1402async fn join_project(
1403 request: proto::JoinProject,
1404 response: Response<proto::JoinProject>,
1405 session: Session,
1406) -> Result<()> {
1407 let project_id = ProjectId::from_proto(request.project_id);
1408 let guest_user_id = session.user_id;
1409
1410 tracing::info!(%project_id, "join project");
1411
1412 let (project, replica_id) = &mut *session
1413 .db()
1414 .await
1415 .join_project(project_id, session.connection_id)
1416 .await?;
1417
1418 let collaborators = project
1419 .collaborators
1420 .iter()
1421 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1422 .map(|collaborator| collaborator.to_proto())
1423 .collect::<Vec<_>>();
1424
1425 let worktrees = project
1426 .worktrees
1427 .iter()
1428 .map(|(id, worktree)| proto::WorktreeMetadata {
1429 id: *id,
1430 root_name: worktree.root_name.clone(),
1431 visible: worktree.visible,
1432 abs_path: worktree.abs_path.clone(),
1433 })
1434 .collect::<Vec<_>>();
1435
1436 for collaborator in &collaborators {
1437 session
1438 .peer
1439 .send(
1440 collaborator.peer_id.unwrap().into(),
1441 proto::AddProjectCollaborator {
1442 project_id: project_id.to_proto(),
1443 collaborator: Some(proto::Collaborator {
1444 peer_id: Some(session.connection_id.into()),
1445 replica_id: replica_id.0 as u32,
1446 user_id: guest_user_id.to_proto(),
1447 }),
1448 },
1449 )
1450 .trace_err();
1451 }
1452
1453 // First, we send the metadata associated with each worktree.
1454 response.send(proto::JoinProjectResponse {
1455 worktrees: worktrees.clone(),
1456 replica_id: replica_id.0 as u32,
1457 collaborators: collaborators.clone(),
1458 language_servers: project.language_servers.clone(),
1459 })?;
1460
1461 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1462 #[cfg(any(test, feature = "test-support"))]
1463 const MAX_CHUNK_SIZE: usize = 2;
1464 #[cfg(not(any(test, feature = "test-support")))]
1465 const MAX_CHUNK_SIZE: usize = 256;
1466
1467 // Stream this worktree's entries.
1468 let message = proto::UpdateWorktree {
1469 project_id: project_id.to_proto(),
1470 worktree_id,
1471 abs_path: worktree.abs_path.clone(),
1472 root_name: worktree.root_name,
1473 updated_entries: worktree.entries,
1474 removed_entries: Default::default(),
1475 scan_id: worktree.scan_id,
1476 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1477 updated_repositories: worktree.repository_entries.into_values().collect(),
1478 removed_repositories: Default::default(),
1479 };
1480 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1481 session.peer.send(session.connection_id, update.clone())?;
1482 }
1483
1484 // Stream this worktree's diagnostics.
1485 for summary in worktree.diagnostic_summaries {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateDiagnosticSummary {
1489 project_id: project_id.to_proto(),
1490 worktree_id: worktree.id,
1491 summary: Some(summary),
1492 },
1493 )?;
1494 }
1495
1496 for settings_file in worktree.settings_files {
1497 session.peer.send(
1498 session.connection_id,
1499 proto::UpdateWorktreeSettings {
1500 project_id: project_id.to_proto(),
1501 worktree_id: worktree.id,
1502 path: settings_file.path,
1503 content: Some(settings_file.content),
1504 },
1505 )?;
1506 }
1507 }
1508
1509 for language_server in &project.language_servers {
1510 session.peer.send(
1511 session.connection_id,
1512 proto::UpdateLanguageServer {
1513 project_id: project_id.to_proto(),
1514 language_server_id: language_server.id,
1515 variant: Some(
1516 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1517 proto::LspDiskBasedDiagnosticsUpdated {},
1518 ),
1519 ),
1520 },
1521 )?;
1522 }
1523
1524 Ok(())
1525}
1526
1527async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1528 let sender_id = session.connection_id;
1529 let project_id = ProjectId::from_proto(request.project_id);
1530
1531 let (room, project) = &*session
1532 .db()
1533 .await
1534 .leave_project(project_id, sender_id)
1535 .await?;
1536 tracing::info!(
1537 %project_id,
1538 host_user_id = %project.host_user_id,
1539 host_connection_id = %project.host_connection_id,
1540 "leave project"
1541 );
1542
1543 project_left(&project, &session);
1544 room_updated(&room, &session.peer);
1545
1546 Ok(())
1547}
1548
1549async fn update_project(
1550 request: proto::UpdateProject,
1551 response: Response<proto::UpdateProject>,
1552 session: Session,
1553) -> Result<()> {
1554 let project_id = ProjectId::from_proto(request.project_id);
1555 let (room, guest_connection_ids) = &*session
1556 .db()
1557 .await
1558 .update_project(project_id, session.connection_id, &request.worktrees)
1559 .await?;
1560 broadcast(
1561 Some(session.connection_id),
1562 guest_connection_ids.iter().copied(),
1563 |connection_id| {
1564 session
1565 .peer
1566 .forward_send(session.connection_id, connection_id, request.clone())
1567 },
1568 );
1569 room_updated(&room, &session.peer);
1570 response.send(proto::Ack {})?;
1571
1572 Ok(())
1573}
1574
1575async fn update_worktree(
1576 request: proto::UpdateWorktree,
1577 response: Response<proto::UpdateWorktree>,
1578 session: Session,
1579) -> Result<()> {
1580 let guest_connection_ids = session
1581 .db()
1582 .await
1583 .update_worktree(&request, session.connection_id)
1584 .await?;
1585
1586 broadcast(
1587 Some(session.connection_id),
1588 guest_connection_ids.iter().copied(),
1589 |connection_id| {
1590 session
1591 .peer
1592 .forward_send(session.connection_id, connection_id, request.clone())
1593 },
1594 );
1595 response.send(proto::Ack {})?;
1596 Ok(())
1597}
1598
1599async fn update_diagnostic_summary(
1600 message: proto::UpdateDiagnosticSummary,
1601 session: Session,
1602) -> Result<()> {
1603 let guest_connection_ids = session
1604 .db()
1605 .await
1606 .update_diagnostic_summary(&message, session.connection_id)
1607 .await?;
1608
1609 broadcast(
1610 Some(session.connection_id),
1611 guest_connection_ids.iter().copied(),
1612 |connection_id| {
1613 session
1614 .peer
1615 .forward_send(session.connection_id, connection_id, message.clone())
1616 },
1617 );
1618
1619 Ok(())
1620}
1621
1622async fn update_worktree_settings(
1623 message: proto::UpdateWorktreeSettings,
1624 session: Session,
1625) -> Result<()> {
1626 let guest_connection_ids = session
1627 .db()
1628 .await
1629 .update_worktree_settings(&message, session.connection_id)
1630 .await?;
1631
1632 broadcast(
1633 Some(session.connection_id),
1634 guest_connection_ids.iter().copied(),
1635 |connection_id| {
1636 session
1637 .peer
1638 .forward_send(session.connection_id, connection_id, message.clone())
1639 },
1640 );
1641
1642 Ok(())
1643}
1644
1645async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1646 broadcast_project_message(request.project_id, request, session).await
1647}
1648
1649async fn start_language_server(
1650 request: proto::StartLanguageServer,
1651 session: Session,
1652) -> Result<()> {
1653 let guest_connection_ids = session
1654 .db()
1655 .await
1656 .start_language_server(&request, session.connection_id)
1657 .await?;
1658
1659 broadcast(
1660 Some(session.connection_id),
1661 guest_connection_ids.iter().copied(),
1662 |connection_id| {
1663 session
1664 .peer
1665 .forward_send(session.connection_id, connection_id, request.clone())
1666 },
1667 );
1668 Ok(())
1669}
1670
1671async fn update_language_server(
1672 request: proto::UpdateLanguageServer,
1673 session: Session,
1674) -> Result<()> {
1675 session.executor.record_backtrace();
1676 let project_id = ProjectId::from_proto(request.project_id);
1677 let project_connection_ids = session
1678 .db()
1679 .await
1680 .project_connection_ids(project_id, session.connection_id)
1681 .await?;
1682 broadcast(
1683 Some(session.connection_id),
1684 project_connection_ids.iter().copied(),
1685 |connection_id| {
1686 session
1687 .peer
1688 .forward_send(session.connection_id, connection_id, request.clone())
1689 },
1690 );
1691 Ok(())
1692}
1693
1694async fn forward_project_request<T>(
1695 request: T,
1696 response: Response<T>,
1697 session: Session,
1698) -> Result<()>
1699where
1700 T: EntityMessage + RequestMessage,
1701{
1702 session.executor.record_backtrace();
1703 let project_id = ProjectId::from_proto(request.remote_entity_id());
1704 let host_connection_id = {
1705 let collaborators = session
1706 .db()
1707 .await
1708 .project_collaborators(project_id, session.connection_id)
1709 .await?;
1710 collaborators
1711 .iter()
1712 .find(|collaborator| collaborator.is_host)
1713 .ok_or_else(|| anyhow!("host not found"))?
1714 .connection_id
1715 };
1716
1717 let payload = session
1718 .peer
1719 .forward_request(session.connection_id, host_connection_id, request)
1720 .await?;
1721
1722 response.send(payload)?;
1723 Ok(())
1724}
1725
1726async fn create_buffer_for_peer(
1727 request: proto::CreateBufferForPeer,
1728 session: Session,
1729) -> Result<()> {
1730 session.executor.record_backtrace();
1731 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1732 session
1733 .peer
1734 .forward_send(session.connection_id, peer_id.into(), request)?;
1735 Ok(())
1736}
1737
1738async fn update_buffer(
1739 request: proto::UpdateBuffer,
1740 response: Response<proto::UpdateBuffer>,
1741 session: Session,
1742) -> Result<()> {
1743 session.executor.record_backtrace();
1744 let project_id = ProjectId::from_proto(request.project_id);
1745 let mut guest_connection_ids;
1746 let mut host_connection_id = None;
1747 {
1748 let collaborators = session
1749 .db()
1750 .await
1751 .project_collaborators(project_id, session.connection_id)
1752 .await?;
1753 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1754 for collaborator in collaborators.iter() {
1755 if collaborator.is_host {
1756 host_connection_id = Some(collaborator.connection_id);
1757 } else {
1758 guest_connection_ids.push(collaborator.connection_id);
1759 }
1760 }
1761 }
1762 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1763
1764 session.executor.record_backtrace();
1765 broadcast(
1766 Some(session.connection_id),
1767 guest_connection_ids,
1768 |connection_id| {
1769 session
1770 .peer
1771 .forward_send(session.connection_id, connection_id, request.clone())
1772 },
1773 );
1774 if host_connection_id != session.connection_id {
1775 session
1776 .peer
1777 .forward_request(session.connection_id, host_connection_id, request.clone())
1778 .await?;
1779 }
1780
1781 response.send(proto::Ack {})?;
1782 Ok(())
1783}
1784
1785async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1786 let project_id = ProjectId::from_proto(request.project_id);
1787 let project_connection_ids = session
1788 .db()
1789 .await
1790 .project_connection_ids(project_id, session.connection_id)
1791 .await?;
1792
1793 broadcast(
1794 Some(session.connection_id),
1795 project_connection_ids.iter().copied(),
1796 |connection_id| {
1797 session
1798 .peer
1799 .forward_send(session.connection_id, connection_id, request.clone())
1800 },
1801 );
1802 Ok(())
1803}
1804
1805async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1806 let project_id = ProjectId::from_proto(request.project_id);
1807 let project_connection_ids = session
1808 .db()
1809 .await
1810 .project_connection_ids(project_id, session.connection_id)
1811 .await?;
1812 broadcast(
1813 Some(session.connection_id),
1814 project_connection_ids.iter().copied(),
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 Ok(())
1822}
1823
1824async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1825 broadcast_project_message(request.project_id, request, session).await
1826}
1827
1828async fn broadcast_project_message<T: EnvelopedMessage>(
1829 project_id: u64,
1830 request: T,
1831 session: Session,
1832) -> Result<()> {
1833 let project_id = ProjectId::from_proto(project_id);
1834 let project_connection_ids = session
1835 .db()
1836 .await
1837 .project_connection_ids(project_id, session.connection_id)
1838 .await?;
1839 broadcast(
1840 Some(session.connection_id),
1841 project_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851async fn follow(
1852 request: proto::Follow,
1853 response: Response<proto::Follow>,
1854 session: Session,
1855) -> Result<()> {
1856 let project_id = ProjectId::from_proto(request.project_id);
1857 let leader_id = request
1858 .leader_id
1859 .ok_or_else(|| anyhow!("invalid leader id"))?
1860 .into();
1861 let follower_id = session.connection_id;
1862
1863 {
1864 let project_connection_ids = session
1865 .db()
1866 .await
1867 .project_connection_ids(project_id, session.connection_id)
1868 .await?;
1869
1870 if !project_connection_ids.contains(&leader_id) {
1871 Err(anyhow!("no such peer"))?;
1872 }
1873 }
1874
1875 let mut response_payload = session
1876 .peer
1877 .forward_request(session.connection_id, leader_id, request)
1878 .await?;
1879 response_payload
1880 .views
1881 .retain(|view| view.leader_id != Some(follower_id.into()));
1882 response.send(response_payload)?;
1883
1884 let room = session
1885 .db()
1886 .await
1887 .follow(project_id, leader_id, follower_id)
1888 .await?;
1889 room_updated(&room, &session.peer);
1890
1891 Ok(())
1892}
1893
1894async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1895 let project_id = ProjectId::from_proto(request.project_id);
1896 let leader_id = request
1897 .leader_id
1898 .ok_or_else(|| anyhow!("invalid leader id"))?
1899 .into();
1900 let follower_id = session.connection_id;
1901
1902 if !session
1903 .db()
1904 .await
1905 .project_connection_ids(project_id, session.connection_id)
1906 .await?
1907 .contains(&leader_id)
1908 {
1909 Err(anyhow!("no such peer"))?;
1910 }
1911
1912 session
1913 .peer
1914 .forward_send(session.connection_id, leader_id, request)?;
1915
1916 let room = session
1917 .db()
1918 .await
1919 .unfollow(project_id, leader_id, follower_id)
1920 .await?;
1921 room_updated(&room, &session.peer);
1922
1923 Ok(())
1924}
1925
1926async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1927 let project_id = ProjectId::from_proto(request.project_id);
1928 let project_connection_ids = session
1929 .db
1930 .lock()
1931 .await
1932 .project_connection_ids(project_id, session.connection_id)
1933 .await?;
1934
1935 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1936 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1937 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1938 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1939 });
1940 for follower_peer_id in request.follower_ids.iter().copied() {
1941 let follower_connection_id = follower_peer_id.into();
1942 if project_connection_ids.contains(&follower_connection_id)
1943 && Some(follower_peer_id) != leader_id
1944 {
1945 session.peer.forward_send(
1946 session.connection_id,
1947 follower_connection_id,
1948 request.clone(),
1949 )?;
1950 }
1951 }
1952 Ok(())
1953}
1954
1955async fn get_users(
1956 request: proto::GetUsers,
1957 response: Response<proto::GetUsers>,
1958 session: Session,
1959) -> Result<()> {
1960 let user_ids = request
1961 .user_ids
1962 .into_iter()
1963 .map(UserId::from_proto)
1964 .collect();
1965 let users = session
1966 .db()
1967 .await
1968 .get_users_by_ids(user_ids)
1969 .await?
1970 .into_iter()
1971 .map(|user| proto::User {
1972 id: user.id.to_proto(),
1973 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1974 github_login: user.github_login,
1975 })
1976 .collect();
1977 response.send(proto::UsersResponse { users })?;
1978 Ok(())
1979}
1980
1981async fn fuzzy_search_users(
1982 request: proto::FuzzySearchUsers,
1983 response: Response<proto::FuzzySearchUsers>,
1984 session: Session,
1985) -> Result<()> {
1986 let query = request.query;
1987 let users = match query.len() {
1988 0 => vec![],
1989 1 | 2 => session
1990 .db()
1991 .await
1992 .get_user_by_github_login(&query)
1993 .await?
1994 .into_iter()
1995 .collect(),
1996 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1997 };
1998 let users = users
1999 .into_iter()
2000 .filter(|user| user.id != session.user_id)
2001 .map(|user| proto::User {
2002 id: user.id.to_proto(),
2003 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2004 github_login: user.github_login,
2005 })
2006 .collect();
2007 response.send(proto::UsersResponse { users })?;
2008 Ok(())
2009}
2010
2011async fn request_contact(
2012 request: proto::RequestContact,
2013 response: Response<proto::RequestContact>,
2014 session: Session,
2015) -> Result<()> {
2016 let requester_id = session.user_id;
2017 let responder_id = UserId::from_proto(request.responder_id);
2018 if requester_id == responder_id {
2019 return Err(anyhow!("cannot add yourself as a contact"))?;
2020 }
2021
2022 session
2023 .db()
2024 .await
2025 .send_contact_request(requester_id, responder_id)
2026 .await?;
2027
2028 // Update outgoing contact requests of requester
2029 let mut update = proto::UpdateContacts::default();
2030 update.outgoing_requests.push(responder_id.to_proto());
2031 for connection_id in session
2032 .connection_pool()
2033 .await
2034 .user_connection_ids(requester_id)
2035 {
2036 session.peer.send(connection_id, update.clone())?;
2037 }
2038
2039 // Update incoming contact requests of responder
2040 let mut update = proto::UpdateContacts::default();
2041 update
2042 .incoming_requests
2043 .push(proto::IncomingContactRequest {
2044 requester_id: requester_id.to_proto(),
2045 should_notify: true,
2046 });
2047 for connection_id in session
2048 .connection_pool()
2049 .await
2050 .user_connection_ids(responder_id)
2051 {
2052 session.peer.send(connection_id, update.clone())?;
2053 }
2054
2055 response.send(proto::Ack {})?;
2056 Ok(())
2057}
2058
2059async fn respond_to_contact_request(
2060 request: proto::RespondToContactRequest,
2061 response: Response<proto::RespondToContactRequest>,
2062 session: Session,
2063) -> Result<()> {
2064 let responder_id = session.user_id;
2065 let requester_id = UserId::from_proto(request.requester_id);
2066 let db = session.db().await;
2067 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2068 db.dismiss_contact_notification(responder_id, requester_id)
2069 .await?;
2070 } else {
2071 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2072
2073 db.respond_to_contact_request(responder_id, requester_id, accept)
2074 .await?;
2075 let requester_busy = db.is_user_busy(requester_id).await?;
2076 let responder_busy = db.is_user_busy(responder_id).await?;
2077
2078 let pool = session.connection_pool().await;
2079 // Update responder with new contact
2080 let mut update = proto::UpdateContacts::default();
2081 if accept {
2082 update
2083 .contacts
2084 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2085 }
2086 update
2087 .remove_incoming_requests
2088 .push(requester_id.to_proto());
2089 for connection_id in pool.user_connection_ids(responder_id) {
2090 session.peer.send(connection_id, update.clone())?;
2091 }
2092
2093 // Update requester with new contact
2094 let mut update = proto::UpdateContacts::default();
2095 if accept {
2096 update
2097 .contacts
2098 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2099 }
2100 update
2101 .remove_outgoing_requests
2102 .push(responder_id.to_proto());
2103 for connection_id in pool.user_connection_ids(requester_id) {
2104 session.peer.send(connection_id, update.clone())?;
2105 }
2106 }
2107
2108 response.send(proto::Ack {})?;
2109 Ok(())
2110}
2111
2112async fn remove_contact(
2113 request: proto::RemoveContact,
2114 response: Response<proto::RemoveContact>,
2115 session: Session,
2116) -> Result<()> {
2117 let requester_id = session.user_id;
2118 let responder_id = UserId::from_proto(request.user_id);
2119 let db = session.db().await;
2120 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2121
2122 let pool = session.connection_pool().await;
2123 // Update outgoing contact requests of requester
2124 let mut update = proto::UpdateContacts::default();
2125 if contact_accepted {
2126 update.remove_contacts.push(responder_id.to_proto());
2127 } else {
2128 update
2129 .remove_outgoing_requests
2130 .push(responder_id.to_proto());
2131 }
2132 for connection_id in pool.user_connection_ids(requester_id) {
2133 session.peer.send(connection_id, update.clone())?;
2134 }
2135
2136 // Update incoming contact requests of responder
2137 let mut update = proto::UpdateContacts::default();
2138 if contact_accepted {
2139 update.remove_contacts.push(requester_id.to_proto());
2140 } else {
2141 update
2142 .remove_incoming_requests
2143 .push(requester_id.to_proto());
2144 }
2145 for connection_id in pool.user_connection_ids(responder_id) {
2146 session.peer.send(connection_id, update.clone())?;
2147 }
2148
2149 response.send(proto::Ack {})?;
2150 Ok(())
2151}
2152
2153async fn create_channel(
2154 request: proto::CreateChannel,
2155 response: Response<proto::CreateChannel>,
2156 session: Session,
2157) -> Result<()> {
2158 let db = session.db().await;
2159 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2160
2161 if let Some(live_kit) = session.live_kit_client.as_ref() {
2162 live_kit.create_room(live_kit_room.clone()).await?;
2163 }
2164
2165 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2166 let id = db
2167 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2168 .await?;
2169
2170 let channel = proto::Channel {
2171 id: id.to_proto(),
2172 name: request.name,
2173 parent_id: request.parent_id,
2174 };
2175
2176 response.send(proto::ChannelResponse {
2177 channel: Some(channel.clone()),
2178 })?;
2179
2180 let mut update = proto::UpdateChannels::default();
2181 update.channels.push(channel);
2182
2183 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2184 db.get_channel_members(parent_id).await?
2185 } else {
2186 vec![session.user_id]
2187 };
2188
2189 let connection_pool = session.connection_pool().await;
2190 for user_id in user_ids_to_notify {
2191 for connection_id in connection_pool.user_connection_ids(user_id) {
2192 let mut update = update.clone();
2193 if user_id == session.user_id {
2194 update.channel_permissions.push(proto::ChannelPermission {
2195 channel_id: id.to_proto(),
2196 is_admin: true,
2197 });
2198 }
2199 session.peer.send(connection_id, update)?;
2200 }
2201 }
2202
2203 Ok(())
2204}
2205
2206async fn remove_channel(
2207 request: proto::RemoveChannel,
2208 response: Response<proto::RemoveChannel>,
2209 session: Session,
2210) -> Result<()> {
2211 let db = session.db().await;
2212
2213 let channel_id = request.channel_id;
2214 let (removed_channels, member_ids) = db
2215 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2216 .await?;
2217 response.send(proto::Ack {})?;
2218
2219 // Notify members of removed channels
2220 let mut update = proto::UpdateChannels::default();
2221 update
2222 .remove_channels
2223 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2224
2225 let connection_pool = session.connection_pool().await;
2226 for member_id in member_ids {
2227 for connection_id in connection_pool.user_connection_ids(member_id) {
2228 session.peer.send(connection_id, update.clone())?;
2229 }
2230 }
2231
2232 Ok(())
2233}
2234
2235async fn invite_channel_member(
2236 request: proto::InviteChannelMember,
2237 response: Response<proto::InviteChannelMember>,
2238 session: Session,
2239) -> Result<()> {
2240 let db = session.db().await;
2241 let channel_id = ChannelId::from_proto(request.channel_id);
2242 let invitee_id = UserId::from_proto(request.user_id);
2243 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2244 .await?;
2245
2246 let (channel, _) = db
2247 .get_channel(channel_id, session.user_id)
2248 .await?
2249 .ok_or_else(|| anyhow!("channel not found"))?;
2250
2251 let mut update = proto::UpdateChannels::default();
2252 update.channel_invitations.push(proto::Channel {
2253 id: channel.id.to_proto(),
2254 name: channel.name,
2255 parent_id: None,
2256 });
2257 for connection_id in session
2258 .connection_pool()
2259 .await
2260 .user_connection_ids(invitee_id)
2261 {
2262 session.peer.send(connection_id, update.clone())?;
2263 }
2264
2265 response.send(proto::Ack {})?;
2266 Ok(())
2267}
2268
2269async fn remove_channel_member(
2270 request: proto::RemoveChannelMember,
2271 response: Response<proto::RemoveChannelMember>,
2272 session: Session,
2273) -> Result<()> {
2274 let db = session.db().await;
2275 let channel_id = ChannelId::from_proto(request.channel_id);
2276 let member_id = UserId::from_proto(request.user_id);
2277
2278 db.remove_channel_member(channel_id, member_id, session.user_id)
2279 .await?;
2280
2281 let mut update = proto::UpdateChannels::default();
2282 update.remove_channels.push(channel_id.to_proto());
2283
2284 for connection_id in session
2285 .connection_pool()
2286 .await
2287 .user_connection_ids(member_id)
2288 {
2289 session.peer.send(connection_id, update.clone())?;
2290 }
2291
2292 response.send(proto::Ack {})?;
2293 Ok(())
2294}
2295
2296async fn set_channel_member_admin(
2297 request: proto::SetChannelMemberAdmin,
2298 response: Response<proto::SetChannelMemberAdmin>,
2299 session: Session,
2300) -> Result<()> {
2301 let db = session.db().await;
2302 let channel_id = ChannelId::from_proto(request.channel_id);
2303 let member_id = UserId::from_proto(request.user_id);
2304 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2305 .await?;
2306
2307 let (channel, has_accepted) = db
2308 .get_channel(channel_id, member_id)
2309 .await?
2310 .ok_or_else(|| anyhow!("channel not found"))?;
2311
2312 let mut update = proto::UpdateChannels::default();
2313 if has_accepted {
2314 update.channel_permissions.push(proto::ChannelPermission {
2315 channel_id: channel.id.to_proto(),
2316 is_admin: request.admin,
2317 });
2318 }
2319
2320 for connection_id in session
2321 .connection_pool()
2322 .await
2323 .user_connection_ids(member_id)
2324 {
2325 session.peer.send(connection_id, update.clone())?;
2326 }
2327
2328 response.send(proto::Ack {})?;
2329 Ok(())
2330}
2331
2332async fn rename_channel(
2333 request: proto::RenameChannel,
2334 response: Response<proto::RenameChannel>,
2335 session: Session,
2336) -> Result<()> {
2337 let db = session.db().await;
2338 let channel_id = ChannelId::from_proto(request.channel_id);
2339 let new_name = db
2340 .rename_channel(channel_id, session.user_id, &request.name)
2341 .await?;
2342
2343 let channel = proto::Channel {
2344 id: request.channel_id,
2345 name: new_name,
2346 parent_id: None,
2347 };
2348 response.send(proto::ChannelResponse {
2349 channel: Some(channel.clone()),
2350 })?;
2351 let mut update = proto::UpdateChannels::default();
2352 update.channels.push(channel);
2353
2354 let member_ids = db.get_channel_members(channel_id).await?;
2355
2356 let connection_pool = session.connection_pool().await;
2357 for member_id in member_ids {
2358 for connection_id in connection_pool.user_connection_ids(member_id) {
2359 session.peer.send(connection_id, update.clone())?;
2360 }
2361 }
2362
2363 Ok(())
2364}
2365
2366async fn get_channel_members(
2367 request: proto::GetChannelMembers,
2368 response: Response<proto::GetChannelMembers>,
2369 session: Session,
2370) -> Result<()> {
2371 let db = session.db().await;
2372 let channel_id = ChannelId::from_proto(request.channel_id);
2373 let members = db
2374 .get_channel_member_details(channel_id, session.user_id)
2375 .await?;
2376 response.send(proto::GetChannelMembersResponse { members })?;
2377 Ok(())
2378}
2379
2380async fn respond_to_channel_invite(
2381 request: proto::RespondToChannelInvite,
2382 response: Response<proto::RespondToChannelInvite>,
2383 session: Session,
2384) -> Result<()> {
2385 let db = session.db().await;
2386 let channel_id = ChannelId::from_proto(request.channel_id);
2387 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2388 .await?;
2389
2390 let mut update = proto::UpdateChannels::default();
2391 update
2392 .remove_channel_invitations
2393 .push(channel_id.to_proto());
2394 if request.accept {
2395 let result = db.get_channels_for_user(session.user_id).await?;
2396 update
2397 .channels
2398 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2399 id: channel.id.to_proto(),
2400 name: channel.name,
2401 parent_id: channel.parent_id.map(ChannelId::to_proto),
2402 }));
2403 update
2404 .channel_participants
2405 .extend(
2406 result
2407 .channel_participants
2408 .into_iter()
2409 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2410 channel_id: channel_id.to_proto(),
2411 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2412 }),
2413 );
2414 update
2415 .channel_permissions
2416 .extend(
2417 result
2418 .channels_with_admin_privileges
2419 .into_iter()
2420 .map(|channel_id| proto::ChannelPermission {
2421 channel_id: channel_id.to_proto(),
2422 is_admin: true,
2423 }),
2424 );
2425 }
2426 session.peer.send(session.connection_id, update)?;
2427 response.send(proto::Ack {})?;
2428
2429 Ok(())
2430}
2431
2432async fn join_channel(
2433 request: proto::JoinChannel,
2434 response: Response<proto::JoinChannel>,
2435 session: Session,
2436) -> Result<()> {
2437 let channel_id = ChannelId::from_proto(request.channel_id);
2438
2439 let joined_room = {
2440 leave_room_for_session(&session).await?;
2441 let db = session.db().await;
2442
2443 let room_id = db.room_id_for_channel(channel_id).await?;
2444
2445 let joined_room = db
2446 .join_room(room_id, session.user_id, session.connection_id)
2447 .await?;
2448
2449 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2450 let token = live_kit
2451 .room_token(
2452 &joined_room.room.live_kit_room,
2453 &session.user_id.to_string(),
2454 )
2455 .trace_err()?;
2456
2457 Some(LiveKitConnectionInfo {
2458 server_url: live_kit.url().into(),
2459 token,
2460 })
2461 });
2462
2463 response.send(proto::JoinRoomResponse {
2464 room: Some(joined_room.room.clone()),
2465 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2466 live_kit_connection_info,
2467 })?;
2468
2469 room_updated(&joined_room.room, &session.peer);
2470
2471 joined_room.into_inner()
2472 };
2473
2474 channel_updated(
2475 channel_id,
2476 &joined_room.room,
2477 &joined_room.channel_members,
2478 &session.peer,
2479 &*session.connection_pool().await,
2480 );
2481
2482 update_user_contacts(session.user_id, &session).await?;
2483
2484 Ok(())
2485}
2486
2487async fn join_channel_buffer(
2488 request: proto::JoinChannelBuffer,
2489 response: Response<proto::JoinChannelBuffer>,
2490 session: Session,
2491) -> Result<()> {
2492 let db = session.db().await;
2493 let channel_id = ChannelId::from_proto(request.channel_id);
2494
2495 let open_response = db
2496 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2497 .await?;
2498
2499 let replica_id = open_response.replica_id;
2500 let collaborators = open_response.collaborators.clone();
2501
2502 response.send(open_response)?;
2503
2504 let update = AddChannelBufferCollaborator {
2505 channel_id: channel_id.to_proto(),
2506 collaborator: Some(proto::Collaborator {
2507 user_id: session.user_id.to_proto(),
2508 peer_id: Some(session.connection_id.into()),
2509 replica_id,
2510 }),
2511 };
2512 channel_buffer_updated(
2513 session.connection_id,
2514 collaborators
2515 .iter()
2516 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2517 &update,
2518 &session.peer,
2519 );
2520
2521 Ok(())
2522}
2523
2524async fn update_channel_buffer(
2525 request: proto::UpdateChannelBuffer,
2526 session: Session,
2527) -> Result<()> {
2528 let db = session.db().await;
2529 let channel_id = ChannelId::from_proto(request.channel_id);
2530
2531 let collaborators = db
2532 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2533 .await?;
2534
2535 channel_buffer_updated(
2536 session.connection_id,
2537 collaborators,
2538 &proto::UpdateChannelBuffer {
2539 channel_id: channel_id.to_proto(),
2540 operations: request.operations,
2541 },
2542 &session.peer,
2543 );
2544 Ok(())
2545}
2546
2547async fn leave_channel_buffer(
2548 request: proto::LeaveChannelBuffer,
2549 response: Response<proto::LeaveChannelBuffer>,
2550 session: Session,
2551) -> Result<()> {
2552 let db = session.db().await;
2553 let channel_id = ChannelId::from_proto(request.channel_id);
2554
2555 let collaborators_to_notify = db
2556 .leave_channel_buffer(channel_id, session.connection_id)
2557 .await?;
2558
2559 response.send(Ack {})?;
2560
2561 channel_buffer_updated(
2562 session.connection_id,
2563 collaborators_to_notify,
2564 &proto::RemoveChannelBufferCollaborator {
2565 channel_id: channel_id.to_proto(),
2566 peer_id: Some(session.connection_id.into()),
2567 },
2568 &session.peer,
2569 );
2570
2571 Ok(())
2572}
2573
2574fn channel_buffer_updated<T: EnvelopedMessage>(
2575 sender_id: ConnectionId,
2576 collaborators: impl IntoIterator<Item = ConnectionId>,
2577 message: &T,
2578 peer: &Peer,
2579) {
2580 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2581 peer.send(peer_id.into(), message.clone())
2582 });
2583}
2584
2585async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2586 let project_id = ProjectId::from_proto(request.project_id);
2587 let project_connection_ids = session
2588 .db()
2589 .await
2590 .project_connection_ids(project_id, session.connection_id)
2591 .await?;
2592 broadcast(
2593 Some(session.connection_id),
2594 project_connection_ids.iter().copied(),
2595 |connection_id| {
2596 session
2597 .peer
2598 .forward_send(session.connection_id, connection_id, request.clone())
2599 },
2600 );
2601 Ok(())
2602}
2603
2604async fn get_private_user_info(
2605 _request: proto::GetPrivateUserInfo,
2606 response: Response<proto::GetPrivateUserInfo>,
2607 session: Session,
2608) -> Result<()> {
2609 let metrics_id = session
2610 .db()
2611 .await
2612 .get_user_metrics_id(session.user_id)
2613 .await?;
2614 let user = session
2615 .db()
2616 .await
2617 .get_user_by_id(session.user_id)
2618 .await?
2619 .ok_or_else(|| anyhow!("user not found"))?;
2620 response.send(proto::GetPrivateUserInfoResponse {
2621 metrics_id,
2622 staff: user.admin,
2623 })?;
2624 Ok(())
2625}
2626
2627fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2628 match message {
2629 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2630 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2631 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2632 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2633 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2634 code: frame.code.into(),
2635 reason: frame.reason,
2636 })),
2637 }
2638}
2639
2640fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2641 match message {
2642 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2643 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2644 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2645 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2646 AxumMessage::Close(frame) => {
2647 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2648 code: frame.code.into(),
2649 reason: frame.reason,
2650 }))
2651 }
2652 }
2653}
2654
2655fn build_initial_channels_update(
2656 channels: ChannelsForUser,
2657 channel_invites: Vec<db::Channel>,
2658) -> proto::UpdateChannels {
2659 let mut update = proto::UpdateChannels::default();
2660
2661 for channel in channels.channels {
2662 update.channels.push(proto::Channel {
2663 id: channel.id.to_proto(),
2664 name: channel.name,
2665 parent_id: channel.parent_id.map(|id| id.to_proto()),
2666 });
2667 }
2668
2669 for (channel_id, participants) in channels.channel_participants {
2670 update
2671 .channel_participants
2672 .push(proto::ChannelParticipants {
2673 channel_id: channel_id.to_proto(),
2674 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2675 });
2676 }
2677
2678 update
2679 .channel_permissions
2680 .extend(
2681 channels
2682 .channels_with_admin_privileges
2683 .into_iter()
2684 .map(|id| proto::ChannelPermission {
2685 channel_id: id.to_proto(),
2686 is_admin: true,
2687 }),
2688 );
2689
2690 for channel in channel_invites {
2691 update.channel_invitations.push(proto::Channel {
2692 id: channel.id.to_proto(),
2693 name: channel.name,
2694 parent_id: None,
2695 });
2696 }
2697
2698 update
2699}
2700
2701fn build_initial_contacts_update(
2702 contacts: Vec<db::Contact>,
2703 pool: &ConnectionPool,
2704) -> proto::UpdateContacts {
2705 let mut update = proto::UpdateContacts::default();
2706
2707 for contact in contacts {
2708 match contact {
2709 db::Contact::Accepted {
2710 user_id,
2711 should_notify,
2712 busy,
2713 } => {
2714 update
2715 .contacts
2716 .push(contact_for_user(user_id, should_notify, busy, &pool));
2717 }
2718 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2719 db::Contact::Incoming {
2720 user_id,
2721 should_notify,
2722 } => update
2723 .incoming_requests
2724 .push(proto::IncomingContactRequest {
2725 requester_id: user_id.to_proto(),
2726 should_notify,
2727 }),
2728 }
2729 }
2730
2731 update
2732}
2733
2734fn contact_for_user(
2735 user_id: UserId,
2736 should_notify: bool,
2737 busy: bool,
2738 pool: &ConnectionPool,
2739) -> proto::Contact {
2740 proto::Contact {
2741 user_id: user_id.to_proto(),
2742 online: pool.is_user_online(user_id),
2743 busy,
2744 should_notify,
2745 }
2746}
2747
2748fn room_updated(room: &proto::Room, peer: &Peer) {
2749 broadcast(
2750 None,
2751 room.participants
2752 .iter()
2753 .filter_map(|participant| Some(participant.peer_id?.into())),
2754 |peer_id| {
2755 peer.send(
2756 peer_id.into(),
2757 proto::RoomUpdated {
2758 room: Some(room.clone()),
2759 },
2760 )
2761 },
2762 );
2763}
2764
2765fn channel_updated(
2766 channel_id: ChannelId,
2767 room: &proto::Room,
2768 channel_members: &[UserId],
2769 peer: &Peer,
2770 pool: &ConnectionPool,
2771) {
2772 let participants = room
2773 .participants
2774 .iter()
2775 .map(|p| p.user_id)
2776 .collect::<Vec<_>>();
2777
2778 broadcast(
2779 None,
2780 channel_members
2781 .iter()
2782 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2783 |peer_id| {
2784 peer.send(
2785 peer_id.into(),
2786 proto::UpdateChannels {
2787 channel_participants: vec![proto::ChannelParticipants {
2788 channel_id: channel_id.to_proto(),
2789 participant_user_ids: participants.clone(),
2790 }],
2791 ..Default::default()
2792 },
2793 )
2794 },
2795 );
2796}
2797
2798async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2799 let db = session.db().await;
2800
2801 let contacts = db.get_contacts(user_id).await?;
2802 let busy = db.is_user_busy(user_id).await?;
2803
2804 let pool = session.connection_pool().await;
2805 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2806 for contact in contacts {
2807 if let db::Contact::Accepted {
2808 user_id: contact_user_id,
2809 ..
2810 } = contact
2811 {
2812 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2813 session
2814 .peer
2815 .send(
2816 contact_conn_id,
2817 proto::UpdateContacts {
2818 contacts: vec![updated_contact.clone()],
2819 remove_contacts: Default::default(),
2820 incoming_requests: Default::default(),
2821 remove_incoming_requests: Default::default(),
2822 outgoing_requests: Default::default(),
2823 remove_outgoing_requests: Default::default(),
2824 },
2825 )
2826 .trace_err();
2827 }
2828 }
2829 }
2830 Ok(())
2831}
2832
2833async fn leave_room_for_session(session: &Session) -> Result<()> {
2834 let mut contacts_to_update = HashSet::default();
2835
2836 let room_id;
2837 let canceled_calls_to_user_ids;
2838 let live_kit_room;
2839 let delete_live_kit_room;
2840 let room;
2841 let channel_members;
2842 let channel_id;
2843
2844 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2845 contacts_to_update.insert(session.user_id);
2846
2847 for project in left_room.left_projects.values() {
2848 project_left(project, session);
2849 }
2850
2851 room_id = RoomId::from_proto(left_room.room.id);
2852 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2853 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2854 delete_live_kit_room = left_room.deleted;
2855 room = mem::take(&mut left_room.room);
2856 channel_members = mem::take(&mut left_room.channel_members);
2857 channel_id = left_room.channel_id;
2858
2859 room_updated(&room, &session.peer);
2860 } else {
2861 return Ok(());
2862 }
2863
2864 if let Some(channel_id) = channel_id {
2865 channel_updated(
2866 channel_id,
2867 &room,
2868 &channel_members,
2869 &session.peer,
2870 &*session.connection_pool().await,
2871 );
2872 }
2873
2874 {
2875 let pool = session.connection_pool().await;
2876 for canceled_user_id in canceled_calls_to_user_ids {
2877 for connection_id in pool.user_connection_ids(canceled_user_id) {
2878 session
2879 .peer
2880 .send(
2881 connection_id,
2882 proto::CallCanceled {
2883 room_id: room_id.to_proto(),
2884 },
2885 )
2886 .trace_err();
2887 }
2888 contacts_to_update.insert(canceled_user_id);
2889 }
2890 }
2891
2892 for contact_user_id in contacts_to_update {
2893 update_user_contacts(contact_user_id, &session).await?;
2894 }
2895
2896 if let Some(live_kit) = session.live_kit_client.as_ref() {
2897 live_kit
2898 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2899 .await
2900 .trace_err();
2901
2902 if delete_live_kit_room {
2903 live_kit.delete_room(live_kit_room).await.trace_err();
2904 }
2905 }
2906
2907 Ok(())
2908}
2909
2910async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2911 let left_channel_buffers = session
2912 .db()
2913 .await
2914 .leave_channel_buffers(session.connection_id)
2915 .await?;
2916
2917 for (channel_id, connections) in left_channel_buffers {
2918 channel_buffer_updated(
2919 session.connection_id,
2920 connections,
2921 &proto::RemoveChannelBufferCollaborator {
2922 channel_id: channel_id.to_proto(),
2923 peer_id: Some(session.connection_id.into()),
2924 },
2925 &session.peer,
2926 );
2927 }
2928
2929 Ok(())
2930}
2931
2932fn project_left(project: &db::LeftProject, session: &Session) {
2933 for connection_id in &project.connection_ids {
2934 if project.host_user_id == session.user_id {
2935 session
2936 .peer
2937 .send(
2938 *connection_id,
2939 proto::UnshareProject {
2940 project_id: project.id.to_proto(),
2941 },
2942 )
2943 .trace_err();
2944 } else {
2945 session
2946 .peer
2947 .send(
2948 *connection_id,
2949 proto::RemoveProjectCollaborator {
2950 project_id: project.id.to_proto(),
2951 peer_id: Some(session.connection_id.into()),
2952 },
2953 )
2954 .trace_err();
2955 }
2956 }
2957}
2958
2959pub trait ResultExt {
2960 type Ok;
2961
2962 fn trace_err(self) -> Option<Self::Ok>;
2963}
2964
2965impl<T, E> ResultExt for Result<T, E>
2966where
2967 E: std::fmt::Debug,
2968{
2969 type Ok = T;
2970
2971 fn trace_err(self) -> Option<T> {
2972 match self {
2973 Ok(value) => Some(value),
2974 Err(error) => {
2975 tracing::error!("{:?}", error);
2976 None
2977 }
2978 }
2979 }
2980}