1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// Grace period a dropped client has to reconnect before `connection_lost`
/// makes its session leave the room it was in.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay `Server::start` waits before cleaning up rooms left stale by previous
/// server instances.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the payload it accepts. It receives the boxed envelope plus the
/// sender's `Session` and returns a boxed, `'static` future that does the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
/// Handle a request handler uses to send its response for request type `R`.
///
/// `send` consumes the handle, so at most one response can be sent through it.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the dispatcher in `Server::add_request_handler`, which reads it
    // after the handler finishes to detect handlers that never responded.
    responded: Arc<AtomicBool>,
}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection context passed to every message handler.
///
/// `Server::handle_connection` builds one per connection and clones it for each
/// incoming message; all clones share the same database mutex, peer and pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Created once per connection, so `Session::db` serializes database access
    // among this connection's handlers.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // Copied from `AppState::live_kit_client`; `None` when that is unset.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the peer, the pool of live connections,
/// and the table of registered message handlers.
pub struct Server {
    // Mutable so the test-only `reset` can assign a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `Server::teardown`; every connection subscribes to it.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `Rc` marker makes the guard `!Send`. Handler futures must be `Send`, so
/// the guard cannot be kept alive across an `.await` inside a handler.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The connection pool stays locked while this value lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the pool itself rather than as the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
impl Server {
    /// Creates a server identified by `id` and registers every message and
    /// request handler it serves.
    ///
    /// The peer is created with `id.0 as u32`. Registering two handlers for the
    /// same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via
            // `subscribe` in `handle_connection`.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_message_handler(refresh_inlay_hints)
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_project_request::<proto::InlayHints>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(remove_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_admin)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }

    /// Schedules cleanup of rooms left stale by previous server instances and
    /// returns `Ok(())` immediately.
    ///
    /// A detached task first waits for the [`CLEANUP_TIMEOUT`] sleep. Then, for
    /// each room the database reports as stale for this environment and server
    /// id, it:
    /// - refreshes the room and broadcasts room (and channel) updates;
    /// - sends `CallCanceled` to users whose calls were canceled;
    /// - pushes updated contact entries to the accepted contacts of affected users;
    /// - deletes the LiveKit room if no participants remain.
    ///
    /// Finally it deletes stale server records. Errors inside the task are only
    /// traced.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some(room_ids) = app_state
                    .db
                    .stale_room_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .refresh_room(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            // Both stale participants and users whose calls were
                            // canceled need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that the
                        // call into this room is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to the
                        // connections of their accepted contacts. Users whose busy
                        // state or contact list could not be loaded are skipped.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, false, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Shuts the server down: tears down the peer, resets the connection pool,
    /// and signals every subscriber of the teardown channel.
    ///
    /// `handle_connection` loops return as soon as they see the signal, and a
    /// pending `connection_lost` reconnect wait stops early.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // An error only means nobody is subscribed, which is fine.
        let _ = self.teardown.send(());
    }

    /// Test helper: tears the server down and re-initializes it under a new `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test helper: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for message type `M` under `TypeId::of::<M>()`.
    ///
    /// The stored closure downcasts the type-erased envelope back to
    /// `TypedEnvelope<M>` and runs the handler inside a "handle message" span.
    /// It logs receipt, then completion or error with the elapsed time in
    /// milliseconds.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are looked up by payload TypeId, so this downcast
                // matches the type the handler was registered for.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a fire-and-forget message; the handler receives
    /// only the payload, not the envelope.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for request type `M`, passing it a [`Response`]
    /// through which it replies.
    ///
    /// - If the handler returns `Ok` without having responded, an error is
    ///   returned, which `add_handler` then logs.
    /// - If the handler returns `Err`, the error's message is sent back to the
    ///   requester as a `proto::Error` and the error is returned.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // If sending the error reply itself fails, that failure
                        // is returned instead of the handler's error.
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Serves one client connection from handshake until it ends.
    ///
    /// On startup it:
    /// - registers the connection with the peer and sends `Hello`;
    /// - reports the `ConnectionId` through `send_connection_id`, if provided;
    /// - sends `ShowContacts` on the user's first connection;
    /// - adds the connection to the pool;
    /// - sends the initial contacts, channels, invite info and any pending
    ///   incoming call.
    ///
    /// It then dispatches incoming messages to registered handlers. The loop
    /// ends when the client closes the connection or I/O fails, after which
    /// `connection_lost` runs. On server teardown it returns `Ok(())` right
    /// away without running `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                // The receiver may already be gone; that is not an error here.
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id)
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_initial_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 messages per connection are in flight: a permit is taken
            // before reading the next message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are polled in order: teardown, I/O completion, progress on
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones
                                // are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still running.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter's connections that `invitee_id` redeemed their
    /// invite: sends the invitee's contact entry and refreshed invite info.
    ///
    /// Does nothing if the inviter does not exist or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and count) to every connection of
    /// `user_id`. Does nothing if the user does not exist or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable view of the peer and connection pool.
    ///
    /// The connection pool stays locked until the snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
