1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, BufferId, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User,
7 UserId,
8 },
9 executor::Executor,
10 AppState, Result,
11};
12use anyhow::anyhow;
13use async_tungstenite::tungstenite::{
14 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
15};
16use axum::{
17 body::Body,
18 extract::{
19 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
20 ConnectInfo, WebSocketUpgrade,
21 },
22 headers::{Header, HeaderName},
23 http::StatusCode,
24 middleware,
25 response::IntoResponse,
26 routing::get,
27 Extension, Router, TypedHeader,
28};
29use collections::{HashMap, HashSet};
30pub use connection_pool::ConnectionPool;
31use futures::{
32 channel::oneshot,
33 future::{self, BoxFuture},
34 stream::FuturesUnordered,
35 FutureExt, SinkExt, StreamExt, TryStreamExt,
36};
37use lazy_static::lazy_static;
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
42 OpenChannelBufferResponse, RequestMessage,
43 },
44 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc,
59 },
60 time::{Duration, Instant},
61};
62use tokio::sync::{watch, Semaphore};
63use tower::ServiceBuilder;
64use tracing::{info_span, instrument, Instrument};
65
/// Grace period a disconnected connection gets before `connection_lost` removes
/// the user from their room, declines pending calls, and refreshes contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay `Server::start` waits before cleaning up rooms that the database
/// reports as stale for this server.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
68
69lazy_static! {
70 static ref METRIC_CONNECTIONS: IntGauge =
71 register_int_gauge!("connections", "number of connections").unwrap();
72 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
73 "shared_projects",
74 "number of open projects with one or more guests"
75 )
76 .unwrap();
77}
78
79type MessageHandler =
80 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
81
82struct Response<R> {
83 peer: Arc<Peer>,
84 receipt: Receipt<R>,
85 responded: Arc<AtomicBool>,
86}
87
88impl<R: RequestMessage> Response<R> {
89 fn send(self, payload: R::Response) -> Result<()> {
90 self.responded.store(true, SeqCst);
91 self.peer.respond(self.receipt, payload)?;
92 Ok(())
93 }
94}
95
96#[derive(Clone)]
97struct Session {
98 user_id: UserId,
99 connection_id: ConnectionId,
100 db: Arc<tokio::sync::Mutex<DbHandle>>,
101 peer: Arc<Peer>,
102 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
103 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
104 executor: Executor,
105}
106
107impl Session {
108 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 let guard = self.db.lock().await;
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 guard
115 }
116
117 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.connection_pool.lock();
121 ConnectionPoolGuard {
122 guard,
123 _not_send: PhantomData,
124 }
125 }
126}
127
128impl fmt::Debug for Session {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 f.debug_struct("Session")
131 .field("user_id", &self.user_id)
132 .field("connection_id", &self.connection_id)
133 .finish()
134 }
135}
136
137struct DbHandle(Arc<Database>);
138
139impl Deref for DbHandle {
140 type Target = Database;
141
142 fn deref(&self) -> &Self::Target {
143 self.0.as_ref()
144 }
145}
146
/// The collaboration RPC server: owns the peer, the shared connection pool and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Behind a mutex so test builds can swap the id in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Shared with every `Session`; tracks live connections per user.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `teardown`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}
156
157pub(crate) struct ConnectionPoolGuard<'a> {
158 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
159 _not_send: PhantomData<Rc<()>>,
160}
161
/// Serializable view of the server state produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref`, i.e. as the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
168
169pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
170where
171 S: Serializer,
172 T: Deref<Target = U>,
173 U: Serialize,
174{
175 Serialize::serialize(value.deref(), serializer)
176}
177
178impl Server {
179 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
180 let mut server = Self {
181 id: parking_lot::Mutex::new(id),
182 peer: Peer::new(id.0 as u32),
183 app_state,
184 executor,
185 connection_pool: Default::default(),
186 handlers: Default::default(),
187 teardown: watch::channel(()).0,
188 };
189
190 server
191 .add_request_handler(ping)
192 .add_request_handler(create_room)
193 .add_request_handler(join_room)
194 .add_request_handler(rejoin_room)
195 .add_request_handler(leave_room)
196 .add_request_handler(call)
197 .add_request_handler(cancel_call)
198 .add_message_handler(decline_call)
199 .add_request_handler(update_participant_location)
200 .add_request_handler(share_project)
201 .add_message_handler(unshare_project)
202 .add_request_handler(join_project)
203 .add_message_handler(leave_project)
204 .add_request_handler(update_project)
205 .add_request_handler(update_worktree)
206 .add_message_handler(start_language_server)
207 .add_message_handler(update_language_server)
208 .add_message_handler(update_diagnostic_summary)
209 .add_message_handler(update_worktree_settings)
210 .add_message_handler(refresh_inlay_hints)
211 .add_request_handler(forward_project_request::<proto::GetHover>)
212 .add_request_handler(forward_project_request::<proto::GetDefinition>)
213 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
214 .add_request_handler(forward_project_request::<proto::GetReferences>)
215 .add_request_handler(forward_project_request::<proto::SearchProject>)
216 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
217 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
218 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
219 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
220 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
221 .add_request_handler(forward_project_request::<proto::GetCompletions>)
222 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
223 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
224 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
225 .add_request_handler(forward_project_request::<proto::PrepareRename>)
226 .add_request_handler(forward_project_request::<proto::PerformRename>)
227 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
228 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
229 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
230 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
233 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
234 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
235 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
236 .add_request_handler(forward_project_request::<proto::InlayHints>)
237 .add_message_handler(create_buffer_for_peer)
238 .add_request_handler(update_buffer)
239 .add_message_handler(update_buffer_file)
240 .add_message_handler(buffer_reloaded)
241 .add_message_handler(buffer_saved)
242 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
243 .add_request_handler(get_users)
244 .add_request_handler(fuzzy_search_users)
245 .add_request_handler(request_contact)
246 .add_request_handler(remove_contact)
247 .add_request_handler(respond_to_contact_request)
248 .add_request_handler(create_channel)
249 .add_request_handler(remove_channel)
250 .add_request_handler(invite_channel_member)
251 .add_request_handler(remove_channel_member)
252 .add_request_handler(set_channel_member_admin)
253 .add_request_handler(rename_channel)
254 .add_request_handler(open_channel_buffer)
255 .add_request_handler(close_channel_buffer)
256 .add_message_handler(update_channel_buffer)
257 .add_request_handler(get_channel_members)
258 .add_request_handler(respond_to_channel_invite)
259 .add_request_handler(join_channel)
260 .add_request_handler(follow)
261 .add_message_handler(unfollow)
262 .add_message_handler(update_followers)
263 .add_message_handler(update_diff_base)
264 .add_request_handler(get_private_user_info);
265
266 Arc::new(server)
267 }
268
269 pub async fn start(&self) -> Result<()> {
270 let server_id = *self.id.lock();
271 let app_state = self.app_state.clone();
272 let peer = self.peer.clone();
273 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
274 let pool = self.connection_pool.clone();
275 let live_kit_client = self.app_state.live_kit_client.clone();
276
277 let span = info_span!("start server");
278 self.executor.spawn_detached(
279 async move {
280 tracing::info!("waiting for cleanup timeout");
281 timeout.await;
282 tracing::info!("cleanup timeout expired, retrieving stale rooms");
283 if let Some(room_ids) = app_state
284 .db
285 .stale_room_ids(&app_state.config.zed_environment, server_id)
286 .await
287 .trace_err()
288 {
289 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
290 for room_id in room_ids {
291 let mut contacts_to_update = HashSet::default();
292 let mut canceled_calls_to_user_ids = Vec::new();
293 let mut live_kit_room = String::new();
294 let mut delete_live_kit_room = false;
295
296 if let Some(mut refreshed_room) = app_state
297 .db
298 .refresh_room(room_id, server_id)
299 .await
300 .trace_err()
301 {
302 tracing::info!(
303 room_id = room_id.0,
304 new_participant_count = refreshed_room.room.participants.len(),
305 "refreshed room"
306 );
307 room_updated(&refreshed_room.room, &peer);
308 if let Some(channel_id) = refreshed_room.channel_id {
309 channel_updated(
310 channel_id,
311 &refreshed_room.room,
312 &refreshed_room.channel_members,
313 &peer,
314 &*pool.lock(),
315 );
316 }
317 contacts_to_update
318 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
319 contacts_to_update
320 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
321 canceled_calls_to_user_ids =
322 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
323 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
324 delete_live_kit_room = refreshed_room.room.participants.is_empty();
325 }
326
327 {
328 let pool = pool.lock();
329 for canceled_user_id in canceled_calls_to_user_ids {
330 for connection_id in pool.user_connection_ids(canceled_user_id) {
331 peer.send(
332 connection_id,
333 proto::CallCanceled {
334 room_id: room_id.to_proto(),
335 },
336 )
337 .trace_err();
338 }
339 }
340 }
341
342 for user_id in contacts_to_update {
343 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
344 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
345 if let Some((busy, contacts)) = busy.zip(contacts) {
346 let pool = pool.lock();
347 let updated_contact = contact_for_user(user_id, false, busy, &pool);
348 for contact in contacts {
349 if let db::Contact::Accepted {
350 user_id: contact_user_id,
351 ..
352 } = contact
353 {
354 for contact_conn_id in
355 pool.user_connection_ids(contact_user_id)
356 {
357 peer.send(
358 contact_conn_id,
359 proto::UpdateContacts {
360 contacts: vec![updated_contact.clone()],
361 remove_contacts: Default::default(),
362 incoming_requests: Default::default(),
363 remove_incoming_requests: Default::default(),
364 outgoing_requests: Default::default(),
365 remove_outgoing_requests: Default::default(),
366 },
367 )
368 .trace_err();
369 }
370 }
371 }
372 }
373 }
374
375 if let Some(live_kit) = live_kit_client.as_ref() {
376 if delete_live_kit_room {
377 live_kit.delete_room(live_kit_room).await.trace_err();
378 }
379 }
380 }
381 }
382
383 app_state
384 .db
385 .delete_stale_servers(&app_state.config.zed_environment, server_id)
386 .await
387 .trace_err();
388 }
389 .instrument(span),
390 );
391 Ok(())
392 }
393
394 pub fn teardown(&self) {
395 self.peer.teardown();
396 self.connection_pool.lock().reset();
397 let _ = self.teardown.send(());
398 }
399
400 #[cfg(test)]
401 pub fn reset(&self, id: ServerId) {
402 self.teardown();
403 *self.id.lock() = id;
404 self.peer.reset(id.0 as u32);
405 }
406
407 #[cfg(test)]
408 pub fn id(&self) -> ServerId {
409 *self.id.lock()
410 }
411
412 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
413 where
414 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
415 Fut: 'static + Send + Future<Output = Result<()>>,
416 M: EnvelopedMessage,
417 {
418 let prev_handler = self.handlers.insert(
419 TypeId::of::<M>(),
420 Box::new(move |envelope, session| {
421 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
422 let span = info_span!(
423 "handle message",
424 payload_type = envelope.payload_type_name()
425 );
426 span.in_scope(|| {
427 tracing::info!(
428 payload_type = envelope.payload_type_name(),
429 "message received"
430 );
431 });
432 let start_time = Instant::now();
433 let future = (handler)(*envelope, session);
434 async move {
435 let result = future.await;
436 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
437 match result {
438 Err(error) => {
439 tracing::error!(%error, ?duration_ms, "error handling message")
440 }
441 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
442 }
443 }
444 .instrument(span)
445 .boxed()
446 }),
447 );
448 if prev_handler.is_some() {
449 panic!("registered a handler for the same message twice");
450 }
451 self
452 }
453
454 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
455 where
456 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
457 Fut: 'static + Send + Future<Output = Result<()>>,
458 M: EnvelopedMessage,
459 {
460 self.add_handler(move |envelope, session| handler(envelope.payload, session));
461 self
462 }
463
464 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
465 where
466 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
467 Fut: Send + Future<Output = Result<()>>,
468 M: RequestMessage,
469 {
470 let handler = Arc::new(handler);
471 self.add_handler(move |envelope, session| {
472 let receipt = envelope.receipt();
473 let handler = handler.clone();
474 async move {
475 let peer = session.peer.clone();
476 let responded = Arc::new(AtomicBool::default());
477 let response = Response {
478 peer: peer.clone(),
479 responded: responded.clone(),
480 receipt,
481 };
482 match (handler)(envelope.payload, response, session).await {
483 Ok(()) => {
484 if responded.load(std::sync::atomic::Ordering::SeqCst) {
485 Ok(())
486 } else {
487 Err(anyhow!("handler did not send a response"))?
488 }
489 }
490 Err(error) => {
491 peer.respond_with_error(
492 receipt,
493 proto::Error {
494 message: error.to_string(),
495 },
496 )?;
497 Err(error)
498 }
499 }
500 }
501 })
502 }
503
504 pub fn handle_connection(
505 self: &Arc<Self>,
506 connection: Connection,
507 address: String,
508 user: User,
509 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
510 executor: Executor,
511 ) -> impl Future<Output = Result<()>> {
512 let this = self.clone();
513 let user_id = user.id;
514 let login = user.github_login;
515 let span = info_span!("handle connection", %user_id, %login, %address);
516 let mut teardown = self.teardown.subscribe();
517 async move {
518 let (connection_id, handle_io, mut incoming_rx) = this
519 .peer
520 .add_connection(connection, {
521 let executor = executor.clone();
522 move |duration| executor.sleep(duration)
523 });
524
525 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
526 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
527 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
528
529 if let Some(send_connection_id) = send_connection_id.take() {
530 let _ = send_connection_id.send(connection_id);
531 }
532
533 if !user.connected_once {
534 this.peer.send(connection_id, proto::ShowContacts {})?;
535 this.app_state.db.set_user_connected_once(user_id, true).await?;
536 }
537
538 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
539 this.app_state.db.get_contacts(user_id),
540 this.app_state.db.get_invite_code_for_user(user_id),
541 this.app_state.db.get_channels_for_user(user_id),
542 this.app_state.db.get_channel_invites_for_user(user_id)
543 ).await?;
544
545 {
546 let mut pool = this.connection_pool.lock();
547 pool.add_connection(connection_id, user_id, user.admin);
548 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
549 this.peer.send(connection_id, build_initial_channels_update(
550 channels_for_user,
551 channel_invites
552 ))?;
553
554 if let Some((code, count)) = invite_code {
555 this.peer.send(connection_id, proto::UpdateInviteInfo {
556 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
557 count: count as u32,
558 })?;
559 }
560 }
561
562 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
563 this.peer.send(connection_id, incoming_call)?;
564 }
565
566 let session = Session {
567 user_id,
568 connection_id,
569 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
570 peer: this.peer.clone(),
571 connection_pool: this.connection_pool.clone(),
572 live_kit_client: this.app_state.live_kit_client.clone(),
573 executor: executor.clone(),
574 };
575 update_user_contacts(user_id, &session).await?;
576
577 let handle_io = handle_io.fuse();
578 futures::pin_mut!(handle_io);
579
580 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
581 // This prevents deadlocks when e.g., client A performs a request to client B and
582 // client B performs a request to client A. If both clients stop processing further
583 // messages until their respective request completes, they won't have a chance to
584 // respond to the other client's request and cause a deadlock.
585 //
586 // This arrangement ensures we will attempt to process earlier messages first, but fall
587 // back to processing messages arrived later in the spirit of making progress.
588 let mut foreground_message_handlers = FuturesUnordered::new();
589 let concurrent_handlers = Arc::new(Semaphore::new(256));
590 loop {
591 let next_message = async {
592 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
593 let message = incoming_rx.next().await;
594 (permit, message)
595 }.fuse();
596 futures::pin_mut!(next_message);
597 futures::select_biased! {
598 _ = teardown.changed().fuse() => return Ok(()),
599 result = handle_io => {
600 if let Err(error) = result {
601 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
602 }
603 break;
604 }
605 _ = foreground_message_handlers.next() => {}
606 next_message = next_message => {
607 let (permit, message) = next_message;
608 if let Some(message) = message {
609 let type_name = message.payload_type_name();
610 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
611 let span_enter = span.enter();
612 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
613 let is_background = message.is_background();
614 let handle_message = (handler)(message, session.clone());
615 drop(span_enter);
616
617 let handle_message = async move {
618 handle_message.await;
619 drop(permit);
620 }.instrument(span);
621 if is_background {
622 executor.spawn_detached(handle_message);
623 } else {
624 foreground_message_handlers.push(handle_message);
625 }
626 } else {
627 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
628 }
629 } else {
630 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
631 break;
632 }
633 }
634 }
635 }
636
637 drop(foreground_message_handlers);
638 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
639 if let Err(error) = connection_lost(session, teardown, executor).await {
640 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
641 }
642
643 Ok(())
644 }.instrument(span)
645 }
646
647 pub async fn invite_code_redeemed(
648 self: &Arc<Self>,
649 inviter_id: UserId,
650 invitee_id: UserId,
651 ) -> Result<()> {
652 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
653 if let Some(code) = &user.invite_code {
654 let pool = self.connection_pool.lock();
655 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
656 for connection_id in pool.user_connection_ids(inviter_id) {
657 self.peer.send(
658 connection_id,
659 proto::UpdateContacts {
660 contacts: vec![invitee_contact.clone()],
661 ..Default::default()
662 },
663 )?;
664 self.peer.send(
665 connection_id,
666 proto::UpdateInviteInfo {
667 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
668 count: user.invite_count as u32,
669 },
670 )?;
671 }
672 }
673 }
674 Ok(())
675 }
676
677 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
678 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
679 if let Some(invite_code) = &user.invite_code {
680 let pool = self.connection_pool.lock();
681 for connection_id in pool.user_connection_ids(user_id) {
682 self.peer.send(
683 connection_id,
684 proto::UpdateInviteInfo {
685 url: format!(
686 "{}{}",
687 self.app_state.config.invite_link_prefix, invite_code
688 ),
689 count: user.invite_count as u32,
690 },
691 )?;
692 }
693 }
694 }
695 Ok(())
696 }
697
698 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
699 ServerSnapshot {
700 connection_pool: ConnectionPoolGuard {
701 guard: self.connection_pool.lock(),
702 _not_send: PhantomData,
703 },
704 peer: &self.peer,
705 }
706 }
707}
708
709impl<'a> Deref for ConnectionPoolGuard<'a> {
710 type Target = ConnectionPool;
711
712 fn deref(&self) -> &Self::Target {
713 &*self.guard
714 }
715}
716
717impl<'a> DerefMut for ConnectionPoolGuard<'a> {
718 fn deref_mut(&mut self) -> &mut Self::Target {
719 &mut *self.guard
720 }
721}
722
723impl<'a> Drop for ConnectionPoolGuard<'a> {
724 fn drop(&mut self) {
725 #[cfg(test)]
726 self.check_invariants();
727 }
728}
729
730fn broadcast<F>(
731 sender_id: Option<ConnectionId>,
732 receiver_ids: impl IntoIterator<Item = ConnectionId>,
733 mut f: F,
734) where
735 F: FnMut(ConnectionId) -> anyhow::Result<()>,
736{
737 for receiver_id in receiver_ids {
738 if Some(receiver_id) != sender_id {
739 if let Err(error) = f(receiver_id) {
740 tracing::error!("failed to send to {:?} {}", receiver_id, error);
741 }
742 }
743 }
744}
745
746lazy_static! {
747 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
748}
749
750pub struct ProtocolVersion(u32);
751
752impl Header for ProtocolVersion {
753 fn name() -> &'static HeaderName {
754 &ZED_PROTOCOL_VERSION
755 }
756
757 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
758 where
759 Self: Sized,
760 I: Iterator<Item = &'i axum::http::HeaderValue>,
761 {
762 let version = values
763 .next()
764 .ok_or_else(axum::headers::Error::invalid)?
765 .to_str()
766 .map_err(|_| axum::headers::Error::invalid())?
767 .parse()
768 .map_err(|_| axum::headers::Error::invalid())?;
769 Ok(Self(version))
770 }
771
772 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
773 values.extend([self.0.to_string().parse().unwrap()]);
774 }
775}
776
777pub fn routes(server: Arc<Server>) -> Router<Body> {
778 Router::new()
779 .route("/rpc", get(handle_websocket_request))
780 .layer(
781 ServiceBuilder::new()
782 .layer(Extension(server.app_state.clone()))
783 .layer(middleware::from_fn(auth::validate_header)),
784 )
785 .route("/metrics", get(handle_metrics))
786 .layer(Extension(server))
787}
788
789pub async fn handle_websocket_request(
790 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
791 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
792 Extension(server): Extension<Arc<Server>>,
793 Extension(user): Extension<User>,
794 ws: WebSocketUpgrade,
795) -> axum::response::Response {
796 if protocol_version != rpc::PROTOCOL_VERSION {
797 return (
798 StatusCode::UPGRADE_REQUIRED,
799 "client must be upgraded".to_string(),
800 )
801 .into_response();
802 }
803 let socket_address = socket_address.to_string();
804 ws.on_upgrade(move |socket| {
805 use util::ResultExt;
806 let socket = socket
807 .map_ok(to_tungstenite_message)
808 .err_into()
809 .with(|message| async move { Ok(to_axum_message(message)) });
810 let connection = Connection::new(Box::pin(socket));
811 async move {
812 server
813 .handle_connection(connection, socket_address, user, None, Executor::Production)
814 .await
815 .log_err();
816 }
817 })
818}
819
820pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
821 let connections = server
822 .connection_pool
823 .lock()
824 .connections()
825 .filter(|connection| !connection.admin)
826 .count();
827
828 METRIC_CONNECTIONS.set(connections as _);
829
830 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
831 METRIC_SHARED_PROJECTS.set(shared_projects as _);
832
833 let encoder = prometheus::TextEncoder::new();
834 let metric_families = prometheus::gather();
835 let encoded_metrics = encoder
836 .encode_to_string(&metric_families)
837 .map_err(|err| anyhow!("{}", err))?;
838 Ok(encoded_metrics)
839}
840
841#[instrument(err, skip(executor))]
842async fn connection_lost(
843 session: Session,
844 mut teardown: watch::Receiver<()>,
845 executor: Executor,
846) -> Result<()> {
847 session.peer.disconnect(session.connection_id);
848 session
849 .connection_pool()
850 .await
851 .remove_connection(session.connection_id)?;
852
853 session
854 .db()
855 .await
856 .connection_lost(session.connection_id)
857 .await
858 .trace_err();
859
860 futures::select_biased! {
861 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
862 leave_room_for_session(&session).await.trace_err();
863
864 if !session
865 .connection_pool()
866 .await
867 .is_user_online(session.user_id)
868 {
869 let db = session.db().await;
870 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
871 room_updated(&room, &session.peer);
872 }
873 }
874 update_user_contacts(session.user_id, &session).await?;
875 }
876 _ = teardown.changed().fuse() => {}
877 }
878
879 Ok(())
880}
881
882async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
883 response.send(proto::Ack {})?;
884 Ok(())
885}
886
887async fn create_room(
888 _request: proto::CreateRoom,
889 response: Response<proto::CreateRoom>,
890 session: Session,
891) -> Result<()> {
892 let live_kit_room = nanoid::nanoid!(30);
893
894 let live_kit_connection_info = {
895 let live_kit_room = live_kit_room.clone();
896 let live_kit = session.live_kit_client.as_ref();
897
898 util::async_iife!({
899 let live_kit = live_kit?;
900
901 live_kit
902 .create_room(live_kit_room.clone())
903 .await
904 .trace_err()?;
905
906 let token = live_kit
907 .room_token(&live_kit_room, &session.user_id.to_string())
908 .trace_err()?;
909
910 Some(proto::LiveKitConnectionInfo {
911 server_url: live_kit.url().into(),
912 token,
913 })
914 })
915 }
916 .await;
917
918 let room = session
919 .db()
920 .await
921 .create_room(session.user_id, session.connection_id, &live_kit_room)
922 .await?;
923
924 response.send(proto::CreateRoomResponse {
925 room: Some(room.clone()),
926 live_kit_connection_info,
927 })?;
928
929 update_user_contacts(session.user_id, &session).await?;
930 Ok(())
931}
932
933async fn join_room(
934 request: proto::JoinRoom,
935 response: Response<proto::JoinRoom>,
936 session: Session,
937) -> Result<()> {
938 let room_id = RoomId::from_proto(request.id);
939 let joined_room = {
940 let room = session
941 .db()
942 .await
943 .join_room(room_id, session.user_id, session.connection_id)
944 .await?;
945 room_updated(&room.room, &session.peer);
946 room.into_inner()
947 };
948
949 if let Some(channel_id) = joined_room.channel_id {
950 channel_updated(
951 channel_id,
952 &joined_room.room,
953 &joined_room.channel_members,
954 &session.peer,
955 &*session.connection_pool().await,
956 )
957 }
958
959 for connection_id in session
960 .connection_pool()
961 .await
962 .user_connection_ids(session.user_id)
963 {
964 session
965 .peer
966 .send(
967 connection_id,
968 proto::CallCanceled {
969 room_id: room_id.to_proto(),
970 },
971 )
972 .trace_err();
973 }
974
975 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
976 if let Some(token) = live_kit
977 .room_token(
978 &joined_room.room.live_kit_room,
979 &session.user_id.to_string(),
980 )
981 .trace_err()
982 {
983 Some(proto::LiveKitConnectionInfo {
984 server_url: live_kit.url().into(),
985 token,
986 })
987 } else {
988 None
989 }
990 } else {
991 None
992 };
993
994 response.send(proto::JoinRoomResponse {
995 room: Some(joined_room.room),
996 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
997 live_kit_connection_info,
998 })?;
999
1000 update_user_contacts(session.user_id, &session).await?;
1001 Ok(())
1002}
1003
/// Handles `RejoinRoom`: re-attaches the current connection to a room the user was in,
/// together with that user's projects in the room, after the user's connection id changed.
///
/// Steps, in order:
/// 1. `rejoin_room` in the database yields the room, the `reshared_projects`
///    (NOTE(review): presumably projects this user hosts — confirm) and the
///    `rejoined_projects`.
/// 2. The caller receives a `RejoinRoomResponse`, then the room update is broadcast.
/// 3. Every collaborator of those projects is sent `UpdateProjectCollaborator`, mapping the
///    project's `old_connection_id` to the current connection (failures are only logged).
/// 4. For each reshared project, its worktree list is forwarded as `UpdateProject` to the
///    other collaborators.
/// 5. For each rejoined project, worktree entries (in chunks), diagnostic summaries,
///    settings files and a `DiskBasedDiagnosticsUpdated` per language server are sent to
///    the current connection.
/// 6. After `rejoined_room` is released, the channel (when the room has one) and the user's
///    contacts are updated.
///
/// The first failed send to the current connection aborts the handler with that error.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the inner scope so they outlive `rejoined_room`.
    let room;
    let channel_id;
    let channel_members;
    {
        // NOTE(review): `into_inner` below suggests this value wraps a guard; the scope ends
        // before the connection pool is awaited — confirm the intent.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator this user's peer id changed; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            // Same peer-id remapping as above, for projects this user rejoined.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        for project in &mut rejoined_room.rejoined_projects {
            // Take the worktrees out so their fields can be moved into messages.
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size so updates get split into many messages.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Unwrap the value so the room/channel data can leave this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1194
1195async fn leave_room(
1196 _: proto::LeaveRoom,
1197 response: Response<proto::LeaveRoom>,
1198 session: Session,
1199) -> Result<()> {
1200 leave_room_for_session(&session).await?;
1201 response.send(proto::Ack {})?;
1202 Ok(())
1203}
1204
/// Handles `Call`: rings `request.called_user_id` from the current connection.
///
/// The callee must be a contact of the caller. The call is recorded in the database, the
/// updated room is broadcast and the callee's contacts are refreshed. Every connection of
/// the callee is then sent the incoming call concurrently. The first successful response
/// makes the handler acknowledge the caller and return. If none succeeds, the call is
/// marked as failed in the database, the room is broadcast again, the callee's contacts
/// are refreshed, and "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the message out; the value returned by the database is dropped at the end
        // of this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once and handle replies as they arrive.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // One successful response is enough; the remaining requests are dropped
                // along with `calls`.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failure and broadcast the room again.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1272
1273async fn cancel_call(
1274 request: proto::CancelCall,
1275 response: Response<proto::CancelCall>,
1276 session: Session,
1277) -> Result<()> {
1278 let called_user_id = UserId::from_proto(request.called_user_id);
1279 let room_id = RoomId::from_proto(request.room_id);
1280 {
1281 let room = session
1282 .db()
1283 .await
1284 .cancel_call(room_id, session.connection_id, called_user_id)
1285 .await?;
1286 room_updated(&room, &session.peer);
1287 }
1288
1289 for connection_id in session
1290 .connection_pool()
1291 .await
1292 .user_connection_ids(called_user_id)
1293 {
1294 session
1295 .peer
1296 .send(
1297 connection_id,
1298 proto::CallCanceled {
1299 room_id: room_id.to_proto(),
1300 },
1301 )
1302 .trace_err();
1303 }
1304 response.send(proto::Ack {})?;
1305
1306 update_user_contacts(called_user_id, &session).await?;
1307 Ok(())
1308}
1309
1310async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1311 let room_id = RoomId::from_proto(message.room_id);
1312 {
1313 let room = session
1314 .db()
1315 .await
1316 .decline_call(Some(room_id), session.user_id)
1317 .await?
1318 .ok_or_else(|| anyhow!("failed to decline call"))?;
1319 room_updated(&room, &session.peer);
1320 }
1321
1322 for connection_id in session
1323 .connection_pool()
1324 .await
1325 .user_connection_ids(session.user_id)
1326 {
1327 session
1328 .peer
1329 .send(
1330 connection_id,
1331 proto::CallCanceled {
1332 room_id: room_id.to_proto(),
1333 },
1334 )
1335 .trace_err();
1336 }
1337 update_user_contacts(session.user_id, &session).await?;
1338 Ok(())
1339}
1340
1341async fn update_participant_location(
1342 request: proto::UpdateParticipantLocation,
1343 response: Response<proto::UpdateParticipantLocation>,
1344 session: Session,
1345) -> Result<()> {
1346 let room_id = RoomId::from_proto(request.room_id);
1347 let location = request
1348 .location
1349 .ok_or_else(|| anyhow!("invalid location"))?;
1350
1351 let db = session.db().await;
1352 let room = db
1353 .update_room_participant_location(room_id, session.connection_id, location)
1354 .await?;
1355
1356 room_updated(&room, &session.peer);
1357 response.send(proto::Ack {})?;
1358 Ok(())
1359}
1360
1361async fn share_project(
1362 request: proto::ShareProject,
1363 response: Response<proto::ShareProject>,
1364 session: Session,
1365) -> Result<()> {
1366 let (project_id, room) = &*session
1367 .db()
1368 .await
1369 .share_project(
1370 RoomId::from_proto(request.room_id),
1371 session.connection_id,
1372 &request.worktrees,
1373 )
1374 .await?;
1375 response.send(proto::ShareProjectResponse {
1376 project_id: project_id.to_proto(),
1377 })?;
1378 room_updated(&room, &session.peer);
1379
1380 Ok(())
1381}
1382
1383async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1384 let project_id = ProjectId::from_proto(message.project_id);
1385
1386 let (room, guest_connection_ids) = &*session
1387 .db()
1388 .await
1389 .unshare_project(project_id, session.connection_id)
1390 .await?;
1391
1392 broadcast(
1393 Some(session.connection_id),
1394 guest_connection_ids.iter().copied(),
1395 |conn_id| session.peer.send(conn_id, message.clone()),
1396 );
1397 room_updated(&room, &session.peer);
1398
1399 Ok(())
1400}
1401
1402async fn join_project(
1403 request: proto::JoinProject,
1404 response: Response<proto::JoinProject>,
1405 session: Session,
1406) -> Result<()> {
1407 let project_id = ProjectId::from_proto(request.project_id);
1408 let guest_user_id = session.user_id;
1409
1410 tracing::info!(%project_id, "join project");
1411
1412 let (project, replica_id) = &mut *session
1413 .db()
1414 .await
1415 .join_project(project_id, session.connection_id)
1416 .await?;
1417
1418 let collaborators = project
1419 .collaborators
1420 .iter()
1421 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1422 .map(|collaborator| collaborator.to_proto())
1423 .collect::<Vec<_>>();
1424
1425 let worktrees = project
1426 .worktrees
1427 .iter()
1428 .map(|(id, worktree)| proto::WorktreeMetadata {
1429 id: *id,
1430 root_name: worktree.root_name.clone(),
1431 visible: worktree.visible,
1432 abs_path: worktree.abs_path.clone(),
1433 })
1434 .collect::<Vec<_>>();
1435
1436 for collaborator in &collaborators {
1437 session
1438 .peer
1439 .send(
1440 collaborator.peer_id.unwrap().into(),
1441 proto::AddProjectCollaborator {
1442 project_id: project_id.to_proto(),
1443 collaborator: Some(proto::Collaborator {
1444 peer_id: Some(session.connection_id.into()),
1445 replica_id: replica_id.0 as u32,
1446 user_id: guest_user_id.to_proto(),
1447 }),
1448 },
1449 )
1450 .trace_err();
1451 }
1452
1453 // First, we send the metadata associated with each worktree.
1454 response.send(proto::JoinProjectResponse {
1455 worktrees: worktrees.clone(),
1456 replica_id: replica_id.0 as u32,
1457 collaborators: collaborators.clone(),
1458 language_servers: project.language_servers.clone(),
1459 })?;
1460
1461 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1462 #[cfg(any(test, feature = "test-support"))]
1463 const MAX_CHUNK_SIZE: usize = 2;
1464 #[cfg(not(any(test, feature = "test-support")))]
1465 const MAX_CHUNK_SIZE: usize = 256;
1466
1467 // Stream this worktree's entries.
1468 let message = proto::UpdateWorktree {
1469 project_id: project_id.to_proto(),
1470 worktree_id,
1471 abs_path: worktree.abs_path.clone(),
1472 root_name: worktree.root_name,
1473 updated_entries: worktree.entries,
1474 removed_entries: Default::default(),
1475 scan_id: worktree.scan_id,
1476 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1477 updated_repositories: worktree.repository_entries.into_values().collect(),
1478 removed_repositories: Default::default(),
1479 };
1480 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1481 session.peer.send(session.connection_id, update.clone())?;
1482 }
1483
1484 // Stream this worktree's diagnostics.
1485 for summary in worktree.diagnostic_summaries {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateDiagnosticSummary {
1489 project_id: project_id.to_proto(),
1490 worktree_id: worktree.id,
1491 summary: Some(summary),
1492 },
1493 )?;
1494 }
1495
1496 for settings_file in worktree.settings_files {
1497 session.peer.send(
1498 session.connection_id,
1499 proto::UpdateWorktreeSettings {
1500 project_id: project_id.to_proto(),
1501 worktree_id: worktree.id,
1502 path: settings_file.path,
1503 content: Some(settings_file.content),
1504 },
1505 )?;
1506 }
1507 }
1508
1509 for language_server in &project.language_servers {
1510 session.peer.send(
1511 session.connection_id,
1512 proto::UpdateLanguageServer {
1513 project_id: project_id.to_proto(),
1514 language_server_id: language_server.id,
1515 variant: Some(
1516 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1517 proto::LspDiskBasedDiagnosticsUpdated {},
1518 ),
1519 ),
1520 },
1521 )?;
1522 }
1523
1524 Ok(())
1525}
1526
1527async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1528 let sender_id = session.connection_id;
1529 let project_id = ProjectId::from_proto(request.project_id);
1530
1531 let (room, project) = &*session
1532 .db()
1533 .await
1534 .leave_project(project_id, sender_id)
1535 .await?;
1536 tracing::info!(
1537 %project_id,
1538 host_user_id = %project.host_user_id,
1539 host_connection_id = %project.host_connection_id,
1540 "leave project"
1541 );
1542
1543 project_left(&project, &session);
1544 room_updated(&room, &session.peer);
1545
1546 Ok(())
1547}
1548
1549async fn update_project(
1550 request: proto::UpdateProject,
1551 response: Response<proto::UpdateProject>,
1552 session: Session,
1553) -> Result<()> {
1554 let project_id = ProjectId::from_proto(request.project_id);
1555 let (room, guest_connection_ids) = &*session
1556 .db()
1557 .await
1558 .update_project(project_id, session.connection_id, &request.worktrees)
1559 .await?;
1560 broadcast(
1561 Some(session.connection_id),
1562 guest_connection_ids.iter().copied(),
1563 |connection_id| {
1564 session
1565 .peer
1566 .forward_send(session.connection_id, connection_id, request.clone())
1567 },
1568 );
1569 room_updated(&room, &session.peer);
1570 response.send(proto::Ack {})?;
1571
1572 Ok(())
1573}
1574
1575async fn update_worktree(
1576 request: proto::UpdateWorktree,
1577 response: Response<proto::UpdateWorktree>,
1578 session: Session,
1579) -> Result<()> {
1580 let guest_connection_ids = session
1581 .db()
1582 .await
1583 .update_worktree(&request, session.connection_id)
1584 .await?;
1585
1586 broadcast(
1587 Some(session.connection_id),
1588 guest_connection_ids.iter().copied(),
1589 |connection_id| {
1590 session
1591 .peer
1592 .forward_send(session.connection_id, connection_id, request.clone())
1593 },
1594 );
1595 response.send(proto::Ack {})?;
1596 Ok(())
1597}
1598
1599async fn update_diagnostic_summary(
1600 message: proto::UpdateDiagnosticSummary,
1601 session: Session,
1602) -> Result<()> {
1603 let guest_connection_ids = session
1604 .db()
1605 .await
1606 .update_diagnostic_summary(&message, session.connection_id)
1607 .await?;
1608
1609 broadcast(
1610 Some(session.connection_id),
1611 guest_connection_ids.iter().copied(),
1612 |connection_id| {
1613 session
1614 .peer
1615 .forward_send(session.connection_id, connection_id, message.clone())
1616 },
1617 );
1618
1619 Ok(())
1620}
1621
1622async fn update_worktree_settings(
1623 message: proto::UpdateWorktreeSettings,
1624 session: Session,
1625) -> Result<()> {
1626 let guest_connection_ids = session
1627 .db()
1628 .await
1629 .update_worktree_settings(&message, session.connection_id)
1630 .await?;
1631
1632 broadcast(
1633 Some(session.connection_id),
1634 guest_connection_ids.iter().copied(),
1635 |connection_id| {
1636 session
1637 .peer
1638 .forward_send(session.connection_id, connection_id, message.clone())
1639 },
1640 );
1641
1642 Ok(())
1643}
1644
1645async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1646 broadcast_project_message(request.project_id, request, session).await
1647}
1648
1649async fn start_language_server(
1650 request: proto::StartLanguageServer,
1651 session: Session,
1652) -> Result<()> {
1653 let guest_connection_ids = session
1654 .db()
1655 .await
1656 .start_language_server(&request, session.connection_id)
1657 .await?;
1658
1659 broadcast(
1660 Some(session.connection_id),
1661 guest_connection_ids.iter().copied(),
1662 |connection_id| {
1663 session
1664 .peer
1665 .forward_send(session.connection_id, connection_id, request.clone())
1666 },
1667 );
1668 Ok(())
1669}
1670
1671async fn update_language_server(
1672 request: proto::UpdateLanguageServer,
1673 session: Session,
1674) -> Result<()> {
1675 session.executor.record_backtrace();
1676 let project_id = ProjectId::from_proto(request.project_id);
1677 let project_connection_ids = session
1678 .db()
1679 .await
1680 .project_connection_ids(project_id, session.connection_id)
1681 .await?;
1682 broadcast(
1683 Some(session.connection_id),
1684 project_connection_ids.iter().copied(),
1685 |connection_id| {
1686 session
1687 .peer
1688 .forward_send(session.connection_id, connection_id, request.clone())
1689 },
1690 );
1691 Ok(())
1692}
1693
1694async fn forward_project_request<T>(
1695 request: T,
1696 response: Response<T>,
1697 session: Session,
1698) -> Result<()>
1699where
1700 T: EntityMessage + RequestMessage,
1701{
1702 session.executor.record_backtrace();
1703 let project_id = ProjectId::from_proto(request.remote_entity_id());
1704 let host_connection_id = {
1705 let collaborators = session
1706 .db()
1707 .await
1708 .project_collaborators(project_id, session.connection_id)
1709 .await?;
1710 collaborators
1711 .iter()
1712 .find(|collaborator| collaborator.is_host)
1713 .ok_or_else(|| anyhow!("host not found"))?
1714 .connection_id
1715 };
1716
1717 let payload = session
1718 .peer
1719 .forward_request(session.connection_id, host_connection_id, request)
1720 .await?;
1721
1722 response.send(payload)?;
1723 Ok(())
1724}
1725
1726async fn create_buffer_for_peer(
1727 request: proto::CreateBufferForPeer,
1728 session: Session,
1729) -> Result<()> {
1730 session.executor.record_backtrace();
1731 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1732 session
1733 .peer
1734 .forward_send(session.connection_id, peer_id.into(), request)?;
1735 Ok(())
1736}
1737
1738async fn update_buffer(
1739 request: proto::UpdateBuffer,
1740 response: Response<proto::UpdateBuffer>,
1741 session: Session,
1742) -> Result<()> {
1743 session.executor.record_backtrace();
1744 let project_id = ProjectId::from_proto(request.project_id);
1745 let mut guest_connection_ids;
1746 let mut host_connection_id = None;
1747 {
1748 let collaborators = session
1749 .db()
1750 .await
1751 .project_collaborators(project_id, session.connection_id)
1752 .await?;
1753 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1754 for collaborator in collaborators.iter() {
1755 if collaborator.is_host {
1756 host_connection_id = Some(collaborator.connection_id);
1757 } else {
1758 guest_connection_ids.push(collaborator.connection_id);
1759 }
1760 }
1761 }
1762 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1763
1764 session.executor.record_backtrace();
1765 broadcast(
1766 Some(session.connection_id),
1767 guest_connection_ids,
1768 |connection_id| {
1769 session
1770 .peer
1771 .forward_send(session.connection_id, connection_id, request.clone())
1772 },
1773 );
1774 if host_connection_id != session.connection_id {
1775 session
1776 .peer
1777 .forward_request(session.connection_id, host_connection_id, request.clone())
1778 .await?;
1779 }
1780
1781 response.send(proto::Ack {})?;
1782 Ok(())
1783}
1784
1785async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1786 let project_id = ProjectId::from_proto(request.project_id);
1787 let project_connection_ids = session
1788 .db()
1789 .await
1790 .project_connection_ids(project_id, session.connection_id)
1791 .await?;
1792
1793 broadcast(
1794 Some(session.connection_id),
1795 project_connection_ids.iter().copied(),
1796 |connection_id| {
1797 session
1798 .peer
1799 .forward_send(session.connection_id, connection_id, request.clone())
1800 },
1801 );
1802 Ok(())
1803}
1804
1805async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1806 let project_id = ProjectId::from_proto(request.project_id);
1807 let project_connection_ids = session
1808 .db()
1809 .await
1810 .project_connection_ids(project_id, session.connection_id)
1811 .await?;
1812 broadcast(
1813 Some(session.connection_id),
1814 project_connection_ids.iter().copied(),
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 Ok(())
1822}
1823
1824async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1825 broadcast_project_message(request.project_id, request, session).await
1826}
1827
1828async fn broadcast_project_message<T: EnvelopedMessage>(
1829 project_id: u64,
1830 request: T,
1831 session: Session,
1832) -> Result<()> {
1833 let project_id = ProjectId::from_proto(project_id);
1834 let project_connection_ids = session
1835 .db()
1836 .await
1837 .project_connection_ids(project_id, session.connection_id)
1838 .await?;
1839 broadcast(
1840 Some(session.connection_id),
1841 project_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851async fn follow(
1852 request: proto::Follow,
1853 response: Response<proto::Follow>,
1854 session: Session,
1855) -> Result<()> {
1856 let project_id = ProjectId::from_proto(request.project_id);
1857 let leader_id = request
1858 .leader_id
1859 .ok_or_else(|| anyhow!("invalid leader id"))?
1860 .into();
1861 let follower_id = session.connection_id;
1862
1863 {
1864 let project_connection_ids = session
1865 .db()
1866 .await
1867 .project_connection_ids(project_id, session.connection_id)
1868 .await?;
1869
1870 if !project_connection_ids.contains(&leader_id) {
1871 Err(anyhow!("no such peer"))?;
1872 }
1873 }
1874
1875 let mut response_payload = session
1876 .peer
1877 .forward_request(session.connection_id, leader_id, request)
1878 .await?;
1879 response_payload
1880 .views
1881 .retain(|view| view.leader_id != Some(follower_id.into()));
1882 response.send(response_payload)?;
1883
1884 let room = session
1885 .db()
1886 .await
1887 .follow(project_id, leader_id, follower_id)
1888 .await?;
1889 room_updated(&room, &session.peer);
1890
1891 Ok(())
1892}
1893
1894async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1895 let project_id = ProjectId::from_proto(request.project_id);
1896 let leader_id = request
1897 .leader_id
1898 .ok_or_else(|| anyhow!("invalid leader id"))?
1899 .into();
1900 let follower_id = session.connection_id;
1901
1902 if !session
1903 .db()
1904 .await
1905 .project_connection_ids(project_id, session.connection_id)
1906 .await?
1907 .contains(&leader_id)
1908 {
1909 Err(anyhow!("no such peer"))?;
1910 }
1911
1912 session
1913 .peer
1914 .forward_send(session.connection_id, leader_id, request)?;
1915
1916 let room = session
1917 .db()
1918 .await
1919 .unfollow(project_id, leader_id, follower_id)
1920 .await?;
1921 room_updated(&room, &session.peer);
1922
1923 Ok(())
1924}
1925
1926async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1927 let project_id = ProjectId::from_proto(request.project_id);
1928 let project_connection_ids = session
1929 .db
1930 .lock()
1931 .await
1932 .project_connection_ids(project_id, session.connection_id)
1933 .await?;
1934
1935 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1936 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1937 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1938 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1939 });
1940 for follower_peer_id in request.follower_ids.iter().copied() {
1941 let follower_connection_id = follower_peer_id.into();
1942 if project_connection_ids.contains(&follower_connection_id)
1943 && Some(follower_peer_id) != leader_id
1944 {
1945 session.peer.forward_send(
1946 session.connection_id,
1947 follower_connection_id,
1948 request.clone(),
1949 )?;
1950 }
1951 }
1952 Ok(())
1953}
1954
1955async fn get_users(
1956 request: proto::GetUsers,
1957 response: Response<proto::GetUsers>,
1958 session: Session,
1959) -> Result<()> {
1960 let user_ids = request
1961 .user_ids
1962 .into_iter()
1963 .map(UserId::from_proto)
1964 .collect();
1965 let users = session
1966 .db()
1967 .await
1968 .get_users_by_ids(user_ids)
1969 .await?
1970 .into_iter()
1971 .map(|user| proto::User {
1972 id: user.id.to_proto(),
1973 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1974 github_login: user.github_login,
1975 })
1976 .collect();
1977 response.send(proto::UsersResponse { users })?;
1978 Ok(())
1979}
1980
1981async fn fuzzy_search_users(
1982 request: proto::FuzzySearchUsers,
1983 response: Response<proto::FuzzySearchUsers>,
1984 session: Session,
1985) -> Result<()> {
1986 let query = request.query;
1987 let users = match query.len() {
1988 0 => vec![],
1989 1 | 2 => session
1990 .db()
1991 .await
1992 .get_user_by_github_login(&query)
1993 .await?
1994 .into_iter()
1995 .collect(),
1996 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1997 };
1998 let users = users
1999 .into_iter()
2000 .filter(|user| user.id != session.user_id)
2001 .map(|user| proto::User {
2002 id: user.id.to_proto(),
2003 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2004 github_login: user.github_login,
2005 })
2006 .collect();
2007 response.send(proto::UsersResponse { users })?;
2008 Ok(())
2009}
2010
2011async fn request_contact(
2012 request: proto::RequestContact,
2013 response: Response<proto::RequestContact>,
2014 session: Session,
2015) -> Result<()> {
2016 let requester_id = session.user_id;
2017 let responder_id = UserId::from_proto(request.responder_id);
2018 if requester_id == responder_id {
2019 return Err(anyhow!("cannot add yourself as a contact"))?;
2020 }
2021
2022 session
2023 .db()
2024 .await
2025 .send_contact_request(requester_id, responder_id)
2026 .await?;
2027
2028 // Update outgoing contact requests of requester
2029 let mut update = proto::UpdateContacts::default();
2030 update.outgoing_requests.push(responder_id.to_proto());
2031 for connection_id in session
2032 .connection_pool()
2033 .await
2034 .user_connection_ids(requester_id)
2035 {
2036 session.peer.send(connection_id, update.clone())?;
2037 }
2038
2039 // Update incoming contact requests of responder
2040 let mut update = proto::UpdateContacts::default();
2041 update
2042 .incoming_requests
2043 .push(proto::IncomingContactRequest {
2044 requester_id: requester_id.to_proto(),
2045 should_notify: true,
2046 });
2047 for connection_id in session
2048 .connection_pool()
2049 .await
2050 .user_connection_ids(responder_id)
2051 {
2052 session.peer.send(connection_id, update.clone())?;
2053 }
2054
2055 response.send(proto::Ack {})?;
2056 Ok(())
2057}
2058
2059async fn respond_to_contact_request(
2060 request: proto::RespondToContactRequest,
2061 response: Response<proto::RespondToContactRequest>,
2062 session: Session,
2063) -> Result<()> {
2064 let responder_id = session.user_id;
2065 let requester_id = UserId::from_proto(request.requester_id);
2066 let db = session.db().await;
2067 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2068 db.dismiss_contact_notification(responder_id, requester_id)
2069 .await?;
2070 } else {
2071 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2072
2073 db.respond_to_contact_request(responder_id, requester_id, accept)
2074 .await?;
2075 let requester_busy = db.is_user_busy(requester_id).await?;
2076 let responder_busy = db.is_user_busy(responder_id).await?;
2077
2078 let pool = session.connection_pool().await;
2079 // Update responder with new contact
2080 let mut update = proto::UpdateContacts::default();
2081 if accept {
2082 update
2083 .contacts
2084 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2085 }
2086 update
2087 .remove_incoming_requests
2088 .push(requester_id.to_proto());
2089 for connection_id in pool.user_connection_ids(responder_id) {
2090 session.peer.send(connection_id, update.clone())?;
2091 }
2092
2093 // Update requester with new contact
2094 let mut update = proto::UpdateContacts::default();
2095 if accept {
2096 update
2097 .contacts
2098 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2099 }
2100 update
2101 .remove_outgoing_requests
2102 .push(responder_id.to_proto());
2103 for connection_id in pool.user_connection_ids(requester_id) {
2104 session.peer.send(connection_id, update.clone())?;
2105 }
2106 }
2107
2108 response.send(proto::Ack {})?;
2109 Ok(())
2110}
2111
2112async fn remove_contact(
2113 request: proto::RemoveContact,
2114 response: Response<proto::RemoveContact>,
2115 session: Session,
2116) -> Result<()> {
2117 let requester_id = session.user_id;
2118 let responder_id = UserId::from_proto(request.user_id);
2119 let db = session.db().await;
2120 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2121
2122 let pool = session.connection_pool().await;
2123 // Update outgoing contact requests of requester
2124 let mut update = proto::UpdateContacts::default();
2125 if contact_accepted {
2126 update.remove_contacts.push(responder_id.to_proto());
2127 } else {
2128 update
2129 .remove_outgoing_requests
2130 .push(responder_id.to_proto());
2131 }
2132 for connection_id in pool.user_connection_ids(requester_id) {
2133 session.peer.send(connection_id, update.clone())?;
2134 }
2135
2136 // Update incoming contact requests of responder
2137 let mut update = proto::UpdateContacts::default();
2138 if contact_accepted {
2139 update.remove_contacts.push(requester_id.to_proto());
2140 } else {
2141 update
2142 .remove_incoming_requests
2143 .push(requester_id.to_proto());
2144 }
2145 for connection_id in pool.user_connection_ids(responder_id) {
2146 session.peer.send(connection_id, update.clone())?;
2147 }
2148
2149 response.send(proto::Ack {})?;
2150 Ok(())
2151}
2152
2153async fn create_channel(
2154 request: proto::CreateChannel,
2155 response: Response<proto::CreateChannel>,
2156 session: Session,
2157) -> Result<()> {
2158 let db = session.db().await;
2159 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2160
2161 if let Some(live_kit) = session.live_kit_client.as_ref() {
2162 live_kit.create_room(live_kit_room.clone()).await?;
2163 }
2164
2165 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2166 let id = db
2167 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2168 .await?;
2169
2170 let channel = proto::Channel {
2171 id: id.to_proto(),
2172 name: request.name,
2173 parent_id: request.parent_id,
2174 };
2175
2176 response.send(proto::ChannelResponse {
2177 channel: Some(channel.clone()),
2178 })?;
2179
2180 let mut update = proto::UpdateChannels::default();
2181 update.channels.push(channel);
2182
2183 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2184 db.get_channel_members(parent_id).await?
2185 } else {
2186 vec![session.user_id]
2187 };
2188
2189 let connection_pool = session.connection_pool().await;
2190 for user_id in user_ids_to_notify {
2191 for connection_id in connection_pool.user_connection_ids(user_id) {
2192 let mut update = update.clone();
2193 if user_id == session.user_id {
2194 update.channel_permissions.push(proto::ChannelPermission {
2195 channel_id: id.to_proto(),
2196 is_admin: true,
2197 });
2198 }
2199 session.peer.send(connection_id, update)?;
2200 }
2201 }
2202
2203 Ok(())
2204}
2205
2206async fn remove_channel(
2207 request: proto::RemoveChannel,
2208 response: Response<proto::RemoveChannel>,
2209 session: Session,
2210) -> Result<()> {
2211 let db = session.db().await;
2212
2213 let channel_id = request.channel_id;
2214 let (removed_channels, member_ids) = db
2215 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2216 .await?;
2217 response.send(proto::Ack {})?;
2218
2219 // Notify members of removed channels
2220 let mut update = proto::UpdateChannels::default();
2221 update
2222 .remove_channels
2223 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2224
2225 let connection_pool = session.connection_pool().await;
2226 for member_id in member_ids {
2227 for connection_id in connection_pool.user_connection_ids(member_id) {
2228 session.peer.send(connection_id, update.clone())?;
2229 }
2230 }
2231
2232 Ok(())
2233}
2234
2235async fn invite_channel_member(
2236 request: proto::InviteChannelMember,
2237 response: Response<proto::InviteChannelMember>,
2238 session: Session,
2239) -> Result<()> {
2240 let db = session.db().await;
2241 let channel_id = ChannelId::from_proto(request.channel_id);
2242 let invitee_id = UserId::from_proto(request.user_id);
2243 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2244 .await?;
2245
2246 let (channel, _) = db
2247 .get_channel(channel_id, session.user_id)
2248 .await?
2249 .ok_or_else(|| anyhow!("channel not found"))?;
2250
2251 let mut update = proto::UpdateChannels::default();
2252 update.channel_invitations.push(proto::Channel {
2253 id: channel.id.to_proto(),
2254 name: channel.name,
2255 parent_id: None,
2256 });
2257 for connection_id in session
2258 .connection_pool()
2259 .await
2260 .user_connection_ids(invitee_id)
2261 {
2262 session.peer.send(connection_id, update.clone())?;
2263 }
2264
2265 response.send(proto::Ack {})?;
2266 Ok(())
2267}
2268
2269async fn remove_channel_member(
2270 request: proto::RemoveChannelMember,
2271 response: Response<proto::RemoveChannelMember>,
2272 session: Session,
2273) -> Result<()> {
2274 let db = session.db().await;
2275 let channel_id = ChannelId::from_proto(request.channel_id);
2276 let member_id = UserId::from_proto(request.user_id);
2277
2278 db.remove_channel_member(channel_id, member_id, session.user_id)
2279 .await?;
2280
2281 let mut update = proto::UpdateChannels::default();
2282 update.remove_channels.push(channel_id.to_proto());
2283
2284 for connection_id in session
2285 .connection_pool()
2286 .await
2287 .user_connection_ids(member_id)
2288 {
2289 session.peer.send(connection_id, update.clone())?;
2290 }
2291
2292 response.send(proto::Ack {})?;
2293 Ok(())
2294}
2295
2296async fn set_channel_member_admin(
2297 request: proto::SetChannelMemberAdmin,
2298 response: Response<proto::SetChannelMemberAdmin>,
2299 session: Session,
2300) -> Result<()> {
2301 let db = session.db().await;
2302 let channel_id = ChannelId::from_proto(request.channel_id);
2303 let member_id = UserId::from_proto(request.user_id);
2304 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2305 .await?;
2306
2307 let (channel, has_accepted) = db
2308 .get_channel(channel_id, member_id)
2309 .await?
2310 .ok_or_else(|| anyhow!("channel not found"))?;
2311
2312 let mut update = proto::UpdateChannels::default();
2313 if has_accepted {
2314 update.channel_permissions.push(proto::ChannelPermission {
2315 channel_id: channel.id.to_proto(),
2316 is_admin: request.admin,
2317 });
2318 }
2319
2320 for connection_id in session
2321 .connection_pool()
2322 .await
2323 .user_connection_ids(member_id)
2324 {
2325 session.peer.send(connection_id, update.clone())?;
2326 }
2327
2328 response.send(proto::Ack {})?;
2329 Ok(())
2330}
2331
2332async fn rename_channel(
2333 request: proto::RenameChannel,
2334 response: Response<proto::RenameChannel>,
2335 session: Session,
2336) -> Result<()> {
2337 let db = session.db().await;
2338 let channel_id = ChannelId::from_proto(request.channel_id);
2339 let new_name = db
2340 .rename_channel(channel_id, session.user_id, &request.name)
2341 .await?;
2342
2343 let channel = proto::Channel {
2344 id: request.channel_id,
2345 name: new_name,
2346 parent_id: None,
2347 };
2348 response.send(proto::ChannelResponse {
2349 channel: Some(channel.clone()),
2350 })?;
2351 let mut update = proto::UpdateChannels::default();
2352 update.channels.push(channel);
2353
2354 let member_ids = db.get_channel_members(channel_id).await?;
2355
2356 let connection_pool = session.connection_pool().await;
2357 for member_id in member_ids {
2358 for connection_id in connection_pool.user_connection_ids(member_id) {
2359 session.peer.send(connection_id, update.clone())?;
2360 }
2361 }
2362
2363 Ok(())
2364}
2365
2366async fn get_channel_members(
2367 request: proto::GetChannelMembers,
2368 response: Response<proto::GetChannelMembers>,
2369 session: Session,
2370) -> Result<()> {
2371 let db = session.db().await;
2372 let channel_id = ChannelId::from_proto(request.channel_id);
2373 let members = db
2374 .get_channel_member_details(channel_id, session.user_id)
2375 .await?;
2376 response.send(proto::GetChannelMembersResponse { members })?;
2377 Ok(())
2378}
2379
2380async fn respond_to_channel_invite(
2381 request: proto::RespondToChannelInvite,
2382 response: Response<proto::RespondToChannelInvite>,
2383 session: Session,
2384) -> Result<()> {
2385 let db = session.db().await;
2386 let channel_id = ChannelId::from_proto(request.channel_id);
2387 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2388 .await?;
2389
2390 let mut update = proto::UpdateChannels::default();
2391 update
2392 .remove_channel_invitations
2393 .push(channel_id.to_proto());
2394 if request.accept {
2395 let result = db.get_channels_for_user(session.user_id).await?;
2396 update
2397 .channels
2398 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2399 id: channel.id.to_proto(),
2400 name: channel.name,
2401 parent_id: channel.parent_id.map(ChannelId::to_proto),
2402 }));
2403 update
2404 .channel_participants
2405 .extend(
2406 result
2407 .channel_participants
2408 .into_iter()
2409 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2410 channel_id: channel_id.to_proto(),
2411 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2412 }),
2413 );
2414 update
2415 .channel_permissions
2416 .extend(
2417 result
2418 .channels_with_admin_privileges
2419 .into_iter()
2420 .map(|channel_id| proto::ChannelPermission {
2421 channel_id: channel_id.to_proto(),
2422 is_admin: true,
2423 }),
2424 );
2425 }
2426 session.peer.send(session.connection_id, update)?;
2427 response.send(proto::Ack {})?;
2428
2429 Ok(())
2430}
2431
/// Moves the requesting connection into the room attached to `request.channel_id`.
///
/// Steps, in order:
/// 1. Leave whatever room this session is currently in (`leave_room_for_session`).
/// 2. Look up the channel's room and join it in the database.
/// 3. Respond with the joined room and its channel id. The response also includes LiveKit
///    connection info when a LiveKit client is configured and a token could be created;
///    token errors are logged and the info is omitted.
/// 4. Broadcast the updated room to its participants, broadcast the updated participant list
///    to the channel's members, then refresh this user's status for their contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // The inner block scopes the `db` handle so it is dropped before the channel and
    // contact notifications below run.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // `trace_err()?` logs a token failure and yields `None`, so the response then simply
        // carries no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // NOTE(review): `join_room` presumably returns a guard-like wrapper; `into_inner`
        // extracts the joined-room value so it can be used after this block. Confirm in `db`.
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2486
2487async fn open_channel_buffer(
2488 request: proto::OpenChannelBuffer,
2489 response: Response<proto::OpenChannelBuffer>,
2490 session: Session,
2491) -> Result<()> {
2492 let db = session.db().await;
2493 let channel_id = ChannelId::from_proto(request.channel_id);
2494
2495 let buffer_id = db.get_or_create_buffer_for_channel(channel_id).await?;
2496
2497 // TODO: join channel_buffer
2498
2499 let buffer = db.open_buffer(buffer_id).await?;
2500
2501 response.send(OpenChannelBufferResponse {
2502 buffer_id: buffer_id.to_proto(),
2503 base_text: buffer.base_text,
2504 operations: buffer.operations,
2505 })?;
2506
2507 Ok(())
2508}
2509
2510async fn close_channel_buffer(
2511 request: proto::CloseChannelBuffer,
2512 response: Response<proto::CloseChannelBuffer>,
2513 session: Session,
2514) -> Result<()> {
2515 let db = session.db().await;
2516 let buffer_id = BufferId::from_proto(request.buffer_id);
2517
2518 // TODO: close channel buffer here
2519 //
2520 response.send(Ack {})?;
2521
2522 Ok(())
2523}
2524
2525async fn update_channel_buffer(
2526 request: proto::UpdateChannelBuffer,
2527 session: Session,
2528) -> Result<()> {
2529 let db = session.db().await;
2530
2531 // TODO: Broadcast to buffer members
2532
2533 Ok(())
2534}
2535
2536async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2537 let project_id = ProjectId::from_proto(request.project_id);
2538 let project_connection_ids = session
2539 .db()
2540 .await
2541 .project_connection_ids(project_id, session.connection_id)
2542 .await?;
2543 broadcast(
2544 Some(session.connection_id),
2545 project_connection_ids.iter().copied(),
2546 |connection_id| {
2547 session
2548 .peer
2549 .forward_send(session.connection_id, connection_id, request.clone())
2550 },
2551 );
2552 Ok(())
2553}
2554
2555async fn get_private_user_info(
2556 _request: proto::GetPrivateUserInfo,
2557 response: Response<proto::GetPrivateUserInfo>,
2558 session: Session,
2559) -> Result<()> {
2560 let metrics_id = session
2561 .db()
2562 .await
2563 .get_user_metrics_id(session.user_id)
2564 .await?;
2565 let user = session
2566 .db()
2567 .await
2568 .get_user_by_id(session.user_id)
2569 .await?
2570 .ok_or_else(|| anyhow!("user not found"))?;
2571 response.send(proto::GetPrivateUserInfoResponse {
2572 metrics_id,
2573 staff: user.admin,
2574 })?;
2575 Ok(())
2576}
2577
2578fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2579 match message {
2580 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2581 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2582 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2583 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2584 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2585 code: frame.code.into(),
2586 reason: frame.reason,
2587 })),
2588 }
2589}
2590
2591fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2592 match message {
2593 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2594 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2595 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2596 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2597 AxumMessage::Close(frame) => {
2598 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2599 code: frame.code.into(),
2600 reason: frame.reason,
2601 }))
2602 }
2603 }
2604}
2605
2606fn build_initial_channels_update(
2607 channels: ChannelsForUser,
2608 channel_invites: Vec<db::Channel>,
2609) -> proto::UpdateChannels {
2610 let mut update = proto::UpdateChannels::default();
2611
2612 for channel in channels.channels {
2613 update.channels.push(proto::Channel {
2614 id: channel.id.to_proto(),
2615 name: channel.name,
2616 parent_id: channel.parent_id.map(|id| id.to_proto()),
2617 });
2618 }
2619
2620 for (channel_id, participants) in channels.channel_participants {
2621 update
2622 .channel_participants
2623 .push(proto::ChannelParticipants {
2624 channel_id: channel_id.to_proto(),
2625 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2626 });
2627 }
2628
2629 update
2630 .channel_permissions
2631 .extend(
2632 channels
2633 .channels_with_admin_privileges
2634 .into_iter()
2635 .map(|id| proto::ChannelPermission {
2636 channel_id: id.to_proto(),
2637 is_admin: true,
2638 }),
2639 );
2640
2641 for channel in channel_invites {
2642 update.channel_invitations.push(proto::Channel {
2643 id: channel.id.to_proto(),
2644 name: channel.name,
2645 parent_id: None,
2646 });
2647 }
2648
2649 update
2650}
2651
2652fn build_initial_contacts_update(
2653 contacts: Vec<db::Contact>,
2654 pool: &ConnectionPool,
2655) -> proto::UpdateContacts {
2656 let mut update = proto::UpdateContacts::default();
2657
2658 for contact in contacts {
2659 match contact {
2660 db::Contact::Accepted {
2661 user_id,
2662 should_notify,
2663 busy,
2664 } => {
2665 update
2666 .contacts
2667 .push(contact_for_user(user_id, should_notify, busy, &pool));
2668 }
2669 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2670 db::Contact::Incoming {
2671 user_id,
2672 should_notify,
2673 } => update
2674 .incoming_requests
2675 .push(proto::IncomingContactRequest {
2676 requester_id: user_id.to_proto(),
2677 should_notify,
2678 }),
2679 }
2680 }
2681
2682 update
2683}
2684
2685fn contact_for_user(
2686 user_id: UserId,
2687 should_notify: bool,
2688 busy: bool,
2689 pool: &ConnectionPool,
2690) -> proto::Contact {
2691 proto::Contact {
2692 user_id: user_id.to_proto(),
2693 online: pool.is_user_online(user_id),
2694 busy,
2695 should_notify,
2696 }
2697}
2698
2699fn room_updated(room: &proto::Room, peer: &Peer) {
2700 broadcast(
2701 None,
2702 room.participants
2703 .iter()
2704 .filter_map(|participant| Some(participant.peer_id?.into())),
2705 |peer_id| {
2706 peer.send(
2707 peer_id.into(),
2708 proto::RoomUpdated {
2709 room: Some(room.clone()),
2710 },
2711 )
2712 },
2713 );
2714}
2715
2716fn channel_updated(
2717 channel_id: ChannelId,
2718 room: &proto::Room,
2719 channel_members: &[UserId],
2720 peer: &Peer,
2721 pool: &ConnectionPool,
2722) {
2723 let participants = room
2724 .participants
2725 .iter()
2726 .map(|p| p.user_id)
2727 .collect::<Vec<_>>();
2728
2729 broadcast(
2730 None,
2731 channel_members
2732 .iter()
2733 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2734 |peer_id| {
2735 peer.send(
2736 peer_id.into(),
2737 proto::UpdateChannels {
2738 channel_participants: vec![proto::ChannelParticipants {
2739 channel_id: channel_id.to_proto(),
2740 participant_user_ids: participants.clone(),
2741 }],
2742 ..Default::default()
2743 },
2744 )
2745 },
2746 );
2747}
2748
2749async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2750 let db = session.db().await;
2751
2752 let contacts = db.get_contacts(user_id).await?;
2753 let busy = db.is_user_busy(user_id).await?;
2754
2755 let pool = session.connection_pool().await;
2756 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2757 for contact in contacts {
2758 if let db::Contact::Accepted {
2759 user_id: contact_user_id,
2760 ..
2761 } = contact
2762 {
2763 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2764 session
2765 .peer
2766 .send(
2767 contact_conn_id,
2768 proto::UpdateContacts {
2769 contacts: vec![updated_contact.clone()],
2770 remove_contacts: Default::default(),
2771 incoming_requests: Default::default(),
2772 remove_incoming_requests: Default::default(),
2773 outgoing_requests: Default::default(),
2774 remove_outgoing_requests: Default::default(),
2775 },
2776 )
2777 .trace_err();
2778 }
2779 }
2780 }
2781 Ok(())
2782}
2783
/// Removes the session's connection from whatever room it is in, and notifies everyone
/// affected.
///
/// If the database reports that the connection was not in a room, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// - notifies the other connections of every project the connection left (`project_left`);
/// - broadcasts the updated room to its participants;
/// - broadcasts the new participant list to channel members, if the room belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls into this room were canceled;
/// - refreshes contact status for this user and for those users;
/// - when a LiveKit client is configured, removes the user from the LiveKit room, and deletes
///   that room if the database marked the room as deleted. LiveKit and send failures are
///   only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values outlive the
    // database call that produced them.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE: the LiveKit room name is taken before `room` is moved out on the next lines,
        // so the room broadcast by `room_updated` below carries an empty `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the `.await`s that follow.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2860
