1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits before releasing a lost connection's rooms and buffers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before stale server resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
65
// Prometheus gauges exported by `handle_metrics`. Registration happens lazily on first
// access; `.unwrap()` panics if the registry rejects the gauge.
lazy_static! {
    // Set by `handle_metrics` to the number of non-admin connections in the pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set by `handle_metrics` from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection context; `handle_connection` clones it for every incoming message.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // One mutex per connection (created in `handle_connection`) shared by every clone, so
    // handlers for the same connection take turns using the database.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when the app state has no LiveKit client configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// per-message handlers built in `Server::new`.
pub struct Server {
    // Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Fired by `teardown()`; connection loops and `connection_lost` subscribe to it.
    teardown: watch::Sender<()>,
}
153
/// Lock guard over the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the guard `!Send` so it cannot live across an `.await`
    // in the `Send` futures that handlers must return.
    _not_send: PhantomData<Rc<()>>,
}
158
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. It holds the pool lock for as long as it exists.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the `ConnectionPool` itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(join_channel_buffer)
252 .add_request_handler(leave_channel_buffer)
253 .add_message_handler(update_channel_buffer)
254 .add_request_handler(rejoin_channel_buffers)
255 .add_request_handler(get_channel_members)
256 .add_request_handler(respond_to_channel_invite)
257 .add_request_handler(join_channel)
258 .add_request_handler(follow)
259 .add_message_handler(unfollow)
260 .add_message_handler(update_followers)
261 .add_message_handler(update_diff_base)
262 .add_request_handler(get_private_user_info);
263
264 Arc::new(server)
265 }
266
267 pub async fn start(&self) -> Result<()> {
268 let server_id = *self.id.lock();
269 let app_state = self.app_state.clone();
270 let peer = self.peer.clone();
271 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
272 let pool = self.connection_pool.clone();
273 let live_kit_client = self.app_state.live_kit_client.clone();
274
275 let span = info_span!("start server");
276 self.executor.spawn_detached(
277 async move {
278 tracing::info!("waiting for cleanup timeout");
279 timeout.await;
280 tracing::info!("cleanup timeout expired, retrieving stale rooms");
281 if let Some((room_ids, channel_ids)) = app_state
282 .db
283 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
284 .await
285 .trace_err()
286 {
287 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
288 tracing::info!(
289 stale_channel_buffer_count = channel_ids.len(),
290 "retrieved stale channel buffers"
291 );
292
293 for channel_id in channel_ids {
294 if let Some(refreshed_channel_buffer) = app_state
295 .db
296 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
297 .await
298 .trace_err()
299 {
300 for connection_id in refreshed_channel_buffer.connection_ids {
301 for message in &refreshed_channel_buffer.removed_collaborators {
302 peer.send(connection_id, message.clone()).trace_err();
303 }
304 }
305 }
306 }
307
308 for room_id in room_ids {
309 let mut contacts_to_update = HashSet::default();
310 let mut canceled_calls_to_user_ids = Vec::new();
311 let mut live_kit_room = String::new();
312 let mut delete_live_kit_room = false;
313
314 if let Some(mut refreshed_room) = app_state
315 .db
316 .clear_stale_room_participants(room_id, server_id)
317 .await
318 .trace_err()
319 {
320 tracing::info!(
321 room_id = room_id.0,
322 new_participant_count = refreshed_room.room.participants.len(),
323 "refreshed room"
324 );
325 room_updated(&refreshed_room.room, &peer);
326 if let Some(channel_id) = refreshed_room.channel_id {
327 channel_updated(
328 channel_id,
329 &refreshed_room.room,
330 &refreshed_room.channel_members,
331 &peer,
332 &*pool.lock(),
333 );
334 }
335 contacts_to_update
336 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
337 contacts_to_update
338 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
339 canceled_calls_to_user_ids =
340 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
341 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
342 delete_live_kit_room = refreshed_room.room.participants.is_empty();
343 }
344
345 {
346 let pool = pool.lock();
347 for canceled_user_id in canceled_calls_to_user_ids {
348 for connection_id in pool.user_connection_ids(canceled_user_id) {
349 peer.send(
350 connection_id,
351 proto::CallCanceled {
352 room_id: room_id.to_proto(),
353 },
354 )
355 .trace_err();
356 }
357 }
358 }
359
360 for user_id in contacts_to_update {
361 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
362 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
363 if let Some((busy, contacts)) = busy.zip(contacts) {
364 let pool = pool.lock();
365 let updated_contact = contact_for_user(user_id, false, busy, &pool);
366 for contact in contacts {
367 if let db::Contact::Accepted {
368 user_id: contact_user_id,
369 ..
370 } = contact
371 {
372 for contact_conn_id in
373 pool.user_connection_ids(contact_user_id)
374 {
375 peer.send(
376 contact_conn_id,
377 proto::UpdateContacts {
378 contacts: vec![updated_contact.clone()],
379 remove_contacts: Default::default(),
380 incoming_requests: Default::default(),
381 remove_incoming_requests: Default::default(),
382 outgoing_requests: Default::default(),
383 remove_outgoing_requests: Default::default(),
384 },
385 )
386 .trace_err();
387 }
388 }
389 }
390 }
391 }
392
393 if let Some(live_kit) = live_kit_client.as_ref() {
394 if delete_live_kit_room {
395 live_kit.delete_room(live_kit_room).await.trace_err();
396 }
397 }
398 }
399 }
400
401 app_state
402 .db
403 .delete_stale_servers(&app_state.config.zed_environment, server_id)
404 .await
405 .trace_err();
406 }
407 .instrument(span),
408 );
409 Ok(())
410 }
411
412 pub fn teardown(&self) {
413 self.peer.teardown();
414 self.connection_pool.lock().reset();
415 let _ = self.teardown.send(());
416 }
417
418 #[cfg(test)]
419 pub fn reset(&self, id: ServerId) {
420 self.teardown();
421 *self.id.lock() = id;
422 self.peer.reset(id.0 as u32);
423 }
424
425 #[cfg(test)]
426 pub fn id(&self) -> ServerId {
427 *self.id.lock()
428 }
429
430 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
431 where
432 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
433 Fut: 'static + Send + Future<Output = Result<()>>,
434 M: EnvelopedMessage,
435 {
436 let prev_handler = self.handlers.insert(
437 TypeId::of::<M>(),
438 Box::new(move |envelope, session| {
439 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
440 let span = info_span!(
441 "handle message",
442 payload_type = envelope.payload_type_name()
443 );
444 span.in_scope(|| {
445 tracing::info!(
446 payload_type = envelope.payload_type_name(),
447 "message received"
448 );
449 });
450 let start_time = Instant::now();
451 let future = (handler)(*envelope, session);
452 async move {
453 let result = future.await;
454 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
455 match result {
456 Err(error) => {
457 tracing::error!(%error, ?duration_ms, "error handling message")
458 }
459 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
460 }
461 }
462 .instrument(span)
463 .boxed()
464 }),
465 );
466 if prev_handler.is_some() {
467 panic!("registered a handler for the same message twice");
468 }
469 self
470 }
471
472 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
473 where
474 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
475 Fut: 'static + Send + Future<Output = Result<()>>,
476 M: EnvelopedMessage,
477 {
478 self.add_handler(move |envelope, session| handler(envelope.payload, session));
479 self
480 }
481
482 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
483 where
484 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
485 Fut: Send + Future<Output = Result<()>>,
486 M: RequestMessage,
487 {
488 let handler = Arc::new(handler);
489 self.add_handler(move |envelope, session| {
490 let receipt = envelope.receipt();
491 let handler = handler.clone();
492 async move {
493 let peer = session.peer.clone();
494 let responded = Arc::new(AtomicBool::default());
495 let response = Response {
496 peer: peer.clone(),
497 responded: responded.clone(),
498 receipt,
499 };
500 match (handler)(envelope.payload, response, session).await {
501 Ok(()) => {
502 if responded.load(std::sync::atomic::Ordering::SeqCst) {
503 Ok(())
504 } else {
505 Err(anyhow!("handler did not send a response"))?
506 }
507 }
508 Err(error) => {
509 peer.respond_with_error(
510 receipt,
511 proto::Error {
512 message: error.to_string(),
513 },
514 )?;
515 Err(error)
516 }
517 }
518 }
519 })
520 }
521
522 pub fn handle_connection(
523 self: &Arc<Self>,
524 connection: Connection,
525 address: String,
526 user: User,
527 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
528 executor: Executor,
529 ) -> impl Future<Output = Result<()>> {
530 let this = self.clone();
531 let user_id = user.id;
532 let login = user.github_login;
533 let span = info_span!("handle connection", %user_id, %login, %address);
534 let mut teardown = self.teardown.subscribe();
535 async move {
536 let (connection_id, handle_io, mut incoming_rx) = this
537 .peer
538 .add_connection(connection, {
539 let executor = executor.clone();
540 move |duration| executor.sleep(duration)
541 });
542
543 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
544 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
545 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
546
547 if let Some(send_connection_id) = send_connection_id.take() {
548 let _ = send_connection_id.send(connection_id);
549 }
550
551 if !user.connected_once {
552 this.peer.send(connection_id, proto::ShowContacts {})?;
553 this.app_state.db.set_user_connected_once(user_id, true).await?;
554 }
555
556 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
557 this.app_state.db.get_contacts(user_id),
558 this.app_state.db.get_invite_code_for_user(user_id),
559 this.app_state.db.get_channels_for_user(user_id),
560 this.app_state.db.get_channel_invites_for_user(user_id)
561 ).await?;
562
563 {
564 let mut pool = this.connection_pool.lock();
565 pool.add_connection(connection_id, user_id, user.admin);
566 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
567 this.peer.send(connection_id, build_initial_channels_update(
568 channels_for_user,
569 channel_invites
570 ))?;
571
572 if let Some((code, count)) = invite_code {
573 this.peer.send(connection_id, proto::UpdateInviteInfo {
574 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
575 count: count as u32,
576 })?;
577 }
578 }
579
580 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
581 this.peer.send(connection_id, incoming_call)?;
582 }
583
584 let session = Session {
585 user_id,
586 connection_id,
587 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
588 peer: this.peer.clone(),
589 connection_pool: this.connection_pool.clone(),
590 live_kit_client: this.app_state.live_kit_client.clone(),
591 executor: executor.clone(),
592 };
593 update_user_contacts(user_id, &session).await?;
594
595 let handle_io = handle_io.fuse();
596 futures::pin_mut!(handle_io);
597
598 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
599 // This prevents deadlocks when e.g., client A performs a request to client B and
600 // client B performs a request to client A. If both clients stop processing further
601 // messages until their respective request completes, they won't have a chance to
602 // respond to the other client's request and cause a deadlock.
603 //
604 // This arrangement ensures we will attempt to process earlier messages first, but fall
605 // back to processing messages arrived later in the spirit of making progress.
606 let mut foreground_message_handlers = FuturesUnordered::new();
607 let concurrent_handlers = Arc::new(Semaphore::new(256));
608 loop {
609 let next_message = async {
610 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
611 let message = incoming_rx.next().await;
612 (permit, message)
613 }.fuse();
614 futures::pin_mut!(next_message);
615 futures::select_biased! {
616 _ = teardown.changed().fuse() => return Ok(()),
617 result = handle_io => {
618 if let Err(error) = result {
619 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
620 }
621 break;
622 }
623 _ = foreground_message_handlers.next() => {}
624 next_message = next_message => {
625 let (permit, message) = next_message;
626 if let Some(message) = message {
627 let type_name = message.payload_type_name();
628 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
629 let span_enter = span.enter();
630 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
631 let is_background = message.is_background();
632 let handle_message = (handler)(message, session.clone());
633 drop(span_enter);
634
635 let handle_message = async move {
636 handle_message.await;
637 drop(permit);
638 }.instrument(span);
639 if is_background {
640 executor.spawn_detached(handle_message);
641 } else {
642 foreground_message_handlers.push(handle_message);
643 }
644 } else {
645 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
646 }
647 } else {
648 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
649 break;
650 }
651 }
652 }
653 }
654
655 drop(foreground_message_handlers);
656 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
657 if let Err(error) = connection_lost(session, teardown, executor).await {
658 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
659 }
660
661 Ok(())
662 }.instrument(span)
663 }
664
665 pub async fn invite_code_redeemed(
666 self: &Arc<Self>,
667 inviter_id: UserId,
668 invitee_id: UserId,
669 ) -> Result<()> {
670 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
671 if let Some(code) = &user.invite_code {
672 let pool = self.connection_pool.lock();
673 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
674 for connection_id in pool.user_connection_ids(inviter_id) {
675 self.peer.send(
676 connection_id,
677 proto::UpdateContacts {
678 contacts: vec![invitee_contact.clone()],
679 ..Default::default()
680 },
681 )?;
682 self.peer.send(
683 connection_id,
684 proto::UpdateInviteInfo {
685 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
686 count: user.invite_count as u32,
687 },
688 )?;
689 }
690 }
691 }
692 Ok(())
693 }
694
695 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
696 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
697 if let Some(invite_code) = &user.invite_code {
698 let pool = self.connection_pool.lock();
699 for connection_id in pool.user_connection_ids(user_id) {
700 self.peer.send(
701 connection_id,
702 proto::UpdateInviteInfo {
703 url: format!(
704 "{}{}",
705 self.app_state.config.invite_link_prefix, invite_code
706 ),
707 count: user.invite_count as u32,
708 },
709 )?;
710 }
711 }
712 }
713 Ok(())
714 }
715
716 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
717 ServerSnapshot {
718 connection_pool: ConnectionPoolGuard {
719 guard: self.connection_pool.lock(),
720 _not_send: PhantomData,
721 },
722 peer: &self.peer,
723 }
724 }
725}
726
727impl<'a> Deref for ConnectionPoolGuard<'a> {
728 type Target = ConnectionPool;
729
730 fn deref(&self) -> &Self::Target {
731 &*self.guard
732 }
733}
734
735impl<'a> DerefMut for ConnectionPoolGuard<'a> {
736 fn deref_mut(&mut self) -> &mut Self::Target {
737 &mut *self.guard
738 }
739}
740
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` on the pool each time a guard is released,
    /// before the lock itself is dropped. Release builds do nothing extra here.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
747
748fn broadcast<F>(
749 sender_id: Option<ConnectionId>,
750 receiver_ids: impl IntoIterator<Item = ConnectionId>,
751 mut f: F,
752) where
753 F: FnMut(ConnectionId) -> anyhow::Result<()>,
754{
755 for receiver_id in receiver_ids {
756 if Some(receiver_id) != sender_id {
757 if let Err(error) = f(receiver_id) {
758 tracing::error!("failed to send to {:?} {}", receiver_id, error);
759 }
760 }
761 }
762}
763
// Name of the request header carrying the client's RPC protocol version; decoded by
// `ProtocolVersion` and compared against `rpc::PROTOCOL_VERSION` on websocket upgrade.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
767
768pub struct ProtocolVersion(u32);
769
770impl Header for ProtocolVersion {
771 fn name() -> &'static HeaderName {
772 &ZED_PROTOCOL_VERSION
773 }
774
775 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
776 where
777 Self: Sized,
778 I: Iterator<Item = &'i axum::http::HeaderValue>,
779 {
780 let version = values
781 .next()
782 .ok_or_else(axum::headers::Error::invalid)?
783 .to_str()
784 .map_err(|_| axum::headers::Error::invalid())?
785 .parse()
786 .map_err(|_| axum::headers::Error::invalid())?;
787 Ok(Self(version))
788 }
789
790 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
791 values.extend([self.0.to_string().parse().unwrap()]);
792 }
793}
794
795pub fn routes(server: Arc<Server>) -> Router<Body> {
796 Router::new()
797 .route("/rpc", get(handle_websocket_request))
798 .layer(
799 ServiceBuilder::new()
800 .layer(Extension(server.app_state.clone()))
801 .layer(middleware::from_fn(auth::validate_header)),
802 )
803 .route("/metrics", get(handle_metrics))
804 .layer(Extension(server))
805}
806
807pub async fn handle_websocket_request(
808 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
809 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
810 Extension(server): Extension<Arc<Server>>,
811 Extension(user): Extension<User>,
812 ws: WebSocketUpgrade,
813) -> axum::response::Response {
814 if protocol_version != rpc::PROTOCOL_VERSION {
815 return (
816 StatusCode::UPGRADE_REQUIRED,
817 "client must be upgraded".to_string(),
818 )
819 .into_response();
820 }
821 let socket_address = socket_address.to_string();
822 ws.on_upgrade(move |socket| {
823 use util::ResultExt;
824 let socket = socket
825 .map_ok(to_tungstenite_message)
826 .err_into()
827 .with(|message| async move { Ok(to_axum_message(message)) });
828 let connection = Connection::new(Box::pin(socket));
829 async move {
830 server
831 .handle_connection(connection, socket_address, user, None, Executor::Production)
832 .await
833 .log_err();
834 }
835 })
836}
837
838pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
839 let connections = server
840 .connection_pool
841 .lock()
842 .connections()
843 .filter(|connection| !connection.admin)
844 .count();
845
846 METRIC_CONNECTIONS.set(connections as _);
847
848 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
849 METRIC_SHARED_PROJECTS.set(shared_projects as _);
850
851 let encoder = prometheus::TextEncoder::new();
852 let metric_families = prometheus::gather();
853 let encoded_metrics = encoder
854 .encode_to_string(&metric_families)
855 .map_err(|err| anyhow!("{}", err))?;
856 Ok(encoded_metrics)
857}
858
/// Cleans up after a connection's I/O has ended.
///
/// Runs in two stages:
/// 1. **Immediately:** disconnects the peer, removes the connection from the pool (returning
///    early on failure) and reports the loss to the database (failures only logged).
/// 2. **After `RECONNECT_TIMEOUT`:** unless the server's teardown signal fires first, leaves
///    the session's room and channel buffers. If the user then has no online connection,
///    it declines any pending call and broadcasts the resulting room update. Finally it
///    refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the timeout branch is polled first; teardown cancels the delayed cleanup.
    // NOTE(review): the delayed cleanup is not cancelled if the user reconnects on a new
    // connection; presumably the helpers below are scoped to this `connection_id` — confirm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard is a temporary of the condition, so it is released before the
            // database lock is taken inside the block.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            // `db` from the block above has been dropped here, so this can lock it again.
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
905
906async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
907 response.send(proto::Ack {})?;
908 Ok(())
909}
910
911async fn create_room(
912 _request: proto::CreateRoom,
913 response: Response<proto::CreateRoom>,
914 session: Session,
915) -> Result<()> {
916 let live_kit_room = nanoid::nanoid!(30);
917
918 let live_kit_connection_info = {
919 let live_kit_room = live_kit_room.clone();
920 let live_kit = session.live_kit_client.as_ref();
921
922 util::async_iife!({
923 let live_kit = live_kit?;
924
925 live_kit
926 .create_room(live_kit_room.clone())
927 .await
928 .trace_err()?;
929
930 let token = live_kit
931 .room_token(&live_kit_room, &session.user_id.to_string())
932 .trace_err()?;
933
934 Some(proto::LiveKitConnectionInfo {
935 server_url: live_kit.url().into(),
936 token,
937 })
938 })
939 }
940 .await;
941
942 let room = session
943 .db()
944 .await
945 .create_room(session.user_id, session.connection_id, &live_kit_room)
946 .await?;
947
948 response.send(proto::CreateRoomResponse {
949 room: Some(room.clone()),
950 live_kit_connection_info,
951 })?;
952
953 update_user_contacts(session.user_id, &session).await?;
954 Ok(())
955}
956
957async fn join_room(
958 request: proto::JoinRoom,
959 response: Response<proto::JoinRoom>,
960 session: Session,
961) -> Result<()> {
962 let room_id = RoomId::from_proto(request.id);
963 let joined_room = {
964 let room = session
965 .db()
966 .await
967 .join_room(room_id, session.user_id, session.connection_id)
968 .await?;
969 room_updated(&room.room, &session.peer);
970 room.into_inner()
971 };
972
973 if let Some(channel_id) = joined_room.channel_id {
974 channel_updated(
975 channel_id,
976 &joined_room.room,
977 &joined_room.channel_members,
978 &session.peer,
979 &*session.connection_pool().await,
980 )
981 }
982
983 for connection_id in session
984 .connection_pool()
985 .await
986 .user_connection_ids(session.user_id)
987 {
988 session
989 .peer
990 .send(
991 connection_id,
992 proto::CallCanceled {
993 room_id: room_id.to_proto(),
994 },
995 )
996 .trace_err();
997 }
998
999 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1000 if let Some(token) = live_kit
1001 .room_token(
1002 &joined_room.room.live_kit_room,
1003 &session.user_id.to_string(),
1004 )
1005 .trace_err()
1006 {
1007 Some(proto::LiveKitConnectionInfo {
1008 server_url: live_kit.url().into(),
1009 token,
1010 })
1011 } else {
1012 None
1013 }
1014 } else {
1015 None
1016 };
1017
1018 response.send(proto::JoinRoomResponse {
1019 room: Some(joined_room.room),
1020 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1021 live_kit_connection_info,
1022 })?;
1023
1024 update_user_contacts(session.user_id, &session).await?;
1025 Ok(())
1026}
1027
1028async fn rejoin_room(
1029 request: proto::RejoinRoom,
1030 response: Response<proto::RejoinRoom>,
1031 session: Session,
1032) -> Result<()> {
1033 let room;
1034 let channel_id;
1035 let channel_members;
1036 {
1037 let mut rejoined_room = session
1038 .db()
1039 .await
1040 .rejoin_room(request, session.user_id, session.connection_id)
1041 .await?;
1042
1043 response.send(proto::RejoinRoomResponse {
1044 room: Some(rejoined_room.room.clone()),
1045 reshared_projects: rejoined_room
1046 .reshared_projects
1047 .iter()
1048 .map(|project| proto::ResharedProject {
1049 id: project.id.to_proto(),
1050 collaborators: project
1051 .collaborators
1052 .iter()
1053 .map(|collaborator| collaborator.to_proto())
1054 .collect(),
1055 })
1056 .collect(),
1057 rejoined_projects: rejoined_room
1058 .rejoined_projects
1059 .iter()
1060 .map(|rejoined_project| proto::RejoinedProject {
1061 id: rejoined_project.id.to_proto(),
1062 worktrees: rejoined_project
1063 .worktrees
1064 .iter()
1065 .map(|worktree| proto::WorktreeMetadata {
1066 id: worktree.id,
1067 root_name: worktree.root_name.clone(),
1068 visible: worktree.visible,
1069 abs_path: worktree.abs_path.clone(),
1070 })
1071 .collect(),
1072 collaborators: rejoined_project
1073 .collaborators
1074 .iter()
1075 .map(|collaborator| collaborator.to_proto())
1076 .collect(),
1077 language_servers: rejoined_project.language_servers.clone(),
1078 })
1079 .collect(),
1080 })?;
1081 room_updated(&rejoined_room.room, &session.peer);
1082
1083 for project in &rejoined_room.reshared_projects {
1084 for collaborator in &project.collaborators {
1085 session
1086 .peer
1087 .send(
1088 collaborator.connection_id,
1089 proto::UpdateProjectCollaborator {
1090 project_id: project.id.to_proto(),
1091 old_peer_id: Some(project.old_connection_id.into()),
1092 new_peer_id: Some(session.connection_id.into()),
1093 },
1094 )
1095 .trace_err();
1096 }
1097
1098 broadcast(
1099 Some(session.connection_id),
1100 project
1101 .collaborators
1102 .iter()
1103 .map(|collaborator| collaborator.connection_id),
1104 |connection_id| {
1105 session.peer.forward_send(
1106 session.connection_id,
1107 connection_id,
1108 proto::UpdateProject {
1109 project_id: project.id.to_proto(),
1110 worktrees: project.worktrees.clone(),
1111 },
1112 )
1113 },
1114 );
1115 }
1116
1117 for project in &rejoined_room.rejoined_projects {
1118 for collaborator in &project.collaborators {
1119 session
1120 .peer
1121 .send(
1122 collaborator.connection_id,
1123 proto::UpdateProjectCollaborator {
1124 project_id: project.id.to_proto(),
1125 old_peer_id: Some(project.old_connection_id.into()),
1126 new_peer_id: Some(session.connection_id.into()),
1127 },
1128 )
1129 .trace_err();
1130 }
1131 }
1132
1133 for project in &mut rejoined_room.rejoined_projects {
1134 for worktree in mem::take(&mut project.worktrees) {
1135 #[cfg(any(test, feature = "test-support"))]
1136 const MAX_CHUNK_SIZE: usize = 2;
1137 #[cfg(not(any(test, feature = "test-support")))]
1138 const MAX_CHUNK_SIZE: usize = 256;
1139
1140 // Stream this worktree's entries.
1141 let message = proto::UpdateWorktree {
1142 project_id: project.id.to_proto(),
1143 worktree_id: worktree.id,
1144 abs_path: worktree.abs_path.clone(),
1145 root_name: worktree.root_name,
1146 updated_entries: worktree.updated_entries,
1147 removed_entries: worktree.removed_entries,
1148 scan_id: worktree.scan_id,
1149 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1150 updated_repositories: worktree.updated_repositories,
1151 removed_repositories: worktree.removed_repositories,
1152 };
1153 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1154 session.peer.send(session.connection_id, update.clone())?;
1155 }
1156
1157 // Stream this worktree's diagnostics.
1158 for summary in worktree.diagnostic_summaries {
1159 session.peer.send(
1160 session.connection_id,
1161 proto::UpdateDiagnosticSummary {
1162 project_id: project.id.to_proto(),
1163 worktree_id: worktree.id,
1164 summary: Some(summary),
1165 },
1166 )?;
1167 }
1168
1169 for settings_file in worktree.settings_files {
1170 session.peer.send(
1171 session.connection_id,
1172 proto::UpdateWorktreeSettings {
1173 project_id: project.id.to_proto(),
1174 worktree_id: worktree.id,
1175 path: settings_file.path,
1176 content: Some(settings_file.content),
1177 },
1178 )?;
1179 }
1180 }
1181
1182 for language_server in &project.language_servers {
1183 session.peer.send(
1184 session.connection_id,
1185 proto::UpdateLanguageServer {
1186 project_id: project.id.to_proto(),
1187 language_server_id: language_server.id,
1188 variant: Some(
1189 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1190 proto::LspDiskBasedDiagnosticsUpdated {},
1191 ),
1192 ),
1193 },
1194 )?;
1195 }
1196 }
1197
1198 let rejoined_room = rejoined_room.into_inner();
1199
1200 room = rejoined_room.room;
1201 channel_id = rejoined_room.channel_id;
1202 channel_members = rejoined_room.channel_members;
1203 }
1204
1205 if let Some(channel_id) = channel_id {
1206 channel_updated(
1207 channel_id,
1208 &room,
1209 &channel_members,
1210 &session.peer,
1211 &*session.connection_pool().await,
1212 );
1213 }
1214
1215 update_user_contacts(session.user_id, &session).await?;
1216 Ok(())
1217}
1218
1219async fn leave_room(
1220 _: proto::LeaveRoom,
1221 response: Response<proto::LeaveRoom>,
1222 session: Session,
1223) -> Result<()> {
1224 leave_room_for_session(&session).await?;
1225 response.send(proto::Ack {})?;
1226 Ok(())
1227}
1228
1229async fn call(
1230 request: proto::Call,
1231 response: Response<proto::Call>,
1232 session: Session,
1233) -> Result<()> {
1234 let room_id = RoomId::from_proto(request.room_id);
1235 let calling_user_id = session.user_id;
1236 let calling_connection_id = session.connection_id;
1237 let called_user_id = UserId::from_proto(request.called_user_id);
1238 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1239 if !session
1240 .db()
1241 .await
1242 .has_contact(calling_user_id, called_user_id)
1243 .await?
1244 {
1245 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1246 }
1247
1248 let incoming_call = {
1249 let (room, incoming_call) = &mut *session
1250 .db()
1251 .await
1252 .call(
1253 room_id,
1254 calling_user_id,
1255 calling_connection_id,
1256 called_user_id,
1257 initial_project_id,
1258 )
1259 .await?;
1260 room_updated(&room, &session.peer);
1261 mem::take(incoming_call)
1262 };
1263 update_user_contacts(called_user_id, &session).await?;
1264
1265 let mut calls = session
1266 .connection_pool()
1267 .await
1268 .user_connection_ids(called_user_id)
1269 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1270 .collect::<FuturesUnordered<_>>();
1271
1272 while let Some(call_response) = calls.next().await {
1273 match call_response.as_ref() {
1274 Ok(_) => {
1275 response.send(proto::Ack {})?;
1276 return Ok(());
1277 }
1278 Err(_) => {
1279 call_response.trace_err();
1280 }
1281 }
1282 }
1283
1284 {
1285 let room = session
1286 .db()
1287 .await
1288 .call_failed(room_id, called_user_id)
1289 .await?;
1290 room_updated(&room, &session.peer);
1291 }
1292 update_user_contacts(called_user_id, &session).await?;
1293
1294 Err(anyhow!("failed to ring user"))?
1295}
1296
1297async fn cancel_call(
1298 request: proto::CancelCall,
1299 response: Response<proto::CancelCall>,
1300 session: Session,
1301) -> Result<()> {
1302 let called_user_id = UserId::from_proto(request.called_user_id);
1303 let room_id = RoomId::from_proto(request.room_id);
1304 {
1305 let room = session
1306 .db()
1307 .await
1308 .cancel_call(room_id, session.connection_id, called_user_id)
1309 .await?;
1310 room_updated(&room, &session.peer);
1311 }
1312
1313 for connection_id in session
1314 .connection_pool()
1315 .await
1316 .user_connection_ids(called_user_id)
1317 {
1318 session
1319 .peer
1320 .send(
1321 connection_id,
1322 proto::CallCanceled {
1323 room_id: room_id.to_proto(),
1324 },
1325 )
1326 .trace_err();
1327 }
1328 response.send(proto::Ack {})?;
1329
1330 update_user_contacts(called_user_id, &session).await?;
1331 Ok(())
1332}
1333
1334async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1335 let room_id = RoomId::from_proto(message.room_id);
1336 {
1337 let room = session
1338 .db()
1339 .await
1340 .decline_call(Some(room_id), session.user_id)
1341 .await?
1342 .ok_or_else(|| anyhow!("failed to decline call"))?;
1343 room_updated(&room, &session.peer);
1344 }
1345
1346 for connection_id in session
1347 .connection_pool()
1348 .await
1349 .user_connection_ids(session.user_id)
1350 {
1351 session
1352 .peer
1353 .send(
1354 connection_id,
1355 proto::CallCanceled {
1356 room_id: room_id.to_proto(),
1357 },
1358 )
1359 .trace_err();
1360 }
1361 update_user_contacts(session.user_id, &session).await?;
1362 Ok(())
1363}
1364
1365async fn update_participant_location(
1366 request: proto::UpdateParticipantLocation,
1367 response: Response<proto::UpdateParticipantLocation>,
1368 session: Session,
1369) -> Result<()> {
1370 let room_id = RoomId::from_proto(request.room_id);
1371 let location = request
1372 .location
1373 .ok_or_else(|| anyhow!("invalid location"))?;
1374
1375 let db = session.db().await;
1376 let room = db
1377 .update_room_participant_location(room_id, session.connection_id, location)
1378 .await?;
1379
1380 room_updated(&room, &session.peer);
1381 response.send(proto::Ack {})?;
1382 Ok(())
1383}
1384
1385async fn share_project(
1386 request: proto::ShareProject,
1387 response: Response<proto::ShareProject>,
1388 session: Session,
1389) -> Result<()> {
1390 let (project_id, room) = &*session
1391 .db()
1392 .await
1393 .share_project(
1394 RoomId::from_proto(request.room_id),
1395 session.connection_id,
1396 &request.worktrees,
1397 )
1398 .await?;
1399 response.send(proto::ShareProjectResponse {
1400 project_id: project_id.to_proto(),
1401 })?;
1402 room_updated(&room, &session.peer);
1403
1404 Ok(())
1405}
1406
1407async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1408 let project_id = ProjectId::from_proto(message.project_id);
1409
1410 let (room, guest_connection_ids) = &*session
1411 .db()
1412 .await
1413 .unshare_project(project_id, session.connection_id)
1414 .await?;
1415
1416 broadcast(
1417 Some(session.connection_id),
1418 guest_connection_ids.iter().copied(),
1419 |conn_id| session.peer.send(conn_id, message.clone()),
1420 );
1421 room_updated(&room, &session.peer);
1422
1423 Ok(())
1424}
1425
1426async fn join_project(
1427 request: proto::JoinProject,
1428 response: Response<proto::JoinProject>,
1429 session: Session,
1430) -> Result<()> {
1431 let project_id = ProjectId::from_proto(request.project_id);
1432 let guest_user_id = session.user_id;
1433
1434 tracing::info!(%project_id, "join project");
1435
1436 let (project, replica_id) = &mut *session
1437 .db()
1438 .await
1439 .join_project(project_id, session.connection_id)
1440 .await?;
1441
1442 let collaborators = project
1443 .collaborators
1444 .iter()
1445 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1446 .map(|collaborator| collaborator.to_proto())
1447 .collect::<Vec<_>>();
1448
1449 let worktrees = project
1450 .worktrees
1451 .iter()
1452 .map(|(id, worktree)| proto::WorktreeMetadata {
1453 id: *id,
1454 root_name: worktree.root_name.clone(),
1455 visible: worktree.visible,
1456 abs_path: worktree.abs_path.clone(),
1457 })
1458 .collect::<Vec<_>>();
1459
1460 for collaborator in &collaborators {
1461 session
1462 .peer
1463 .send(
1464 collaborator.peer_id.unwrap().into(),
1465 proto::AddProjectCollaborator {
1466 project_id: project_id.to_proto(),
1467 collaborator: Some(proto::Collaborator {
1468 peer_id: Some(session.connection_id.into()),
1469 replica_id: replica_id.0 as u32,
1470 user_id: guest_user_id.to_proto(),
1471 }),
1472 },
1473 )
1474 .trace_err();
1475 }
1476
1477 // First, we send the metadata associated with each worktree.
1478 response.send(proto::JoinProjectResponse {
1479 worktrees: worktrees.clone(),
1480 replica_id: replica_id.0 as u32,
1481 collaborators: collaborators.clone(),
1482 language_servers: project.language_servers.clone(),
1483 })?;
1484
1485 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1486 #[cfg(any(test, feature = "test-support"))]
1487 const MAX_CHUNK_SIZE: usize = 2;
1488 #[cfg(not(any(test, feature = "test-support")))]
1489 const MAX_CHUNK_SIZE: usize = 256;
1490
1491 // Stream this worktree's entries.
1492 let message = proto::UpdateWorktree {
1493 project_id: project_id.to_proto(),
1494 worktree_id,
1495 abs_path: worktree.abs_path.clone(),
1496 root_name: worktree.root_name,
1497 updated_entries: worktree.entries,
1498 removed_entries: Default::default(),
1499 scan_id: worktree.scan_id,
1500 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1501 updated_repositories: worktree.repository_entries.into_values().collect(),
1502 removed_repositories: Default::default(),
1503 };
1504 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1505 session.peer.send(session.connection_id, update.clone())?;
1506 }
1507
1508 // Stream this worktree's diagnostics.
1509 for summary in worktree.diagnostic_summaries {
1510 session.peer.send(
1511 session.connection_id,
1512 proto::UpdateDiagnosticSummary {
1513 project_id: project_id.to_proto(),
1514 worktree_id: worktree.id,
1515 summary: Some(summary),
1516 },
1517 )?;
1518 }
1519
1520 for settings_file in worktree.settings_files {
1521 session.peer.send(
1522 session.connection_id,
1523 proto::UpdateWorktreeSettings {
1524 project_id: project_id.to_proto(),
1525 worktree_id: worktree.id,
1526 path: settings_file.path,
1527 content: Some(settings_file.content),
1528 },
1529 )?;
1530 }
1531 }
1532
1533 for language_server in &project.language_servers {
1534 session.peer.send(
1535 session.connection_id,
1536 proto::UpdateLanguageServer {
1537 project_id: project_id.to_proto(),
1538 language_server_id: language_server.id,
1539 variant: Some(
1540 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1541 proto::LspDiskBasedDiagnosticsUpdated {},
1542 ),
1543 ),
1544 },
1545 )?;
1546 }
1547
1548 Ok(())
1549}
1550
1551async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1552 let sender_id = session.connection_id;
1553 let project_id = ProjectId::from_proto(request.project_id);
1554
1555 let (room, project) = &*session
1556 .db()
1557 .await
1558 .leave_project(project_id, sender_id)
1559 .await?;
1560 tracing::info!(
1561 %project_id,
1562 host_user_id = %project.host_user_id,
1563 host_connection_id = %project.host_connection_id,
1564 "leave project"
1565 );
1566
1567 project_left(&project, &session);
1568 room_updated(&room, &session.peer);
1569
1570 Ok(())
1571}
1572
1573async fn update_project(
1574 request: proto::UpdateProject,
1575 response: Response<proto::UpdateProject>,
1576 session: Session,
1577) -> Result<()> {
1578 let project_id = ProjectId::from_proto(request.project_id);
1579 let (room, guest_connection_ids) = &*session
1580 .db()
1581 .await
1582 .update_project(project_id, session.connection_id, &request.worktrees)
1583 .await?;
1584 broadcast(
1585 Some(session.connection_id),
1586 guest_connection_ids.iter().copied(),
1587 |connection_id| {
1588 session
1589 .peer
1590 .forward_send(session.connection_id, connection_id, request.clone())
1591 },
1592 );
1593 room_updated(&room, &session.peer);
1594 response.send(proto::Ack {})?;
1595
1596 Ok(())
1597}
1598
1599async fn update_worktree(
1600 request: proto::UpdateWorktree,
1601 response: Response<proto::UpdateWorktree>,
1602 session: Session,
1603) -> Result<()> {
1604 let guest_connection_ids = session
1605 .db()
1606 .await
1607 .update_worktree(&request, session.connection_id)
1608 .await?;
1609
1610 broadcast(
1611 Some(session.connection_id),
1612 guest_connection_ids.iter().copied(),
1613 |connection_id| {
1614 session
1615 .peer
1616 .forward_send(session.connection_id, connection_id, request.clone())
1617 },
1618 );
1619 response.send(proto::Ack {})?;
1620 Ok(())
1621}
1622
1623async fn update_diagnostic_summary(
1624 message: proto::UpdateDiagnosticSummary,
1625 session: Session,
1626) -> Result<()> {
1627 let guest_connection_ids = session
1628 .db()
1629 .await
1630 .update_diagnostic_summary(&message, session.connection_id)
1631 .await?;
1632
1633 broadcast(
1634 Some(session.connection_id),
1635 guest_connection_ids.iter().copied(),
1636 |connection_id| {
1637 session
1638 .peer
1639 .forward_send(session.connection_id, connection_id, message.clone())
1640 },
1641 );
1642
1643 Ok(())
1644}
1645
1646async fn update_worktree_settings(
1647 message: proto::UpdateWorktreeSettings,
1648 session: Session,
1649) -> Result<()> {
1650 let guest_connection_ids = session
1651 .db()
1652 .await
1653 .update_worktree_settings(&message, session.connection_id)
1654 .await?;
1655
1656 broadcast(
1657 Some(session.connection_id),
1658 guest_connection_ids.iter().copied(),
1659 |connection_id| {
1660 session
1661 .peer
1662 .forward_send(session.connection_id, connection_id, message.clone())
1663 },
1664 );
1665
1666 Ok(())
1667}
1668
1669async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1670 broadcast_project_message(request.project_id, request, session).await
1671}
1672
1673async fn start_language_server(
1674 request: proto::StartLanguageServer,
1675 session: Session,
1676) -> Result<()> {
1677 let guest_connection_ids = session
1678 .db()
1679 .await
1680 .start_language_server(&request, session.connection_id)
1681 .await?;
1682
1683 broadcast(
1684 Some(session.connection_id),
1685 guest_connection_ids.iter().copied(),
1686 |connection_id| {
1687 session
1688 .peer
1689 .forward_send(session.connection_id, connection_id, request.clone())
1690 },
1691 );
1692 Ok(())
1693}
1694
1695async fn update_language_server(
1696 request: proto::UpdateLanguageServer,
1697 session: Session,
1698) -> Result<()> {
1699 session.executor.record_backtrace();
1700 let project_id = ProjectId::from_proto(request.project_id);
1701 let project_connection_ids = session
1702 .db()
1703 .await
1704 .project_connection_ids(project_id, session.connection_id)
1705 .await?;
1706 broadcast(
1707 Some(session.connection_id),
1708 project_connection_ids.iter().copied(),
1709 |connection_id| {
1710 session
1711 .peer
1712 .forward_send(session.connection_id, connection_id, request.clone())
1713 },
1714 );
1715 Ok(())
1716}
1717
1718async fn forward_project_request<T>(
1719 request: T,
1720 response: Response<T>,
1721 session: Session,
1722) -> Result<()>
1723where
1724 T: EntityMessage + RequestMessage,
1725{
1726 session.executor.record_backtrace();
1727 let project_id = ProjectId::from_proto(request.remote_entity_id());
1728 let host_connection_id = {
1729 let collaborators = session
1730 .db()
1731 .await
1732 .project_collaborators(project_id, session.connection_id)
1733 .await?;
1734 collaborators
1735 .iter()
1736 .find(|collaborator| collaborator.is_host)
1737 .ok_or_else(|| anyhow!("host not found"))?
1738 .connection_id
1739 };
1740
1741 let payload = session
1742 .peer
1743 .forward_request(session.connection_id, host_connection_id, request)
1744 .await?;
1745
1746 response.send(payload)?;
1747 Ok(())
1748}
1749
1750async fn create_buffer_for_peer(
1751 request: proto::CreateBufferForPeer,
1752 session: Session,
1753) -> Result<()> {
1754 session.executor.record_backtrace();
1755 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1756 session
1757 .peer
1758 .forward_send(session.connection_id, peer_id.into(), request)?;
1759 Ok(())
1760}
1761
1762async fn update_buffer(
1763 request: proto::UpdateBuffer,
1764 response: Response<proto::UpdateBuffer>,
1765 session: Session,
1766) -> Result<()> {
1767 session.executor.record_backtrace();
1768 let project_id = ProjectId::from_proto(request.project_id);
1769 let mut guest_connection_ids;
1770 let mut host_connection_id = None;
1771 {
1772 let collaborators = session
1773 .db()
1774 .await
1775 .project_collaborators(project_id, session.connection_id)
1776 .await?;
1777 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1778 for collaborator in collaborators.iter() {
1779 if collaborator.is_host {
1780 host_connection_id = Some(collaborator.connection_id);
1781 } else {
1782 guest_connection_ids.push(collaborator.connection_id);
1783 }
1784 }
1785 }
1786 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1787
1788 session.executor.record_backtrace();
1789 broadcast(
1790 Some(session.connection_id),
1791 guest_connection_ids,
1792 |connection_id| {
1793 session
1794 .peer
1795 .forward_send(session.connection_id, connection_id, request.clone())
1796 },
1797 );
1798 if host_connection_id != session.connection_id {
1799 session
1800 .peer
1801 .forward_request(session.connection_id, host_connection_id, request.clone())
1802 .await?;
1803 }
1804
1805 response.send(proto::Ack {})?;
1806 Ok(())
1807}
1808
1809async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1810 let project_id = ProjectId::from_proto(request.project_id);
1811 let project_connection_ids = session
1812 .db()
1813 .await
1814 .project_connection_ids(project_id, session.connection_id)
1815 .await?;
1816
1817 broadcast(
1818 Some(session.connection_id),
1819 project_connection_ids.iter().copied(),
1820 |connection_id| {
1821 session
1822 .peer
1823 .forward_send(session.connection_id, connection_id, request.clone())
1824 },
1825 );
1826 Ok(())
1827}
1828
1829async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1830 let project_id = ProjectId::from_proto(request.project_id);
1831 let project_connection_ids = session
1832 .db()
1833 .await
1834 .project_connection_ids(project_id, session.connection_id)
1835 .await?;
1836 broadcast(
1837 Some(session.connection_id),
1838 project_connection_ids.iter().copied(),
1839 |connection_id| {
1840 session
1841 .peer
1842 .forward_send(session.connection_id, connection_id, request.clone())
1843 },
1844 );
1845 Ok(())
1846}
1847
1848async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1849 broadcast_project_message(request.project_id, request, session).await
1850}
1851
1852async fn broadcast_project_message<T: EnvelopedMessage>(
1853 project_id: u64,
1854 request: T,
1855 session: Session,
1856) -> Result<()> {
1857 let project_id = ProjectId::from_proto(project_id);
1858 let project_connection_ids = session
1859 .db()
1860 .await
1861 .project_connection_ids(project_id, session.connection_id)
1862 .await?;
1863 broadcast(
1864 Some(session.connection_id),
1865 project_connection_ids.iter().copied(),
1866 |connection_id| {
1867 session
1868 .peer
1869 .forward_send(session.connection_id, connection_id, request.clone())
1870 },
1871 );
1872 Ok(())
1873}
1874
1875async fn follow(
1876 request: proto::Follow,
1877 response: Response<proto::Follow>,
1878 session: Session,
1879) -> Result<()> {
1880 let project_id = ProjectId::from_proto(request.project_id);
1881 let leader_id = request
1882 .leader_id
1883 .ok_or_else(|| anyhow!("invalid leader id"))?
1884 .into();
1885 let follower_id = session.connection_id;
1886
1887 {
1888 let project_connection_ids = session
1889 .db()
1890 .await
1891 .project_connection_ids(project_id, session.connection_id)
1892 .await?;
1893
1894 if !project_connection_ids.contains(&leader_id) {
1895 Err(anyhow!("no such peer"))?;
1896 }
1897 }
1898
1899 let mut response_payload = session
1900 .peer
1901 .forward_request(session.connection_id, leader_id, request)
1902 .await?;
1903 response_payload
1904 .views
1905 .retain(|view| view.leader_id != Some(follower_id.into()));
1906 response.send(response_payload)?;
1907
1908 let room = session
1909 .db()
1910 .await
1911 .follow(project_id, leader_id, follower_id)
1912 .await?;
1913 room_updated(&room, &session.peer);
1914
1915 Ok(())
1916}
1917
1918async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1919 let project_id = ProjectId::from_proto(request.project_id);
1920 let leader_id = request
1921 .leader_id
1922 .ok_or_else(|| anyhow!("invalid leader id"))?
1923 .into();
1924 let follower_id = session.connection_id;
1925
1926 if !session
1927 .db()
1928 .await
1929 .project_connection_ids(project_id, session.connection_id)
1930 .await?
1931 .contains(&leader_id)
1932 {
1933 Err(anyhow!("no such peer"))?;
1934 }
1935
1936 session
1937 .peer
1938 .forward_send(session.connection_id, leader_id, request)?;
1939
1940 let room = session
1941 .db()
1942 .await
1943 .unfollow(project_id, leader_id, follower_id)
1944 .await?;
1945 room_updated(&room, &session.peer);
1946
1947 Ok(())
1948}
1949
1950async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1951 let project_id = ProjectId::from_proto(request.project_id);
1952 let project_connection_ids = session
1953 .db
1954 .lock()
1955 .await
1956 .project_connection_ids(project_id, session.connection_id)
1957 .await?;
1958
1959 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1960 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1961 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1962 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1963 });
1964 for follower_peer_id in request.follower_ids.iter().copied() {
1965 let follower_connection_id = follower_peer_id.into();
1966 if project_connection_ids.contains(&follower_connection_id)
1967 && Some(follower_peer_id) != leader_id
1968 {
1969 session.peer.forward_send(
1970 session.connection_id,
1971 follower_connection_id,
1972 request.clone(),
1973 )?;
1974 }
1975 }
1976 Ok(())
1977}
1978
1979async fn get_users(
1980 request: proto::GetUsers,
1981 response: Response<proto::GetUsers>,
1982 session: Session,
1983) -> Result<()> {
1984 let user_ids = request
1985 .user_ids
1986 .into_iter()
1987 .map(UserId::from_proto)
1988 .collect();
1989 let users = session
1990 .db()
1991 .await
1992 .get_users_by_ids(user_ids)
1993 .await?
1994 .into_iter()
1995 .map(|user| proto::User {
1996 id: user.id.to_proto(),
1997 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1998 github_login: user.github_login,
1999 })
2000 .collect();
2001 response.send(proto::UsersResponse { users })?;
2002 Ok(())
2003}
2004
2005async fn fuzzy_search_users(
2006 request: proto::FuzzySearchUsers,
2007 response: Response<proto::FuzzySearchUsers>,
2008 session: Session,
2009) -> Result<()> {
2010 let query = request.query;
2011 let users = match query.len() {
2012 0 => vec![],
2013 1 | 2 => session
2014 .db()
2015 .await
2016 .get_user_by_github_login(&query)
2017 .await?
2018 .into_iter()
2019 .collect(),
2020 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2021 };
2022 let users = users
2023 .into_iter()
2024 .filter(|user| user.id != session.user_id)
2025 .map(|user| proto::User {
2026 id: user.id.to_proto(),
2027 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2028 github_login: user.github_login,
2029 })
2030 .collect();
2031 response.send(proto::UsersResponse { users })?;
2032 Ok(())
2033}
2034
2035async fn request_contact(
2036 request: proto::RequestContact,
2037 response: Response<proto::RequestContact>,
2038 session: Session,
2039) -> Result<()> {
2040 let requester_id = session.user_id;
2041 let responder_id = UserId::from_proto(request.responder_id);
2042 if requester_id == responder_id {
2043 return Err(anyhow!("cannot add yourself as a contact"))?;
2044 }
2045
2046 session
2047 .db()
2048 .await
2049 .send_contact_request(requester_id, responder_id)
2050 .await?;
2051
2052 // Update outgoing contact requests of requester
2053 let mut update = proto::UpdateContacts::default();
2054 update.outgoing_requests.push(responder_id.to_proto());
2055 for connection_id in session
2056 .connection_pool()
2057 .await
2058 .user_connection_ids(requester_id)
2059 {
2060 session.peer.send(connection_id, update.clone())?;
2061 }
2062
2063 // Update incoming contact requests of responder
2064 let mut update = proto::UpdateContacts::default();
2065 update
2066 .incoming_requests
2067 .push(proto::IncomingContactRequest {
2068 requester_id: requester_id.to_proto(),
2069 should_notify: true,
2070 });
2071 for connection_id in session
2072 .connection_pool()
2073 .await
2074 .user_connection_ids(responder_id)
2075 {
2076 session.peer.send(connection_id, update.clone())?;
2077 }
2078
2079 response.send(proto::Ack {})?;
2080 Ok(())
2081}
2082
2083async fn respond_to_contact_request(
2084 request: proto::RespondToContactRequest,
2085 response: Response<proto::RespondToContactRequest>,
2086 session: Session,
2087) -> Result<()> {
2088 let responder_id = session.user_id;
2089 let requester_id = UserId::from_proto(request.requester_id);
2090 let db = session.db().await;
2091 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2092 db.dismiss_contact_notification(responder_id, requester_id)
2093 .await?;
2094 } else {
2095 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2096
2097 db.respond_to_contact_request(responder_id, requester_id, accept)
2098 .await?;
2099 let requester_busy = db.is_user_busy(requester_id).await?;
2100 let responder_busy = db.is_user_busy(responder_id).await?;
2101
2102 let pool = session.connection_pool().await;
2103 // Update responder with new contact
2104 let mut update = proto::UpdateContacts::default();
2105 if accept {
2106 update
2107 .contacts
2108 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2109 }
2110 update
2111 .remove_incoming_requests
2112 .push(requester_id.to_proto());
2113 for connection_id in pool.user_connection_ids(responder_id) {
2114 session.peer.send(connection_id, update.clone())?;
2115 }
2116
2117 // Update requester with new contact
2118 let mut update = proto::UpdateContacts::default();
2119 if accept {
2120 update
2121 .contacts
2122 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2123 }
2124 update
2125 .remove_outgoing_requests
2126 .push(responder_id.to_proto());
2127 for connection_id in pool.user_connection_ids(requester_id) {
2128 session.peer.send(connection_id, update.clone())?;
2129 }
2130 }
2131
2132 response.send(proto::Ack {})?;
2133 Ok(())
2134}
2135
2136async fn remove_contact(
2137 request: proto::RemoveContact,
2138 response: Response<proto::RemoveContact>,
2139 session: Session,
2140) -> Result<()> {
2141 let requester_id = session.user_id;
2142 let responder_id = UserId::from_proto(request.user_id);
2143 let db = session.db().await;
2144 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2145
2146 let pool = session.connection_pool().await;
2147 // Update outgoing contact requests of requester
2148 let mut update = proto::UpdateContacts::default();
2149 if contact_accepted {
2150 update.remove_contacts.push(responder_id.to_proto());
2151 } else {
2152 update
2153 .remove_outgoing_requests
2154 .push(responder_id.to_proto());
2155 }
2156 for connection_id in pool.user_connection_ids(requester_id) {
2157 session.peer.send(connection_id, update.clone())?;
2158 }
2159
2160 // Update incoming contact requests of responder
2161 let mut update = proto::UpdateContacts::default();
2162 if contact_accepted {
2163 update.remove_contacts.push(requester_id.to_proto());
2164 } else {
2165 update
2166 .remove_incoming_requests
2167 .push(requester_id.to_proto());
2168 }
2169 for connection_id in pool.user_connection_ids(responder_id) {
2170 session.peer.send(connection_id, update.clone())?;
2171 }
2172
2173 response.send(proto::Ack {})?;
2174 Ok(())
2175}
2176
2177async fn create_channel(
2178 request: proto::CreateChannel,
2179 response: Response<proto::CreateChannel>,
2180 session: Session,
2181) -> Result<()> {
2182 let db = session.db().await;
2183 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2184
2185 if let Some(live_kit) = session.live_kit_client.as_ref() {
2186 live_kit.create_room(live_kit_room.clone()).await?;
2187 }
2188
2189 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2190 let id = db
2191 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2192 .await?;
2193
2194 let channel = proto::Channel {
2195 id: id.to_proto(),
2196 name: request.name,
2197 parent_id: request.parent_id,
2198 };
2199
2200 response.send(proto::ChannelResponse {
2201 channel: Some(channel.clone()),
2202 })?;
2203
2204 let mut update = proto::UpdateChannels::default();
2205 update.channels.push(channel);
2206
2207 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2208 db.get_channel_members(parent_id).await?
2209 } else {
2210 vec![session.user_id]
2211 };
2212
2213 let connection_pool = session.connection_pool().await;
2214 for user_id in user_ids_to_notify {
2215 for connection_id in connection_pool.user_connection_ids(user_id) {
2216 let mut update = update.clone();
2217 if user_id == session.user_id {
2218 update.channel_permissions.push(proto::ChannelPermission {
2219 channel_id: id.to_proto(),
2220 is_admin: true,
2221 });
2222 }
2223 session.peer.send(connection_id, update)?;
2224 }
2225 }
2226
2227 Ok(())
2228}
2229
2230async fn remove_channel(
2231 request: proto::RemoveChannel,
2232 response: Response<proto::RemoveChannel>,
2233 session: Session,
2234) -> Result<()> {
2235 let db = session.db().await;
2236
2237 let channel_id = request.channel_id;
2238 let (removed_channels, member_ids) = db
2239 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2240 .await?;
2241 response.send(proto::Ack {})?;
2242
2243 // Notify members of removed channels
2244 let mut update = proto::UpdateChannels::default();
2245 update
2246 .remove_channels
2247 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2248
2249 let connection_pool = session.connection_pool().await;
2250 for member_id in member_ids {
2251 for connection_id in connection_pool.user_connection_ids(member_id) {
2252 session.peer.send(connection_id, update.clone())?;
2253 }
2254 }
2255
2256 Ok(())
2257}
2258
2259async fn invite_channel_member(
2260 request: proto::InviteChannelMember,
2261 response: Response<proto::InviteChannelMember>,
2262 session: Session,
2263) -> Result<()> {
2264 let db = session.db().await;
2265 let channel_id = ChannelId::from_proto(request.channel_id);
2266 let invitee_id = UserId::from_proto(request.user_id);
2267 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2268 .await?;
2269
2270 let (channel, _) = db
2271 .get_channel(channel_id, session.user_id)
2272 .await?
2273 .ok_or_else(|| anyhow!("channel not found"))?;
2274
2275 let mut update = proto::UpdateChannels::default();
2276 update.channel_invitations.push(proto::Channel {
2277 id: channel.id.to_proto(),
2278 name: channel.name,
2279 parent_id: None,
2280 });
2281 for connection_id in session
2282 .connection_pool()
2283 .await
2284 .user_connection_ids(invitee_id)
2285 {
2286 session.peer.send(connection_id, update.clone())?;
2287 }
2288
2289 response.send(proto::Ack {})?;
2290 Ok(())
2291}
2292
2293async fn remove_channel_member(
2294 request: proto::RemoveChannelMember,
2295 response: Response<proto::RemoveChannelMember>,
2296 session: Session,
2297) -> Result<()> {
2298 let db = session.db().await;
2299 let channel_id = ChannelId::from_proto(request.channel_id);
2300 let member_id = UserId::from_proto(request.user_id);
2301
2302 db.remove_channel_member(channel_id, member_id, session.user_id)
2303 .await?;
2304
2305 let mut update = proto::UpdateChannels::default();
2306 update.remove_channels.push(channel_id.to_proto());
2307
2308 for connection_id in session
2309 .connection_pool()
2310 .await
2311 .user_connection_ids(member_id)
2312 {
2313 session.peer.send(connection_id, update.clone())?;
2314 }
2315
2316 response.send(proto::Ack {})?;
2317 Ok(())
2318}
2319
2320async fn set_channel_member_admin(
2321 request: proto::SetChannelMemberAdmin,
2322 response: Response<proto::SetChannelMemberAdmin>,
2323 session: Session,
2324) -> Result<()> {
2325 let db = session.db().await;
2326 let channel_id = ChannelId::from_proto(request.channel_id);
2327 let member_id = UserId::from_proto(request.user_id);
2328 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2329 .await?;
2330
2331 let (channel, has_accepted) = db
2332 .get_channel(channel_id, member_id)
2333 .await?
2334 .ok_or_else(|| anyhow!("channel not found"))?;
2335
2336 let mut update = proto::UpdateChannels::default();
2337 if has_accepted {
2338 update.channel_permissions.push(proto::ChannelPermission {
2339 channel_id: channel.id.to_proto(),
2340 is_admin: request.admin,
2341 });
2342 }
2343
2344 for connection_id in session
2345 .connection_pool()
2346 .await
2347 .user_connection_ids(member_id)
2348 {
2349 session.peer.send(connection_id, update.clone())?;
2350 }
2351
2352 response.send(proto::Ack {})?;
2353 Ok(())
2354}
2355
2356async fn rename_channel(
2357 request: proto::RenameChannel,
2358 response: Response<proto::RenameChannel>,
2359 session: Session,
2360) -> Result<()> {
2361 let db = session.db().await;
2362 let channel_id = ChannelId::from_proto(request.channel_id);
2363 let new_name = db
2364 .rename_channel(channel_id, session.user_id, &request.name)
2365 .await?;
2366
2367 let channel = proto::Channel {
2368 id: request.channel_id,
2369 name: new_name,
2370 parent_id: None,
2371 };
2372 response.send(proto::ChannelResponse {
2373 channel: Some(channel.clone()),
2374 })?;
2375 let mut update = proto::UpdateChannels::default();
2376 update.channels.push(channel);
2377
2378 let member_ids = db.get_channel_members(channel_id).await?;
2379
2380 let connection_pool = session.connection_pool().await;
2381 for member_id in member_ids {
2382 for connection_id in connection_pool.user_connection_ids(member_id) {
2383 session.peer.send(connection_id, update.clone())?;
2384 }
2385 }
2386
2387 Ok(())
2388}
2389
2390async fn get_channel_members(
2391 request: proto::GetChannelMembers,
2392 response: Response<proto::GetChannelMembers>,
2393 session: Session,
2394) -> Result<()> {
2395 let db = session.db().await;
2396 let channel_id = ChannelId::from_proto(request.channel_id);
2397 let members = db
2398 .get_channel_member_details(channel_id, session.user_id)
2399 .await?;
2400 response.send(proto::GetChannelMembersResponse { members })?;
2401 Ok(())
2402}
2403
2404async fn respond_to_channel_invite(
2405 request: proto::RespondToChannelInvite,
2406 response: Response<proto::RespondToChannelInvite>,
2407 session: Session,
2408) -> Result<()> {
2409 let db = session.db().await;
2410 let channel_id = ChannelId::from_proto(request.channel_id);
2411 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2412 .await?;
2413
2414 let mut update = proto::UpdateChannels::default();
2415 update
2416 .remove_channel_invitations
2417 .push(channel_id.to_proto());
2418 if request.accept {
2419 let result = db.get_channels_for_user(session.user_id).await?;
2420 update
2421 .channels
2422 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2423 id: channel.id.to_proto(),
2424 name: channel.name,
2425 parent_id: channel.parent_id.map(ChannelId::to_proto),
2426 }));
2427 update
2428 .channel_participants
2429 .extend(
2430 result
2431 .channel_participants
2432 .into_iter()
2433 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2434 channel_id: channel_id.to_proto(),
2435 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2436 }),
2437 );
2438 update
2439 .channel_permissions
2440 .extend(
2441 result
2442 .channels_with_admin_privileges
2443 .into_iter()
2444 .map(|channel_id| proto::ChannelPermission {
2445 channel_id: channel_id.to_proto(),
2446 is_admin: true,
2447 }),
2448 );
2449 }
2450 session.peer.send(session.connection_id, update)?;
2451 response.send(proto::Ack {})?;
2452
2453 Ok(())
2454}
2455
/// Moves the requesting user into the room that backs `request.channel_id`.
///
/// Steps:
/// - Any room this session is currently in is left first
///   (`leave_room_for_session`).
/// - The user then joins the channel's room.
/// - If a LiveKit client is configured, a LiveKit token for the room is minted.
///   A token error is logged (`trace_err`) and the response then simply has no
///   LiveKit connection info.
///
/// Notifications, in order:
/// - the requester receives a `JoinRoomResponse`;
/// - the room's participants receive the updated room;
/// - the channel's members receive the new participant list;
/// - the user's contact entry is pushed to their accepted contacts.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Everything that touches the value returned by `db.join_room` happens
    // inside this block; it ends with `into_inner()`.
    // NOTE(review): `joined_room` looks like a guard (e.g. a room lock) that
    // `into_inner()` releases before the connection pool is taken below;
    // confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let room_id = db.room_id_for_channel(channel_id).await?;

        let joined_room = db
            .join_room(room_id, session.user_id, session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or the token could not be
        // generated (the error is logged by `trace_err`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let token = live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id.to_string(),
                )
                .trace_err()?;

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        // Tell everyone already in the room about the new participant.
        room_updated(&joined_room.room, &session.peer);

        joined_room.into_inner()
    };

    // Tell all channel members (not just room participants) who is now in
    // the channel's room.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
2510
2511async fn join_channel_buffer(
2512 request: proto::JoinChannelBuffer,
2513 response: Response<proto::JoinChannelBuffer>,
2514 session: Session,
2515) -> Result<()> {
2516 let db = session.db().await;
2517 let channel_id = ChannelId::from_proto(request.channel_id);
2518
2519 let open_response = db
2520 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2521 .await?;
2522
2523 let replica_id = open_response.replica_id;
2524 let collaborators = open_response.collaborators.clone();
2525
2526 response.send(open_response)?;
2527
2528 let update = AddChannelBufferCollaborator {
2529 channel_id: channel_id.to_proto(),
2530 collaborator: Some(proto::Collaborator {
2531 user_id: session.user_id.to_proto(),
2532 peer_id: Some(session.connection_id.into()),
2533 replica_id,
2534 }),
2535 };
2536 channel_buffer_updated(
2537 session.connection_id,
2538 collaborators
2539 .iter()
2540 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2541 &update,
2542 &session.peer,
2543 );
2544
2545 Ok(())
2546}
2547
2548async fn update_channel_buffer(
2549 request: proto::UpdateChannelBuffer,
2550 session: Session,
2551) -> Result<()> {
2552 let db = session.db().await;
2553 let channel_id = ChannelId::from_proto(request.channel_id);
2554
2555 let collaborators = db
2556 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2557 .await?;
2558
2559 channel_buffer_updated(
2560 session.connection_id,
2561 collaborators,
2562 &proto::UpdateChannelBuffer {
2563 channel_id: channel_id.to_proto(),
2564 operations: request.operations,
2565 },
2566 &session.peer,
2567 );
2568 Ok(())
2569}
2570
2571async fn rejoin_channel_buffers(
2572 request: proto::RejoinChannelBuffers,
2573 response: Response<proto::RejoinChannelBuffers>,
2574 session: Session,
2575) -> Result<()> {
2576 let db = session.db().await;
2577 let buffers = db
2578 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2579 .await?;
2580
2581 for buffer in &buffers {
2582 let collaborators_to_notify = buffer
2583 .buffer
2584 .collaborators
2585 .iter()
2586 .filter_map(|c| Some(c.peer_id?.into()));
2587 channel_buffer_updated(
2588 session.connection_id,
2589 collaborators_to_notify,
2590 &proto::UpdateChannelBufferCollaborator {
2591 channel_id: buffer.buffer.channel_id,
2592 old_peer_id: Some(buffer.old_connection_id.into()),
2593 new_peer_id: Some(session.connection_id.into()),
2594 },
2595 &session.peer,
2596 );
2597 }
2598
2599 response.send(proto::RejoinChannelBuffersResponse {
2600 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2601 })?;
2602
2603 Ok(())
2604}
2605
2606async fn leave_channel_buffer(
2607 request: proto::LeaveChannelBuffer,
2608 response: Response<proto::LeaveChannelBuffer>,
2609 session: Session,
2610) -> Result<()> {
2611 let db = session.db().await;
2612 let channel_id = ChannelId::from_proto(request.channel_id);
2613
2614 let collaborators_to_notify = db
2615 .leave_channel_buffer(channel_id, session.connection_id)
2616 .await?;
2617
2618 response.send(Ack {})?;
2619
2620 channel_buffer_updated(
2621 session.connection_id,
2622 collaborators_to_notify,
2623 &proto::RemoveChannelBufferCollaborator {
2624 channel_id: channel_id.to_proto(),
2625 peer_id: Some(session.connection_id.into()),
2626 },
2627 &session.peer,
2628 );
2629
2630 Ok(())
2631}
2632
2633fn channel_buffer_updated<T: EnvelopedMessage>(
2634 sender_id: ConnectionId,
2635 collaborators: impl IntoIterator<Item = ConnectionId>,
2636 message: &T,
2637 peer: &Peer,
2638) {
2639 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2640 peer.send(peer_id.into(), message.clone())
2641 });
2642}
2643
2644async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2645 let project_id = ProjectId::from_proto(request.project_id);
2646 let project_connection_ids = session
2647 .db()
2648 .await
2649 .project_connection_ids(project_id, session.connection_id)
2650 .await?;
2651 broadcast(
2652 Some(session.connection_id),
2653 project_connection_ids.iter().copied(),
2654 |connection_id| {
2655 session
2656 .peer
2657 .forward_send(session.connection_id, connection_id, request.clone())
2658 },
2659 );
2660 Ok(())
2661}
2662
2663async fn get_private_user_info(
2664 _request: proto::GetPrivateUserInfo,
2665 response: Response<proto::GetPrivateUserInfo>,
2666 session: Session,
2667) -> Result<()> {
2668 let db = session.db().await;
2669
2670 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2671 let user = db
2672 .get_user_by_id(session.user_id)
2673 .await?
2674 .ok_or_else(|| anyhow!("user not found"))?;
2675 let flags = db.get_user_flags(session.user_id).await?;
2676
2677 response.send(proto::GetPrivateUserInfoResponse {
2678 metrics_id,
2679 staff: user.admin,
2680 flags,
2681 })?;
2682 Ok(())
2683}
2684
2685fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2686 match message {
2687 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2688 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2689 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2690 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2691 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2692 code: frame.code.into(),
2693 reason: frame.reason,
2694 })),
2695 }
2696}
2697
2698fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2699 match message {
2700 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2701 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2702 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2703 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2704 AxumMessage::Close(frame) => {
2705 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2706 code: frame.code.into(),
2707 reason: frame.reason,
2708 }))
2709 }
2710 }
2711}
2712
2713fn build_initial_channels_update(
2714 channels: ChannelsForUser,
2715 channel_invites: Vec<db::Channel>,
2716) -> proto::UpdateChannels {
2717 let mut update = proto::UpdateChannels::default();
2718
2719 for channel in channels.channels {
2720 update.channels.push(proto::Channel {
2721 id: channel.id.to_proto(),
2722 name: channel.name,
2723 parent_id: channel.parent_id.map(|id| id.to_proto()),
2724 });
2725 }
2726
2727 for (channel_id, participants) in channels.channel_participants {
2728 update
2729 .channel_participants
2730 .push(proto::ChannelParticipants {
2731 channel_id: channel_id.to_proto(),
2732 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2733 });
2734 }
2735
2736 update
2737 .channel_permissions
2738 .extend(
2739 channels
2740 .channels_with_admin_privileges
2741 .into_iter()
2742 .map(|id| proto::ChannelPermission {
2743 channel_id: id.to_proto(),
2744 is_admin: true,
2745 }),
2746 );
2747
2748 for channel in channel_invites {
2749 update.channel_invitations.push(proto::Channel {
2750 id: channel.id.to_proto(),
2751 name: channel.name,
2752 parent_id: None,
2753 });
2754 }
2755
2756 update
2757}
2758
2759fn build_initial_contacts_update(
2760 contacts: Vec<db::Contact>,
2761 pool: &ConnectionPool,
2762) -> proto::UpdateContacts {
2763 let mut update = proto::UpdateContacts::default();
2764
2765 for contact in contacts {
2766 match contact {
2767 db::Contact::Accepted {
2768 user_id,
2769 should_notify,
2770 busy,
2771 } => {
2772 update
2773 .contacts
2774 .push(contact_for_user(user_id, should_notify, busy, &pool));
2775 }
2776 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2777 db::Contact::Incoming {
2778 user_id,
2779 should_notify,
2780 } => update
2781 .incoming_requests
2782 .push(proto::IncomingContactRequest {
2783 requester_id: user_id.to_proto(),
2784 should_notify,
2785 }),
2786 }
2787 }
2788
2789 update
2790}
2791
2792fn contact_for_user(
2793 user_id: UserId,
2794 should_notify: bool,
2795 busy: bool,
2796 pool: &ConnectionPool,
2797) -> proto::Contact {
2798 proto::Contact {
2799 user_id: user_id.to_proto(),
2800 online: pool.is_user_online(user_id),
2801 busy,
2802 should_notify,
2803 }
2804}
2805
2806fn room_updated(room: &proto::Room, peer: &Peer) {
2807 broadcast(
2808 None,
2809 room.participants
2810 .iter()
2811 .filter_map(|participant| Some(participant.peer_id?.into())),
2812 |peer_id| {
2813 peer.send(
2814 peer_id.into(),
2815 proto::RoomUpdated {
2816 room: Some(room.clone()),
2817 },
2818 )
2819 },
2820 );
2821}
2822
2823fn channel_updated(
2824 channel_id: ChannelId,
2825 room: &proto::Room,
2826 channel_members: &[UserId],
2827 peer: &Peer,
2828 pool: &ConnectionPool,
2829) {
2830 let participants = room
2831 .participants
2832 .iter()
2833 .map(|p| p.user_id)
2834 .collect::<Vec<_>>();
2835
2836 broadcast(
2837 None,
2838 channel_members
2839 .iter()
2840 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2841 |peer_id| {
2842 peer.send(
2843 peer_id.into(),
2844 proto::UpdateChannels {
2845 channel_participants: vec![proto::ChannelParticipants {
2846 channel_id: channel_id.to_proto(),
2847 participant_user_ids: participants.clone(),
2848 }],
2849 ..Default::default()
2850 },
2851 )
2852 },
2853 );
2854}
2855
2856async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2857 let db = session.db().await;
2858
2859 let contacts = db.get_contacts(user_id).await?;
2860 let busy = db.is_user_busy(user_id).await?;
2861
2862 let pool = session.connection_pool().await;
2863 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2864 for contact in contacts {
2865 if let db::Contact::Accepted {
2866 user_id: contact_user_id,
2867 ..
2868 } = contact
2869 {
2870 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2871 session
2872 .peer
2873 .send(
2874 contact_conn_id,
2875 proto::UpdateContacts {
2876 contacts: vec![updated_contact.clone()],
2877 remove_contacts: Default::default(),
2878 incoming_requests: Default::default(),
2879 remove_incoming_requests: Default::default(),
2880 outgoing_requests: Default::default(),
2881 remove_outgoing_requests: Default::default(),
2882 },
2883 )
2884 .trace_err();
2885 }
2886 }
2887 }
2888 Ok(())
2889}
2890
2891async fn leave_room_for_session(session: &Session) -> Result<()> {
2892 let mut contacts_to_update = HashSet::default();
2893
2894 let room_id;
2895 let canceled_calls_to_user_ids;
2896 let live_kit_room;
2897 let delete_live_kit_room;
2898 let room;
2899 let channel_members;
2900 let channel_id;
2901
2902 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2903 contacts_to_update.insert(session.user_id);
2904
2905 for project in left_room.left_projects.values() {
2906 project_left(project, session);
2907 }
2908
2909 room_id = RoomId::from_proto(left_room.room.id);
2910 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2911 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2912 delete_live_kit_room = left_room.deleted;
2913 room = mem::take(&mut left_room.room);
2914 channel_members = mem::take(&mut left_room.channel_members);
2915 channel_id = left_room.channel_id;
2916
2917 room_updated(&room, &session.peer);
2918 } else {
2919 return Ok(());
2920 }
2921
2922 if let Some(channel_id) = channel_id {
2923 channel_updated(
2924 channel_id,
2925 &room,
2926 &channel_members,
2927 &session.peer,
2928 &*session.connection_pool().await,
2929 );
2930 }
2931
2932 {
2933 let pool = session.connection_pool().await;
2934 for canceled_user_id in canceled_calls_to_user_ids {
2935 for connection_id in pool.user_connection_ids(canceled_user_id) {
2936 session
2937 .peer
2938 .send(
2939 connection_id,
2940 proto::CallCanceled {
2941 room_id: room_id.to_proto(),
2942 },
2943 )
2944 .trace_err();
2945 }
2946 contacts_to_update.insert(canceled_user_id);
2947 }
2948 }
2949
2950 for contact_user_id in contacts_to_update {
2951 update_user_contacts(contact_user_id, &session).await?;
2952 }
2953
2954 if let Some(live_kit) = session.live_kit_client.as_ref() {
2955 live_kit
2956 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2957 .await
2958 .trace_err();
2959
2960 if delete_live_kit_room {
2961 live_kit.delete_room(live_kit_room).await.trace_err();
2962 }
2963 }
2964
2965 Ok(())
2966}
2967
2968async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2969 let left_channel_buffers = session
2970 .db()
2971 .await
2972 .leave_channel_buffers(session.connection_id)
2973 .await?;
2974
2975 for (channel_id, connections) in left_channel_buffers {
2976 channel_buffer_updated(
2977 session.connection_id,
2978 connections,
2979 &proto::RemoveChannelBufferCollaborator {
2980 channel_id: channel_id.to_proto(),
2981 peer_id: Some(session.connection_id.into()),
2982 },
2983 &session.peer,
2984 );
2985 }
2986
2987 Ok(())
2988}
2989
2990fn project_left(project: &db::LeftProject, session: &Session) {
2991 for connection_id in &project.connection_ids {
2992 if project.host_user_id == session.user_id {
2993 session
2994 .peer
2995 .send(
2996 *connection_id,
2997 proto::UnshareProject {
2998 project_id: project.id.to_proto(),
2999 },
3000 )
3001 .trace_err();
3002 } else {
3003 session
3004 .peer
3005 .send(
3006 *connection_id,
3007 proto::RemoveProjectCollaborator {
3008 project_id: project.id.to_proto(),
3009 peer_id: Some(session.connection_id.into()),
3010 },
3011 )
3012 .trace_err();
3013 }
3014 }
3015}
3016
3017pub trait ResultExt {
3018 type Ok;
3019
3020 fn trace_err(self) -> Option<Self::Ok>;
3021}
3022
3023impl<T, E> ResultExt for Result<T, E>
3024where
3025 E: std::fmt::Debug,
3026{
3027 type Ok = T;
3028
3029 fn trace_err(self) -> Option<T> {
3030 match self {
3031 Ok(value) => Some(value),
3032 Err(error) => {
3033 tracing::error!("{:?}", error);
3034 None
3035 }
3036 }
3037 }
3038}