1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
63pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
64pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(join_channel_buffer)
252 .add_request_handler(leave_channel_buffer)
253 .add_message_handler(update_channel_buffer)
254 .add_request_handler(rejoin_channel_buffers)
255 .add_request_handler(get_channel_members)
256 .add_request_handler(respond_to_channel_invite)
257 .add_request_handler(join_channel)
258 .add_request_handler(follow)
259 .add_message_handler(unfollow)
260 .add_message_handler(update_followers)
261 .add_message_handler(update_diff_base)
262 .add_request_handler(get_private_user_info);
263
264 Arc::new(server)
265 }
266
267 pub async fn start(&self) -> Result<()> {
268 let server_id = *self.id.lock();
269 let app_state = self.app_state.clone();
270 let peer = self.peer.clone();
271 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
272 let pool = self.connection_pool.clone();
273 let live_kit_client = self.app_state.live_kit_client.clone();
274
275 let span = info_span!("start server");
276 self.executor.spawn_detached(
277 async move {
278 tracing::info!("waiting for cleanup timeout");
279 timeout.await;
280 tracing::info!("cleanup timeout expired, retrieving stale rooms");
281 if let Some((room_ids, channel_ids)) = app_state
282 .db
283 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
284 .await
285 .trace_err()
286 {
287 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
288 tracing::info!(
289 stale_channel_buffer_count = channel_ids.len(),
290 "retrieved stale channel buffers"
291 );
292
293 for channel_id in channel_ids {
294 if let Some(refreshed_channel_buffer) = app_state
295 .db
296 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
297 .await
298 .trace_err()
299 {
300 for connection_id in refreshed_channel_buffer.connection_ids {
301 for message in &refreshed_channel_buffer.removed_collaborators {
302 peer.send(connection_id, message.clone()).trace_err();
303 }
304 }
305 }
306 }
307
308 for room_id in room_ids {
309 let mut contacts_to_update = HashSet::default();
310 let mut canceled_calls_to_user_ids = Vec::new();
311 let mut live_kit_room = String::new();
312 let mut delete_live_kit_room = false;
313
314 if let Some(mut refreshed_room) = app_state
315 .db
316 .clear_stale_room_participants(room_id, server_id)
317 .await
318 .trace_err()
319 {
320 tracing::info!(
321 room_id = room_id.0,
322 new_participant_count = refreshed_room.room.participants.len(),
323 "refreshed room"
324 );
325 room_updated(&refreshed_room.room, &peer);
326 if let Some(channel_id) = refreshed_room.channel_id {
327 channel_updated(
328 channel_id,
329 &refreshed_room.room,
330 &refreshed_room.channel_members,
331 &peer,
332 &*pool.lock(),
333 );
334 }
335 contacts_to_update
336 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
337 contacts_to_update
338 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
339 canceled_calls_to_user_ids =
340 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
341 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
342 delete_live_kit_room = refreshed_room.room.participants.is_empty();
343 }
344
345 {
346 let pool = pool.lock();
347 for canceled_user_id in canceled_calls_to_user_ids {
348 for connection_id in pool.user_connection_ids(canceled_user_id) {
349 peer.send(
350 connection_id,
351 proto::CallCanceled {
352 room_id: room_id.to_proto(),
353 },
354 )
355 .trace_err();
356 }
357 }
358 }
359
360 for user_id in contacts_to_update {
361 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
362 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
363 if let Some((busy, contacts)) = busy.zip(contacts) {
364 let pool = pool.lock();
365 let updated_contact = contact_for_user(user_id, false, busy, &pool);
366 for contact in contacts {
367 if let db::Contact::Accepted {
368 user_id: contact_user_id,
369 ..
370 } = contact
371 {
372 for contact_conn_id in
373 pool.user_connection_ids(contact_user_id)
374 {
375 peer.send(
376 contact_conn_id,
377 proto::UpdateContacts {
378 contacts: vec![updated_contact.clone()],
379 remove_contacts: Default::default(),
380 incoming_requests: Default::default(),
381 remove_incoming_requests: Default::default(),
382 outgoing_requests: Default::default(),
383 remove_outgoing_requests: Default::default(),
384 },
385 )
386 .trace_err();
387 }
388 }
389 }
390 }
391 }
392
393 if let Some(live_kit) = live_kit_client.as_ref() {
394 if delete_live_kit_room {
395 live_kit.delete_room(live_kit_room).await.trace_err();
396 }
397 }
398 }
399 }
400
401 app_state
402 .db
403 .delete_stale_servers(&app_state.config.zed_environment, server_id)
404 .await
405 .trace_err();
406 }
407 .instrument(span),
408 );
409 Ok(())
410 }
411
412 pub fn teardown(&self) {
413 self.peer.teardown();
414 self.connection_pool.lock().reset();
415 let _ = self.teardown.send(());
416 }
417
/// Test-only: tears the server down and gives it (and its peer) a new id.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    let mut current_id = self.id.lock();
    *current_id = id;
    drop(current_id);
    self.peer.reset(id.0 as u32);
}
424
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let id = self.id.lock();
    *id
}
429
430 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
431 where
432 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
433 Fut: 'static + Send + Future<Output = Result<()>>,
434 M: EnvelopedMessage,
435 {
436 let prev_handler = self.handlers.insert(
437 TypeId::of::<M>(),
438 Box::new(move |envelope, session| {
439 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
440 let span = info_span!(
441 "handle message",
442 payload_type = envelope.payload_type_name()
443 );
444 span.in_scope(|| {
445 tracing::info!(
446 payload_type = envelope.payload_type_name(),
447 "message received"
448 );
449 });
450 let start_time = Instant::now();
451 let future = (handler)(*envelope, session);
452 async move {
453 let result = future.await;
454 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
455 match result {
456 Err(error) => {
457 tracing::error!(%error, ?duration_ms, "error handling message")
458 }
459 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
460 }
461 }
462 .instrument(span)
463 .boxed()
464 }),
465 );
466 if prev_handler.is_some() {
467 panic!("registered a handler for the same message twice");
468 }
469 self
470 }
471
472 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
473 where
474 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
475 Fut: 'static + Send + Future<Output = Result<()>>,
476 M: EnvelopedMessage,
477 {
478 self.add_handler(move |envelope, session| handler(envelope.payload, session));
479 self
480 }
481
482 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
483 where
484 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
485 Fut: Send + Future<Output = Result<()>>,
486 M: RequestMessage,
487 {
488 let handler = Arc::new(handler);
489 self.add_handler(move |envelope, session| {
490 let receipt = envelope.receipt();
491 let handler = handler.clone();
492 async move {
493 let peer = session.peer.clone();
494 let responded = Arc::new(AtomicBool::default());
495 let response = Response {
496 peer: peer.clone(),
497 responded: responded.clone(),
498 receipt,
499 };
500 match (handler)(envelope.payload, response, session).await {
501 Ok(()) => {
502 if responded.load(std::sync::atomic::Ordering::SeqCst) {
503 Ok(())
504 } else {
505 Err(anyhow!("handler did not send a response"))?
506 }
507 }
508 Err(error) => {
509 peer.respond_with_error(
510 receipt,
511 proto::Error {
512 message: error.to_string(),
513 },
514 )?;
515 Err(error)
516 }
517 }
518 }
519 })
520 }
521
522 pub fn handle_connection(
523 self: &Arc<Self>,
524 connection: Connection,
525 address: String,
526 user: User,
527 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
528 executor: Executor,
529 ) -> impl Future<Output = Result<()>> {
530 let this = self.clone();
531 let user_id = user.id;
532 let login = user.github_login;
533 let span = info_span!("handle connection", %user_id, %login, %address);
534 let mut teardown = self.teardown.subscribe();
535 async move {
536 let (connection_id, handle_io, mut incoming_rx) = this
537 .peer
538 .add_connection(connection, {
539 let executor = executor.clone();
540 move |duration| executor.sleep(duration)
541 });
542
543 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
544 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
545 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
546
547 if let Some(send_connection_id) = send_connection_id.take() {
548 let _ = send_connection_id.send(connection_id);
549 }
550
551 if !user.connected_once {
552 this.peer.send(connection_id, proto::ShowContacts {})?;
553 this.app_state.db.set_user_connected_once(user_id, true).await?;
554 }
555
556 let (contacts, channels_for_user, channel_invites) = future::try_join3(
557 this.app_state.db.get_contacts(user_id),
558 this.app_state.db.get_channels_for_user(user_id),
559 this.app_state.db.get_channel_invites_for_user(user_id)
560 ).await?;
561
562 {
563 let mut pool = this.connection_pool.lock();
564 pool.add_connection(connection_id, user_id, user.admin);
565 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
566 this.peer.send(connection_id, build_initial_channels_update(
567 channels_for_user,
568 channel_invites
569 ))?;
570 }
571
572 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
573 this.peer.send(connection_id, incoming_call)?;
574 }
575
576 let session = Session {
577 user_id,
578 connection_id,
579 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
580 peer: this.peer.clone(),
581 connection_pool: this.connection_pool.clone(),
582 live_kit_client: this.app_state.live_kit_client.clone(),
583 executor: executor.clone(),
584 };
585 update_user_contacts(user_id, &session).await?;
586
587 let handle_io = handle_io.fuse();
588 futures::pin_mut!(handle_io);
589
590 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
591 // This prevents deadlocks when e.g., client A performs a request to client B and
592 // client B performs a request to client A. If both clients stop processing further
593 // messages until their respective request completes, they won't have a chance to
594 // respond to the other client's request and cause a deadlock.
595 //
596 // This arrangement ensures we will attempt to process earlier messages first, but fall
597 // back to processing messages arrived later in the spirit of making progress.
598 let mut foreground_message_handlers = FuturesUnordered::new();
599 let concurrent_handlers = Arc::new(Semaphore::new(256));
600 loop {
601 let next_message = async {
602 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
603 let message = incoming_rx.next().await;
604 (permit, message)
605 }.fuse();
606 futures::pin_mut!(next_message);
607 futures::select_biased! {
608 _ = teardown.changed().fuse() => return Ok(()),
609 result = handle_io => {
610 if let Err(error) = result {
611 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
612 }
613 break;
614 }
615 _ = foreground_message_handlers.next() => {}
616 next_message = next_message => {
617 let (permit, message) = next_message;
618 if let Some(message) = message {
619 let type_name = message.payload_type_name();
620 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
621 let span_enter = span.enter();
622 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
623 let is_background = message.is_background();
624 let handle_message = (handler)(message, session.clone());
625 drop(span_enter);
626
627 let handle_message = async move {
628 handle_message.await;
629 drop(permit);
630 }.instrument(span);
631 if is_background {
632 executor.spawn_detached(handle_message);
633 } else {
634 foreground_message_handlers.push(handle_message);
635 }
636 } else {
637 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
638 }
639 } else {
640 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
641 break;
642 }
643 }
644 }
645 }
646
647 drop(foreground_message_handlers);
648 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
649 if let Err(error) = connection_lost(session, teardown, executor).await {
650 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
651 }
652
653 Ok(())
654 }.instrument(span)
655 }
656
657 pub async fn invite_code_redeemed(
658 self: &Arc<Self>,
659 inviter_id: UserId,
660 invitee_id: UserId,
661 ) -> Result<()> {
662 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
663 if let Some(code) = &user.invite_code {
664 let pool = self.connection_pool.lock();
665 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
666 for connection_id in pool.user_connection_ids(inviter_id) {
667 self.peer.send(
668 connection_id,
669 proto::UpdateContacts {
670 contacts: vec![invitee_contact.clone()],
671 ..Default::default()
672 },
673 )?;
674 self.peer.send(
675 connection_id,
676 proto::UpdateInviteInfo {
677 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
678 count: user.invite_count as u32,
679 },
680 )?;
681 }
682 }
683 }
684 Ok(())
685 }
686
687 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
688 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
689 if let Some(invite_code) = &user.invite_code {
690 let pool = self.connection_pool.lock();
691 for connection_id in pool.user_connection_ids(user_id) {
692 self.peer.send(
693 connection_id,
694 proto::UpdateInviteInfo {
695 url: format!(
696 "{}{}",
697 self.app_state.config.invite_link_prefix, invite_code
698 ),
699 count: user.invite_count as u32,
700 },
701 )?;
702 }
703 }
704 }
705 Ok(())
706 }
707
708 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
709 ServerSnapshot {
710 connection_pool: ConnectionPoolGuard {
711 guard: self.connection_pool.lock(),
712 _not_send: PhantomData,
713 },
714 peer: &self.peer,
715 }
716 }
717}
718
719impl<'a> Deref for ConnectionPoolGuard<'a> {
720 type Target = ConnectionPool;
721
722 fn deref(&self) -> &Self::Target {
723 &*self.guard
724 }
725}
726
727impl<'a> DerefMut for ConnectionPoolGuard<'a> {
728 fn deref_mut(&mut self) -> &mut Self::Target {
729 &mut *self.guard
730 }
731}
732
733impl<'a> Drop for ConnectionPoolGuard<'a> {
734 fn drop(&mut self) {
735 #[cfg(test)]
736 self.check_invariants();
737 }
738}
739
740fn broadcast<F>(
741 sender_id: Option<ConnectionId>,
742 receiver_ids: impl IntoIterator<Item = ConnectionId>,
743 mut f: F,
744) where
745 F: FnMut(ConnectionId) -> anyhow::Result<()>,
746{
747 for receiver_id in receiver_ids {
748 if Some(receiver_id) != sender_id {
749 if let Err(error) = f(receiver_id) {
750 tracing::error!("failed to send to {:?} {}", receiver_id, error);
751 }
752 }
753 }
754}
755
756lazy_static! {
757 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
758}
759
760pub struct ProtocolVersion(u32);
761
762impl Header for ProtocolVersion {
763 fn name() -> &'static HeaderName {
764 &ZED_PROTOCOL_VERSION
765 }
766
767 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
768 where
769 Self: Sized,
770 I: Iterator<Item = &'i axum::http::HeaderValue>,
771 {
772 let version = values
773 .next()
774 .ok_or_else(axum::headers::Error::invalid)?
775 .to_str()
776 .map_err(|_| axum::headers::Error::invalid())?
777 .parse()
778 .map_err(|_| axum::headers::Error::invalid())?;
779 Ok(Self(version))
780 }
781
782 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
783 values.extend([self.0.to_string().parse().unwrap()]);
784 }
785}
786
787pub fn routes(server: Arc<Server>) -> Router<Body> {
788 Router::new()
789 .route("/rpc", get(handle_websocket_request))
790 .layer(
791 ServiceBuilder::new()
792 .layer(Extension(server.app_state.clone()))
793 .layer(middleware::from_fn(auth::validate_header)),
794 )
795 .route("/metrics", get(handle_metrics))
796 .layer(Extension(server))
797}
798
799pub async fn handle_websocket_request(
800 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
801 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
802 Extension(server): Extension<Arc<Server>>,
803 Extension(user): Extension<User>,
804 ws: WebSocketUpgrade,
805) -> axum::response::Response {
806 if protocol_version != rpc::PROTOCOL_VERSION {
807 return (
808 StatusCode::UPGRADE_REQUIRED,
809 "client must be upgraded".to_string(),
810 )
811 .into_response();
812 }
813 let socket_address = socket_address.to_string();
814 ws.on_upgrade(move |socket| {
815 use util::ResultExt;
816 let socket = socket
817 .map_ok(to_tungstenite_message)
818 .err_into()
819 .with(|message| async move { Ok(to_axum_message(message)) });
820 let connection = Connection::new(Box::pin(socket));
821 async move {
822 server
823 .handle_connection(connection, socket_address, user, None, Executor::Production)
824 .await
825 .log_err();
826 }
827 })
828}
829
830pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
831 let connections = server
832 .connection_pool
833 .lock()
834 .connections()
835 .filter(|connection| !connection.admin)
836 .count();
837
838 METRIC_CONNECTIONS.set(connections as _);
839
840 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
841 METRIC_SHARED_PROJECTS.set(shared_projects as _);
842
843 let encoder = prometheus::TextEncoder::new();
844 let metric_families = prometheus::gather();
845 let encoded_metrics = encoder
846 .encode_to_string(&metric_families)
847 .map_err(|err| anyhow!("{}", err))?;
848 Ok(encoded_metrics)
849}
850
851#[instrument(err, skip(executor))]
852async fn connection_lost(
853 session: Session,
854 mut teardown: watch::Receiver<()>,
855 executor: Executor,
856) -> Result<()> {
857 session.peer.disconnect(session.connection_id);
858 session
859 .connection_pool()
860 .await
861 .remove_connection(session.connection_id)?;
862
863 session
864 .db()
865 .await
866 .connection_lost(session.connection_id)
867 .await
868 .trace_err();
869
870 futures::select_biased! {
871 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
872 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
873 leave_room_for_session(&session).await.trace_err();
874 leave_channel_buffers_for_session(&session)
875 .await
876 .trace_err();
877
878 if !session
879 .connection_pool()
880 .await
881 .is_user_online(session.user_id)
882 {
883 let db = session.db().await;
884 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
885 room_updated(&room, &session.peer);
886 }
887 }
888 update_user_contacts(session.user_id, &session).await?;
889
890
891 }
892 _ = teardown.changed().fuse() => {}
893 }
894
895 Ok(())
896}
897
898async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
899 response.send(proto::Ack {})?;
900 Ok(())
901}
902
903async fn create_room(
904 _request: proto::CreateRoom,
905 response: Response<proto::CreateRoom>,
906 session: Session,
907) -> Result<()> {
908 let live_kit_room = nanoid::nanoid!(30);
909
910 let live_kit_connection_info = {
911 let live_kit_room = live_kit_room.clone();
912 let live_kit = session.live_kit_client.as_ref();
913
914 util::async_iife!({
915 let live_kit = live_kit?;
916
917 live_kit
918 .create_room(live_kit_room.clone())
919 .await
920 .trace_err()?;
921
922 let token = live_kit
923 .room_token(&live_kit_room, &session.user_id.to_string())
924 .trace_err()?;
925
926 Some(proto::LiveKitConnectionInfo {
927 server_url: live_kit.url().into(),
928 token,
929 })
930 })
931 }
932 .await;
933
934 let room = session
935 .db()
936 .await
937 .create_room(session.user_id, session.connection_id, &live_kit_room)
938 .await?;
939
940 response.send(proto::CreateRoomResponse {
941 room: Some(room.clone()),
942 live_kit_connection_info,
943 })?;
944
945 update_user_contacts(session.user_id, &session).await?;
946 Ok(())
947}
948
949async fn join_room(
950 request: proto::JoinRoom,
951 response: Response<proto::JoinRoom>,
952 session: Session,
953) -> Result<()> {
954 let room_id = RoomId::from_proto(request.id);
955 let joined_room = {
956 let room = session
957 .db()
958 .await
959 .join_room(room_id, session.user_id, session.connection_id)
960 .await?;
961 room_updated(&room.room, &session.peer);
962 room.into_inner()
963 };
964
965 if let Some(channel_id) = joined_room.channel_id {
966 channel_updated(
967 channel_id,
968 &joined_room.room,
969 &joined_room.channel_members,
970 &session.peer,
971 &*session.connection_pool().await,
972 )
973 }
974
975 for connection_id in session
976 .connection_pool()
977 .await
978 .user_connection_ids(session.user_id)
979 {
980 session
981 .peer
982 .send(
983 connection_id,
984 proto::CallCanceled {
985 room_id: room_id.to_proto(),
986 },
987 )
988 .trace_err();
989 }
990
991 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
992 if let Some(token) = live_kit
993 .room_token(
994 &joined_room.room.live_kit_room,
995 &session.user_id.to_string(),
996 )
997 .trace_err()
998 {
999 Some(proto::LiveKitConnectionInfo {
1000 server_url: live_kit.url().into(),
1001 token,
1002 })
1003 } else {
1004 None
1005 }
1006 } else {
1007 None
1008 };
1009
1010 response.send(proto::JoinRoomResponse {
1011 room: Some(joined_room.room),
1012 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1013 live_kit_connection_info,
1014 })?;
1015
1016 update_user_contacts(session.user_id, &session).await?;
1017 Ok(())
1018}
1019
1020async fn rejoin_room(
1021 request: proto::RejoinRoom,
1022 response: Response<proto::RejoinRoom>,
1023 session: Session,
1024) -> Result<()> {
1025 let room;
1026 let channel_id;
1027 let channel_members;
1028 {
1029 let mut rejoined_room = session
1030 .db()
1031 .await
1032 .rejoin_room(request, session.user_id, session.connection_id)
1033 .await?;
1034
1035 response.send(proto::RejoinRoomResponse {
1036 room: Some(rejoined_room.room.clone()),
1037 reshared_projects: rejoined_room
1038 .reshared_projects
1039 .iter()
1040 .map(|project| proto::ResharedProject {
1041 id: project.id.to_proto(),
1042 collaborators: project
1043 .collaborators
1044 .iter()
1045 .map(|collaborator| collaborator.to_proto())
1046 .collect(),
1047 })
1048 .collect(),
1049 rejoined_projects: rejoined_room
1050 .rejoined_projects
1051 .iter()
1052 .map(|rejoined_project| proto::RejoinedProject {
1053 id: rejoined_project.id.to_proto(),
1054 worktrees: rejoined_project
1055 .worktrees
1056 .iter()
1057 .map(|worktree| proto::WorktreeMetadata {
1058 id: worktree.id,
1059 root_name: worktree.root_name.clone(),
1060 visible: worktree.visible,
1061 abs_path: worktree.abs_path.clone(),
1062 })
1063 .collect(),
1064 collaborators: rejoined_project
1065 .collaborators
1066 .iter()
1067 .map(|collaborator| collaborator.to_proto())
1068 .collect(),
1069 language_servers: rejoined_project.language_servers.clone(),
1070 })
1071 .collect(),
1072 })?;
1073 room_updated(&rejoined_room.room, &session.peer);
1074
1075 for project in &rejoined_room.reshared_projects {
1076 for collaborator in &project.collaborators {
1077 session
1078 .peer
1079 .send(
1080 collaborator.connection_id,
1081 proto::UpdateProjectCollaborator {
1082 project_id: project.id.to_proto(),
1083 old_peer_id: Some(project.old_connection_id.into()),
1084 new_peer_id: Some(session.connection_id.into()),
1085 },
1086 )
1087 .trace_err();
1088 }
1089
1090 broadcast(
1091 Some(session.connection_id),
1092 project
1093 .collaborators
1094 .iter()
1095 .map(|collaborator| collaborator.connection_id),
1096 |connection_id| {
1097 session.peer.forward_send(
1098 session.connection_id,
1099 connection_id,
1100 proto::UpdateProject {
1101 project_id: project.id.to_proto(),
1102 worktrees: project.worktrees.clone(),
1103 },
1104 )
1105 },
1106 );
1107 }
1108
1109 for project in &rejoined_room.rejoined_projects {
1110 for collaborator in &project.collaborators {
1111 session
1112 .peer
1113 .send(
1114 collaborator.connection_id,
1115 proto::UpdateProjectCollaborator {
1116 project_id: project.id.to_proto(),
1117 old_peer_id: Some(project.old_connection_id.into()),
1118 new_peer_id: Some(session.connection_id.into()),
1119 },
1120 )
1121 .trace_err();
1122 }
1123 }
1124
1125 for project in &mut rejoined_room.rejoined_projects {
1126 for worktree in mem::take(&mut project.worktrees) {
1127 #[cfg(any(test, feature = "test-support"))]
1128 const MAX_CHUNK_SIZE: usize = 2;
1129 #[cfg(not(any(test, feature = "test-support")))]
1130 const MAX_CHUNK_SIZE: usize = 256;
1131
1132 // Stream this worktree's entries.
1133 let message = proto::UpdateWorktree {
1134 project_id: project.id.to_proto(),
1135 worktree_id: worktree.id,
1136 abs_path: worktree.abs_path.clone(),
1137 root_name: worktree.root_name,
1138 updated_entries: worktree.updated_entries,
1139 removed_entries: worktree.removed_entries,
1140 scan_id: worktree.scan_id,
1141 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1142 updated_repositories: worktree.updated_repositories,
1143 removed_repositories: worktree.removed_repositories,
1144 };
1145 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1146 session.peer.send(session.connection_id, update.clone())?;
1147 }
1148
1149 // Stream this worktree's diagnostics.
1150 for summary in worktree.diagnostic_summaries {
1151 session.peer.send(
1152 session.connection_id,
1153 proto::UpdateDiagnosticSummary {
1154 project_id: project.id.to_proto(),
1155 worktree_id: worktree.id,
1156 summary: Some(summary),
1157 },
1158 )?;
1159 }
1160
1161 for settings_file in worktree.settings_files {
1162 session.peer.send(
1163 session.connection_id,
1164 proto::UpdateWorktreeSettings {
1165 project_id: project.id.to_proto(),
1166 worktree_id: worktree.id,
1167 path: settings_file.path,
1168 content: Some(settings_file.content),
1169 },
1170 )?;
1171 }
1172 }
1173
1174 for language_server in &project.language_servers {
1175 session.peer.send(
1176 session.connection_id,
1177 proto::UpdateLanguageServer {
1178 project_id: project.id.to_proto(),
1179 language_server_id: language_server.id,
1180 variant: Some(
1181 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1182 proto::LspDiskBasedDiagnosticsUpdated {},
1183 ),
1184 ),
1185 },
1186 )?;
1187 }
1188 }
1189
1190 let rejoined_room = rejoined_room.into_inner();
1191
1192 room = rejoined_room.room;
1193 channel_id = rejoined_room.channel_id;
1194 channel_members = rejoined_room.channel_members;
1195 }
1196
1197 if let Some(channel_id) = channel_id {
1198 channel_updated(
1199 channel_id,
1200 &room,
1201 &channel_members,
1202 &session.peer,
1203 &*session.connection_pool().await,
1204 );
1205 }
1206
1207 update_user_contacts(session.user_id, &session).await?;
1208 Ok(())
1209}
1210
1211async fn leave_room(
1212 _: proto::LeaveRoom,
1213 response: Response<proto::LeaveRoom>,
1214 session: Session,
1215) -> Result<()> {
1216 leave_room_for_session(&session).await?;
1217 response.send(proto::Ack {})?;
1218 Ok(())
1219}
1220
1221async fn call(
1222 request: proto::Call,
1223 response: Response<proto::Call>,
1224 session: Session,
1225) -> Result<()> {
1226 let room_id = RoomId::from_proto(request.room_id);
1227 let calling_user_id = session.user_id;
1228 let calling_connection_id = session.connection_id;
1229 let called_user_id = UserId::from_proto(request.called_user_id);
1230 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1231 if !session
1232 .db()
1233 .await
1234 .has_contact(calling_user_id, called_user_id)
1235 .await?
1236 {
1237 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1238 }
1239
1240 let incoming_call = {
1241 let (room, incoming_call) = &mut *session
1242 .db()
1243 .await
1244 .call(
1245 room_id,
1246 calling_user_id,
1247 calling_connection_id,
1248 called_user_id,
1249 initial_project_id,
1250 )
1251 .await?;
1252 room_updated(&room, &session.peer);
1253 mem::take(incoming_call)
1254 };
1255 update_user_contacts(called_user_id, &session).await?;
1256
1257 let mut calls = session
1258 .connection_pool()
1259 .await
1260 .user_connection_ids(called_user_id)
1261 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1262 .collect::<FuturesUnordered<_>>();
1263
1264 while let Some(call_response) = calls.next().await {
1265 match call_response.as_ref() {
1266 Ok(_) => {
1267 response.send(proto::Ack {})?;
1268 return Ok(());
1269 }
1270 Err(_) => {
1271 call_response.trace_err();
1272 }
1273 }
1274 }
1275
1276 {
1277 let room = session
1278 .db()
1279 .await
1280 .call_failed(room_id, called_user_id)
1281 .await?;
1282 room_updated(&room, &session.peer);
1283 }
1284 update_user_contacts(called_user_id, &session).await?;
1285
1286 Err(anyhow!("failed to ring user"))?
1287}
1288
1289async fn cancel_call(
1290 request: proto::CancelCall,
1291 response: Response<proto::CancelCall>,
1292 session: Session,
1293) -> Result<()> {
1294 let called_user_id = UserId::from_proto(request.called_user_id);
1295 let room_id = RoomId::from_proto(request.room_id);
1296 {
1297 let room = session
1298 .db()
1299 .await
1300 .cancel_call(room_id, session.connection_id, called_user_id)
1301 .await?;
1302 room_updated(&room, &session.peer);
1303 }
1304
1305 for connection_id in session
1306 .connection_pool()
1307 .await
1308 .user_connection_ids(called_user_id)
1309 {
1310 session
1311 .peer
1312 .send(
1313 connection_id,
1314 proto::CallCanceled {
1315 room_id: room_id.to_proto(),
1316 },
1317 )
1318 .trace_err();
1319 }
1320 response.send(proto::Ack {})?;
1321
1322 update_user_contacts(called_user_id, &session).await?;
1323 Ok(())
1324}
1325
1326async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1327 let room_id = RoomId::from_proto(message.room_id);
1328 {
1329 let room = session
1330 .db()
1331 .await
1332 .decline_call(Some(room_id), session.user_id)
1333 .await?
1334 .ok_or_else(|| anyhow!("failed to decline call"))?;
1335 room_updated(&room, &session.peer);
1336 }
1337
1338 for connection_id in session
1339 .connection_pool()
1340 .await
1341 .user_connection_ids(session.user_id)
1342 {
1343 session
1344 .peer
1345 .send(
1346 connection_id,
1347 proto::CallCanceled {
1348 room_id: room_id.to_proto(),
1349 },
1350 )
1351 .trace_err();
1352 }
1353 update_user_contacts(session.user_id, &session).await?;
1354 Ok(())
1355}
1356
1357async fn update_participant_location(
1358 request: proto::UpdateParticipantLocation,
1359 response: Response<proto::UpdateParticipantLocation>,
1360 session: Session,
1361) -> Result<()> {
1362 let room_id = RoomId::from_proto(request.room_id);
1363 let location = request
1364 .location
1365 .ok_or_else(|| anyhow!("invalid location"))?;
1366
1367 let db = session.db().await;
1368 let room = db
1369 .update_room_participant_location(room_id, session.connection_id, location)
1370 .await?;
1371
1372 room_updated(&room, &session.peer);
1373 response.send(proto::Ack {})?;
1374 Ok(())
1375}
1376
1377async fn share_project(
1378 request: proto::ShareProject,
1379 response: Response<proto::ShareProject>,
1380 session: Session,
1381) -> Result<()> {
1382 let (project_id, room) = &*session
1383 .db()
1384 .await
1385 .share_project(
1386 RoomId::from_proto(request.room_id),
1387 session.connection_id,
1388 &request.worktrees,
1389 )
1390 .await?;
1391 response.send(proto::ShareProjectResponse {
1392 project_id: project_id.to_proto(),
1393 })?;
1394 room_updated(&room, &session.peer);
1395
1396 Ok(())
1397}
1398
1399async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1400 let project_id = ProjectId::from_proto(message.project_id);
1401
1402 let (room, guest_connection_ids) = &*session
1403 .db()
1404 .await
1405 .unshare_project(project_id, session.connection_id)
1406 .await?;
1407
1408 broadcast(
1409 Some(session.connection_id),
1410 guest_connection_ids.iter().copied(),
1411 |conn_id| session.peer.send(conn_id, message.clone()),
1412 );
1413 room_updated(&room, &session.peer);
1414
1415 Ok(())
1416}
1417
1418async fn join_project(
1419 request: proto::JoinProject,
1420 response: Response<proto::JoinProject>,
1421 session: Session,
1422) -> Result<()> {
1423 let project_id = ProjectId::from_proto(request.project_id);
1424 let guest_user_id = session.user_id;
1425
1426 tracing::info!(%project_id, "join project");
1427
1428 let (project, replica_id) = &mut *session
1429 .db()
1430 .await
1431 .join_project(project_id, session.connection_id)
1432 .await?;
1433
1434 let collaborators = project
1435 .collaborators
1436 .iter()
1437 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1438 .map(|collaborator| collaborator.to_proto())
1439 .collect::<Vec<_>>();
1440
1441 let worktrees = project
1442 .worktrees
1443 .iter()
1444 .map(|(id, worktree)| proto::WorktreeMetadata {
1445 id: *id,
1446 root_name: worktree.root_name.clone(),
1447 visible: worktree.visible,
1448 abs_path: worktree.abs_path.clone(),
1449 })
1450 .collect::<Vec<_>>();
1451
1452 for collaborator in &collaborators {
1453 session
1454 .peer
1455 .send(
1456 collaborator.peer_id.unwrap().into(),
1457 proto::AddProjectCollaborator {
1458 project_id: project_id.to_proto(),
1459 collaborator: Some(proto::Collaborator {
1460 peer_id: Some(session.connection_id.into()),
1461 replica_id: replica_id.0 as u32,
1462 user_id: guest_user_id.to_proto(),
1463 }),
1464 },
1465 )
1466 .trace_err();
1467 }
1468
1469 // First, we send the metadata associated with each worktree.
1470 response.send(proto::JoinProjectResponse {
1471 worktrees: worktrees.clone(),
1472 replica_id: replica_id.0 as u32,
1473 collaborators: collaborators.clone(),
1474 language_servers: project.language_servers.clone(),
1475 })?;
1476
1477 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1478 #[cfg(any(test, feature = "test-support"))]
1479 const MAX_CHUNK_SIZE: usize = 2;
1480 #[cfg(not(any(test, feature = "test-support")))]
1481 const MAX_CHUNK_SIZE: usize = 256;
1482
1483 // Stream this worktree's entries.
1484 let message = proto::UpdateWorktree {
1485 project_id: project_id.to_proto(),
1486 worktree_id,
1487 abs_path: worktree.abs_path.clone(),
1488 root_name: worktree.root_name,
1489 updated_entries: worktree.entries,
1490 removed_entries: Default::default(),
1491 scan_id: worktree.scan_id,
1492 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1493 updated_repositories: worktree.repository_entries.into_values().collect(),
1494 removed_repositories: Default::default(),
1495 };
1496 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1497 session.peer.send(session.connection_id, update.clone())?;
1498 }
1499
1500 // Stream this worktree's diagnostics.
1501 for summary in worktree.diagnostic_summaries {
1502 session.peer.send(
1503 session.connection_id,
1504 proto::UpdateDiagnosticSummary {
1505 project_id: project_id.to_proto(),
1506 worktree_id: worktree.id,
1507 summary: Some(summary),
1508 },
1509 )?;
1510 }
1511
1512 for settings_file in worktree.settings_files {
1513 session.peer.send(
1514 session.connection_id,
1515 proto::UpdateWorktreeSettings {
1516 project_id: project_id.to_proto(),
1517 worktree_id: worktree.id,
1518 path: settings_file.path,
1519 content: Some(settings_file.content),
1520 },
1521 )?;
1522 }
1523 }
1524
1525 for language_server in &project.language_servers {
1526 session.peer.send(
1527 session.connection_id,
1528 proto::UpdateLanguageServer {
1529 project_id: project_id.to_proto(),
1530 language_server_id: language_server.id,
1531 variant: Some(
1532 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1533 proto::LspDiskBasedDiagnosticsUpdated {},
1534 ),
1535 ),
1536 },
1537 )?;
1538 }
1539
1540 Ok(())
1541}
1542
1543async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1544 let sender_id = session.connection_id;
1545 let project_id = ProjectId::from_proto(request.project_id);
1546
1547 let (room, project) = &*session
1548 .db()
1549 .await
1550 .leave_project(project_id, sender_id)
1551 .await?;
1552 tracing::info!(
1553 %project_id,
1554 host_user_id = %project.host_user_id,
1555 host_connection_id = %project.host_connection_id,
1556 "leave project"
1557 );
1558
1559 project_left(&project, &session);
1560 room_updated(&room, &session.peer);
1561
1562 Ok(())
1563}
1564
1565async fn update_project(
1566 request: proto::UpdateProject,
1567 response: Response<proto::UpdateProject>,
1568 session: Session,
1569) -> Result<()> {
1570 let project_id = ProjectId::from_proto(request.project_id);
1571 let (room, guest_connection_ids) = &*session
1572 .db()
1573 .await
1574 .update_project(project_id, session.connection_id, &request.worktrees)
1575 .await?;
1576 broadcast(
1577 Some(session.connection_id),
1578 guest_connection_ids.iter().copied(),
1579 |connection_id| {
1580 session
1581 .peer
1582 .forward_send(session.connection_id, connection_id, request.clone())
1583 },
1584 );
1585 room_updated(&room, &session.peer);
1586 response.send(proto::Ack {})?;
1587
1588 Ok(())
1589}
1590
1591async fn update_worktree(
1592 request: proto::UpdateWorktree,
1593 response: Response<proto::UpdateWorktree>,
1594 session: Session,
1595) -> Result<()> {
1596 let guest_connection_ids = session
1597 .db()
1598 .await
1599 .update_worktree(&request, session.connection_id)
1600 .await?;
1601
1602 broadcast(
1603 Some(session.connection_id),
1604 guest_connection_ids.iter().copied(),
1605 |connection_id| {
1606 session
1607 .peer
1608 .forward_send(session.connection_id, connection_id, request.clone())
1609 },
1610 );
1611 response.send(proto::Ack {})?;
1612 Ok(())
1613}
1614
1615async fn update_diagnostic_summary(
1616 message: proto::UpdateDiagnosticSummary,
1617 session: Session,
1618) -> Result<()> {
1619 let guest_connection_ids = session
1620 .db()
1621 .await
1622 .update_diagnostic_summary(&message, session.connection_id)
1623 .await?;
1624
1625 broadcast(
1626 Some(session.connection_id),
1627 guest_connection_ids.iter().copied(),
1628 |connection_id| {
1629 session
1630 .peer
1631 .forward_send(session.connection_id, connection_id, message.clone())
1632 },
1633 );
1634
1635 Ok(())
1636}
1637
1638async fn update_worktree_settings(
1639 message: proto::UpdateWorktreeSettings,
1640 session: Session,
1641) -> Result<()> {
1642 let guest_connection_ids = session
1643 .db()
1644 .await
1645 .update_worktree_settings(&message, session.connection_id)
1646 .await?;
1647
1648 broadcast(
1649 Some(session.connection_id),
1650 guest_connection_ids.iter().copied(),
1651 |connection_id| {
1652 session
1653 .peer
1654 .forward_send(session.connection_id, connection_id, message.clone())
1655 },
1656 );
1657
1658 Ok(())
1659}
1660
1661async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1662 broadcast_project_message(request.project_id, request, session).await
1663}
1664
1665async fn start_language_server(
1666 request: proto::StartLanguageServer,
1667 session: Session,
1668) -> Result<()> {
1669 let guest_connection_ids = session
1670 .db()
1671 .await
1672 .start_language_server(&request, session.connection_id)
1673 .await?;
1674
1675 broadcast(
1676 Some(session.connection_id),
1677 guest_connection_ids.iter().copied(),
1678 |connection_id| {
1679 session
1680 .peer
1681 .forward_send(session.connection_id, connection_id, request.clone())
1682 },
1683 );
1684 Ok(())
1685}
1686
1687async fn update_language_server(
1688 request: proto::UpdateLanguageServer,
1689 session: Session,
1690) -> Result<()> {
1691 session.executor.record_backtrace();
1692 let project_id = ProjectId::from_proto(request.project_id);
1693 let project_connection_ids = session
1694 .db()
1695 .await
1696 .project_connection_ids(project_id, session.connection_id)
1697 .await?;
1698 broadcast(
1699 Some(session.connection_id),
1700 project_connection_ids.iter().copied(),
1701 |connection_id| {
1702 session
1703 .peer
1704 .forward_send(session.connection_id, connection_id, request.clone())
1705 },
1706 );
1707 Ok(())
1708}
1709
1710async fn forward_project_request<T>(
1711 request: T,
1712 response: Response<T>,
1713 session: Session,
1714) -> Result<()>
1715where
1716 T: EntityMessage + RequestMessage,
1717{
1718 session.executor.record_backtrace();
1719 let project_id = ProjectId::from_proto(request.remote_entity_id());
1720 let host_connection_id = {
1721 let collaborators = session
1722 .db()
1723 .await
1724 .project_collaborators(project_id, session.connection_id)
1725 .await?;
1726 collaborators
1727 .iter()
1728 .find(|collaborator| collaborator.is_host)
1729 .ok_or_else(|| anyhow!("host not found"))?
1730 .connection_id
1731 };
1732
1733 let payload = session
1734 .peer
1735 .forward_request(session.connection_id, host_connection_id, request)
1736 .await?;
1737
1738 response.send(payload)?;
1739 Ok(())
1740}
1741
1742async fn create_buffer_for_peer(
1743 request: proto::CreateBufferForPeer,
1744 session: Session,
1745) -> Result<()> {
1746 session.executor.record_backtrace();
1747 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1748 session
1749 .peer
1750 .forward_send(session.connection_id, peer_id.into(), request)?;
1751 Ok(())
1752}
1753
1754async fn update_buffer(
1755 request: proto::UpdateBuffer,
1756 response: Response<proto::UpdateBuffer>,
1757 session: Session,
1758) -> Result<()> {
1759 session.executor.record_backtrace();
1760 let project_id = ProjectId::from_proto(request.project_id);
1761 let mut guest_connection_ids;
1762 let mut host_connection_id = None;
1763 {
1764 let collaborators = session
1765 .db()
1766 .await
1767 .project_collaborators(project_id, session.connection_id)
1768 .await?;
1769 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1770 for collaborator in collaborators.iter() {
1771 if collaborator.is_host {
1772 host_connection_id = Some(collaborator.connection_id);
1773 } else {
1774 guest_connection_ids.push(collaborator.connection_id);
1775 }
1776 }
1777 }
1778 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1779
1780 session.executor.record_backtrace();
1781 broadcast(
1782 Some(session.connection_id),
1783 guest_connection_ids,
1784 |connection_id| {
1785 session
1786 .peer
1787 .forward_send(session.connection_id, connection_id, request.clone())
1788 },
1789 );
1790 if host_connection_id != session.connection_id {
1791 session
1792 .peer
1793 .forward_request(session.connection_id, host_connection_id, request.clone())
1794 .await?;
1795 }
1796
1797 response.send(proto::Ack {})?;
1798 Ok(())
1799}
1800
1801async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1802 let project_id = ProjectId::from_proto(request.project_id);
1803 let project_connection_ids = session
1804 .db()
1805 .await
1806 .project_connection_ids(project_id, session.connection_id)
1807 .await?;
1808
1809 broadcast(
1810 Some(session.connection_id),
1811 project_connection_ids.iter().copied(),
1812 |connection_id| {
1813 session
1814 .peer
1815 .forward_send(session.connection_id, connection_id, request.clone())
1816 },
1817 );
1818 Ok(())
1819}
1820
1821async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1822 let project_id = ProjectId::from_proto(request.project_id);
1823 let project_connection_ids = session
1824 .db()
1825 .await
1826 .project_connection_ids(project_id, session.connection_id)
1827 .await?;
1828 broadcast(
1829 Some(session.connection_id),
1830 project_connection_ids.iter().copied(),
1831 |connection_id| {
1832 session
1833 .peer
1834 .forward_send(session.connection_id, connection_id, request.clone())
1835 },
1836 );
1837 Ok(())
1838}
1839
1840async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1841 broadcast_project_message(request.project_id, request, session).await
1842}
1843
1844async fn broadcast_project_message<T: EnvelopedMessage>(
1845 project_id: u64,
1846 request: T,
1847 session: Session,
1848) -> Result<()> {
1849 let project_id = ProjectId::from_proto(project_id);
1850 let project_connection_ids = session
1851 .db()
1852 .await
1853 .project_connection_ids(project_id, session.connection_id)
1854 .await?;
1855 broadcast(
1856 Some(session.connection_id),
1857 project_connection_ids.iter().copied(),
1858 |connection_id| {
1859 session
1860 .peer
1861 .forward_send(session.connection_id, connection_id, request.clone())
1862 },
1863 );
1864 Ok(())
1865}
1866
1867async fn follow(
1868 request: proto::Follow,
1869 response: Response<proto::Follow>,
1870 session: Session,
1871) -> Result<()> {
1872 let project_id = ProjectId::from_proto(request.project_id);
1873 let leader_id = request
1874 .leader_id
1875 .ok_or_else(|| anyhow!("invalid leader id"))?
1876 .into();
1877 let follower_id = session.connection_id;
1878
1879 {
1880 let project_connection_ids = session
1881 .db()
1882 .await
1883 .project_connection_ids(project_id, session.connection_id)
1884 .await?;
1885
1886 if !project_connection_ids.contains(&leader_id) {
1887 Err(anyhow!("no such peer"))?;
1888 }
1889 }
1890
1891 let mut response_payload = session
1892 .peer
1893 .forward_request(session.connection_id, leader_id, request)
1894 .await?;
1895 response_payload
1896 .views
1897 .retain(|view| view.leader_id != Some(follower_id.into()));
1898 response.send(response_payload)?;
1899
1900 let room = session
1901 .db()
1902 .await
1903 .follow(project_id, leader_id, follower_id)
1904 .await?;
1905 room_updated(&room, &session.peer);
1906
1907 Ok(())
1908}
1909
1910async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1911 let project_id = ProjectId::from_proto(request.project_id);
1912 let leader_id = request
1913 .leader_id
1914 .ok_or_else(|| anyhow!("invalid leader id"))?
1915 .into();
1916 let follower_id = session.connection_id;
1917
1918 if !session
1919 .db()
1920 .await
1921 .project_connection_ids(project_id, session.connection_id)
1922 .await?
1923 .contains(&leader_id)
1924 {
1925 Err(anyhow!("no such peer"))?;
1926 }
1927
1928 session
1929 .peer
1930 .forward_send(session.connection_id, leader_id, request)?;
1931
1932 let room = session
1933 .db()
1934 .await
1935 .unfollow(project_id, leader_id, follower_id)
1936 .await?;
1937 room_updated(&room, &session.peer);
1938
1939 Ok(())
1940}
1941
1942async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1943 let project_id = ProjectId::from_proto(request.project_id);
1944 let project_connection_ids = session
1945 .db
1946 .lock()
1947 .await
1948 .project_connection_ids(project_id, session.connection_id)
1949 .await?;
1950
1951 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1952 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1953 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1954 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1955 });
1956 for follower_peer_id in request.follower_ids.iter().copied() {
1957 let follower_connection_id = follower_peer_id.into();
1958 if project_connection_ids.contains(&follower_connection_id)
1959 && Some(follower_peer_id) != leader_id
1960 {
1961 session.peer.forward_send(
1962 session.connection_id,
1963 follower_connection_id,
1964 request.clone(),
1965 )?;
1966 }
1967 }
1968 Ok(())
1969}
1970
1971async fn get_users(
1972 request: proto::GetUsers,
1973 response: Response<proto::GetUsers>,
1974 session: Session,
1975) -> Result<()> {
1976 let user_ids = request
1977 .user_ids
1978 .into_iter()
1979 .map(UserId::from_proto)
1980 .collect();
1981 let users = session
1982 .db()
1983 .await
1984 .get_users_by_ids(user_ids)
1985 .await?
1986 .into_iter()
1987 .map(|user| proto::User {
1988 id: user.id.to_proto(),
1989 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1990 github_login: user.github_login,
1991 })
1992 .collect();
1993 response.send(proto::UsersResponse { users })?;
1994 Ok(())
1995}
1996
1997async fn fuzzy_search_users(
1998 request: proto::FuzzySearchUsers,
1999 response: Response<proto::FuzzySearchUsers>,
2000 session: Session,
2001) -> Result<()> {
2002 let query = request.query;
2003 let users = match query.len() {
2004 0 => vec![],
2005 1 | 2 => session
2006 .db()
2007 .await
2008 .get_user_by_github_login(&query)
2009 .await?
2010 .into_iter()
2011 .collect(),
2012 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2013 };
2014 let users = users
2015 .into_iter()
2016 .filter(|user| user.id != session.user_id)
2017 .map(|user| proto::User {
2018 id: user.id.to_proto(),
2019 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2020 github_login: user.github_login,
2021 })
2022 .collect();
2023 response.send(proto::UsersResponse { users })?;
2024 Ok(())
2025}
2026
2027async fn request_contact(
2028 request: proto::RequestContact,
2029 response: Response<proto::RequestContact>,
2030 session: Session,
2031) -> Result<()> {
2032 let requester_id = session.user_id;
2033 let responder_id = UserId::from_proto(request.responder_id);
2034 if requester_id == responder_id {
2035 return Err(anyhow!("cannot add yourself as a contact"))?;
2036 }
2037
2038 session
2039 .db()
2040 .await
2041 .send_contact_request(requester_id, responder_id)
2042 .await?;
2043
2044 // Update outgoing contact requests of requester
2045 let mut update = proto::UpdateContacts::default();
2046 update.outgoing_requests.push(responder_id.to_proto());
2047 for connection_id in session
2048 .connection_pool()
2049 .await
2050 .user_connection_ids(requester_id)
2051 {
2052 session.peer.send(connection_id, update.clone())?;
2053 }
2054
2055 // Update incoming contact requests of responder
2056 let mut update = proto::UpdateContacts::default();
2057 update
2058 .incoming_requests
2059 .push(proto::IncomingContactRequest {
2060 requester_id: requester_id.to_proto(),
2061 should_notify: true,
2062 });
2063 for connection_id in session
2064 .connection_pool()
2065 .await
2066 .user_connection_ids(responder_id)
2067 {
2068 session.peer.send(connection_id, update.clone())?;
2069 }
2070
2071 response.send(proto::Ack {})?;
2072 Ok(())
2073}
2074
2075async fn respond_to_contact_request(
2076 request: proto::RespondToContactRequest,
2077 response: Response<proto::RespondToContactRequest>,
2078 session: Session,
2079) -> Result<()> {
2080 let responder_id = session.user_id;
2081 let requester_id = UserId::from_proto(request.requester_id);
2082 let db = session.db().await;
2083 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2084 db.dismiss_contact_notification(responder_id, requester_id)
2085 .await?;
2086 } else {
2087 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2088
2089 db.respond_to_contact_request(responder_id, requester_id, accept)
2090 .await?;
2091 let requester_busy = db.is_user_busy(requester_id).await?;
2092 let responder_busy = db.is_user_busy(responder_id).await?;
2093
2094 let pool = session.connection_pool().await;
2095 // Update responder with new contact
2096 let mut update = proto::UpdateContacts::default();
2097 if accept {
2098 update
2099 .contacts
2100 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2101 }
2102 update
2103 .remove_incoming_requests
2104 .push(requester_id.to_proto());
2105 for connection_id in pool.user_connection_ids(responder_id) {
2106 session.peer.send(connection_id, update.clone())?;
2107 }
2108
2109 // Update requester with new contact
2110 let mut update = proto::UpdateContacts::default();
2111 if accept {
2112 update
2113 .contacts
2114 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2115 }
2116 update
2117 .remove_outgoing_requests
2118 .push(responder_id.to_proto());
2119 for connection_id in pool.user_connection_ids(requester_id) {
2120 session.peer.send(connection_id, update.clone())?;
2121 }
2122 }
2123
2124 response.send(proto::Ack {})?;
2125 Ok(())
2126}
2127
2128async fn remove_contact(
2129 request: proto::RemoveContact,
2130 response: Response<proto::RemoveContact>,
2131 session: Session,
2132) -> Result<()> {
2133 let requester_id = session.user_id;
2134 let responder_id = UserId::from_proto(request.user_id);
2135 let db = session.db().await;
2136 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2137
2138 let pool = session.connection_pool().await;
2139 // Update outgoing contact requests of requester
2140 let mut update = proto::UpdateContacts::default();
2141 if contact_accepted {
2142 update.remove_contacts.push(responder_id.to_proto());
2143 } else {
2144 update
2145 .remove_outgoing_requests
2146 .push(responder_id.to_proto());
2147 }
2148 for connection_id in pool.user_connection_ids(requester_id) {
2149 session.peer.send(connection_id, update.clone())?;
2150 }
2151
2152 // Update incoming contact requests of responder
2153 let mut update = proto::UpdateContacts::default();
2154 if contact_accepted {
2155 update.remove_contacts.push(requester_id.to_proto());
2156 } else {
2157 update
2158 .remove_incoming_requests
2159 .push(requester_id.to_proto());
2160 }
2161 for connection_id in pool.user_connection_ids(responder_id) {
2162 session.peer.send(connection_id, update.clone())?;
2163 }
2164
2165 response.send(proto::Ack {})?;
2166 Ok(())
2167}
2168
2169async fn create_channel(
2170 request: proto::CreateChannel,
2171 response: Response<proto::CreateChannel>,
2172 session: Session,
2173) -> Result<()> {
2174 let db = session.db().await;
2175 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2176
2177 if let Some(live_kit) = session.live_kit_client.as_ref() {
2178 live_kit.create_room(live_kit_room.clone()).await?;
2179 }
2180
2181 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2182 let id = db
2183 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2184 .await?;
2185
2186 let channel = proto::Channel {
2187 id: id.to_proto(),
2188 name: request.name,
2189 parent_id: request.parent_id,
2190 };
2191
2192 response.send(proto::ChannelResponse {
2193 channel: Some(channel.clone()),
2194 })?;
2195
2196 let mut update = proto::UpdateChannels::default();
2197 update.channels.push(channel);
2198
2199 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2200 db.get_channel_members(parent_id).await?
2201 } else {
2202 vec![session.user_id]
2203 };
2204
2205 let connection_pool = session.connection_pool().await;
2206 for user_id in user_ids_to_notify {
2207 for connection_id in connection_pool.user_connection_ids(user_id) {
2208 let mut update = update.clone();
2209 if user_id == session.user_id {
2210 update.channel_permissions.push(proto::ChannelPermission {
2211 channel_id: id.to_proto(),
2212 is_admin: true,
2213 });
2214 }
2215 session.peer.send(connection_id, update)?;
2216 }
2217 }
2218
2219 Ok(())
2220}
2221
2222async fn remove_channel(
2223 request: proto::RemoveChannel,
2224 response: Response<proto::RemoveChannel>,
2225 session: Session,
2226) -> Result<()> {
2227 let db = session.db().await;
2228
2229 let channel_id = request.channel_id;
2230 let (removed_channels, member_ids) = db
2231 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2232 .await?;
2233 response.send(proto::Ack {})?;
2234
2235 // Notify members of removed channels
2236 let mut update = proto::UpdateChannels::default();
2237 update
2238 .remove_channels
2239 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2240
2241 let connection_pool = session.connection_pool().await;
2242 for member_id in member_ids {
2243 for connection_id in connection_pool.user_connection_ids(member_id) {
2244 session.peer.send(connection_id, update.clone())?;
2245 }
2246 }
2247
2248 Ok(())
2249}
2250
2251async fn invite_channel_member(
2252 request: proto::InviteChannelMember,
2253 response: Response<proto::InviteChannelMember>,
2254 session: Session,
2255) -> Result<()> {
2256 let db = session.db().await;
2257 let channel_id = ChannelId::from_proto(request.channel_id);
2258 let invitee_id = UserId::from_proto(request.user_id);
2259 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2260 .await?;
2261
2262 let (channel, _) = db
2263 .get_channel(channel_id, session.user_id)
2264 .await?
2265 .ok_or_else(|| anyhow!("channel not found"))?;
2266
2267 let mut update = proto::UpdateChannels::default();
2268 update.channel_invitations.push(proto::Channel {
2269 id: channel.id.to_proto(),
2270 name: channel.name,
2271 parent_id: None,
2272 });
2273 for connection_id in session
2274 .connection_pool()
2275 .await
2276 .user_connection_ids(invitee_id)
2277 {
2278 session.peer.send(connection_id, update.clone())?;
2279 }
2280
2281 response.send(proto::Ack {})?;
2282 Ok(())
2283}
2284
2285async fn remove_channel_member(
2286 request: proto::RemoveChannelMember,
2287 response: Response<proto::RemoveChannelMember>,
2288 session: Session,
2289) -> Result<()> {
2290 let db = session.db().await;
2291 let channel_id = ChannelId::from_proto(request.channel_id);
2292 let member_id = UserId::from_proto(request.user_id);
2293
2294 db.remove_channel_member(channel_id, member_id, session.user_id)
2295 .await?;
2296
2297 let mut update = proto::UpdateChannels::default();
2298 update.remove_channels.push(channel_id.to_proto());
2299
2300 for connection_id in session
2301 .connection_pool()
2302 .await
2303 .user_connection_ids(member_id)
2304 {
2305 session.peer.send(connection_id, update.clone())?;
2306 }
2307
2308 response.send(proto::Ack {})?;
2309 Ok(())
2310}
2311
2312async fn set_channel_member_admin(
2313 request: proto::SetChannelMemberAdmin,
2314 response: Response<proto::SetChannelMemberAdmin>,
2315 session: Session,
2316) -> Result<()> {
2317 let db = session.db().await;
2318 let channel_id = ChannelId::from_proto(request.channel_id);
2319 let member_id = UserId::from_proto(request.user_id);
2320 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2321 .await?;
2322
2323 let (channel, has_accepted) = db
2324 .get_channel(channel_id, member_id)
2325 .await?
2326 .ok_or_else(|| anyhow!("channel not found"))?;
2327
2328 let mut update = proto::UpdateChannels::default();
2329 if has_accepted {
2330 update.channel_permissions.push(proto::ChannelPermission {
2331 channel_id: channel.id.to_proto(),
2332 is_admin: request.admin,
2333 });
2334 }
2335
2336 for connection_id in session
2337 .connection_pool()
2338 .await
2339 .user_connection_ids(member_id)
2340 {
2341 session.peer.send(connection_id, update.clone())?;
2342 }
2343
2344 response.send(proto::Ack {})?;
2345 Ok(())
2346}
2347
2348async fn rename_channel(
2349 request: proto::RenameChannel,
2350 response: Response<proto::RenameChannel>,
2351 session: Session,
2352) -> Result<()> {
2353 let db = session.db().await;
2354 let channel_id = ChannelId::from_proto(request.channel_id);
2355 let new_name = db
2356 .rename_channel(channel_id, session.user_id, &request.name)
2357 .await?;
2358
2359 let channel = proto::Channel {
2360 id: request.channel_id,
2361 name: new_name,
2362 parent_id: None,
2363 };
2364 response.send(proto::ChannelResponse {
2365 channel: Some(channel.clone()),
2366 })?;
2367 let mut update = proto::UpdateChannels::default();
2368 update.channels.push(channel);
2369
2370 let member_ids = db.get_channel_members(channel_id).await?;
2371
2372 let connection_pool = session.connection_pool().await;
2373 for member_id in member_ids {
2374 for connection_id in connection_pool.user_connection_ids(member_id) {
2375 session.peer.send(connection_id, update.clone())?;
2376 }
2377 }
2378
2379 Ok(())
2380}
2381
2382async fn get_channel_members(
2383 request: proto::GetChannelMembers,
2384 response: Response<proto::GetChannelMembers>,
2385 session: Session,
2386) -> Result<()> {
2387 let db = session.db().await;
2388 let channel_id = ChannelId::from_proto(request.channel_id);
2389 let members = db
2390 .get_channel_member_details(channel_id, session.user_id)
2391 .await?;
2392 response.send(proto::GetChannelMembersResponse { members })?;
2393 Ok(())
2394}
2395
2396async fn respond_to_channel_invite(
2397 request: proto::RespondToChannelInvite,
2398 response: Response<proto::RespondToChannelInvite>,
2399 session: Session,
2400) -> Result<()> {
2401 let db = session.db().await;
2402 let channel_id = ChannelId::from_proto(request.channel_id);
2403 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2404 .await?;
2405
2406 let mut update = proto::UpdateChannels::default();
2407 update
2408 .remove_channel_invitations
2409 .push(channel_id.to_proto());
2410 if request.accept {
2411 let result = db.get_channels_for_user(session.user_id).await?;
2412 update
2413 .channels
2414 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2415 id: channel.id.to_proto(),
2416 name: channel.name,
2417 parent_id: channel.parent_id.map(ChannelId::to_proto),
2418 }));
2419 update
2420 .channel_participants
2421 .extend(
2422 result
2423 .channel_participants
2424 .into_iter()
2425 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2426 channel_id: channel_id.to_proto(),
2427 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2428 }),
2429 );
2430 update
2431 .channel_permissions
2432 .extend(
2433 result
2434 .channels_with_admin_privileges
2435 .into_iter()
2436 .map(|channel_id| proto::ChannelPermission {
2437 channel_id: channel_id.to_proto(),
2438 is_admin: true,
2439 }),
2440 );
2441 }
2442 session.peer.send(session.connection_id, update)?;
2443 response.send(proto::Ack {})?;
2444
2445 Ok(())
2446}
2447
2448async fn join_channel(
2449 request: proto::JoinChannel,
2450 response: Response<proto::JoinChannel>,
2451 session: Session,
2452) -> Result<()> {
2453 let channel_id = ChannelId::from_proto(request.channel_id);
2454
2455 let joined_room = {
2456 leave_room_for_session(&session).await?;
2457 let db = session.db().await;
2458
2459 let room_id = db.room_id_for_channel(channel_id).await?;
2460
2461 let joined_room = db
2462 .join_room(room_id, session.user_id, session.connection_id)
2463 .await?;
2464
2465 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2466 let token = live_kit
2467 .room_token(
2468 &joined_room.room.live_kit_room,
2469 &session.user_id.to_string(),
2470 )
2471 .trace_err()?;
2472
2473 Some(LiveKitConnectionInfo {
2474 server_url: live_kit.url().into(),
2475 token,
2476 })
2477 });
2478
2479 response.send(proto::JoinRoomResponse {
2480 room: Some(joined_room.room.clone()),
2481 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2482 live_kit_connection_info,
2483 })?;
2484
2485 room_updated(&joined_room.room, &session.peer);
2486
2487 joined_room.into_inner()
2488 };
2489
2490 channel_updated(
2491 channel_id,
2492 &joined_room.room,
2493 &joined_room.channel_members,
2494 &session.peer,
2495 &*session.connection_pool().await,
2496 );
2497
2498 update_user_contacts(session.user_id, &session).await?;
2499
2500 Ok(())
2501}
2502
2503async fn join_channel_buffer(
2504 request: proto::JoinChannelBuffer,
2505 response: Response<proto::JoinChannelBuffer>,
2506 session: Session,
2507) -> Result<()> {
2508 let db = session.db().await;
2509 let channel_id = ChannelId::from_proto(request.channel_id);
2510
2511 let open_response = db
2512 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2513 .await?;
2514
2515 let replica_id = open_response.replica_id;
2516 let collaborators = open_response.collaborators.clone();
2517
2518 response.send(open_response)?;
2519
2520 let update = AddChannelBufferCollaborator {
2521 channel_id: channel_id.to_proto(),
2522 collaborator: Some(proto::Collaborator {
2523 user_id: session.user_id.to_proto(),
2524 peer_id: Some(session.connection_id.into()),
2525 replica_id,
2526 }),
2527 };
2528 channel_buffer_updated(
2529 session.connection_id,
2530 collaborators
2531 .iter()
2532 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2533 &update,
2534 &session.peer,
2535 );
2536
2537 Ok(())
2538}
2539
2540async fn update_channel_buffer(
2541 request: proto::UpdateChannelBuffer,
2542 session: Session,
2543) -> Result<()> {
2544 let db = session.db().await;
2545 let channel_id = ChannelId::from_proto(request.channel_id);
2546
2547 let collaborators = db
2548 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2549 .await?;
2550
2551 channel_buffer_updated(
2552 session.connection_id,
2553 collaborators,
2554 &proto::UpdateChannelBuffer {
2555 channel_id: channel_id.to_proto(),
2556 operations: request.operations,
2557 },
2558 &session.peer,
2559 );
2560 Ok(())
2561}
2562
2563async fn rejoin_channel_buffers(
2564 request: proto::RejoinChannelBuffers,
2565 response: Response<proto::RejoinChannelBuffers>,
2566 session: Session,
2567) -> Result<()> {
2568 let db = session.db().await;
2569 let buffers = db
2570 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2571 .await?;
2572
2573 for buffer in &buffers {
2574 let collaborators_to_notify = buffer
2575 .buffer
2576 .collaborators
2577 .iter()
2578 .filter_map(|c| Some(c.peer_id?.into()));
2579 channel_buffer_updated(
2580 session.connection_id,
2581 collaborators_to_notify,
2582 &proto::UpdateChannelBufferCollaborator {
2583 channel_id: buffer.buffer.channel_id,
2584 old_peer_id: Some(buffer.old_connection_id.into()),
2585 new_peer_id: Some(session.connection_id.into()),
2586 },
2587 &session.peer,
2588 );
2589 }
2590
2591 response.send(proto::RejoinChannelBuffersResponse {
2592 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2593 })?;
2594
2595 Ok(())
2596}
2597
2598async fn leave_channel_buffer(
2599 request: proto::LeaveChannelBuffer,
2600 response: Response<proto::LeaveChannelBuffer>,
2601 session: Session,
2602) -> Result<()> {
2603 let db = session.db().await;
2604 let channel_id = ChannelId::from_proto(request.channel_id);
2605
2606 let collaborators_to_notify = db
2607 .leave_channel_buffer(channel_id, session.connection_id)
2608 .await?;
2609
2610 response.send(Ack {})?;
2611
2612 channel_buffer_updated(
2613 session.connection_id,
2614 collaborators_to_notify,
2615 &proto::RemoveChannelBufferCollaborator {
2616 channel_id: channel_id.to_proto(),
2617 peer_id: Some(session.connection_id.into()),
2618 },
2619 &session.peer,
2620 );
2621
2622 Ok(())
2623}
2624
2625fn channel_buffer_updated<T: EnvelopedMessage>(
2626 sender_id: ConnectionId,
2627 collaborators: impl IntoIterator<Item = ConnectionId>,
2628 message: &T,
2629 peer: &Peer,
2630) {
2631 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2632 peer.send(peer_id.into(), message.clone())
2633 });
2634}
2635
2636async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2637 let project_id = ProjectId::from_proto(request.project_id);
2638 let project_connection_ids = session
2639 .db()
2640 .await
2641 .project_connection_ids(project_id, session.connection_id)
2642 .await?;
2643 broadcast(
2644 Some(session.connection_id),
2645 project_connection_ids.iter().copied(),
2646 |connection_id| {
2647 session
2648 .peer
2649 .forward_send(session.connection_id, connection_id, request.clone())
2650 },
2651 );
2652 Ok(())
2653}
2654
2655async fn get_private_user_info(
2656 _request: proto::GetPrivateUserInfo,
2657 response: Response<proto::GetPrivateUserInfo>,
2658 session: Session,
2659) -> Result<()> {
2660 let db = session.db().await;
2661
2662 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2663 let user = db
2664 .get_user_by_id(session.user_id)
2665 .await?
2666 .ok_or_else(|| anyhow!("user not found"))?;
2667 let flags = db.get_user_flags(session.user_id).await?;
2668
2669 response.send(proto::GetPrivateUserInfoResponse {
2670 metrics_id,
2671 staff: user.admin,
2672 flags,
2673 })?;
2674 Ok(())
2675}
2676
2677fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2678 match message {
2679 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2680 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2681 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2682 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2683 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2684 code: frame.code.into(),
2685 reason: frame.reason,
2686 })),
2687 }
2688}
2689
2690fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2691 match message {
2692 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2693 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2694 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2695 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2696 AxumMessage::Close(frame) => {
2697 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2698 code: frame.code.into(),
2699 reason: frame.reason,
2700 }))
2701 }
2702 }
2703}
2704
2705fn build_initial_channels_update(
2706 channels: ChannelsForUser,
2707 channel_invites: Vec<db::Channel>,
2708) -> proto::UpdateChannels {
2709 let mut update = proto::UpdateChannels::default();
2710
2711 for channel in channels.channels {
2712 update.channels.push(proto::Channel {
2713 id: channel.id.to_proto(),
2714 name: channel.name,
2715 parent_id: channel.parent_id.map(|id| id.to_proto()),
2716 });
2717 }
2718
2719 for (channel_id, participants) in channels.channel_participants {
2720 update
2721 .channel_participants
2722 .push(proto::ChannelParticipants {
2723 channel_id: channel_id.to_proto(),
2724 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2725 });
2726 }
2727
2728 update
2729 .channel_permissions
2730 .extend(
2731 channels
2732 .channels_with_admin_privileges
2733 .into_iter()
2734 .map(|id| proto::ChannelPermission {
2735 channel_id: id.to_proto(),
2736 is_admin: true,
2737 }),
2738 );
2739
2740 for channel in channel_invites {
2741 update.channel_invitations.push(proto::Channel {
2742 id: channel.id.to_proto(),
2743 name: channel.name,
2744 parent_id: None,
2745 });
2746 }
2747
2748 update
2749}
2750
2751fn build_initial_contacts_update(
2752 contacts: Vec<db::Contact>,
2753 pool: &ConnectionPool,
2754) -> proto::UpdateContacts {
2755 let mut update = proto::UpdateContacts::default();
2756
2757 for contact in contacts {
2758 match contact {
2759 db::Contact::Accepted {
2760 user_id,
2761 should_notify,
2762 busy,
2763 } => {
2764 update
2765 .contacts
2766 .push(contact_for_user(user_id, should_notify, busy, &pool));
2767 }
2768 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2769 db::Contact::Incoming {
2770 user_id,
2771 should_notify,
2772 } => update
2773 .incoming_requests
2774 .push(proto::IncomingContactRequest {
2775 requester_id: user_id.to_proto(),
2776 should_notify,
2777 }),
2778 }
2779 }
2780
2781 update
2782}
2783
2784fn contact_for_user(
2785 user_id: UserId,
2786 should_notify: bool,
2787 busy: bool,
2788 pool: &ConnectionPool,
2789) -> proto::Contact {
2790 proto::Contact {
2791 user_id: user_id.to_proto(),
2792 online: pool.is_user_online(user_id),
2793 busy,
2794 should_notify,
2795 }
2796}
2797
2798fn room_updated(room: &proto::Room, peer: &Peer) {
2799 broadcast(
2800 None,
2801 room.participants
2802 .iter()
2803 .filter_map(|participant| Some(participant.peer_id?.into())),
2804 |peer_id| {
2805 peer.send(
2806 peer_id.into(),
2807 proto::RoomUpdated {
2808 room: Some(room.clone()),
2809 },
2810 )
2811 },
2812 );
2813}
2814
2815fn channel_updated(
2816 channel_id: ChannelId,
2817 room: &proto::Room,
2818 channel_members: &[UserId],
2819 peer: &Peer,
2820 pool: &ConnectionPool,
2821) {
2822 let participants = room
2823 .participants
2824 .iter()
2825 .map(|p| p.user_id)
2826 .collect::<Vec<_>>();
2827
2828 broadcast(
2829 None,
2830 channel_members
2831 .iter()
2832 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2833 |peer_id| {
2834 peer.send(
2835 peer_id.into(),
2836 proto::UpdateChannels {
2837 channel_participants: vec![proto::ChannelParticipants {
2838 channel_id: channel_id.to_proto(),
2839 participant_user_ids: participants.clone(),
2840 }],
2841 ..Default::default()
2842 },
2843 )
2844 },
2845 );
2846}
2847
2848async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2849 let db = session.db().await;
2850
2851 let contacts = db.get_contacts(user_id).await?;
2852 let busy = db.is_user_busy(user_id).await?;
2853
2854 let pool = session.connection_pool().await;
2855 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2856 for contact in contacts {
2857 if let db::Contact::Accepted {
2858 user_id: contact_user_id,
2859 ..
2860 } = contact
2861 {
2862 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2863 session
2864 .peer
2865 .send(
2866 contact_conn_id,
2867 proto::UpdateContacts {
2868 contacts: vec![updated_contact.clone()],
2869 remove_contacts: Default::default(),
2870 incoming_requests: Default::default(),
2871 remove_incoming_requests: Default::default(),
2872 outgoing_requests: Default::default(),
2873 remove_outgoing_requests: Default::default(),
2874 },
2875 )
2876 .trace_err();
2877 }
2878 }
2879 }
2880 Ok(())
2881}
2882
2883async fn leave_room_for_session(session: &Session) -> Result<()> {
2884 let mut contacts_to_update = HashSet::default();
2885
2886 let room_id;
2887 let canceled_calls_to_user_ids;
2888 let live_kit_room;
2889 let delete_live_kit_room;
2890 let room;
2891 let channel_members;
2892 let channel_id;
2893
2894 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2895 contacts_to_update.insert(session.user_id);
2896
2897 for project in left_room.left_projects.values() {
2898 project_left(project, session);
2899 }
2900
2901 room_id = RoomId::from_proto(left_room.room.id);
2902 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2903 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2904 delete_live_kit_room = left_room.deleted;
2905 room = mem::take(&mut left_room.room);
2906 channel_members = mem::take(&mut left_room.channel_members);
2907 channel_id = left_room.channel_id;
2908
2909 room_updated(&room, &session.peer);
2910 } else {
2911 return Ok(());
2912 }
2913
2914 if let Some(channel_id) = channel_id {
2915 channel_updated(
2916 channel_id,
2917 &room,
2918 &channel_members,
2919 &session.peer,
2920 &*session.connection_pool().await,
2921 );
2922 }
2923
2924 {
2925 let pool = session.connection_pool().await;
2926 for canceled_user_id in canceled_calls_to_user_ids {
2927 for connection_id in pool.user_connection_ids(canceled_user_id) {
2928 session
2929 .peer
2930 .send(
2931 connection_id,
2932 proto::CallCanceled {
2933 room_id: room_id.to_proto(),
2934 },
2935 )
2936 .trace_err();
2937 }
2938 contacts_to_update.insert(canceled_user_id);
2939 }
2940 }
2941
2942 for contact_user_id in contacts_to_update {
2943 update_user_contacts(contact_user_id, &session).await?;
2944 }
2945
2946 if let Some(live_kit) = session.live_kit_client.as_ref() {
2947 live_kit
2948 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2949 .await
2950 .trace_err();
2951
2952 if delete_live_kit_room {
2953 live_kit.delete_room(live_kit_room).await.trace_err();
2954 }
2955 }
2956
2957 Ok(())
2958}
2959
2960async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2961 let left_channel_buffers = session
2962 .db()
2963 .await
2964 .leave_channel_buffers(session.connection_id)
2965 .await?;
2966
2967 for (channel_id, connections) in left_channel_buffers {
2968 channel_buffer_updated(
2969 session.connection_id,
2970 connections,
2971 &proto::RemoveChannelBufferCollaborator {
2972 channel_id: channel_id.to_proto(),
2973 peer_id: Some(session.connection_id.into()),
2974 },
2975 &session.peer,
2976 );
2977 }
2978
2979 Ok(())
2980}
2981
2982fn project_left(project: &db::LeftProject, session: &Session) {
2983 for connection_id in &project.connection_ids {
2984 if project.host_user_id == session.user_id {
2985 session
2986 .peer
2987 .send(
2988 *connection_id,
2989 proto::UnshareProject {
2990 project_id: project.id.to_proto(),
2991 },
2992 )
2993 .trace_err();
2994 } else {
2995 session
2996 .peer
2997 .send(
2998 *connection_id,
2999 proto::RemoveProjectCollaborator {
3000 project_id: project.id.to_proto(),
3001 peer_id: Some(session.connection_id.into()),
3002 },
3003 )
3004 .trace_err();
3005 }
3006 }
3007}
3008
3009pub trait ResultExt {
3010 type Ok;
3011
3012 fn trace_err(self) -> Option<Self::Ok>;
3013}
3014
3015impl<T, E> ResultExt for Result<T, E>
3016where
3017 E: std::fmt::Debug,
3018{
3019 type Ok = T;
3020
3021 fn trace_err(self) -> Option<T> {
3022 match self {
3023 Ok(value) => Some(value),
3024 Err(error) => {
3025 tracing::error!("{:?}", error);
3026 None
3027 }
3028 }
3029 }
3030}