705
706impl<'a> Deref for ConnectionPoolGuard<'a> {
707 type Target = ConnectionPool;
708
709 fn deref(&self) -> &Self::Target {
710 &*self.guard
711 }
712}
713
714impl<'a> DerefMut for ConnectionPoolGuard<'a> {
715 fn deref_mut(&mut self) -> &mut Self::Target {
716 &mut *self.guard
717 }
718}
719
720impl<'a> Drop for ConnectionPoolGuard<'a> {
721 fn drop(&mut self) {
722 #[cfg(test)]
723 self.check_invariants();
724 }
725}
726
727fn broadcast<F>(
728 sender_id: Option<ConnectionId>,
729 receiver_ids: impl IntoIterator<Item = ConnectionId>,
730 mut f: F,
731) where
732 F: FnMut(ConnectionId) -> anyhow::Result<()>,
733{
734 for receiver_id in receiver_ids {
735 if Some(receiver_id) != sender_id {
736 if let Err(error) = f(receiver_id) {
737 tracing::error!("failed to send to {:?} {}", receiver_id, error);
738 }
739 }
740 }
741}
742
743lazy_static! {
744 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
745}
746
747pub struct ProtocolVersion(u32);
748
749impl Header for ProtocolVersion {
750 fn name() -> &'static HeaderName {
751 &ZED_PROTOCOL_VERSION
752 }
753
754 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
755 where
756 Self: Sized,
757 I: Iterator<Item = &'i axum::http::HeaderValue>,
758 {
759 let version = values
760 .next()
761 .ok_or_else(axum::headers::Error::invalid)?
762 .to_str()
763 .map_err(|_| axum::headers::Error::invalid())?
764 .parse()
765 .map_err(|_| axum::headers::Error::invalid())?;
766 Ok(Self(version))
767 }
768
769 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
770 values.extend([self.0.to_string().parse().unwrap()]);
771 }
772}
773
774pub fn routes(server: Arc<Server>) -> Router<Body> {
775 Router::new()
776 .route("/rpc", get(handle_websocket_request))
777 .layer(
778 ServiceBuilder::new()
779 .layer(Extension(server.app_state.clone()))
780 .layer(middleware::from_fn(auth::validate_header)),
781 )
782 .route("/metrics", get(handle_metrics))
783 .layer(Extension(server))
784}
785
786pub async fn handle_websocket_request(
787 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
788 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
789 Extension(server): Extension<Arc<Server>>,
790 Extension(user): Extension<User>,
791 ws: WebSocketUpgrade,
792) -> axum::response::Response {
793 if protocol_version != rpc::PROTOCOL_VERSION {
794 return (
795 StatusCode::UPGRADE_REQUIRED,
796 "client must be upgraded".to_string(),
797 )
798 .into_response();
799 }
800 let socket_address = socket_address.to_string();
801 ws.on_upgrade(move |socket| {
802 use util::ResultExt;
803 let socket = socket
804 .map_ok(to_tungstenite_message)
805 .err_into()
806 .with(|message| async move { Ok(to_axum_message(message)) });
807 let connection = Connection::new(Box::pin(socket));
808 async move {
809 server
810 .handle_connection(connection, socket_address, user, None, Executor::Production)
811 .await
812 .log_err();
813 }
814 })
815}
816
817pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
818 let connections = server
819 .connection_pool
820 .lock()
821 .connections()
822 .filter(|connection| !connection.admin)
823 .count();
824
825 METRIC_CONNECTIONS.set(connections as _);
826
827 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
828 METRIC_SHARED_PROJECTS.set(shared_projects as _);
829
830 let encoder = prometheus::TextEncoder::new();
831 let metric_families = prometheus::gather();
832 let encoded_metrics = encoder
833 .encode_to_string(&metric_families)
834 .map_err(|err| anyhow!("{}", err))?;
835 Ok(encoded_metrics)
836}
837
838#[instrument(err, skip(executor))]
839async fn connection_lost(
840 session: Session,
841 mut teardown: watch::Receiver<()>,
842 executor: Executor,
843) -> Result<()> {
844 session.peer.disconnect(session.connection_id);
845 session
846 .connection_pool()
847 .await
848 .remove_connection(session.connection_id)?;
849
850 session
851 .db()
852 .await
853 .connection_lost(session.connection_id)
854 .await
855 .trace_err();
856
857 leave_channel_buffers_for_session(&session)
858 .await
859 .trace_err();
860
861 futures::select_biased! {
862 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
863 leave_room_for_session(&session).await.trace_err();
864
865 if !session
866 .connection_pool()
867 .await
868 .is_user_online(session.user_id)
869 {
870 let db = session.db().await;
871 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
872 room_updated(&room, &session.peer);
873 }
874 }
875 update_user_contacts(session.user_id, &session).await?;
876
877
878 }
879 _ = teardown.changed().fuse() => {}
880 }
881
882 Ok(())
883}
884
885async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
886 response.send(proto::Ack {})?;
887 Ok(())
888}
889
890async fn create_room(
891 _request: proto::CreateRoom,
892 response: Response<proto::CreateRoom>,
893 session: Session,
894) -> Result<()> {
895 let live_kit_room = nanoid::nanoid!(30);
896
897 let live_kit_connection_info = {
898 let live_kit_room = live_kit_room.clone();
899 let live_kit = session.live_kit_client.as_ref();
900
901 util::async_iife!({
902 let live_kit = live_kit?;
903
904 live_kit
905 .create_room(live_kit_room.clone())
906 .await
907 .trace_err()?;
908
909 let token = live_kit
910 .room_token(&live_kit_room, &session.user_id.to_string())
911 .trace_err()?;
912
913 Some(proto::LiveKitConnectionInfo {
914 server_url: live_kit.url().into(),
915 token,
916 })
917 })
918 }
919 .await;
920
921 let room = session
922 .db()
923 .await
924 .create_room(session.user_id, session.connection_id, &live_kit_room)
925 .await?;
926
927 response.send(proto::CreateRoomResponse {
928 room: Some(room.clone()),
929 live_kit_connection_info,
930 })?;
931
932 update_user_contacts(session.user_id, &session).await?;
933 Ok(())
934}
935
936async fn join_room(
937 request: proto::JoinRoom,
938 response: Response<proto::JoinRoom>,
939 session: Session,
940) -> Result<()> {
941 let room_id = RoomId::from_proto(request.id);
942 let joined_room = {
943 let room = session
944 .db()
945 .await
946 .join_room(room_id, session.user_id, session.connection_id)
947 .await?;
948 room_updated(&room.room, &session.peer);
949 room.into_inner()
950 };
951
952 if let Some(channel_id) = joined_room.channel_id {
953 channel_updated(
954 channel_id,
955 &joined_room.room,
956 &joined_room.channel_members,
957 &session.peer,
958 &*session.connection_pool().await,
959 )
960 }
961
962 for connection_id in session
963 .connection_pool()
964 .await
965 .user_connection_ids(session.user_id)
966 {
967 session
968 .peer
969 .send(
970 connection_id,
971 proto::CallCanceled {
972 room_id: room_id.to_proto(),
973 },
974 )
975 .trace_err();
976 }
977
978 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
979 if let Some(token) = live_kit
980 .room_token(
981 &joined_room.room.live_kit_room,
982 &session.user_id.to_string(),
983 )
984 .trace_err()
985 {
986 Some(proto::LiveKitConnectionInfo {
987 server_url: live_kit.url().into(),
988 token,
989 })
990 } else {
991 None
992 }
993 } else {
994 None
995 };
996
997 response.send(proto::JoinRoomResponse {
998 room: Some(joined_room.room),
999 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1000 live_kit_connection_info,
1001 })?;
1002
1003 update_user_contacts(session.user_id, &session).await?;
1004 Ok(())
1005}
1006
1007async fn rejoin_room(
1008 request: proto::RejoinRoom,
1009 response: Response<proto::RejoinRoom>,
1010 session: Session,
1011) -> Result<()> {
1012 let room;
1013 let channel_id;
1014 let channel_members;
1015 {
1016 let mut rejoined_room = session
1017 .db()
1018 .await
1019 .rejoin_room(request, session.user_id, session.connection_id)
1020 .await?;
1021
1022 response.send(proto::RejoinRoomResponse {
1023 room: Some(rejoined_room.room.clone()),
1024 reshared_projects: rejoined_room
1025 .reshared_projects
1026 .iter()
1027 .map(|project| proto::ResharedProject {
1028 id: project.id.to_proto(),
1029 collaborators: project
1030 .collaborators
1031 .iter()
1032 .map(|collaborator| collaborator.to_proto())
1033 .collect(),
1034 })
1035 .collect(),
1036 rejoined_projects: rejoined_room
1037 .rejoined_projects
1038 .iter()
1039 .map(|rejoined_project| proto::RejoinedProject {
1040 id: rejoined_project.id.to_proto(),
1041 worktrees: rejoined_project
1042 .worktrees
1043 .iter()
1044 .map(|worktree| proto::WorktreeMetadata {
1045 id: worktree.id,
1046 root_name: worktree.root_name.clone(),
1047 visible: worktree.visible,
1048 abs_path: worktree.abs_path.clone(),
1049 })
1050 .collect(),
1051 collaborators: rejoined_project
1052 .collaborators
1053 .iter()
1054 .map(|collaborator| collaborator.to_proto())
1055 .collect(),
1056 language_servers: rejoined_project.language_servers.clone(),
1057 })
1058 .collect(),
1059 })?;
1060 room_updated(&rejoined_room.room, &session.peer);
1061
1062 for project in &rejoined_room.reshared_projects {
1063 for collaborator in &project.collaborators {
1064 session
1065 .peer
1066 .send(
1067 collaborator.connection_id,
1068 proto::UpdateProjectCollaborator {
1069 project_id: project.id.to_proto(),
1070 old_peer_id: Some(project.old_connection_id.into()),
1071 new_peer_id: Some(session.connection_id.into()),
1072 },
1073 )
1074 .trace_err();
1075 }
1076
1077 broadcast(
1078 Some(session.connection_id),
1079 project
1080 .collaborators
1081 .iter()
1082 .map(|collaborator| collaborator.connection_id),
1083 |connection_id| {
1084 session.peer.forward_send(
1085 session.connection_id,
1086 connection_id,
1087 proto::UpdateProject {
1088 project_id: project.id.to_proto(),
1089 worktrees: project.worktrees.clone(),
1090 },
1091 )
1092 },
1093 );
1094 }
1095
1096 for project in &rejoined_room.rejoined_projects {
1097 for collaborator in &project.collaborators {
1098 session
1099 .peer
1100 .send(
1101 collaborator.connection_id,
1102 proto::UpdateProjectCollaborator {
1103 project_id: project.id.to_proto(),
1104 old_peer_id: Some(project.old_connection_id.into()),
1105 new_peer_id: Some(session.connection_id.into()),
1106 },
1107 )
1108 .trace_err();
1109 }
1110 }
1111
1112 for project in &mut rejoined_room.rejoined_projects {
1113 for worktree in mem::take(&mut project.worktrees) {
1114 #[cfg(any(test, feature = "test-support"))]
1115 const MAX_CHUNK_SIZE: usize = 2;
1116 #[cfg(not(any(test, feature = "test-support")))]
1117 const MAX_CHUNK_SIZE: usize = 256;
1118
1119 // Stream this worktree's entries.
1120 let message = proto::UpdateWorktree {
1121 project_id: project.id.to_proto(),
1122 worktree_id: worktree.id,
1123 abs_path: worktree.abs_path.clone(),
1124 root_name: worktree.root_name,
1125 updated_entries: worktree.updated_entries,
1126 removed_entries: worktree.removed_entries,
1127 scan_id: worktree.scan_id,
1128 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1129 updated_repositories: worktree.updated_repositories,
1130 removed_repositories: worktree.removed_repositories,
1131 };
1132 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1133 session.peer.send(session.connection_id, update.clone())?;
1134 }
1135
1136 // Stream this worktree's diagnostics.
1137 for summary in worktree.diagnostic_summaries {
1138 session.peer.send(
1139 session.connection_id,
1140 proto::UpdateDiagnosticSummary {
1141 project_id: project.id.to_proto(),
1142 worktree_id: worktree.id,
1143 summary: Some(summary),
1144 },
1145 )?;
1146 }
1147
1148 for settings_file in worktree.settings_files {
1149 session.peer.send(
1150 session.connection_id,
1151 proto::UpdateWorktreeSettings {
1152 project_id: project.id.to_proto(),
1153 worktree_id: worktree.id,
1154 path: settings_file.path,
1155 content: Some(settings_file.content),
1156 },
1157 )?;
1158 }
1159 }
1160
1161 for language_server in &project.language_servers {
1162 session.peer.send(
1163 session.connection_id,
1164 proto::UpdateLanguageServer {
1165 project_id: project.id.to_proto(),
1166 language_server_id: language_server.id,
1167 variant: Some(
1168 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1169 proto::LspDiskBasedDiagnosticsUpdated {},
1170 ),
1171 ),
1172 },
1173 )?;
1174 }
1175 }
1176
1177 let rejoined_room = rejoined_room.into_inner();
1178
1179 room = rejoined_room.room;
1180 channel_id = rejoined_room.channel_id;
1181 channel_members = rejoined_room.channel_members;
1182 }
1183
1184 if let Some(channel_id) = channel_id {
1185 channel_updated(
1186 channel_id,
1187 &room,
1188 &channel_members,
1189 &session.peer,
1190 &*session.connection_pool().await,
1191 );
1192 }
1193
1194 update_user_contacts(session.user_id, &session).await?;
1195 Ok(())
1196}
1197
1198async fn leave_room(
1199 _: proto::LeaveRoom,
1200 response: Response<proto::LeaveRoom>,
1201 session: Session,
1202) -> Result<()> {
1203 leave_room_for_session(&session).await?;
1204 response.send(proto::Ack {})?;
1205 Ok(())
1206}
1207
1208async fn call(
1209 request: proto::Call,
1210 response: Response<proto::Call>,
1211 session: Session,
1212) -> Result<()> {
1213 let room_id = RoomId::from_proto(request.room_id);
1214 let calling_user_id = session.user_id;
1215 let calling_connection_id = session.connection_id;
1216 let called_user_id = UserId::from_proto(request.called_user_id);
1217 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1218 if !session
1219 .db()
1220 .await
1221 .has_contact(calling_user_id, called_user_id)
1222 .await?
1223 {
1224 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1225 }
1226
1227 let incoming_call = {
1228 let (room, incoming_call) = &mut *session
1229 .db()
1230 .await
1231 .call(
1232 room_id,
1233 calling_user_id,
1234 calling_connection_id,
1235 called_user_id,
1236 initial_project_id,
1237 )
1238 .await?;
1239 room_updated(&room, &session.peer);
1240 mem::take(incoming_call)
1241 };
1242 update_user_contacts(called_user_id, &session).await?;
1243
1244 let mut calls = session
1245 .connection_pool()
1246 .await
1247 .user_connection_ids(called_user_id)
1248 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1249 .collect::<FuturesUnordered<_>>();
1250
1251 while let Some(call_response) = calls.next().await {
1252 match call_response.as_ref() {
1253 Ok(_) => {
1254 response.send(proto::Ack {})?;
1255 return Ok(());
1256 }
1257 Err(_) => {
1258 call_response.trace_err();
1259 }
1260 }
1261 }
1262
1263 {
1264 let room = session
1265 .db()
1266 .await
1267 .call_failed(room_id, called_user_id)
1268 .await?;
1269 room_updated(&room, &session.peer);
1270 }
1271 update_user_contacts(called_user_id, &session).await?;
1272
1273 Err(anyhow!("failed to ring user"))?
1274}
1275
1276async fn cancel_call(
1277 request: proto::CancelCall,
1278 response: Response<proto::CancelCall>,
1279 session: Session,
1280) -> Result<()> {
1281 let called_user_id = UserId::from_proto(request.called_user_id);
1282 let room_id = RoomId::from_proto(request.room_id);
1283 {
1284 let room = session
1285 .db()
1286 .await
1287 .cancel_call(room_id, session.connection_id, called_user_id)
1288 .await?;
1289 room_updated(&room, &session.peer);
1290 }
1291
1292 for connection_id in session
1293 .connection_pool()
1294 .await
1295 .user_connection_ids(called_user_id)
1296 {
1297 session
1298 .peer
1299 .send(
1300 connection_id,
1301 proto::CallCanceled {
1302 room_id: room_id.to_proto(),
1303 },
1304 )
1305 .trace_err();
1306 }
1307 response.send(proto::Ack {})?;
1308
1309 update_user_contacts(called_user_id, &session).await?;
1310 Ok(())
1311}
1312
1313async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1314 let room_id = RoomId::from_proto(message.room_id);
1315 {
1316 let room = session
1317 .db()
1318 .await
1319 .decline_call(Some(room_id), session.user_id)
1320 .await?
1321 .ok_or_else(|| anyhow!("failed to decline call"))?;
1322 room_updated(&room, &session.peer);
1323 }
1324
1325 for connection_id in session
1326 .connection_pool()
1327 .await
1328 .user_connection_ids(session.user_id)
1329 {
1330 session
1331 .peer
1332 .send(
1333 connection_id,
1334 proto::CallCanceled {
1335 room_id: room_id.to_proto(),
1336 },
1337 )
1338 .trace_err();
1339 }
1340 update_user_contacts(session.user_id, &session).await?;
1341 Ok(())
1342}
1343
1344async fn update_participant_location(
1345 request: proto::UpdateParticipantLocation,
1346 response: Response<proto::UpdateParticipantLocation>,
1347 session: Session,
1348) -> Result<()> {
1349 let room_id = RoomId::from_proto(request.room_id);
1350 let location = request
1351 .location
1352 .ok_or_else(|| anyhow!("invalid location"))?;
1353
1354 let db = session.db().await;
1355 let room = db
1356 .update_room_participant_location(room_id, session.connection_id, location)
1357 .await?;
1358
1359 room_updated(&room, &session.peer);
1360 response.send(proto::Ack {})?;
1361 Ok(())
1362}
1363
1364async fn share_project(
1365 request: proto::ShareProject,
1366 response: Response<proto::ShareProject>,
1367 session: Session,
1368) -> Result<()> {
1369 let (project_id, room) = &*session
1370 .db()
1371 .await
1372 .share_project(
1373 RoomId::from_proto(request.room_id),
1374 session.connection_id,
1375 &request.worktrees,
1376 )
1377 .await?;
1378 response.send(proto::ShareProjectResponse {
1379 project_id: project_id.to_proto(),
1380 })?;
1381 room_updated(&room, &session.peer);
1382
1383 Ok(())
1384}
1385
1386async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1387 let project_id = ProjectId::from_proto(message.project_id);
1388
1389 let (room, guest_connection_ids) = &*session
1390 .db()
1391 .await
1392 .unshare_project(project_id, session.connection_id)
1393 .await?;
1394
1395 broadcast(
1396 Some(session.connection_id),
1397 guest_connection_ids.iter().copied(),
1398 |conn_id| session.peer.send(conn_id, message.clone()),
1399 );
1400 room_updated(&room, &session.peer);
1401
1402 Ok(())
1403}
1404
1405async fn join_project(
1406 request: proto::JoinProject,
1407 response: Response<proto::JoinProject>,
1408 session: Session,
1409) -> Result<()> {
1410 let project_id = ProjectId::from_proto(request.project_id);
1411 let guest_user_id = session.user_id;
1412
1413 tracing::info!(%project_id, "join project");
1414
1415 let (project, replica_id) = &mut *session
1416 .db()
1417 .await
1418 .join_project(project_id, session.connection_id)
1419 .await?;
1420
1421 let collaborators = project
1422 .collaborators
1423 .iter()
1424 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1425 .map(|collaborator| collaborator.to_proto())
1426 .collect::<Vec<_>>();
1427
1428 let worktrees = project
1429 .worktrees
1430 .iter()
1431 .map(|(id, worktree)| proto::WorktreeMetadata {
1432 id: *id,
1433 root_name: worktree.root_name.clone(),
1434 visible: worktree.visible,
1435 abs_path: worktree.abs_path.clone(),
1436 })
1437 .collect::<Vec<_>>();
1438
1439 for collaborator in &collaborators {
1440 session
1441 .peer
1442 .send(
1443 collaborator.peer_id.unwrap().into(),
1444 proto::AddProjectCollaborator {
1445 project_id: project_id.to_proto(),
1446 collaborator: Some(proto::Collaborator {
1447 peer_id: Some(session.connection_id.into()),
1448 replica_id: replica_id.0 as u32,
1449 user_id: guest_user_id.to_proto(),
1450 }),
1451 },
1452 )
1453 .trace_err();
1454 }
1455
1456 // First, we send the metadata associated with each worktree.
1457 response.send(proto::JoinProjectResponse {
1458 worktrees: worktrees.clone(),
1459 replica_id: replica_id.0 as u32,
1460 collaborators: collaborators.clone(),
1461 language_servers: project.language_servers.clone(),
1462 })?;
1463
1464 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1465 #[cfg(any(test, feature = "test-support"))]
1466 const MAX_CHUNK_SIZE: usize = 2;
1467 #[cfg(not(any(test, feature = "test-support")))]
1468 const MAX_CHUNK_SIZE: usize = 256;
1469
1470 // Stream this worktree's entries.
1471 let message = proto::UpdateWorktree {
1472 project_id: project_id.to_proto(),
1473 worktree_id,
1474 abs_path: worktree.abs_path.clone(),
1475 root_name: worktree.root_name,
1476 updated_entries: worktree.entries,
1477 removed_entries: Default::default(),
1478 scan_id: worktree.scan_id,
1479 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1480 updated_repositories: worktree.repository_entries.into_values().collect(),
1481 removed_repositories: Default::default(),
1482 };
1483 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1484 session.peer.send(session.connection_id, update.clone())?;
1485 }
1486
1487 // Stream this worktree's diagnostics.
1488 for summary in worktree.diagnostic_summaries {
1489 session.peer.send(
1490 session.connection_id,
1491 proto::UpdateDiagnosticSummary {
1492 project_id: project_id.to_proto(),
1493 worktree_id: worktree.id,
1494 summary: Some(summary),
1495 },
1496 )?;
1497 }
1498
1499 for settings_file in worktree.settings_files {
1500 session.peer.send(
1501 session.connection_id,
1502 proto::UpdateWorktreeSettings {
1503 project_id: project_id.to_proto(),
1504 worktree_id: worktree.id,
1505 path: settings_file.path,
1506 content: Some(settings_file.content),
1507 },
1508 )?;
1509 }
1510 }
1511
1512 for language_server in &project.language_servers {
1513 session.peer.send(
1514 session.connection_id,
1515 proto::UpdateLanguageServer {
1516 project_id: project_id.to_proto(),
1517 language_server_id: language_server.id,
1518 variant: Some(
1519 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1520 proto::LspDiskBasedDiagnosticsUpdated {},
1521 ),
1522 ),
1523 },
1524 )?;
1525 }
1526
1527 Ok(())
1528}
1529
1530async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1531 let sender_id = session.connection_id;
1532 let project_id = ProjectId::from_proto(request.project_id);
1533
1534 let (room, project) = &*session
1535 .db()
1536 .await
1537 .leave_project(project_id, sender_id)
1538 .await?;
1539 tracing::info!(
1540 %project_id,
1541 host_user_id = %project.host_user_id,
1542 host_connection_id = %project.host_connection_id,
1543 "leave project"
1544 );
1545
1546 project_left(&project, &session);
1547 room_updated(&room, &session.peer);
1548
1549 Ok(())
1550}
1551
1552async fn update_project(
1553 request: proto::UpdateProject,
1554 response: Response<proto::UpdateProject>,
1555 session: Session,
1556) -> Result<()> {
1557 let project_id = ProjectId::from_proto(request.project_id);
1558 let (room, guest_connection_ids) = &*session
1559 .db()
1560 .await
1561 .update_project(project_id, session.connection_id, &request.worktrees)
1562 .await?;
1563 broadcast(
1564 Some(session.connection_id),
1565 guest_connection_ids.iter().copied(),
1566 |connection_id| {
1567 session
1568 .peer
1569 .forward_send(session.connection_id, connection_id, request.clone())
1570 },
1571 );
1572 room_updated(&room, &session.peer);
1573 response.send(proto::Ack {})?;
1574
1575 Ok(())
1576}
1577
1578async fn update_worktree(
1579 request: proto::UpdateWorktree,
1580 response: Response<proto::UpdateWorktree>,
1581 session: Session,
1582) -> Result<()> {
1583 let guest_connection_ids = session
1584 .db()
1585 .await
1586 .update_worktree(&request, session.connection_id)
1587 .await?;
1588
1589 broadcast(
1590 Some(session.connection_id),
1591 guest_connection_ids.iter().copied(),
1592 |connection_id| {
1593 session
1594 .peer
1595 .forward_send(session.connection_id, connection_id, request.clone())
1596 },
1597 );
1598 response.send(proto::Ack {})?;
1599 Ok(())
1600}
1601
1602async fn update_diagnostic_summary(
1603 message: proto::UpdateDiagnosticSummary,
1604 session: Session,
1605) -> Result<()> {
1606 let guest_connection_ids = session
1607 .db()
1608 .await
1609 .update_diagnostic_summary(&message, session.connection_id)
1610 .await?;
1611
1612 broadcast(
1613 Some(session.connection_id),
1614 guest_connection_ids.iter().copied(),
1615 |connection_id| {
1616 session
1617 .peer
1618 .forward_send(session.connection_id, connection_id, message.clone())
1619 },
1620 );
1621
1622 Ok(())
1623}
1624
1625async fn update_worktree_settings(
1626 message: proto::UpdateWorktreeSettings,
1627 session: Session,
1628) -> Result<()> {
1629 let guest_connection_ids = session
1630 .db()
1631 .await
1632 .update_worktree_settings(&message, session.connection_id)
1633 .await?;
1634
1635 broadcast(
1636 Some(session.connection_id),
1637 guest_connection_ids.iter().copied(),
1638 |connection_id| {
1639 session
1640 .peer
1641 .forward_send(session.connection_id, connection_id, message.clone())
1642 },
1643 );
1644
1645 Ok(())
1646}
1647
1648async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1649 broadcast_project_message(request.project_id, request, session).await
1650}
1651
1652async fn start_language_server(
1653 request: proto::StartLanguageServer,
1654 session: Session,
1655) -> Result<()> {
1656 let guest_connection_ids = session
1657 .db()
1658 .await
1659 .start_language_server(&request, session.connection_id)
1660 .await?;
1661
1662 broadcast(
1663 Some(session.connection_id),
1664 guest_connection_ids.iter().copied(),
1665 |connection_id| {
1666 session
1667 .peer
1668 .forward_send(session.connection_id, connection_id, request.clone())
1669 },
1670 );
1671 Ok(())
1672}
1673
1674async fn update_language_server(
1675 request: proto::UpdateLanguageServer,
1676 session: Session,
1677) -> Result<()> {
1678 session.executor.record_backtrace();
1679 let project_id = ProjectId::from_proto(request.project_id);
1680 let project_connection_ids = session
1681 .db()
1682 .await
1683 .project_connection_ids(project_id, session.connection_id)
1684 .await?;
1685 broadcast(
1686 Some(session.connection_id),
1687 project_connection_ids.iter().copied(),
1688 |connection_id| {
1689 session
1690 .peer
1691 .forward_send(session.connection_id, connection_id, request.clone())
1692 },
1693 );
1694 Ok(())
1695}
1696
1697async fn forward_project_request<T>(
1698 request: T,
1699 response: Response<T>,
1700 session: Session,
1701) -> Result<()>
1702where
1703 T: EntityMessage + RequestMessage,
1704{
1705 session.executor.record_backtrace();
1706 let project_id = ProjectId::from_proto(request.remote_entity_id());
1707 let host_connection_id = {
1708 let collaborators = session
1709 .db()
1710 .await
1711 .project_collaborators(project_id, session.connection_id)
1712 .await?;
1713 collaborators
1714 .iter()
1715 .find(|collaborator| collaborator.is_host)
1716 .ok_or_else(|| anyhow!("host not found"))?
1717 .connection_id
1718 };
1719
1720 let payload = session
1721 .peer
1722 .forward_request(session.connection_id, host_connection_id, request)
1723 .await?;
1724
1725 response.send(payload)?;
1726 Ok(())
1727}
1728
1729async fn create_buffer_for_peer(
1730 request: proto::CreateBufferForPeer,
1731 session: Session,
1732) -> Result<()> {
1733 session.executor.record_backtrace();
1734 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1735 session
1736 .peer
1737 .forward_send(session.connection_id, peer_id.into(), request)?;
1738 Ok(())
1739}
1740
1741async fn update_buffer(
1742 request: proto::UpdateBuffer,
1743 response: Response<proto::UpdateBuffer>,
1744 session: Session,
1745) -> Result<()> {
1746 session.executor.record_backtrace();
1747 let project_id = ProjectId::from_proto(request.project_id);
1748 let mut guest_connection_ids;
1749 let mut host_connection_id = None;
1750 {
1751 let collaborators = session
1752 .db()
1753 .await
1754 .project_collaborators(project_id, session.connection_id)
1755 .await?;
1756 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1757 for collaborator in collaborators.iter() {
1758 if collaborator.is_host {
1759 host_connection_id = Some(collaborator.connection_id);
1760 } else {
1761 guest_connection_ids.push(collaborator.connection_id);
1762 }
1763 }
1764 }
1765 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1766
1767 session.executor.record_backtrace();
1768 broadcast(
1769 Some(session.connection_id),
1770 guest_connection_ids,
1771 |connection_id| {
1772 session
1773 .peer
1774 .forward_send(session.connection_id, connection_id, request.clone())
1775 },
1776 );
1777 if host_connection_id != session.connection_id {
1778 session
1779 .peer
1780 .forward_request(session.connection_id, host_connection_id, request.clone())
1781 .await?;
1782 }
1783
1784 response.send(proto::Ack {})?;
1785 Ok(())
1786}
1787
1788async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1789 let project_id = ProjectId::from_proto(request.project_id);
1790 let project_connection_ids = session
1791 .db()
1792 .await
1793 .project_connection_ids(project_id, session.connection_id)
1794 .await?;
1795
1796 broadcast(
1797 Some(session.connection_id),
1798 project_connection_ids.iter().copied(),
1799 |connection_id| {
1800 session
1801 .peer
1802 .forward_send(session.connection_id, connection_id, request.clone())
1803 },
1804 );
1805 Ok(())
1806}
1807
1808async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1809 let project_id = ProjectId::from_proto(request.project_id);
1810 let project_connection_ids = session
1811 .db()
1812 .await
1813 .project_connection_ids(project_id, session.connection_id)
1814 .await?;
1815 broadcast(
1816 Some(session.connection_id),
1817 project_connection_ids.iter().copied(),
1818 |connection_id| {
1819 session
1820 .peer
1821 .forward_send(session.connection_id, connection_id, request.clone())
1822 },
1823 );
1824 Ok(())
1825}
1826
1827async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1828 broadcast_project_message(request.project_id, request, session).await
1829}
1830
1831async fn broadcast_project_message<T: EnvelopedMessage>(
1832 project_id: u64,
1833 request: T,
1834 session: Session,
1835) -> Result<()> {
1836 let project_id = ProjectId::from_proto(project_id);
1837 let project_connection_ids = session
1838 .db()
1839 .await
1840 .project_connection_ids(project_id, session.connection_id)
1841 .await?;
1842 broadcast(
1843 Some(session.connection_id),
1844 project_connection_ids.iter().copied(),
1845 |connection_id| {
1846 session
1847 .peer
1848 .forward_send(session.connection_id, connection_id, request.clone())
1849 },
1850 );
1851 Ok(())
1852}
1853
1854async fn follow(
1855 request: proto::Follow,
1856 response: Response<proto::Follow>,
1857 session: Session,
1858) -> Result<()> {
1859 let project_id = ProjectId::from_proto(request.project_id);
1860 let leader_id = request
1861 .leader_id
1862 .ok_or_else(|| anyhow!("invalid leader id"))?
1863 .into();
1864 let follower_id = session.connection_id;
1865
1866 {
1867 let project_connection_ids = session
1868 .db()
1869 .await
1870 .project_connection_ids(project_id, session.connection_id)
1871 .await?;
1872
1873 if !project_connection_ids.contains(&leader_id) {
1874 Err(anyhow!("no such peer"))?;
1875 }
1876 }
1877
1878 let mut response_payload = session
1879 .peer
1880 .forward_request(session.connection_id, leader_id, request)
1881 .await?;
1882 response_payload
1883 .views
1884 .retain(|view| view.leader_id != Some(follower_id.into()));
1885 response.send(response_payload)?;
1886
1887 let room = session
1888 .db()
1889 .await
1890 .follow(project_id, leader_id, follower_id)
1891 .await?;
1892 room_updated(&room, &session.peer);
1893
1894 Ok(())
1895}
1896
1897async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1898 let project_id = ProjectId::from_proto(request.project_id);
1899 let leader_id = request
1900 .leader_id
1901 .ok_or_else(|| anyhow!("invalid leader id"))?
1902 .into();
1903 let follower_id = session.connection_id;
1904
1905 if !session
1906 .db()
1907 .await
1908 .project_connection_ids(project_id, session.connection_id)
1909 .await?
1910 .contains(&leader_id)
1911 {
1912 Err(anyhow!("no such peer"))?;
1913 }
1914
1915 session
1916 .peer
1917 .forward_send(session.connection_id, leader_id, request)?;
1918
1919 let room = session
1920 .db()
1921 .await
1922 .unfollow(project_id, leader_id, follower_id)
1923 .await?;
1924 room_updated(&room, &session.peer);
1925
1926 Ok(())
1927}
1928
1929async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1930 let project_id = ProjectId::from_proto(request.project_id);
1931 let project_connection_ids = session
1932 .db
1933 .lock()
1934 .await
1935 .project_connection_ids(project_id, session.connection_id)
1936 .await?;
1937
1938 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1939 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1940 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1941 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1942 });
1943 for follower_peer_id in request.follower_ids.iter().copied() {
1944 let follower_connection_id = follower_peer_id.into();
1945 if project_connection_ids.contains(&follower_connection_id)
1946 && Some(follower_peer_id) != leader_id
1947 {
1948 session.peer.forward_send(
1949 session.connection_id,
1950 follower_connection_id,
1951 request.clone(),
1952 )?;
1953 }
1954 }
1955 Ok(())
1956}
1957
1958async fn get_users(
1959 request: proto::GetUsers,
1960 response: Response<proto::GetUsers>,
1961 session: Session,
1962) -> Result<()> {
1963 let user_ids = request
1964 .user_ids
1965 .into_iter()
1966 .map(UserId::from_proto)
1967 .collect();
1968 let users = session
1969 .db()
1970 .await
1971 .get_users_by_ids(user_ids)
1972 .await?
1973 .into_iter()
1974 .map(|user| proto::User {
1975 id: user.id.to_proto(),
1976 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1977 github_login: user.github_login,
1978 })
1979 .collect();
1980 response.send(proto::UsersResponse { users })?;
1981 Ok(())
1982}
1983
1984async fn fuzzy_search_users(
1985 request: proto::FuzzySearchUsers,
1986 response: Response<proto::FuzzySearchUsers>,
1987 session: Session,
1988) -> Result<()> {
1989 let query = request.query;
1990 let users = match query.len() {
1991 0 => vec![],
1992 1 | 2 => session
1993 .db()
1994 .await
1995 .get_user_by_github_login(&query)
1996 .await?
1997 .into_iter()
1998 .collect(),
1999 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2000 };
2001 let users = users
2002 .into_iter()
2003 .filter(|user| user.id != session.user_id)
2004 .map(|user| proto::User {
2005 id: user.id.to_proto(),
2006 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2007 github_login: user.github_login,
2008 })
2009 .collect();
2010 response.send(proto::UsersResponse { users })?;
2011 Ok(())
2012}
2013
2014async fn request_contact(
2015 request: proto::RequestContact,
2016 response: Response<proto::RequestContact>,
2017 session: Session,
2018) -> Result<()> {
2019 let requester_id = session.user_id;
2020 let responder_id = UserId::from_proto(request.responder_id);
2021 if requester_id == responder_id {
2022 return Err(anyhow!("cannot add yourself as a contact"))?;
2023 }
2024
2025 session
2026 .db()
2027 .await
2028 .send_contact_request(requester_id, responder_id)
2029 .await?;
2030
2031 // Update outgoing contact requests of requester
2032 let mut update = proto::UpdateContacts::default();
2033 update.outgoing_requests.push(responder_id.to_proto());
2034 for connection_id in session
2035 .connection_pool()
2036 .await
2037 .user_connection_ids(requester_id)
2038 {
2039 session.peer.send(connection_id, update.clone())?;
2040 }
2041
2042 // Update incoming contact requests of responder
2043 let mut update = proto::UpdateContacts::default();
2044 update
2045 .incoming_requests
2046 .push(proto::IncomingContactRequest {
2047 requester_id: requester_id.to_proto(),
2048 should_notify: true,
2049 });
2050 for connection_id in session
2051 .connection_pool()
2052 .await
2053 .user_connection_ids(responder_id)
2054 {
2055 session.peer.send(connection_id, update.clone())?;
2056 }
2057
2058 response.send(proto::Ack {})?;
2059 Ok(())
2060}
2061
2062async fn respond_to_contact_request(
2063 request: proto::RespondToContactRequest,
2064 response: Response<proto::RespondToContactRequest>,
2065 session: Session,
2066) -> Result<()> {
2067 let responder_id = session.user_id;
2068 let requester_id = UserId::from_proto(request.requester_id);
2069 let db = session.db().await;
2070 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2071 db.dismiss_contact_notification(responder_id, requester_id)
2072 .await?;
2073 } else {
2074 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2075
2076 db.respond_to_contact_request(responder_id, requester_id, accept)
2077 .await?;
2078 let requester_busy = db.is_user_busy(requester_id).await?;
2079 let responder_busy = db.is_user_busy(responder_id).await?;
2080
2081 let pool = session.connection_pool().await;
2082 // Update responder with new contact
2083 let mut update = proto::UpdateContacts::default();
2084 if accept {
2085 update
2086 .contacts
2087 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2088 }
2089 update
2090 .remove_incoming_requests
2091 .push(requester_id.to_proto());
2092 for connection_id in pool.user_connection_ids(responder_id) {
2093 session.peer.send(connection_id, update.clone())?;
2094 }
2095
2096 // Update requester with new contact
2097 let mut update = proto::UpdateContacts::default();
2098 if accept {
2099 update
2100 .contacts
2101 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2102 }
2103 update
2104 .remove_outgoing_requests
2105 .push(responder_id.to_proto());
2106 for connection_id in pool.user_connection_ids(requester_id) {
2107 session.peer.send(connection_id, update.clone())?;
2108 }
2109 }
2110
2111 response.send(proto::Ack {})?;
2112 Ok(())
2113}
2114
2115async fn remove_contact(
2116 request: proto::RemoveContact,
2117 response: Response<proto::RemoveContact>,
2118 session: Session,
2119) -> Result<()> {
2120 let requester_id = session.user_id;
2121 let responder_id = UserId::from_proto(request.user_id);
2122 let db = session.db().await;
2123 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2124
2125 let pool = session.connection_pool().await;
2126 // Update outgoing contact requests of requester
2127 let mut update = proto::UpdateContacts::default();
2128 if contact_accepted {
2129 update.remove_contacts.push(responder_id.to_proto());
2130 } else {
2131 update
2132 .remove_outgoing_requests
2133 .push(responder_id.to_proto());
2134 }
2135 for connection_id in pool.user_connection_ids(requester_id) {
2136 session.peer.send(connection_id, update.clone())?;
2137 }
2138
2139 // Update incoming contact requests of responder
2140 let mut update = proto::UpdateContacts::default();
2141 if contact_accepted {
2142 update.remove_contacts.push(requester_id.to_proto());
2143 } else {
2144 update
2145 .remove_incoming_requests
2146 .push(requester_id.to_proto());
2147 }
2148 for connection_id in pool.user_connection_ids(responder_id) {
2149 session.peer.send(connection_id, update.clone())?;
2150 }
2151
2152 response.send(proto::Ack {})?;
2153 Ok(())
2154}
2155
2156async fn create_channel(
2157 request: proto::CreateChannel,
2158 response: Response<proto::CreateChannel>,
2159 session: Session,
2160) -> Result<()> {
2161 let db = session.db().await;
2162 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2163
2164 if let Some(live_kit) = session.live_kit_client.as_ref() {
2165 live_kit.create_room(live_kit_room.clone()).await?;
2166 }
2167
2168 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2169 let id = db
2170 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2171 .await?;
2172
2173 let channel = proto::Channel {
2174 id: id.to_proto(),
2175 name: request.name,
2176 parent_id: request.parent_id,
2177 };
2178
2179 response.send(proto::ChannelResponse {
2180 channel: Some(channel.clone()),
2181 })?;
2182
2183 let mut update = proto::UpdateChannels::default();
2184 update.channels.push(channel);
2185
2186 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2187 db.get_channel_members(parent_id).await?
2188 } else {
2189 vec![session.user_id]
2190 };
2191
2192 let connection_pool = session.connection_pool().await;
2193 for user_id in user_ids_to_notify {
2194 for connection_id in connection_pool.user_connection_ids(user_id) {
2195 let mut update = update.clone();
2196 if user_id == session.user_id {
2197 update.channel_permissions.push(proto::ChannelPermission {
2198 channel_id: id.to_proto(),
2199 is_admin: true,
2200 });
2201 }
2202 session.peer.send(connection_id, update)?;
2203 }
2204 }
2205
2206 Ok(())
2207}
2208
2209async fn remove_channel(
2210 request: proto::RemoveChannel,
2211 response: Response<proto::RemoveChannel>,
2212 session: Session,
2213) -> Result<()> {
2214 let db = session.db().await;
2215
2216 let channel_id = request.channel_id;
2217 let (removed_channels, member_ids) = db
2218 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2219 .await?;
2220 response.send(proto::Ack {})?;
2221
2222 // Notify members of removed channels
2223 let mut update = proto::UpdateChannels::default();
2224 update
2225 .remove_channels
2226 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2227
2228 let connection_pool = session.connection_pool().await;
2229 for member_id in member_ids {
2230 for connection_id in connection_pool.user_connection_ids(member_id) {
2231 session.peer.send(connection_id, update.clone())?;
2232 }
2233 }
2234
2235 Ok(())
2236}
2237
2238async fn invite_channel_member(
2239 request: proto::InviteChannelMember,
2240 response: Response<proto::InviteChannelMember>,
2241 session: Session,
2242) -> Result<()> {
2243 let db = session.db().await;
2244 let channel_id = ChannelId::from_proto(request.channel_id);
2245 let invitee_id = UserId::from_proto(request.user_id);
2246 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2247 .await?;
2248
2249 let (channel, _) = db
2250 .get_channel(channel_id, session.user_id)
2251 .await?
2252 .ok_or_else(|| anyhow!("channel not found"))?;
2253
2254 let mut update = proto::UpdateChannels::default();
2255 update.channel_invitations.push(proto::Channel {
2256 id: channel.id.to_proto(),
2257 name: channel.name,
2258 parent_id: None,
2259 });
2260 for connection_id in session
2261 .connection_pool()
2262 .await
2263 .user_connection_ids(invitee_id)
2264 {
2265 session.peer.send(connection_id, update.clone())?;
2266 }
2267
2268 response.send(proto::Ack {})?;
2269 Ok(())
2270}
2271
2272async fn remove_channel_member(
2273 request: proto::RemoveChannelMember,
2274 response: Response<proto::RemoveChannelMember>,
2275 session: Session,
2276) -> Result<()> {
2277 let db = session.db().await;
2278 let channel_id = ChannelId::from_proto(request.channel_id);
2279 let member_id = UserId::from_proto(request.user_id);
2280
2281 db.remove_channel_member(channel_id, member_id, session.user_id)
2282 .await?;
2283
2284 let mut update = proto::UpdateChannels::default();
2285 update.remove_channels.push(channel_id.to_proto());
2286
2287 for connection_id in session
2288 .connection_pool()
2289 .await
2290 .user_connection_ids(member_id)
2291 {
2292 session.peer.send(connection_id, update.clone())?;
2293 }
2294
2295 response.send(proto::Ack {})?;
2296 Ok(())
2297}
2298
2299async fn set_channel_member_admin(
2300 request: proto::SetChannelMemberAdmin,
2301 response: Response<proto::SetChannelMemberAdmin>,
2302 session: Session,
2303) -> Result<()> {
2304 let db = session.db().await;
2305 let channel_id = ChannelId::from_proto(request.channel_id);
2306 let member_id = UserId::from_proto(request.user_id);
2307 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2308 .await?;
2309
2310 let (channel, has_accepted) = db
2311 .get_channel(channel_id, member_id)
2312 .await?
2313 .ok_or_else(|| anyhow!("channel not found"))?;
2314
2315 let mut update = proto::UpdateChannels::default();
2316 if has_accepted {
2317 update.channel_permissions.push(proto::ChannelPermission {
2318 channel_id: channel.id.to_proto(),
2319 is_admin: request.admin,
2320 });
2321 }
2322
2323 for connection_id in session
2324 .connection_pool()
2325 .await
2326 .user_connection_ids(member_id)
2327 {
2328 session.peer.send(connection_id, update.clone())?;
2329 }
2330
2331 response.send(proto::Ack {})?;
2332 Ok(())
2333}
2334
2335async fn rename_channel(
2336 request: proto::RenameChannel,
2337 response: Response<proto::RenameChannel>,
2338 session: Session,
2339) -> Result<()> {
2340 let db = session.db().await;
2341 let channel_id = ChannelId::from_proto(request.channel_id);
2342 let new_name = db
2343 .rename_channel(channel_id, session.user_id, &request.name)
2344 .await?;
2345
2346 let channel = proto::Channel {
2347 id: request.channel_id,
2348 name: new_name,
2349 parent_id: None,
2350 };
2351 response.send(proto::ChannelResponse {
2352 channel: Some(channel.clone()),
2353 })?;
2354 let mut update = proto::UpdateChannels::default();
2355 update.channels.push(channel);
2356
2357 let member_ids = db.get_channel_members(channel_id).await?;
2358
2359 let connection_pool = session.connection_pool().await;
2360 for member_id in member_ids {
2361 for connection_id in connection_pool.user_connection_ids(member_id) {
2362 session.peer.send(connection_id, update.clone())?;
2363 }
2364 }
2365
2366 Ok(())
2367}
2368
2369async fn get_channel_members(
2370 request: proto::GetChannelMembers,
2371 response: Response<proto::GetChannelMembers>,
2372 session: Session,
2373) -> Result<()> {
2374 let db = session.db().await;
2375 let channel_id = ChannelId::from_proto(request.channel_id);
2376 let members = db
2377 .get_channel_member_details(channel_id, session.user_id)
2378 .await?;
2379 response.send(proto::GetChannelMembersResponse { members })?;
2380 Ok(())
2381}
2382
2383async fn respond_to_channel_invite(
2384 request: proto::RespondToChannelInvite,
2385 response: Response<proto::RespondToChannelInvite>,
2386 session: Session,
2387) -> Result<()> {
2388 let db = session.db().await;
2389 let channel_id = ChannelId::from_proto(request.channel_id);
2390 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2391 .await?;
2392
2393 let mut update = proto::UpdateChannels::default();
2394 update
2395 .remove_channel_invitations
2396 .push(channel_id.to_proto());
2397 if request.accept {
2398 let result = db.get_channels_for_user(session.user_id).await?;
2399 update
2400 .channels
2401 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2402 id: channel.id.to_proto(),
2403 name: channel.name,
2404 parent_id: channel.parent_id.map(ChannelId::to_proto),
2405 }));
2406 update
2407 .channel_participants
2408 .extend(
2409 result
2410 .channel_participants
2411 .into_iter()
2412 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2413 channel_id: channel_id.to_proto(),
2414 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2415 }),
2416 );
2417 update
2418 .channel_permissions
2419 .extend(
2420 result
2421 .channels_with_admin_privileges
2422 .into_iter()
2423 .map(|channel_id| proto::ChannelPermission {
2424 channel_id: channel_id.to_proto(),
2425 is_admin: true,
2426 }),
2427 );
2428 }
2429 session.peer.send(session.connection_id, update)?;
2430 response.send(proto::Ack {})?;
2431
2432 Ok(())
2433}
2434
/// Moves the requesting connection into the room associated with `request.channel_id`.
///
/// Steps, in order:
/// 1. Leaves any room the session is currently in (`leave_room_for_session`).
/// 2. Looks up the channel's room and joins it.
/// 3. Replies with the room snapshot, the room's channel id, and LiveKit connection
///    info when a LiveKit client is configured and a token could be minted.
/// 4. Broadcasts the updated room to its participants, then the new participant list
///    to the channel's members.
/// 5. Refreshes this user's contact entry for their accepted contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Scoped so that the `db` handle obtained inside is dropped before the connection
    // pool is locked below and before `update_user_contacts` calls `session.db()` again.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // A token failure is logged via `trace_err` and simply omits the LiveKit info
        // from the response instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        room_updated(&joined_room.room, &session.peer);

        // Unwrap the value returned by `join_room` so only its contents outlive this
        // block (presumably it wraps a lock/guard — TODO confirm against `join_room`).
        joined_room.into_inner()
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2489
2490async fn join_channel_buffer(
2491 request: proto::JoinChannelBuffer,
2492 response: Response<proto::JoinChannelBuffer>,
2493 session: Session,
2494) -> Result<()> {
2495 let db = session.db().await;
2496 let channel_id = ChannelId::from_proto(request.channel_id);
2497
2498 let open_response = db
2499 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2500 .await?;
2501
2502 let replica_id = open_response.replica_id;
2503 let collaborators = open_response.collaborators.clone();
2504
2505 response.send(open_response)?;
2506
2507 let update = AddChannelBufferCollaborator {
2508 channel_id: channel_id.to_proto(),
2509 collaborator: Some(proto::Collaborator {
2510 user_id: session.user_id.to_proto(),
2511 peer_id: Some(session.connection_id.into()),
2512 replica_id,
2513 }),
2514 };
2515 channel_buffer_updated(
2516 session.connection_id,
2517 collaborators
2518 .iter()
2519 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2520 &update,
2521 &session.peer,
2522 );
2523
2524 Ok(())
2525}
2526
2527async fn update_channel_buffer(
2528 request: proto::UpdateChannelBuffer,
2529 session: Session,
2530) -> Result<()> {
2531 let db = session.db().await;
2532 let channel_id = ChannelId::from_proto(request.channel_id);
2533
2534 let collaborators = db
2535 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2536 .await?;
2537
2538 channel_buffer_updated(
2539 session.connection_id,
2540 collaborators,
2541 &proto::UpdateChannelBuffer {
2542 channel_id: channel_id.to_proto(),
2543 operations: request.operations,
2544 },
2545 &session.peer,
2546 );
2547 Ok(())
2548}
2549
2550async fn leave_channel_buffer(
2551 request: proto::LeaveChannelBuffer,
2552 response: Response<proto::LeaveChannelBuffer>,
2553 session: Session,
2554) -> Result<()> {
2555 let db = session.db().await;
2556 let channel_id = ChannelId::from_proto(request.channel_id);
2557
2558 let collaborators_to_notify = db
2559 .leave_channel_buffer(channel_id, session.connection_id)
2560 .await?;
2561
2562 response.send(Ack {})?;
2563
2564 channel_buffer_updated(
2565 session.connection_id,
2566 collaborators_to_notify,
2567 &proto::RemoveChannelBufferCollaborator {
2568 channel_id: channel_id.to_proto(),
2569 peer_id: Some(session.connection_id.into()),
2570 },
2571 &session.peer,
2572 );
2573
2574 Ok(())
2575}
2576
2577fn channel_buffer_updated<T: EnvelopedMessage>(
2578 sender_id: ConnectionId,
2579 collaborators: impl IntoIterator<Item = ConnectionId>,
2580 message: &T,
2581 peer: &Peer,
2582) {
2583 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2584 peer.send(peer_id.into(), message.clone())
2585 });
2586}
2587
2588async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2589 let project_id = ProjectId::from_proto(request.project_id);
2590 let project_connection_ids = session
2591 .db()
2592 .await
2593 .project_connection_ids(project_id, session.connection_id)
2594 .await?;
2595 broadcast(
2596 Some(session.connection_id),
2597 project_connection_ids.iter().copied(),
2598 |connection_id| {
2599 session
2600 .peer
2601 .forward_send(session.connection_id, connection_id, request.clone())
2602 },
2603 );
2604 Ok(())
2605}
2606
2607async fn get_private_user_info(
2608 _request: proto::GetPrivateUserInfo,
2609 response: Response<proto::GetPrivateUserInfo>,
2610 session: Session,
2611) -> Result<()> {
2612 let db = session.db().await;
2613
2614 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2615 let user = db
2616 .get_user_by_id(session.user_id)
2617 .await?
2618 .ok_or_else(|| anyhow!("user not found"))?;
2619 let flags = db.get_user_flags(session.user_id).await?;
2620
2621 response.send(proto::GetPrivateUserInfoResponse {
2622 metrics_id,
2623 staff: user.admin,
2624 flags,
2625 })?;
2626 Ok(())
2627}
2628
2629fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2630 match message {
2631 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2632 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2633 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2634 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2635 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2636 code: frame.code.into(),
2637 reason: frame.reason,
2638 })),
2639 }
2640}
2641
2642fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2643 match message {
2644 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2645 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2646 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2647 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2648 AxumMessage::Close(frame) => {
2649 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2650 code: frame.code.into(),
2651 reason: frame.reason,
2652 }))
2653 }
2654 }
2655}
2656
2657fn build_initial_channels_update(
2658 channels: ChannelsForUser,
2659 channel_invites: Vec<db::Channel>,
2660) -> proto::UpdateChannels {
2661 let mut update = proto::UpdateChannels::default();
2662
2663 for channel in channels.channels {
2664 update.channels.push(proto::Channel {
2665 id: channel.id.to_proto(),
2666 name: channel.name,
2667 parent_id: channel.parent_id.map(|id| id.to_proto()),
2668 });
2669 }
2670
2671 for (channel_id, participants) in channels.channel_participants {
2672 update
2673 .channel_participants
2674 .push(proto::ChannelParticipants {
2675 channel_id: channel_id.to_proto(),
2676 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2677 });
2678 }
2679
2680 update
2681 .channel_permissions
2682 .extend(
2683 channels
2684 .channels_with_admin_privileges
2685 .into_iter()
2686 .map(|id| proto::ChannelPermission {
2687 channel_id: id.to_proto(),
2688 is_admin: true,
2689 }),
2690 );
2691
2692 for channel in channel_invites {
2693 update.channel_invitations.push(proto::Channel {
2694 id: channel.id.to_proto(),
2695 name: channel.name,
2696 parent_id: None,
2697 });
2698 }
2699
2700 update
2701}
2702
2703fn build_initial_contacts_update(
2704 contacts: Vec<db::Contact>,
2705 pool: &ConnectionPool,
2706) -> proto::UpdateContacts {
2707 let mut update = proto::UpdateContacts::default();
2708
2709 for contact in contacts {
2710 match contact {
2711 db::Contact::Accepted {
2712 user_id,
2713 should_notify,
2714 busy,
2715 } => {
2716 update
2717 .contacts
2718 .push(contact_for_user(user_id, should_notify, busy, &pool));
2719 }
2720 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2721 db::Contact::Incoming {
2722 user_id,
2723 should_notify,
2724 } => update
2725 .incoming_requests
2726 .push(proto::IncomingContactRequest {
2727 requester_id: user_id.to_proto(),
2728 should_notify,
2729 }),
2730 }
2731 }
2732
2733 update
2734}
2735
2736fn contact_for_user(
2737 user_id: UserId,
2738 should_notify: bool,
2739 busy: bool,
2740 pool: &ConnectionPool,
2741) -> proto::Contact {
2742 proto::Contact {
2743 user_id: user_id.to_proto(),
2744 online: pool.is_user_online(user_id),
2745 busy,
2746 should_notify,
2747 }
2748}
2749
2750fn room_updated(room: &proto::Room, peer: &Peer) {
2751 broadcast(
2752 None,
2753 room.participants
2754 .iter()
2755 .filter_map(|participant| Some(participant.peer_id?.into())),
2756 |peer_id| {
2757 peer.send(
2758 peer_id.into(),
2759 proto::RoomUpdated {
2760 room: Some(room.clone()),
2761 },
2762 )
2763 },
2764 );
2765}
2766
2767fn channel_updated(
2768 channel_id: ChannelId,
2769 room: &proto::Room,
2770 channel_members: &[UserId],
2771 peer: &Peer,
2772 pool: &ConnectionPool,
2773) {
2774 let participants = room
2775 .participants
2776 .iter()
2777 .map(|p| p.user_id)
2778 .collect::<Vec<_>>();
2779
2780 broadcast(
2781 None,
2782 channel_members
2783 .iter()
2784 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2785 |peer_id| {
2786 peer.send(
2787 peer_id.into(),
2788 proto::UpdateChannels {
2789 channel_participants: vec![proto::ChannelParticipants {
2790 channel_id: channel_id.to_proto(),
2791 participant_user_ids: participants.clone(),
2792 }],
2793 ..Default::default()
2794 },
2795 )
2796 },
2797 );
2798}
2799
2800async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2801 let db = session.db().await;
2802
2803 let contacts = db.get_contacts(user_id).await?;
2804 let busy = db.is_user_busy(user_id).await?;
2805
2806 let pool = session.connection_pool().await;
2807 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2808 for contact in contacts {
2809 if let db::Contact::Accepted {
2810 user_id: contact_user_id,
2811 ..
2812 } = contact
2813 {
2814 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2815 session
2816 .peer
2817 .send(
2818 contact_conn_id,
2819 proto::UpdateContacts {
2820 contacts: vec![updated_contact.clone()],
2821 remove_contacts: Default::default(),
2822 incoming_requests: Default::default(),
2823 remove_incoming_requests: Default::default(),
2824 outgoing_requests: Default::default(),
2825 remove_outgoing_requests: Default::default(),
2826 },
2827 )
2828 .trace_err();
2829 }
2830 }
2831 }
2832 Ok(())
2833}
2834
/// Removes this session's connection from whatever room the database says it is in.
///
/// Returns `Ok(())` immediately when `leave_room` reports no room. Otherwise it:
/// - notifies the other connections of each project the user left (`project_left`);
/// - broadcasts the updated room and, if the room belongs to a channel, the channel's
///   new participant list;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for this user and those users;
/// - removes the user from the LiveKit room, deleting the room when the database
///   reported it as `deleted` (LiveKit errors are only logged).
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Initialized inside the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // The `session.db().await` temporary in the scrutinee lives until the end of this
    // `if let` statement, so the body must not call `session.db()` itself.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out before `room` is, so the room
        // snapshots broadcast below carry a default `live_kit_room`; confirm intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool lock is released before `update_user_contacts`,
    // which locks the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2911