2861fn project_left(project: &db::LeftProject, session: &Session) {
2862 for connection_id in &project.connection_ids {
2863 if project.host_user_id == session.user_id {
2864 session
2865 .peer
2866 .send(
2867 *connection_id,
2868 proto::UnshareProject {
2869 project_id: project.id.to_proto(),
2870 },
2871 )
2872 .trace_err();
2873 } else {
2874 session
2875 .peer
2876 .send(
2877 *connection_id,
2878 proto::RemoveProjectCollaborator {
2879 project_id: project.id.to_proto(),
2880 peer_id: Some(session.connection_id.into()),
2881 },
2882 )
2883 .trace_err();
2884 }
2885 }
2886}
2887
/// Extension trait for turning a fallible value into an `Option` while reporting the error
/// instead of propagating it.
pub trait ResultExt {
    /// The success type of the underlying result.
    type Ok;

    /// Converts a success into `Some(value)` and an error into `None`.
    ///
    /// For `Result<T, E>` in this file, the error is logged with `tracing::error!` using
    /// its `Debug` representation before being discarded.
    fn trace_err(self) -> Option<Self::Ok>;
}
2893
2894impl<T, E> ResultExt for Result<T, E>
2895where
2896 E: std::fmt::Debug,
2897{
2898 type Ok = T;
2899
2900 fn trace_err(self) -> Option<T> {
2901 match self {
2902 Ok(value) => Some(value),
2903 Err(error) => {
2904 tracing::error!("{:?}", error);
2905 None
2906 }
2907 }
2908 }
2909}