1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{
38 self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, EntityMessage, EnvelopedMessage,
39 LiveKitConnectionInfo, RequestMessage,
40 },
41 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
42};
43use serde::{Serialize, Serializer};
44use std::{
45 any::TypeId,
46 fmt,
47 future::Future,
48 marker::PhantomData,
49 mem,
50 net::SocketAddr,
51 ops::{Deref, DerefMut},
52 rc::Rc,
53 sync::{
54 atomic::{AtomicBool, Ordering::SeqCst},
55 Arc,
56 },
57 time::{Duration, Instant},
58};
59use tokio::sync::{watch, Semaphore};
60use tower::ServiceBuilder;
61use tracing::{info_span, instrument, Instrument};
62
/// How long `connection_lost` waits after a connection drops before it makes
/// the session leave its room and channel buffers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before it queries the database for stale
/// rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
65
66lazy_static! {
67 static ref METRIC_CONNECTIONS: IntGauge =
68 register_int_gauge!("connections", "number of connections").unwrap();
69 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
70 "shared_projects",
71 "number of open projects with one or more guests"
72 )
73 .unwrap();
74}
75
76type MessageHandler =
77 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
144pub struct Server {
145 id: parking_lot::Mutex<ServerId>,
146 peer: Arc<Peer>,
147 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 app_state: Arc<AppState>,
149 executor: Executor,
150 handlers: HashMap<TypeId, MessageHandler>,
151 teardown: watch::Sender<()>,
152}
153
154pub(crate) struct ConnectionPoolGuard<'a> {
155 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
156 _not_send: PhantomData<Rc<()>>,
157}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(()).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_message_handler(refresh_inlay_hints)
208 .add_request_handler(forward_project_request::<proto::GetHover>)
209 .add_request_handler(forward_project_request::<proto::GetDefinition>)
210 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
211 .add_request_handler(forward_project_request::<proto::GetReferences>)
212 .add_request_handler(forward_project_request::<proto::SearchProject>)
213 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
214 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
215 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
216 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
217 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
218 .add_request_handler(forward_project_request::<proto::GetCompletions>)
219 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
220 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
221 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
222 .add_request_handler(forward_project_request::<proto::PrepareRename>)
223 .add_request_handler(forward_project_request::<proto::PerformRename>)
224 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
225 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
226 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
227 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
228 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
229 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
230 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
231 .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
232 .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
233 .add_request_handler(forward_project_request::<proto::InlayHints>)
234 .add_message_handler(create_buffer_for_peer)
235 .add_request_handler(update_buffer)
236 .add_message_handler(update_buffer_file)
237 .add_message_handler(buffer_reloaded)
238 .add_message_handler(buffer_saved)
239 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
240 .add_request_handler(get_users)
241 .add_request_handler(fuzzy_search_users)
242 .add_request_handler(request_contact)
243 .add_request_handler(remove_contact)
244 .add_request_handler(respond_to_contact_request)
245 .add_request_handler(create_channel)
246 .add_request_handler(remove_channel)
247 .add_request_handler(invite_channel_member)
248 .add_request_handler(remove_channel_member)
249 .add_request_handler(set_channel_member_admin)
250 .add_request_handler(rename_channel)
251 .add_request_handler(join_channel_buffer)
252 .add_request_handler(leave_channel_buffer)
253 .add_message_handler(update_channel_buffer)
254 .add_request_handler(rejoin_channel_buffers)
255 .add_request_handler(get_channel_members)
256 .add_request_handler(respond_to_channel_invite)
257 .add_request_handler(join_channel)
258 .add_request_handler(follow)
259 .add_message_handler(unfollow)
260 .add_message_handler(update_followers)
261 .add_message_handler(update_diff_base)
262 .add_request_handler(get_private_user_info);
263
264 Arc::new(server)
265 }
266
267 pub async fn start(&self) -> Result<()> {
268 let server_id = *self.id.lock();
269 let app_state = self.app_state.clone();
270 let peer = self.peer.clone();
271 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
272 let pool = self.connection_pool.clone();
273 let live_kit_client = self.app_state.live_kit_client.clone();
274
275 let span = info_span!("start server");
276 self.executor.spawn_detached(
277 async move {
278 tracing::info!("waiting for cleanup timeout");
279 timeout.await;
280 tracing::info!("cleanup timeout expired, retrieving stale rooms");
281 if let Some((room_ids, channel_ids)) = app_state
282 .db
283 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
284 .await
285 .trace_err()
286 {
287 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
288
289 for channel_id in channel_ids {
290 if let Some(refreshed_channel_buffer) = app_state
291 .db
292 .refresh_channel_buffer(channel_id, server_id)
293 .await
294 .trace_err()
295 {
296 for connection_id in refreshed_channel_buffer.connection_ids {
297 for message in &refreshed_channel_buffer.removed_collaborators {
298 peer.send(connection_id, message.clone()).trace_err();
299 }
300 }
301 }
302 }
303
304 for room_id in room_ids {
305 let mut contacts_to_update = HashSet::default();
306 let mut canceled_calls_to_user_ids = Vec::new();
307 let mut live_kit_room = String::new();
308 let mut delete_live_kit_room = false;
309
310 if let Some(mut refreshed_room) = app_state
311 .db
312 .refresh_room(room_id, server_id)
313 .await
314 .trace_err()
315 {
316 tracing::info!(
317 room_id = room_id.0,
318 new_participant_count = refreshed_room.room.participants.len(),
319 "refreshed room"
320 );
321 room_updated(&refreshed_room.room, &peer);
322 if let Some(channel_id) = refreshed_room.channel_id {
323 channel_updated(
324 channel_id,
325 &refreshed_room.room,
326 &refreshed_room.channel_members,
327 &peer,
328 &*pool.lock(),
329 );
330 }
331 contacts_to_update
332 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
333 contacts_to_update
334 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
335 canceled_calls_to_user_ids =
336 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
337 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
338 delete_live_kit_room = refreshed_room.room.participants.is_empty();
339 }
340
341 {
342 let pool = pool.lock();
343 for canceled_user_id in canceled_calls_to_user_ids {
344 for connection_id in pool.user_connection_ids(canceled_user_id) {
345 peer.send(
346 connection_id,
347 proto::CallCanceled {
348 room_id: room_id.to_proto(),
349 },
350 )
351 .trace_err();
352 }
353 }
354 }
355
356 for user_id in contacts_to_update {
357 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
358 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
359 if let Some((busy, contacts)) = busy.zip(contacts) {
360 let pool = pool.lock();
361 let updated_contact = contact_for_user(user_id, false, busy, &pool);
362 for contact in contacts {
363 if let db::Contact::Accepted {
364 user_id: contact_user_id,
365 ..
366 } = contact
367 {
368 for contact_conn_id in
369 pool.user_connection_ids(contact_user_id)
370 {
371 peer.send(
372 contact_conn_id,
373 proto::UpdateContacts {
374 contacts: vec![updated_contact.clone()],
375 remove_contacts: Default::default(),
376 incoming_requests: Default::default(),
377 remove_incoming_requests: Default::default(),
378 outgoing_requests: Default::default(),
379 remove_outgoing_requests: Default::default(),
380 },
381 )
382 .trace_err();
383 }
384 }
385 }
386 }
387 }
388
389 if let Some(live_kit) = live_kit_client.as_ref() {
390 if delete_live_kit_room {
391 live_kit.delete_room(live_kit_room).await.trace_err();
392 }
393 }
394 }
395 }
396
397 app_state
398 .db
399 .delete_stale_servers(&app_state.config.zed_environment, server_id)
400 .await
401 .trace_err();
402 }
403 .instrument(span),
404 );
405 Ok(())
406 }
407
408 pub fn teardown(&self) {
409 self.peer.teardown();
410 self.connection_pool.lock().reset();
411 let _ = self.teardown.send(());
412 }
413
414 #[cfg(test)]
415 pub fn reset(&self, id: ServerId) {
416 self.teardown();
417 *self.id.lock() = id;
418 self.peer.reset(id.0 as u32);
419 }
420
421 #[cfg(test)]
422 pub fn id(&self) -> ServerId {
423 *self.id.lock()
424 }
425
426 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
427 where
428 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
429 Fut: 'static + Send + Future<Output = Result<()>>,
430 M: EnvelopedMessage,
431 {
432 let prev_handler = self.handlers.insert(
433 TypeId::of::<M>(),
434 Box::new(move |envelope, session| {
435 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
436 let span = info_span!(
437 "handle message",
438 payload_type = envelope.payload_type_name()
439 );
440 span.in_scope(|| {
441 tracing::info!(
442 payload_type = envelope.payload_type_name(),
443 "message received"
444 );
445 });
446 let start_time = Instant::now();
447 let future = (handler)(*envelope, session);
448 async move {
449 let result = future.await;
450 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
451 match result {
452 Err(error) => {
453 tracing::error!(%error, ?duration_ms, "error handling message")
454 }
455 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
456 }
457 }
458 .instrument(span)
459 .boxed()
460 }),
461 );
462 if prev_handler.is_some() {
463 panic!("registered a handler for the same message twice");
464 }
465 self
466 }
467
468 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
469 where
470 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
471 Fut: 'static + Send + Future<Output = Result<()>>,
472 M: EnvelopedMessage,
473 {
474 self.add_handler(move |envelope, session| handler(envelope.payload, session));
475 self
476 }
477
478 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
479 where
480 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
481 Fut: Send + Future<Output = Result<()>>,
482 M: RequestMessage,
483 {
484 let handler = Arc::new(handler);
485 self.add_handler(move |envelope, session| {
486 let receipt = envelope.receipt();
487 let handler = handler.clone();
488 async move {
489 let peer = session.peer.clone();
490 let responded = Arc::new(AtomicBool::default());
491 let response = Response {
492 peer: peer.clone(),
493 responded: responded.clone(),
494 receipt,
495 };
496 match (handler)(envelope.payload, response, session).await {
497 Ok(()) => {
498 if responded.load(std::sync::atomic::Ordering::SeqCst) {
499 Ok(())
500 } else {
501 Err(anyhow!("handler did not send a response"))?
502 }
503 }
504 Err(error) => {
505 peer.respond_with_error(
506 receipt,
507 proto::Error {
508 message: error.to_string(),
509 },
510 )?;
511 Err(error)
512 }
513 }
514 }
515 })
516 }
517
518 pub fn handle_connection(
519 self: &Arc<Self>,
520 connection: Connection,
521 address: String,
522 user: User,
523 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
524 executor: Executor,
525 ) -> impl Future<Output = Result<()>> {
526 let this = self.clone();
527 let user_id = user.id;
528 let login = user.github_login;
529 let span = info_span!("handle connection", %user_id, %login, %address);
530 let mut teardown = self.teardown.subscribe();
531 async move {
532 let (connection_id, handle_io, mut incoming_rx) = this
533 .peer
534 .add_connection(connection, {
535 let executor = executor.clone();
536 move |duration| executor.sleep(duration)
537 });
538
539 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
540 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
541 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
542
543 if let Some(send_connection_id) = send_connection_id.take() {
544 let _ = send_connection_id.send(connection_id);
545 }
546
547 if !user.connected_once {
548 this.peer.send(connection_id, proto::ShowContacts {})?;
549 this.app_state.db.set_user_connected_once(user_id, true).await?;
550 }
551
552 let (contacts, invite_code, channels_for_user, channel_invites) = future::try_join4(
553 this.app_state.db.get_contacts(user_id),
554 this.app_state.db.get_invite_code_for_user(user_id),
555 this.app_state.db.get_channels_for_user(user_id),
556 this.app_state.db.get_channel_invites_for_user(user_id)
557 ).await?;
558
559 {
560 let mut pool = this.connection_pool.lock();
561 pool.add_connection(connection_id, user_id, user.admin);
562 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
563 this.peer.send(connection_id, build_initial_channels_update(
564 channels_for_user,
565 channel_invites
566 ))?;
567
568 if let Some((code, count)) = invite_code {
569 this.peer.send(connection_id, proto::UpdateInviteInfo {
570 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
571 count: count as u32,
572 })?;
573 }
574 }
575
576 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
577 this.peer.send(connection_id, incoming_call)?;
578 }
579
580 let session = Session {
581 user_id,
582 connection_id,
583 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
584 peer: this.peer.clone(),
585 connection_pool: this.connection_pool.clone(),
586 live_kit_client: this.app_state.live_kit_client.clone(),
587 executor: executor.clone(),
588 };
589 update_user_contacts(user_id, &session).await?;
590
591 let handle_io = handle_io.fuse();
592 futures::pin_mut!(handle_io);
593
594 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
595 // This prevents deadlocks when e.g., client A performs a request to client B and
596 // client B performs a request to client A. If both clients stop processing further
597 // messages until their respective request completes, they won't have a chance to
598 // respond to the other client's request and cause a deadlock.
599 //
600 // This arrangement ensures we will attempt to process earlier messages first, but fall
601 // back to processing messages arrived later in the spirit of making progress.
602 let mut foreground_message_handlers = FuturesUnordered::new();
603 let concurrent_handlers = Arc::new(Semaphore::new(256));
604 loop {
605 let next_message = async {
606 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
607 let message = incoming_rx.next().await;
608 (permit, message)
609 }.fuse();
610 futures::pin_mut!(next_message);
611 futures::select_biased! {
612 _ = teardown.changed().fuse() => return Ok(()),
613 result = handle_io => {
614 if let Err(error) = result {
615 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
616 }
617 break;
618 }
619 _ = foreground_message_handlers.next() => {}
620 next_message = next_message => {
621 let (permit, message) = next_message;
622 if let Some(message) = message {
623 let type_name = message.payload_type_name();
624 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
625 let span_enter = span.enter();
626 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
627 let is_background = message.is_background();
628 let handle_message = (handler)(message, session.clone());
629 drop(span_enter);
630
631 let handle_message = async move {
632 handle_message.await;
633 drop(permit);
634 }.instrument(span);
635 if is_background {
636 executor.spawn_detached(handle_message);
637 } else {
638 foreground_message_handlers.push(handle_message);
639 }
640 } else {
641 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
642 }
643 } else {
644 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
645 break;
646 }
647 }
648 }
649 }
650
651 drop(foreground_message_handlers);
652 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
653 if let Err(error) = connection_lost(session, teardown, executor).await {
654 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
655 }
656
657 Ok(())
658 }.instrument(span)
659 }
660
661 pub async fn invite_code_redeemed(
662 self: &Arc<Self>,
663 inviter_id: UserId,
664 invitee_id: UserId,
665 ) -> Result<()> {
666 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
667 if let Some(code) = &user.invite_code {
668 let pool = self.connection_pool.lock();
669 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
670 for connection_id in pool.user_connection_ids(inviter_id) {
671 self.peer.send(
672 connection_id,
673 proto::UpdateContacts {
674 contacts: vec![invitee_contact.clone()],
675 ..Default::default()
676 },
677 )?;
678 self.peer.send(
679 connection_id,
680 proto::UpdateInviteInfo {
681 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
682 count: user.invite_count as u32,
683 },
684 )?;
685 }
686 }
687 }
688 Ok(())
689 }
690
691 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
692 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
693 if let Some(invite_code) = &user.invite_code {
694 let pool = self.connection_pool.lock();
695 for connection_id in pool.user_connection_ids(user_id) {
696 self.peer.send(
697 connection_id,
698 proto::UpdateInviteInfo {
699 url: format!(
700 "{}{}",
701 self.app_state.config.invite_link_prefix, invite_code
702 ),
703 count: user.invite_count as u32,
704 },
705 )?;
706 }
707 }
708 }
709 Ok(())
710 }
711
712 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
713 ServerSnapshot {
714 connection_pool: ConnectionPoolGuard {
715 guard: self.connection_pool.lock(),
716 _not_send: PhantomData,
717 },
718 peer: &self.peer,
719 }
720 }
721}
722
723impl<'a> Deref for ConnectionPoolGuard<'a> {
724 type Target = ConnectionPool;
725
726 fn deref(&self) -> &Self::Target {
727 &*self.guard
728 }
729}
730
731impl<'a> DerefMut for ConnectionPoolGuard<'a> {
732 fn deref_mut(&mut self) -> &mut Self::Target {
733 &mut *self.guard
734 }
735}
736
737impl<'a> Drop for ConnectionPoolGuard<'a> {
738 fn drop(&mut self) {
739 #[cfg(test)]
740 self.check_invariants();
741 }
742}
743
744fn broadcast<F>(
745 sender_id: Option<ConnectionId>,
746 receiver_ids: impl IntoIterator<Item = ConnectionId>,
747 mut f: F,
748) where
749 F: FnMut(ConnectionId) -> anyhow::Result<()>,
750{
751 for receiver_id in receiver_ids {
752 if Some(receiver_id) != sender_id {
753 if let Err(error) = f(receiver_id) {
754 tracing::error!("failed to send to {:?} {}", receiver_id, error);
755 }
756 }
757 }
758}
759
760lazy_static! {
761 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
762}
763
/// Typed value of the `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
765
766impl Header for ProtocolVersion {
767 fn name() -> &'static HeaderName {
768 &ZED_PROTOCOL_VERSION
769 }
770
771 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
772 where
773 Self: Sized,
774 I: Iterator<Item = &'i axum::http::HeaderValue>,
775 {
776 let version = values
777 .next()
778 .ok_or_else(axum::headers::Error::invalid)?
779 .to_str()
780 .map_err(|_| axum::headers::Error::invalid())?
781 .parse()
782 .map_err(|_| axum::headers::Error::invalid())?;
783 Ok(Self(version))
784 }
785
786 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
787 values.extend([self.0.to_string().parse().unwrap()]);
788 }
789}
790
791pub fn routes(server: Arc<Server>) -> Router<Body> {
792 Router::new()
793 .route("/rpc", get(handle_websocket_request))
794 .layer(
795 ServiceBuilder::new()
796 .layer(Extension(server.app_state.clone()))
797 .layer(middleware::from_fn(auth::validate_header)),
798 )
799 .route("/metrics", get(handle_metrics))
800 .layer(Extension(server))
801}
802
803pub async fn handle_websocket_request(
804 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
805 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
806 Extension(server): Extension<Arc<Server>>,
807 Extension(user): Extension<User>,
808 ws: WebSocketUpgrade,
809) -> axum::response::Response {
810 if protocol_version != rpc::PROTOCOL_VERSION {
811 return (
812 StatusCode::UPGRADE_REQUIRED,
813 "client must be upgraded".to_string(),
814 )
815 .into_response();
816 }
817 let socket_address = socket_address.to_string();
818 ws.on_upgrade(move |socket| {
819 use util::ResultExt;
820 let socket = socket
821 .map_ok(to_tungstenite_message)
822 .err_into()
823 .with(|message| async move { Ok(to_axum_message(message)) });
824 let connection = Connection::new(Box::pin(socket));
825 async move {
826 server
827 .handle_connection(connection, socket_address, user, None, Executor::Production)
828 .await
829 .log_err();
830 }
831 })
832}
833
834pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
835 let connections = server
836 .connection_pool
837 .lock()
838 .connections()
839 .filter(|connection| !connection.admin)
840 .count();
841
842 METRIC_CONNECTIONS.set(connections as _);
843
844 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
845 METRIC_SHARED_PROJECTS.set(shared_projects as _);
846
847 let encoder = prometheus::TextEncoder::new();
848 let metric_families = prometheus::gather();
849 let encoded_metrics = encoder
850 .encode_to_string(&metric_families)
851 .map_err(|err| anyhow!("{}", err))?;
852 Ok(encoded_metrics)
853}
854
855#[instrument(err, skip(executor))]
856async fn connection_lost(
857 session: Session,
858 mut teardown: watch::Receiver<()>,
859 executor: Executor,
860) -> Result<()> {
861 session.peer.disconnect(session.connection_id);
862 session
863 .connection_pool()
864 .await
865 .remove_connection(session.connection_id)?;
866
867 session
868 .db()
869 .await
870 .connection_lost(session.connection_id)
871 .await
872 .trace_err();
873
874 futures::select_biased! {
875 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
876 leave_room_for_session(&session).await.trace_err();
877 leave_channel_buffers_for_session(&session)
878 .await
879 .trace_err();
880
881 if !session
882 .connection_pool()
883 .await
884 .is_user_online(session.user_id)
885 {
886 let db = session.db().await;
887 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
888 room_updated(&room, &session.peer);
889 }
890 }
891 update_user_contacts(session.user_id, &session).await?;
892
893
894 }
895 _ = teardown.changed().fuse() => {}
896 }
897
898 Ok(())
899}
900
901async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
902 response.send(proto::Ack {})?;
903 Ok(())
904}
905
906async fn create_room(
907 _request: proto::CreateRoom,
908 response: Response<proto::CreateRoom>,
909 session: Session,
910) -> Result<()> {
911 let live_kit_room = nanoid::nanoid!(30);
912
913 let live_kit_connection_info = {
914 let live_kit_room = live_kit_room.clone();
915 let live_kit = session.live_kit_client.as_ref();
916
917 util::async_iife!({
918 let live_kit = live_kit?;
919
920 live_kit
921 .create_room(live_kit_room.clone())
922 .await
923 .trace_err()?;
924
925 let token = live_kit
926 .room_token(&live_kit_room, &session.user_id.to_string())
927 .trace_err()?;
928
929 Some(proto::LiveKitConnectionInfo {
930 server_url: live_kit.url().into(),
931 token,
932 })
933 })
934 }
935 .await;
936
937 let room = session
938 .db()
939 .await
940 .create_room(session.user_id, session.connection_id, &live_kit_room)
941 .await?;
942
943 response.send(proto::CreateRoomResponse {
944 room: Some(room.clone()),
945 live_kit_connection_info,
946 })?;
947
948 update_user_contacts(session.user_id, &session).await?;
949 Ok(())
950}
951
952async fn join_room(
953 request: proto::JoinRoom,
954 response: Response<proto::JoinRoom>,
955 session: Session,
956) -> Result<()> {
957 let room_id = RoomId::from_proto(request.id);
958 let joined_room = {
959 let room = session
960 .db()
961 .await
962 .join_room(room_id, session.user_id, session.connection_id)
963 .await?;
964 room_updated(&room.room, &session.peer);
965 room.into_inner()
966 };
967
968 if let Some(channel_id) = joined_room.channel_id {
969 channel_updated(
970 channel_id,
971 &joined_room.room,
972 &joined_room.channel_members,
973 &session.peer,
974 &*session.connection_pool().await,
975 )
976 }
977
978 for connection_id in session
979 .connection_pool()
980 .await
981 .user_connection_ids(session.user_id)
982 {
983 session
984 .peer
985 .send(
986 connection_id,
987 proto::CallCanceled {
988 room_id: room_id.to_proto(),
989 },
990 )
991 .trace_err();
992 }
993
994 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
995 if let Some(token) = live_kit
996 .room_token(
997 &joined_room.room.live_kit_room,
998 &session.user_id.to_string(),
999 )
1000 .trace_err()
1001 {
1002 Some(proto::LiveKitConnectionInfo {
1003 server_url: live_kit.url().into(),
1004 token,
1005 })
1006 } else {
1007 None
1008 }
1009 } else {
1010 None
1011 };
1012
1013 response.send(proto::JoinRoomResponse {
1014 room: Some(joined_room.room),
1015 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
1016 live_kit_connection_info,
1017 })?;
1018
1019 update_user_contacts(session.user_id, &session).await?;
1020 Ok(())
1021}
1022
1023async fn rejoin_room(
1024 request: proto::RejoinRoom,
1025 response: Response<proto::RejoinRoom>,
1026 session: Session,
1027) -> Result<()> {
1028 let room;
1029 let channel_id;
1030 let channel_members;
1031 {
1032 let mut rejoined_room = session
1033 .db()
1034 .await
1035 .rejoin_room(request, session.user_id, session.connection_id)
1036 .await?;
1037
1038 response.send(proto::RejoinRoomResponse {
1039 room: Some(rejoined_room.room.clone()),
1040 reshared_projects: rejoined_room
1041 .reshared_projects
1042 .iter()
1043 .map(|project| proto::ResharedProject {
1044 id: project.id.to_proto(),
1045 collaborators: project
1046 .collaborators
1047 .iter()
1048 .map(|collaborator| collaborator.to_proto())
1049 .collect(),
1050 })
1051 .collect(),
1052 rejoined_projects: rejoined_room
1053 .rejoined_projects
1054 .iter()
1055 .map(|rejoined_project| proto::RejoinedProject {
1056 id: rejoined_project.id.to_proto(),
1057 worktrees: rejoined_project
1058 .worktrees
1059 .iter()
1060 .map(|worktree| proto::WorktreeMetadata {
1061 id: worktree.id,
1062 root_name: worktree.root_name.clone(),
1063 visible: worktree.visible,
1064 abs_path: worktree.abs_path.clone(),
1065 })
1066 .collect(),
1067 collaborators: rejoined_project
1068 .collaborators
1069 .iter()
1070 .map(|collaborator| collaborator.to_proto())
1071 .collect(),
1072 language_servers: rejoined_project.language_servers.clone(),
1073 })
1074 .collect(),
1075 })?;
1076 room_updated(&rejoined_room.room, &session.peer);
1077
1078 for project in &rejoined_room.reshared_projects {
1079 for collaborator in &project.collaborators {
1080 session
1081 .peer
1082 .send(
1083 collaborator.connection_id,
1084 proto::UpdateProjectCollaborator {
1085 project_id: project.id.to_proto(),
1086 old_peer_id: Some(project.old_connection_id.into()),
1087 new_peer_id: Some(session.connection_id.into()),
1088 },
1089 )
1090 .trace_err();
1091 }
1092
1093 broadcast(
1094 Some(session.connection_id),
1095 project
1096 .collaborators
1097 .iter()
1098 .map(|collaborator| collaborator.connection_id),
1099 |connection_id| {
1100 session.peer.forward_send(
1101 session.connection_id,
1102 connection_id,
1103 proto::UpdateProject {
1104 project_id: project.id.to_proto(),
1105 worktrees: project.worktrees.clone(),
1106 },
1107 )
1108 },
1109 );
1110 }
1111
1112 for project in &rejoined_room.rejoined_projects {
1113 for collaborator in &project.collaborators {
1114 session
1115 .peer
1116 .send(
1117 collaborator.connection_id,
1118 proto::UpdateProjectCollaborator {
1119 project_id: project.id.to_proto(),
1120 old_peer_id: Some(project.old_connection_id.into()),
1121 new_peer_id: Some(session.connection_id.into()),
1122 },
1123 )
1124 .trace_err();
1125 }
1126 }
1127
1128 for project in &mut rejoined_room.rejoined_projects {
1129 for worktree in mem::take(&mut project.worktrees) {
1130 #[cfg(any(test, feature = "test-support"))]
1131 const MAX_CHUNK_SIZE: usize = 2;
1132 #[cfg(not(any(test, feature = "test-support")))]
1133 const MAX_CHUNK_SIZE: usize = 256;
1134
1135 // Stream this worktree's entries.
1136 let message = proto::UpdateWorktree {
1137 project_id: project.id.to_proto(),
1138 worktree_id: worktree.id,
1139 abs_path: worktree.abs_path.clone(),
1140 root_name: worktree.root_name,
1141 updated_entries: worktree.updated_entries,
1142 removed_entries: worktree.removed_entries,
1143 scan_id: worktree.scan_id,
1144 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1145 updated_repositories: worktree.updated_repositories,
1146 removed_repositories: worktree.removed_repositories,
1147 };
1148 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1149 session.peer.send(session.connection_id, update.clone())?;
1150 }
1151
1152 // Stream this worktree's diagnostics.
1153 for summary in worktree.diagnostic_summaries {
1154 session.peer.send(
1155 session.connection_id,
1156 proto::UpdateDiagnosticSummary {
1157 project_id: project.id.to_proto(),
1158 worktree_id: worktree.id,
1159 summary: Some(summary),
1160 },
1161 )?;
1162 }
1163
1164 for settings_file in worktree.settings_files {
1165 session.peer.send(
1166 session.connection_id,
1167 proto::UpdateWorktreeSettings {
1168 project_id: project.id.to_proto(),
1169 worktree_id: worktree.id,
1170 path: settings_file.path,
1171 content: Some(settings_file.content),
1172 },
1173 )?;
1174 }
1175 }
1176
1177 for language_server in &project.language_servers {
1178 session.peer.send(
1179 session.connection_id,
1180 proto::UpdateLanguageServer {
1181 project_id: project.id.to_proto(),
1182 language_server_id: language_server.id,
1183 variant: Some(
1184 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1185 proto::LspDiskBasedDiagnosticsUpdated {},
1186 ),
1187 ),
1188 },
1189 )?;
1190 }
1191 }
1192
1193 let rejoined_room = rejoined_room.into_inner();
1194
1195 room = rejoined_room.room;
1196 channel_id = rejoined_room.channel_id;
1197 channel_members = rejoined_room.channel_members;
1198 }
1199
1200 if let Some(channel_id) = channel_id {
1201 channel_updated(
1202 channel_id,
1203 &room,
1204 &channel_members,
1205 &session.peer,
1206 &*session.connection_pool().await,
1207 );
1208 }
1209
1210 update_user_contacts(session.user_id, &session).await?;
1211 Ok(())
1212}
1213
1214async fn leave_room(
1215 _: proto::LeaveRoom,
1216 response: Response<proto::LeaveRoom>,
1217 session: Session,
1218) -> Result<()> {
1219 leave_room_for_session(&session).await?;
1220 response.send(proto::Ack {})?;
1221 Ok(())
1222}
1223
1224async fn call(
1225 request: proto::Call,
1226 response: Response<proto::Call>,
1227 session: Session,
1228) -> Result<()> {
1229 let room_id = RoomId::from_proto(request.room_id);
1230 let calling_user_id = session.user_id;
1231 let calling_connection_id = session.connection_id;
1232 let called_user_id = UserId::from_proto(request.called_user_id);
1233 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1234 if !session
1235 .db()
1236 .await
1237 .has_contact(calling_user_id, called_user_id)
1238 .await?
1239 {
1240 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1241 }
1242
1243 let incoming_call = {
1244 let (room, incoming_call) = &mut *session
1245 .db()
1246 .await
1247 .call(
1248 room_id,
1249 calling_user_id,
1250 calling_connection_id,
1251 called_user_id,
1252 initial_project_id,
1253 )
1254 .await?;
1255 room_updated(&room, &session.peer);
1256 mem::take(incoming_call)
1257 };
1258 update_user_contacts(called_user_id, &session).await?;
1259
1260 let mut calls = session
1261 .connection_pool()
1262 .await
1263 .user_connection_ids(called_user_id)
1264 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1265 .collect::<FuturesUnordered<_>>();
1266
1267 while let Some(call_response) = calls.next().await {
1268 match call_response.as_ref() {
1269 Ok(_) => {
1270 response.send(proto::Ack {})?;
1271 return Ok(());
1272 }
1273 Err(_) => {
1274 call_response.trace_err();
1275 }
1276 }
1277 }
1278
1279 {
1280 let room = session
1281 .db()
1282 .await
1283 .call_failed(room_id, called_user_id)
1284 .await?;
1285 room_updated(&room, &session.peer);
1286 }
1287 update_user_contacts(called_user_id, &session).await?;
1288
1289 Err(anyhow!("failed to ring user"))?
1290}
1291
1292async fn cancel_call(
1293 request: proto::CancelCall,
1294 response: Response<proto::CancelCall>,
1295 session: Session,
1296) -> Result<()> {
1297 let called_user_id = UserId::from_proto(request.called_user_id);
1298 let room_id = RoomId::from_proto(request.room_id);
1299 {
1300 let room = session
1301 .db()
1302 .await
1303 .cancel_call(room_id, session.connection_id, called_user_id)
1304 .await?;
1305 room_updated(&room, &session.peer);
1306 }
1307
1308 for connection_id in session
1309 .connection_pool()
1310 .await
1311 .user_connection_ids(called_user_id)
1312 {
1313 session
1314 .peer
1315 .send(
1316 connection_id,
1317 proto::CallCanceled {
1318 room_id: room_id.to_proto(),
1319 },
1320 )
1321 .trace_err();
1322 }
1323 response.send(proto::Ack {})?;
1324
1325 update_user_contacts(called_user_id, &session).await?;
1326 Ok(())
1327}
1328
1329async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1330 let room_id = RoomId::from_proto(message.room_id);
1331 {
1332 let room = session
1333 .db()
1334 .await
1335 .decline_call(Some(room_id), session.user_id)
1336 .await?
1337 .ok_or_else(|| anyhow!("failed to decline call"))?;
1338 room_updated(&room, &session.peer);
1339 }
1340
1341 for connection_id in session
1342 .connection_pool()
1343 .await
1344 .user_connection_ids(session.user_id)
1345 {
1346 session
1347 .peer
1348 .send(
1349 connection_id,
1350 proto::CallCanceled {
1351 room_id: room_id.to_proto(),
1352 },
1353 )
1354 .trace_err();
1355 }
1356 update_user_contacts(session.user_id, &session).await?;
1357 Ok(())
1358}
1359
1360async fn update_participant_location(
1361 request: proto::UpdateParticipantLocation,
1362 response: Response<proto::UpdateParticipantLocation>,
1363 session: Session,
1364) -> Result<()> {
1365 let room_id = RoomId::from_proto(request.room_id);
1366 let location = request
1367 .location
1368 .ok_or_else(|| anyhow!("invalid location"))?;
1369
1370 let db = session.db().await;
1371 let room = db
1372 .update_room_participant_location(room_id, session.connection_id, location)
1373 .await?;
1374
1375 room_updated(&room, &session.peer);
1376 response.send(proto::Ack {})?;
1377 Ok(())
1378}
1379
1380async fn share_project(
1381 request: proto::ShareProject,
1382 response: Response<proto::ShareProject>,
1383 session: Session,
1384) -> Result<()> {
1385 let (project_id, room) = &*session
1386 .db()
1387 .await
1388 .share_project(
1389 RoomId::from_proto(request.room_id),
1390 session.connection_id,
1391 &request.worktrees,
1392 )
1393 .await?;
1394 response.send(proto::ShareProjectResponse {
1395 project_id: project_id.to_proto(),
1396 })?;
1397 room_updated(&room, &session.peer);
1398
1399 Ok(())
1400}
1401
1402async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1403 let project_id = ProjectId::from_proto(message.project_id);
1404
1405 let (room, guest_connection_ids) = &*session
1406 .db()
1407 .await
1408 .unshare_project(project_id, session.connection_id)
1409 .await?;
1410
1411 broadcast(
1412 Some(session.connection_id),
1413 guest_connection_ids.iter().copied(),
1414 |conn_id| session.peer.send(conn_id, message.clone()),
1415 );
1416 room_updated(&room, &session.peer);
1417
1418 Ok(())
1419}
1420
1421async fn join_project(
1422 request: proto::JoinProject,
1423 response: Response<proto::JoinProject>,
1424 session: Session,
1425) -> Result<()> {
1426 let project_id = ProjectId::from_proto(request.project_id);
1427 let guest_user_id = session.user_id;
1428
1429 tracing::info!(%project_id, "join project");
1430
1431 let (project, replica_id) = &mut *session
1432 .db()
1433 .await
1434 .join_project(project_id, session.connection_id)
1435 .await?;
1436
1437 let collaborators = project
1438 .collaborators
1439 .iter()
1440 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1441 .map(|collaborator| collaborator.to_proto())
1442 .collect::<Vec<_>>();
1443
1444 let worktrees = project
1445 .worktrees
1446 .iter()
1447 .map(|(id, worktree)| proto::WorktreeMetadata {
1448 id: *id,
1449 root_name: worktree.root_name.clone(),
1450 visible: worktree.visible,
1451 abs_path: worktree.abs_path.clone(),
1452 })
1453 .collect::<Vec<_>>();
1454
1455 for collaborator in &collaborators {
1456 session
1457 .peer
1458 .send(
1459 collaborator.peer_id.unwrap().into(),
1460 proto::AddProjectCollaborator {
1461 project_id: project_id.to_proto(),
1462 collaborator: Some(proto::Collaborator {
1463 peer_id: Some(session.connection_id.into()),
1464 replica_id: replica_id.0 as u32,
1465 user_id: guest_user_id.to_proto(),
1466 }),
1467 },
1468 )
1469 .trace_err();
1470 }
1471
1472 // First, we send the metadata associated with each worktree.
1473 response.send(proto::JoinProjectResponse {
1474 worktrees: worktrees.clone(),
1475 replica_id: replica_id.0 as u32,
1476 collaborators: collaborators.clone(),
1477 language_servers: project.language_servers.clone(),
1478 })?;
1479
1480 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1481 #[cfg(any(test, feature = "test-support"))]
1482 const MAX_CHUNK_SIZE: usize = 2;
1483 #[cfg(not(any(test, feature = "test-support")))]
1484 const MAX_CHUNK_SIZE: usize = 256;
1485
1486 // Stream this worktree's entries.
1487 let message = proto::UpdateWorktree {
1488 project_id: project_id.to_proto(),
1489 worktree_id,
1490 abs_path: worktree.abs_path.clone(),
1491 root_name: worktree.root_name,
1492 updated_entries: worktree.entries,
1493 removed_entries: Default::default(),
1494 scan_id: worktree.scan_id,
1495 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1496 updated_repositories: worktree.repository_entries.into_values().collect(),
1497 removed_repositories: Default::default(),
1498 };
1499 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1500 session.peer.send(session.connection_id, update.clone())?;
1501 }
1502
1503 // Stream this worktree's diagnostics.
1504 for summary in worktree.diagnostic_summaries {
1505 session.peer.send(
1506 session.connection_id,
1507 proto::UpdateDiagnosticSummary {
1508 project_id: project_id.to_proto(),
1509 worktree_id: worktree.id,
1510 summary: Some(summary),
1511 },
1512 )?;
1513 }
1514
1515 for settings_file in worktree.settings_files {
1516 session.peer.send(
1517 session.connection_id,
1518 proto::UpdateWorktreeSettings {
1519 project_id: project_id.to_proto(),
1520 worktree_id: worktree.id,
1521 path: settings_file.path,
1522 content: Some(settings_file.content),
1523 },
1524 )?;
1525 }
1526 }
1527
1528 for language_server in &project.language_servers {
1529 session.peer.send(
1530 session.connection_id,
1531 proto::UpdateLanguageServer {
1532 project_id: project_id.to_proto(),
1533 language_server_id: language_server.id,
1534 variant: Some(
1535 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1536 proto::LspDiskBasedDiagnosticsUpdated {},
1537 ),
1538 ),
1539 },
1540 )?;
1541 }
1542
1543 Ok(())
1544}
1545
1546async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1547 let sender_id = session.connection_id;
1548 let project_id = ProjectId::from_proto(request.project_id);
1549
1550 let (room, project) = &*session
1551 .db()
1552 .await
1553 .leave_project(project_id, sender_id)
1554 .await?;
1555 tracing::info!(
1556 %project_id,
1557 host_user_id = %project.host_user_id,
1558 host_connection_id = %project.host_connection_id,
1559 "leave project"
1560 );
1561
1562 project_left(&project, &session);
1563 room_updated(&room, &session.peer);
1564
1565 Ok(())
1566}
1567
1568async fn update_project(
1569 request: proto::UpdateProject,
1570 response: Response<proto::UpdateProject>,
1571 session: Session,
1572) -> Result<()> {
1573 let project_id = ProjectId::from_proto(request.project_id);
1574 let (room, guest_connection_ids) = &*session
1575 .db()
1576 .await
1577 .update_project(project_id, session.connection_id, &request.worktrees)
1578 .await?;
1579 broadcast(
1580 Some(session.connection_id),
1581 guest_connection_ids.iter().copied(),
1582 |connection_id| {
1583 session
1584 .peer
1585 .forward_send(session.connection_id, connection_id, request.clone())
1586 },
1587 );
1588 room_updated(&room, &session.peer);
1589 response.send(proto::Ack {})?;
1590
1591 Ok(())
1592}
1593
1594async fn update_worktree(
1595 request: proto::UpdateWorktree,
1596 response: Response<proto::UpdateWorktree>,
1597 session: Session,
1598) -> Result<()> {
1599 let guest_connection_ids = session
1600 .db()
1601 .await
1602 .update_worktree(&request, session.connection_id)
1603 .await?;
1604
1605 broadcast(
1606 Some(session.connection_id),
1607 guest_connection_ids.iter().copied(),
1608 |connection_id| {
1609 session
1610 .peer
1611 .forward_send(session.connection_id, connection_id, request.clone())
1612 },
1613 );
1614 response.send(proto::Ack {})?;
1615 Ok(())
1616}
1617
1618async fn update_diagnostic_summary(
1619 message: proto::UpdateDiagnosticSummary,
1620 session: Session,
1621) -> Result<()> {
1622 let guest_connection_ids = session
1623 .db()
1624 .await
1625 .update_diagnostic_summary(&message, session.connection_id)
1626 .await?;
1627
1628 broadcast(
1629 Some(session.connection_id),
1630 guest_connection_ids.iter().copied(),
1631 |connection_id| {
1632 session
1633 .peer
1634 .forward_send(session.connection_id, connection_id, message.clone())
1635 },
1636 );
1637
1638 Ok(())
1639}
1640
1641async fn update_worktree_settings(
1642 message: proto::UpdateWorktreeSettings,
1643 session: Session,
1644) -> Result<()> {
1645 let guest_connection_ids = session
1646 .db()
1647 .await
1648 .update_worktree_settings(&message, session.connection_id)
1649 .await?;
1650
1651 broadcast(
1652 Some(session.connection_id),
1653 guest_connection_ids.iter().copied(),
1654 |connection_id| {
1655 session
1656 .peer
1657 .forward_send(session.connection_id, connection_id, message.clone())
1658 },
1659 );
1660
1661 Ok(())
1662}
1663
1664async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
1665 broadcast_project_message(request.project_id, request, session).await
1666}
1667
1668async fn start_language_server(
1669 request: proto::StartLanguageServer,
1670 session: Session,
1671) -> Result<()> {
1672 let guest_connection_ids = session
1673 .db()
1674 .await
1675 .start_language_server(&request, session.connection_id)
1676 .await?;
1677
1678 broadcast(
1679 Some(session.connection_id),
1680 guest_connection_ids.iter().copied(),
1681 |connection_id| {
1682 session
1683 .peer
1684 .forward_send(session.connection_id, connection_id, request.clone())
1685 },
1686 );
1687 Ok(())
1688}
1689
1690async fn update_language_server(
1691 request: proto::UpdateLanguageServer,
1692 session: Session,
1693) -> Result<()> {
1694 session.executor.record_backtrace();
1695 let project_id = ProjectId::from_proto(request.project_id);
1696 let project_connection_ids = session
1697 .db()
1698 .await
1699 .project_connection_ids(project_id, session.connection_id)
1700 .await?;
1701 broadcast(
1702 Some(session.connection_id),
1703 project_connection_ids.iter().copied(),
1704 |connection_id| {
1705 session
1706 .peer
1707 .forward_send(session.connection_id, connection_id, request.clone())
1708 },
1709 );
1710 Ok(())
1711}
1712
1713async fn forward_project_request<T>(
1714 request: T,
1715 response: Response<T>,
1716 session: Session,
1717) -> Result<()>
1718where
1719 T: EntityMessage + RequestMessage,
1720{
1721 session.executor.record_backtrace();
1722 let project_id = ProjectId::from_proto(request.remote_entity_id());
1723 let host_connection_id = {
1724 let collaborators = session
1725 .db()
1726 .await
1727 .project_collaborators(project_id, session.connection_id)
1728 .await?;
1729 collaborators
1730 .iter()
1731 .find(|collaborator| collaborator.is_host)
1732 .ok_or_else(|| anyhow!("host not found"))?
1733 .connection_id
1734 };
1735
1736 let payload = session
1737 .peer
1738 .forward_request(session.connection_id, host_connection_id, request)
1739 .await?;
1740
1741 response.send(payload)?;
1742 Ok(())
1743}
1744
1745async fn create_buffer_for_peer(
1746 request: proto::CreateBufferForPeer,
1747 session: Session,
1748) -> Result<()> {
1749 session.executor.record_backtrace();
1750 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1751 session
1752 .peer
1753 .forward_send(session.connection_id, peer_id.into(), request)?;
1754 Ok(())
1755}
1756
1757async fn update_buffer(
1758 request: proto::UpdateBuffer,
1759 response: Response<proto::UpdateBuffer>,
1760 session: Session,
1761) -> Result<()> {
1762 session.executor.record_backtrace();
1763 let project_id = ProjectId::from_proto(request.project_id);
1764 let mut guest_connection_ids;
1765 let mut host_connection_id = None;
1766 {
1767 let collaborators = session
1768 .db()
1769 .await
1770 .project_collaborators(project_id, session.connection_id)
1771 .await?;
1772 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1773 for collaborator in collaborators.iter() {
1774 if collaborator.is_host {
1775 host_connection_id = Some(collaborator.connection_id);
1776 } else {
1777 guest_connection_ids.push(collaborator.connection_id);
1778 }
1779 }
1780 }
1781 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1782
1783 session.executor.record_backtrace();
1784 broadcast(
1785 Some(session.connection_id),
1786 guest_connection_ids,
1787 |connection_id| {
1788 session
1789 .peer
1790 .forward_send(session.connection_id, connection_id, request.clone())
1791 },
1792 );
1793 if host_connection_id != session.connection_id {
1794 session
1795 .peer
1796 .forward_request(session.connection_id, host_connection_id, request.clone())
1797 .await?;
1798 }
1799
1800 response.send(proto::Ack {})?;
1801 Ok(())
1802}
1803
1804async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1805 let project_id = ProjectId::from_proto(request.project_id);
1806 let project_connection_ids = session
1807 .db()
1808 .await
1809 .project_connection_ids(project_id, session.connection_id)
1810 .await?;
1811
1812 broadcast(
1813 Some(session.connection_id),
1814 project_connection_ids.iter().copied(),
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 Ok(())
1822}
1823
1824async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1825 let project_id = ProjectId::from_proto(request.project_id);
1826 let project_connection_ids = session
1827 .db()
1828 .await
1829 .project_connection_ids(project_id, session.connection_id)
1830 .await?;
1831 broadcast(
1832 Some(session.connection_id),
1833 project_connection_ids.iter().copied(),
1834 |connection_id| {
1835 session
1836 .peer
1837 .forward_send(session.connection_id, connection_id, request.clone())
1838 },
1839 );
1840 Ok(())
1841}
1842
1843async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1844 broadcast_project_message(request.project_id, request, session).await
1845}
1846
1847async fn broadcast_project_message<T: EnvelopedMessage>(
1848 project_id: u64,
1849 request: T,
1850 session: Session,
1851) -> Result<()> {
1852 let project_id = ProjectId::from_proto(project_id);
1853 let project_connection_ids = session
1854 .db()
1855 .await
1856 .project_connection_ids(project_id, session.connection_id)
1857 .await?;
1858 broadcast(
1859 Some(session.connection_id),
1860 project_connection_ids.iter().copied(),
1861 |connection_id| {
1862 session
1863 .peer
1864 .forward_send(session.connection_id, connection_id, request.clone())
1865 },
1866 );
1867 Ok(())
1868}
1869
1870async fn follow(
1871 request: proto::Follow,
1872 response: Response<proto::Follow>,
1873 session: Session,
1874) -> Result<()> {
1875 let project_id = ProjectId::from_proto(request.project_id);
1876 let leader_id = request
1877 .leader_id
1878 .ok_or_else(|| anyhow!("invalid leader id"))?
1879 .into();
1880 let follower_id = session.connection_id;
1881
1882 {
1883 let project_connection_ids = session
1884 .db()
1885 .await
1886 .project_connection_ids(project_id, session.connection_id)
1887 .await?;
1888
1889 if !project_connection_ids.contains(&leader_id) {
1890 Err(anyhow!("no such peer"))?;
1891 }
1892 }
1893
1894 let mut response_payload = session
1895 .peer
1896 .forward_request(session.connection_id, leader_id, request)
1897 .await?;
1898 response_payload
1899 .views
1900 .retain(|view| view.leader_id != Some(follower_id.into()));
1901 response.send(response_payload)?;
1902
1903 let room = session
1904 .db()
1905 .await
1906 .follow(project_id, leader_id, follower_id)
1907 .await?;
1908 room_updated(&room, &session.peer);
1909
1910 Ok(())
1911}
1912
1913async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1914 let project_id = ProjectId::from_proto(request.project_id);
1915 let leader_id = request
1916 .leader_id
1917 .ok_or_else(|| anyhow!("invalid leader id"))?
1918 .into();
1919 let follower_id = session.connection_id;
1920
1921 if !session
1922 .db()
1923 .await
1924 .project_connection_ids(project_id, session.connection_id)
1925 .await?
1926 .contains(&leader_id)
1927 {
1928 Err(anyhow!("no such peer"))?;
1929 }
1930
1931 session
1932 .peer
1933 .forward_send(session.connection_id, leader_id, request)?;
1934
1935 let room = session
1936 .db()
1937 .await
1938 .unfollow(project_id, leader_id, follower_id)
1939 .await?;
1940 room_updated(&room, &session.peer);
1941
1942 Ok(())
1943}
1944
1945async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1946 let project_id = ProjectId::from_proto(request.project_id);
1947 let project_connection_ids = session
1948 .db
1949 .lock()
1950 .await
1951 .project_connection_ids(project_id, session.connection_id)
1952 .await?;
1953
1954 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1955 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1956 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1957 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1958 });
1959 for follower_peer_id in request.follower_ids.iter().copied() {
1960 let follower_connection_id = follower_peer_id.into();
1961 if project_connection_ids.contains(&follower_connection_id)
1962 && Some(follower_peer_id) != leader_id
1963 {
1964 session.peer.forward_send(
1965 session.connection_id,
1966 follower_connection_id,
1967 request.clone(),
1968 )?;
1969 }
1970 }
1971 Ok(())
1972}
1973
1974async fn get_users(
1975 request: proto::GetUsers,
1976 response: Response<proto::GetUsers>,
1977 session: Session,
1978) -> Result<()> {
1979 let user_ids = request
1980 .user_ids
1981 .into_iter()
1982 .map(UserId::from_proto)
1983 .collect();
1984 let users = session
1985 .db()
1986 .await
1987 .get_users_by_ids(user_ids)
1988 .await?
1989 .into_iter()
1990 .map(|user| proto::User {
1991 id: user.id.to_proto(),
1992 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1993 github_login: user.github_login,
1994 })
1995 .collect();
1996 response.send(proto::UsersResponse { users })?;
1997 Ok(())
1998}
1999
2000async fn fuzzy_search_users(
2001 request: proto::FuzzySearchUsers,
2002 response: Response<proto::FuzzySearchUsers>,
2003 session: Session,
2004) -> Result<()> {
2005 let query = request.query;
2006 let users = match query.len() {
2007 0 => vec![],
2008 1 | 2 => session
2009 .db()
2010 .await
2011 .get_user_by_github_login(&query)
2012 .await?
2013 .into_iter()
2014 .collect(),
2015 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2016 };
2017 let users = users
2018 .into_iter()
2019 .filter(|user| user.id != session.user_id)
2020 .map(|user| proto::User {
2021 id: user.id.to_proto(),
2022 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2023 github_login: user.github_login,
2024 })
2025 .collect();
2026 response.send(proto::UsersResponse { users })?;
2027 Ok(())
2028}
2029
2030async fn request_contact(
2031 request: proto::RequestContact,
2032 response: Response<proto::RequestContact>,
2033 session: Session,
2034) -> Result<()> {
2035 let requester_id = session.user_id;
2036 let responder_id = UserId::from_proto(request.responder_id);
2037 if requester_id == responder_id {
2038 return Err(anyhow!("cannot add yourself as a contact"))?;
2039 }
2040
2041 session
2042 .db()
2043 .await
2044 .send_contact_request(requester_id, responder_id)
2045 .await?;
2046
2047 // Update outgoing contact requests of requester
2048 let mut update = proto::UpdateContacts::default();
2049 update.outgoing_requests.push(responder_id.to_proto());
2050 for connection_id in session
2051 .connection_pool()
2052 .await
2053 .user_connection_ids(requester_id)
2054 {
2055 session.peer.send(connection_id, update.clone())?;
2056 }
2057
2058 // Update incoming contact requests of responder
2059 let mut update = proto::UpdateContacts::default();
2060 update
2061 .incoming_requests
2062 .push(proto::IncomingContactRequest {
2063 requester_id: requester_id.to_proto(),
2064 should_notify: true,
2065 });
2066 for connection_id in session
2067 .connection_pool()
2068 .await
2069 .user_connection_ids(responder_id)
2070 {
2071 session.peer.send(connection_id, update.clone())?;
2072 }
2073
2074 response.send(proto::Ack {})?;
2075 Ok(())
2076}
2077
2078async fn respond_to_contact_request(
2079 request: proto::RespondToContactRequest,
2080 response: Response<proto::RespondToContactRequest>,
2081 session: Session,
2082) -> Result<()> {
2083 let responder_id = session.user_id;
2084 let requester_id = UserId::from_proto(request.requester_id);
2085 let db = session.db().await;
2086 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2087 db.dismiss_contact_notification(responder_id, requester_id)
2088 .await?;
2089 } else {
2090 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2091
2092 db.respond_to_contact_request(responder_id, requester_id, accept)
2093 .await?;
2094 let requester_busy = db.is_user_busy(requester_id).await?;
2095 let responder_busy = db.is_user_busy(responder_id).await?;
2096
2097 let pool = session.connection_pool().await;
2098 // Update responder with new contact
2099 let mut update = proto::UpdateContacts::default();
2100 if accept {
2101 update
2102 .contacts
2103 .push(contact_for_user(requester_id, false, requester_busy, &pool));
2104 }
2105 update
2106 .remove_incoming_requests
2107 .push(requester_id.to_proto());
2108 for connection_id in pool.user_connection_ids(responder_id) {
2109 session.peer.send(connection_id, update.clone())?;
2110 }
2111
2112 // Update requester with new contact
2113 let mut update = proto::UpdateContacts::default();
2114 if accept {
2115 update
2116 .contacts
2117 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2118 }
2119 update
2120 .remove_outgoing_requests
2121 .push(responder_id.to_proto());
2122 for connection_id in pool.user_connection_ids(requester_id) {
2123 session.peer.send(connection_id, update.clone())?;
2124 }
2125 }
2126
2127 response.send(proto::Ack {})?;
2128 Ok(())
2129}
2130
2131async fn remove_contact(
2132 request: proto::RemoveContact,
2133 response: Response<proto::RemoveContact>,
2134 session: Session,
2135) -> Result<()> {
2136 let requester_id = session.user_id;
2137 let responder_id = UserId::from_proto(request.user_id);
2138 let db = session.db().await;
2139 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2140
2141 let pool = session.connection_pool().await;
2142 // Update outgoing contact requests of requester
2143 let mut update = proto::UpdateContacts::default();
2144 if contact_accepted {
2145 update.remove_contacts.push(responder_id.to_proto());
2146 } else {
2147 update
2148 .remove_outgoing_requests
2149 .push(responder_id.to_proto());
2150 }
2151 for connection_id in pool.user_connection_ids(requester_id) {
2152 session.peer.send(connection_id, update.clone())?;
2153 }
2154
2155 // Update incoming contact requests of responder
2156 let mut update = proto::UpdateContacts::default();
2157 if contact_accepted {
2158 update.remove_contacts.push(requester_id.to_proto());
2159 } else {
2160 update
2161 .remove_incoming_requests
2162 .push(requester_id.to_proto());
2163 }
2164 for connection_id in pool.user_connection_ids(responder_id) {
2165 session.peer.send(connection_id, update.clone())?;
2166 }
2167
2168 response.send(proto::Ack {})?;
2169 Ok(())
2170}
2171
2172async fn create_channel(
2173 request: proto::CreateChannel,
2174 response: Response<proto::CreateChannel>,
2175 session: Session,
2176) -> Result<()> {
2177 let db = session.db().await;
2178 let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
2179
2180 if let Some(live_kit) = session.live_kit_client.as_ref() {
2181 live_kit.create_room(live_kit_room.clone()).await?;
2182 }
2183
2184 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2185 let id = db
2186 .create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
2187 .await?;
2188
2189 let channel = proto::Channel {
2190 id: id.to_proto(),
2191 name: request.name,
2192 parent_id: request.parent_id,
2193 };
2194
2195 response.send(proto::ChannelResponse {
2196 channel: Some(channel.clone()),
2197 })?;
2198
2199 let mut update = proto::UpdateChannels::default();
2200 update.channels.push(channel);
2201
2202 let user_ids_to_notify = if let Some(parent_id) = parent_id {
2203 db.get_channel_members(parent_id).await?
2204 } else {
2205 vec![session.user_id]
2206 };
2207
2208 let connection_pool = session.connection_pool().await;
2209 for user_id in user_ids_to_notify {
2210 for connection_id in connection_pool.user_connection_ids(user_id) {
2211 let mut update = update.clone();
2212 if user_id == session.user_id {
2213 update.channel_permissions.push(proto::ChannelPermission {
2214 channel_id: id.to_proto(),
2215 is_admin: true,
2216 });
2217 }
2218 session.peer.send(connection_id, update)?;
2219 }
2220 }
2221
2222 Ok(())
2223}
2224
2225async fn remove_channel(
2226 request: proto::RemoveChannel,
2227 response: Response<proto::RemoveChannel>,
2228 session: Session,
2229) -> Result<()> {
2230 let db = session.db().await;
2231
2232 let channel_id = request.channel_id;
2233 let (removed_channels, member_ids) = db
2234 .remove_channel(ChannelId::from_proto(channel_id), session.user_id)
2235 .await?;
2236 response.send(proto::Ack {})?;
2237
2238 // Notify members of removed channels
2239 let mut update = proto::UpdateChannels::default();
2240 update
2241 .remove_channels
2242 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2243
2244 let connection_pool = session.connection_pool().await;
2245 for member_id in member_ids {
2246 for connection_id in connection_pool.user_connection_ids(member_id) {
2247 session.peer.send(connection_id, update.clone())?;
2248 }
2249 }
2250
2251 Ok(())
2252}
2253
2254async fn invite_channel_member(
2255 request: proto::InviteChannelMember,
2256 response: Response<proto::InviteChannelMember>,
2257 session: Session,
2258) -> Result<()> {
2259 let db = session.db().await;
2260 let channel_id = ChannelId::from_proto(request.channel_id);
2261 let invitee_id = UserId::from_proto(request.user_id);
2262 db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
2263 .await?;
2264
2265 let (channel, _) = db
2266 .get_channel(channel_id, session.user_id)
2267 .await?
2268 .ok_or_else(|| anyhow!("channel not found"))?;
2269
2270 let mut update = proto::UpdateChannels::default();
2271 update.channel_invitations.push(proto::Channel {
2272 id: channel.id.to_proto(),
2273 name: channel.name,
2274 parent_id: None,
2275 });
2276 for connection_id in session
2277 .connection_pool()
2278 .await
2279 .user_connection_ids(invitee_id)
2280 {
2281 session.peer.send(connection_id, update.clone())?;
2282 }
2283
2284 response.send(proto::Ack {})?;
2285 Ok(())
2286}
2287
2288async fn remove_channel_member(
2289 request: proto::RemoveChannelMember,
2290 response: Response<proto::RemoveChannelMember>,
2291 session: Session,
2292) -> Result<()> {
2293 let db = session.db().await;
2294 let channel_id = ChannelId::from_proto(request.channel_id);
2295 let member_id = UserId::from_proto(request.user_id);
2296
2297 db.remove_channel_member(channel_id, member_id, session.user_id)
2298 .await?;
2299
2300 let mut update = proto::UpdateChannels::default();
2301 update.remove_channels.push(channel_id.to_proto());
2302
2303 for connection_id in session
2304 .connection_pool()
2305 .await
2306 .user_connection_ids(member_id)
2307 {
2308 session.peer.send(connection_id, update.clone())?;
2309 }
2310
2311 response.send(proto::Ack {})?;
2312 Ok(())
2313}
2314
2315async fn set_channel_member_admin(
2316 request: proto::SetChannelMemberAdmin,
2317 response: Response<proto::SetChannelMemberAdmin>,
2318 session: Session,
2319) -> Result<()> {
2320 let db = session.db().await;
2321 let channel_id = ChannelId::from_proto(request.channel_id);
2322 let member_id = UserId::from_proto(request.user_id);
2323 db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin)
2324 .await?;
2325
2326 let (channel, has_accepted) = db
2327 .get_channel(channel_id, member_id)
2328 .await?
2329 .ok_or_else(|| anyhow!("channel not found"))?;
2330
2331 let mut update = proto::UpdateChannels::default();
2332 if has_accepted {
2333 update.channel_permissions.push(proto::ChannelPermission {
2334 channel_id: channel.id.to_proto(),
2335 is_admin: request.admin,
2336 });
2337 }
2338
2339 for connection_id in session
2340 .connection_pool()
2341 .await
2342 .user_connection_ids(member_id)
2343 {
2344 session.peer.send(connection_id, update.clone())?;
2345 }
2346
2347 response.send(proto::Ack {})?;
2348 Ok(())
2349}
2350
2351async fn rename_channel(
2352 request: proto::RenameChannel,
2353 response: Response<proto::RenameChannel>,
2354 session: Session,
2355) -> Result<()> {
2356 let db = session.db().await;
2357 let channel_id = ChannelId::from_proto(request.channel_id);
2358 let new_name = db
2359 .rename_channel(channel_id, session.user_id, &request.name)
2360 .await?;
2361
2362 let channel = proto::Channel {
2363 id: request.channel_id,
2364 name: new_name,
2365 parent_id: None,
2366 };
2367 response.send(proto::ChannelResponse {
2368 channel: Some(channel.clone()),
2369 })?;
2370 let mut update = proto::UpdateChannels::default();
2371 update.channels.push(channel);
2372
2373 let member_ids = db.get_channel_members(channel_id).await?;
2374
2375 let connection_pool = session.connection_pool().await;
2376 for member_id in member_ids {
2377 for connection_id in connection_pool.user_connection_ids(member_id) {
2378 session.peer.send(connection_id, update.clone())?;
2379 }
2380 }
2381
2382 Ok(())
2383}
2384
2385async fn get_channel_members(
2386 request: proto::GetChannelMembers,
2387 response: Response<proto::GetChannelMembers>,
2388 session: Session,
2389) -> Result<()> {
2390 let db = session.db().await;
2391 let channel_id = ChannelId::from_proto(request.channel_id);
2392 let members = db
2393 .get_channel_member_details(channel_id, session.user_id)
2394 .await?;
2395 response.send(proto::GetChannelMembersResponse { members })?;
2396 Ok(())
2397}
2398
2399async fn respond_to_channel_invite(
2400 request: proto::RespondToChannelInvite,
2401 response: Response<proto::RespondToChannelInvite>,
2402 session: Session,
2403) -> Result<()> {
2404 let db = session.db().await;
2405 let channel_id = ChannelId::from_proto(request.channel_id);
2406 db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
2407 .await?;
2408
2409 let mut update = proto::UpdateChannels::default();
2410 update
2411 .remove_channel_invitations
2412 .push(channel_id.to_proto());
2413 if request.accept {
2414 let result = db.get_channels_for_user(session.user_id).await?;
2415 update
2416 .channels
2417 .extend(result.channels.into_iter().map(|channel| proto::Channel {
2418 id: channel.id.to_proto(),
2419 name: channel.name,
2420 parent_id: channel.parent_id.map(ChannelId::to_proto),
2421 }));
2422 update
2423 .channel_participants
2424 .extend(
2425 result
2426 .channel_participants
2427 .into_iter()
2428 .map(|(channel_id, user_ids)| proto::ChannelParticipants {
2429 channel_id: channel_id.to_proto(),
2430 participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
2431 }),
2432 );
2433 update
2434 .channel_permissions
2435 .extend(
2436 result
2437 .channels_with_admin_privileges
2438 .into_iter()
2439 .map(|channel_id| proto::ChannelPermission {
2440 channel_id: channel_id.to_proto(),
2441 is_admin: true,
2442 }),
2443 );
2444 }
2445 session.peer.send(session.connection_id, update)?;
2446 response.send(proto::Ack {})?;
2447
2448 Ok(())
2449}
2450
2451async fn join_channel(
2452 request: proto::JoinChannel,
2453 response: Response<proto::JoinChannel>,
2454 session: Session,
2455) -> Result<()> {
2456 let channel_id = ChannelId::from_proto(request.channel_id);
2457
2458 let joined_room = {
2459 leave_room_for_session(&session).await?;
2460 let db = session.db().await;
2461
2462 let room_id = db.room_id_for_channel(channel_id).await?;
2463
2464 let joined_room = db
2465 .join_room(room_id, session.user_id, session.connection_id)
2466 .await?;
2467
2468 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
2469 let token = live_kit
2470 .room_token(
2471 &joined_room.room.live_kit_room,
2472 &session.user_id.to_string(),
2473 )
2474 .trace_err()?;
2475
2476 Some(LiveKitConnectionInfo {
2477 server_url: live_kit.url().into(),
2478 token,
2479 })
2480 });
2481
2482 response.send(proto::JoinRoomResponse {
2483 room: Some(joined_room.room.clone()),
2484 channel_id: joined_room.channel_id.map(|id| id.to_proto()),
2485 live_kit_connection_info,
2486 })?;
2487
2488 room_updated(&joined_room.room, &session.peer);
2489
2490 joined_room.into_inner()
2491 };
2492
2493 channel_updated(
2494 channel_id,
2495 &joined_room.room,
2496 &joined_room.channel_members,
2497 &session.peer,
2498 &*session.connection_pool().await,
2499 );
2500
2501 update_user_contacts(session.user_id, &session).await?;
2502
2503 Ok(())
2504}
2505
2506async fn join_channel_buffer(
2507 request: proto::JoinChannelBuffer,
2508 response: Response<proto::JoinChannelBuffer>,
2509 session: Session,
2510) -> Result<()> {
2511 let db = session.db().await;
2512 let channel_id = ChannelId::from_proto(request.channel_id);
2513
2514 let open_response = db
2515 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2516 .await?;
2517
2518 let replica_id = open_response.replica_id;
2519 let collaborators = open_response.collaborators.clone();
2520
2521 response.send(open_response)?;
2522
2523 let update = AddChannelBufferCollaborator {
2524 channel_id: channel_id.to_proto(),
2525 collaborator: Some(proto::Collaborator {
2526 user_id: session.user_id.to_proto(),
2527 peer_id: Some(session.connection_id.into()),
2528 replica_id,
2529 }),
2530 };
2531 channel_buffer_updated(
2532 session.connection_id,
2533 collaborators
2534 .iter()
2535 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2536 &update,
2537 &session.peer,
2538 );
2539
2540 Ok(())
2541}
2542
2543async fn update_channel_buffer(
2544 request: proto::UpdateChannelBuffer,
2545 session: Session,
2546) -> Result<()> {
2547 let db = session.db().await;
2548 let channel_id = ChannelId::from_proto(request.channel_id);
2549
2550 let collaborators = db
2551 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2552 .await?;
2553
2554 channel_buffer_updated(
2555 session.connection_id,
2556 collaborators,
2557 &proto::UpdateChannelBuffer {
2558 channel_id: channel_id.to_proto(),
2559 operations: request.operations,
2560 },
2561 &session.peer,
2562 );
2563 Ok(())
2564}
2565
2566async fn rejoin_channel_buffers(
2567 request: proto::RejoinChannelBuffers,
2568 response: Response<proto::RejoinChannelBuffers>,
2569 session: Session,
2570) -> Result<()> {
2571 let db = session.db().await;
2572 let buffers = db
2573 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2574 .await?;
2575
2576 for buffer in &buffers {
2577 let collaborators_to_notify = buffer
2578 .buffer
2579 .collaborators
2580 .iter()
2581 .filter_map(|c| Some(c.peer_id?.into()));
2582 channel_buffer_updated(
2583 session.connection_id,
2584 collaborators_to_notify,
2585 &proto::UpdateChannelBufferCollaborator {
2586 channel_id: buffer.buffer.channel_id,
2587 old_peer_id: Some(buffer.old_connection_id.into()),
2588 new_peer_id: Some(session.connection_id.into()),
2589 },
2590 &session.peer,
2591 );
2592 }
2593
2594 response.send(proto::RejoinChannelBuffersResponse {
2595 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2596 })?;
2597
2598 Ok(())
2599}
2600
2601async fn leave_channel_buffer(
2602 request: proto::LeaveChannelBuffer,
2603 response: Response<proto::LeaveChannelBuffer>,
2604 session: Session,
2605) -> Result<()> {
2606 let db = session.db().await;
2607 let channel_id = ChannelId::from_proto(request.channel_id);
2608
2609 let collaborators_to_notify = db
2610 .leave_channel_buffer(channel_id, session.connection_id)
2611 .await?;
2612
2613 response.send(Ack {})?;
2614
2615 channel_buffer_updated(
2616 session.connection_id,
2617 collaborators_to_notify,
2618 &proto::RemoveChannelBufferCollaborator {
2619 channel_id: channel_id.to_proto(),
2620 peer_id: Some(session.connection_id.into()),
2621 },
2622 &session.peer,
2623 );
2624
2625 Ok(())
2626}
2627
2628fn channel_buffer_updated<T: EnvelopedMessage>(
2629 sender_id: ConnectionId,
2630 collaborators: impl IntoIterator<Item = ConnectionId>,
2631 message: &T,
2632 peer: &Peer,
2633) {
2634 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2635 peer.send(peer_id.into(), message.clone())
2636 });
2637}
2638
2639async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2640 let project_id = ProjectId::from_proto(request.project_id);
2641 let project_connection_ids = session
2642 .db()
2643 .await
2644 .project_connection_ids(project_id, session.connection_id)
2645 .await?;
2646 broadcast(
2647 Some(session.connection_id),
2648 project_connection_ids.iter().copied(),
2649 |connection_id| {
2650 session
2651 .peer
2652 .forward_send(session.connection_id, connection_id, request.clone())
2653 },
2654 );
2655 Ok(())
2656}
2657
2658async fn get_private_user_info(
2659 _request: proto::GetPrivateUserInfo,
2660 response: Response<proto::GetPrivateUserInfo>,
2661 session: Session,
2662) -> Result<()> {
2663 let db = session.db().await;
2664
2665 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
2666 let user = db
2667 .get_user_by_id(session.user_id)
2668 .await?
2669 .ok_or_else(|| anyhow!("user not found"))?;
2670 let flags = db.get_user_flags(session.user_id).await?;
2671
2672 response.send(proto::GetPrivateUserInfoResponse {
2673 metrics_id,
2674 staff: user.admin,
2675 flags,
2676 })?;
2677 Ok(())
2678}
2679
2680fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2681 match message {
2682 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2683 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2684 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2685 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2686 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2687 code: frame.code.into(),
2688 reason: frame.reason,
2689 })),
2690 }
2691}
2692
2693fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2694 match message {
2695 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2696 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2697 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2698 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2699 AxumMessage::Close(frame) => {
2700 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2701 code: frame.code.into(),
2702 reason: frame.reason,
2703 }))
2704 }
2705 }
2706}
2707
2708fn build_initial_channels_update(
2709 channels: ChannelsForUser,
2710 channel_invites: Vec<db::Channel>,
2711) -> proto::UpdateChannels {
2712 let mut update = proto::UpdateChannels::default();
2713
2714 for channel in channels.channels {
2715 update.channels.push(proto::Channel {
2716 id: channel.id.to_proto(),
2717 name: channel.name,
2718 parent_id: channel.parent_id.map(|id| id.to_proto()),
2719 });
2720 }
2721
2722 for (channel_id, participants) in channels.channel_participants {
2723 update
2724 .channel_participants
2725 .push(proto::ChannelParticipants {
2726 channel_id: channel_id.to_proto(),
2727 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
2728 });
2729 }
2730
2731 update
2732 .channel_permissions
2733 .extend(
2734 channels
2735 .channels_with_admin_privileges
2736 .into_iter()
2737 .map(|id| proto::ChannelPermission {
2738 channel_id: id.to_proto(),
2739 is_admin: true,
2740 }),
2741 );
2742
2743 for channel in channel_invites {
2744 update.channel_invitations.push(proto::Channel {
2745 id: channel.id.to_proto(),
2746 name: channel.name,
2747 parent_id: None,
2748 });
2749 }
2750
2751 update
2752}
2753
2754fn build_initial_contacts_update(
2755 contacts: Vec<db::Contact>,
2756 pool: &ConnectionPool,
2757) -> proto::UpdateContacts {
2758 let mut update = proto::UpdateContacts::default();
2759
2760 for contact in contacts {
2761 match contact {
2762 db::Contact::Accepted {
2763 user_id,
2764 should_notify,
2765 busy,
2766 } => {
2767 update
2768 .contacts
2769 .push(contact_for_user(user_id, should_notify, busy, &pool));
2770 }
2771 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2772 db::Contact::Incoming {
2773 user_id,
2774 should_notify,
2775 } => update
2776 .incoming_requests
2777 .push(proto::IncomingContactRequest {
2778 requester_id: user_id.to_proto(),
2779 should_notify,
2780 }),
2781 }
2782 }
2783
2784 update
2785}
2786
2787fn contact_for_user(
2788 user_id: UserId,
2789 should_notify: bool,
2790 busy: bool,
2791 pool: &ConnectionPool,
2792) -> proto::Contact {
2793 proto::Contact {
2794 user_id: user_id.to_proto(),
2795 online: pool.is_user_online(user_id),
2796 busy,
2797 should_notify,
2798 }
2799}
2800
2801fn room_updated(room: &proto::Room, peer: &Peer) {
2802 broadcast(
2803 None,
2804 room.participants
2805 .iter()
2806 .filter_map(|participant| Some(participant.peer_id?.into())),
2807 |peer_id| {
2808 peer.send(
2809 peer_id.into(),
2810 proto::RoomUpdated {
2811 room: Some(room.clone()),
2812 },
2813 )
2814 },
2815 );
2816}
2817
2818fn channel_updated(
2819 channel_id: ChannelId,
2820 room: &proto::Room,
2821 channel_members: &[UserId],
2822 peer: &Peer,
2823 pool: &ConnectionPool,
2824) {
2825 let participants = room
2826 .participants
2827 .iter()
2828 .map(|p| p.user_id)
2829 .collect::<Vec<_>>();
2830
2831 broadcast(
2832 None,
2833 channel_members
2834 .iter()
2835 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2836 |peer_id| {
2837 peer.send(
2838 peer_id.into(),
2839 proto::UpdateChannels {
2840 channel_participants: vec![proto::ChannelParticipants {
2841 channel_id: channel_id.to_proto(),
2842 participant_user_ids: participants.clone(),
2843 }],
2844 ..Default::default()
2845 },
2846 )
2847 },
2848 );
2849}
2850
2851async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2852 let db = session.db().await;
2853
2854 let contacts = db.get_contacts(user_id).await?;
2855 let busy = db.is_user_busy(user_id).await?;
2856
2857 let pool = session.connection_pool().await;
2858 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2859 for contact in contacts {
2860 if let db::Contact::Accepted {
2861 user_id: contact_user_id,
2862 ..
2863 } = contact
2864 {
2865 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2866 session
2867 .peer
2868 .send(
2869 contact_conn_id,
2870 proto::UpdateContacts {
2871 contacts: vec![updated_contact.clone()],
2872 remove_contacts: Default::default(),
2873 incoming_requests: Default::default(),
2874 remove_incoming_requests: Default::default(),
2875 outgoing_requests: Default::default(),
2876 remove_outgoing_requests: Default::default(),
2877 },
2878 )
2879 .trace_err();
2880 }
2881 }
2882 }
2883 Ok(())
2884}
2885
2886async fn leave_room_for_session(session: &Session) -> Result<()> {
2887 let mut contacts_to_update = HashSet::default();
2888
2889 let room_id;
2890 let canceled_calls_to_user_ids;
2891 let live_kit_room;
2892 let delete_live_kit_room;
2893 let room;
2894 let channel_members;
2895 let channel_id;
2896
2897 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2898 contacts_to_update.insert(session.user_id);
2899
2900 for project in left_room.left_projects.values() {
2901 project_left(project, session);
2902 }
2903
2904 room_id = RoomId::from_proto(left_room.room.id);
2905 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2906 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2907 delete_live_kit_room = left_room.deleted;
2908 room = mem::take(&mut left_room.room);
2909 channel_members = mem::take(&mut left_room.channel_members);
2910 channel_id = left_room.channel_id;
2911
2912 room_updated(&room, &session.peer);
2913 } else {
2914 return Ok(());
2915 }
2916
2917 if let Some(channel_id) = channel_id {
2918 channel_updated(
2919 channel_id,
2920 &room,
2921 &channel_members,
2922 &session.peer,
2923 &*session.connection_pool().await,
2924 );
2925 }
2926
2927 {
2928 let pool = session.connection_pool().await;
2929 for canceled_user_id in canceled_calls_to_user_ids {
2930 for connection_id in pool.user_connection_ids(canceled_user_id) {
2931 session
2932 .peer
2933 .send(
2934 connection_id,
2935 proto::CallCanceled {
2936 room_id: room_id.to_proto(),
2937 },
2938 )
2939 .trace_err();
2940 }
2941 contacts_to_update.insert(canceled_user_id);
2942 }
2943 }
2944
2945 for contact_user_id in contacts_to_update {
2946 update_user_contacts(contact_user_id, &session).await?;
2947 }
2948
2949 if let Some(live_kit) = session.live_kit_client.as_ref() {
2950 live_kit
2951 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2952 .await
2953 .trace_err();
2954
2955 if delete_live_kit_room {
2956 live_kit.delete_room(live_kit_room).await.trace_err();
2957 }
2958 }
2959
2960 Ok(())
2961}
2962
2963async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
2964 let left_channel_buffers = session
2965 .db()
2966 .await
2967 .leave_channel_buffers(session.connection_id)
2968 .await?;
2969
2970 for (channel_id, connections) in left_channel_buffers {
2971 channel_buffer_updated(
2972 session.connection_id,
2973 connections,
2974 &proto::RemoveChannelBufferCollaborator {
2975 channel_id: channel_id.to_proto(),
2976 peer_id: Some(session.connection_id.into()),
2977 },
2978 &session.peer,
2979 );
2980 }
2981
2982 Ok(())
2983}
2984
2985fn project_left(project: &db::LeftProject, session: &Session) {
2986 for connection_id in &project.connection_ids {
2987 if project.host_user_id == session.user_id {
2988 session
2989 .peer
2990 .send(
2991 *connection_id,
2992 proto::UnshareProject {
2993 project_id: project.id.to_proto(),
2994 },
2995 )
2996 .trace_err();
2997 } else {
2998 session
2999 .peer
3000 .send(
3001 *connection_id,
3002 proto::RemoveProjectCollaborator {
3003 project_id: project.id.to_proto(),
3004 peer_id: Some(session.connection_id.into()),
3005 },
3006 )
3007 .trace_err();
3008 }
3009 }
3010}
3011
/// Extension trait for turning a fallible value into an `Option`.
pub trait ResultExt {
    /// The value produced when the operation succeeded.
    type Ok;

    /// Consumes `self` and returns the success value, if any, as an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3017
3018impl<T, E> ResultExt for Result<T, E>
3019where
3020 E: std::fmt::Debug,
3021{
3022 type Ok = T;
3023
3024 fn trace_err(self) -> Option<T> {
3025 match self {
3026 Ok(value) => Some(value),
3027 Err(error) => {
3028 tracing::error!("{:?}", error);
3029 None
3030 }
3031 }
3032 }
3033}