2912async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2913 let left_channel_buffers = session
2914 .db()
2915 .await
2916 .leave_channel_buffers(session.connection_id)
2917 .await?;
2918
2919 for (channel_id, connections) in left_channel_buffers {
2920 channel_buffer_updated(
2921 session.connection_id,
2922 connections,
2923 &proto::RemoveChannelBufferCollaborator {
2924 channel_id: channel_id.to_proto(),
2925 peer_id: Some(session.connection_id.into()),
2926 },
2927 &session.peer,
2928 );
2929 }
2930
2931 Ok(())
2932}
2933
2934fn project_left(project: &db::LeftProject, session: &Session) {
2935 for connection_id in &project.connection_ids {
2936 if project.host_user_id == session.user_id {
2937 session
2938 .peer
2939 .send(
2940 *connection_id,
2941 proto::UnshareProject {
2942 project_id: project.id.to_proto(),
2943 },
2944 )
2945 .trace_err();
2946 } else {
2947 session
2948 .peer
2949 .send(
2950 *connection_id,
2951 proto::RemoveProjectCollaborator {
2952 project_id: project.id.to_proto(),
2953 peer_id: Some(session.connection_id.into()),
2954 },
2955 )
2956 .trace_err();
2957 }
2958 }
2959}
2960
2961pub trait ResultExt {
2962 type Ok;
2963
2964 fn trace_err(self) -> Option<Self::Ok>;
2965}
2966
2967impl<T, E> ResultExt for Result<T, E>
2968where
2969 E: std::fmt::Debug,
2970{
2971 type Ok = T;
2972
2973 fn trace_err(self) -> Option<T> {
2974 match self {
2975 Ok(value) => Some(value),
2976 Err(error) => {
2977 tracing::error!("{:?}", error);
2978 None
2979 }
2980 }
2981 }
2982}