1mod store;
2
3use crate::{
4 auth,
5 db::{self, ProjectId, RoomId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::{HashMap, HashSet};
26use futures::{
27 channel::oneshot,
28 future::{self, BoxFuture},
29 stream::FuturesUnordered,
30 FutureExt, SinkExt, StreamExt, TryStreamExt,
31};
32use lazy_static::lazy_static;
33use prometheus::{register_int_gauge, IntGauge};
34use rpc::{
35 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
36 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
37};
38use serde::{Serialize, Serializer};
39use std::{
40 any::TypeId,
41 future::Future,
42 marker::PhantomData,
43 net::SocketAddr,
44 ops::{Deref, DerefMut},
45 os::unix::prelude::OsStrExt,
46 rc::Rc,
47 sync::{
48 atomic::{AtomicBool, Ordering::SeqCst},
49 Arc,
50 },
51 time::Duration,
52};
53pub use store::{Store, Worktree};
54use tokio::{
55 sync::{Mutex, MutexGuard},
56 time::Sleep,
57};
58use tower::ServiceBuilder;
59use tracing::{info_span, instrument, Instrument};
60
lazy_static! {
    /// Prometheus gauge named `connections` ("number of connections"), registered via
    /// `register_int_gauge!` on first access. A registration failure panics (`unwrap`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Prometheus gauge named `shared_projects` ("number of open projects with one or more
    /// guests"), registered on first access. A registration failure panics (`unwrap`).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
70
/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It is called with the server, the sending user's id, and the still-boxed envelope, and
/// returns a boxed future that performs the handling. `Server::add_handler` builds these.
type MessageHandler = Box<
    dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
>;
74
/// A decoded message handed to a handler, together with where it came from.
struct Message<T> {
    /// User id associated with the connection the message arrived on.
    sender_user_id: UserId,
    /// Connection the message arrived on.
    sender_connection_id: ConnectionId,
    /// The message payload.
    payload: T,
}
80
/// Handle a request handler uses to reply to the request identified by `receipt`.
struct Response<R> {
    /// Server whose peer sends the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set by `Response::send`. A clone is kept by `add_request_handler`, which inspects it
    /// after the handler finishes to learn whether a reply was sent.
    responded: Arc<AtomicBool>,
}
86
87impl<R: RequestMessage> Response<R> {
88 fn send(self, payload: R::Response) -> Result<()> {
89 self.responded.store(true, SeqCst);
90 self.server.peer.respond(self.receipt, payload)?;
91 Ok(())
92 }
93}
94
/// The collaboration RPC server: dispatches client messages to registered handlers and owns the
/// shared in-memory state.
pub struct Server {
    /// Connection registry used to send, respond to, and forward messages.
    peer: Arc<Peer>,
    /// In-memory state (connections, rooms, projects, ...) behind an async mutex.
    pub(crate) store: Mutex<Store>,
    /// Shared application state; this module uses its database, config, and optional LiveKit
    /// client.
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
101
/// Runtime abstraction used by `Server::handle_connection`; presumably it exists so tests can
/// substitute their own executor (TODO confirm).
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` without keeping a handle to it. `handle_connection` uses this for
    /// background messages.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration`. `handle_connection` passes timers built from it
    /// to `Peer::add_connection`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
107
/// Non-test [`Executor`]. Its `impl` is not shown here; presumably it is backed by tokio
/// (this module imports `tokio::time::Sleep`) — confirm.
#[derive(Clone)]
pub struct RealExecutor;
110
/// Guard over the server's locked `Store`.
pub(crate) struct StoreGuard<'a> {
    guard: MutexGuard<'a, Store>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`. Presumably this prevents holding
    /// the store lock across an `.await` inside a future that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
115
/// Serializable snapshot of the server's peer and (locked) store.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized with `serialize_deref`, i.e. as the value the guard dereferences to
    /// (presumably the `Store`).
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
122
123pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
124where
125 S: Serializer,
126 T: Deref<Target = U>,
127 U: Serialize,
128{
129 Serialize::serialize(value.deref(), serializer)
130}
131
132impl Server {
    /// Creates the server with an empty store and registers every RPC handler.
    ///
    /// Request handlers (`add_request_handler`) must reply to the client. Message handlers
    /// (`add_message_handler`) receive no response handle.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
        };

        server
            // Connection liveness, rooms and calls.
            .add_request_handler(Server::ping)
            .add_request_handler(Server::create_room)
            .add_request_handler(Server::join_room)
            .add_message_handler(Server::leave_room)
            .add_request_handler(Server::call)
            .add_request_handler(Server::cancel_call)
            .add_message_handler(Server::decline_call)
            .add_request_handler(Server::update_participant_location)
            // Project sharing and project/worktree state.
            .add_request_handler(Server::share_project)
            .add_message_handler(Server::unshare_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_message_handler(Server::update_project)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_message_handler(Server::update_diagnostic_summary)
            // Requests forwarded via `forward_project_request`.
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            // Buffers.
            .add_message_handler(Server::create_buffer_for_peer)
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            // Users and contacts.
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            // Following.
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_message_handler(Server::update_followers)
            .add_message_handler(Server::update_diff_base)
            .add_request_handler(Server::get_private_user_info);

        Arc::new(server)
    }
202
203 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
204 where
205 F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
206 Fut: 'static + Send + Future<Output = Result<()>>,
207 M: EnvelopedMessage,
208 {
209 let prev_handler = self.handlers.insert(
210 TypeId::of::<M>(),
211 Box::new(move |server, sender_user_id, envelope| {
212 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
213 let span = info_span!(
214 "handle message",
215 payload_type = envelope.payload_type_name()
216 );
217 span.in_scope(|| {
218 tracing::info!(
219 payload_type = envelope.payload_type_name(),
220 "message received"
221 );
222 });
223 let future = (handler)(server, sender_user_id, *envelope);
224 async move {
225 if let Err(error) = future.await {
226 tracing::error!(%error, "error handling message");
227 }
228 }
229 .instrument(span)
230 .boxed()
231 }),
232 );
233 if prev_handler.is_some() {
234 panic!("registered a handler for the same message twice");
235 }
236 self
237 }
238
239 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
240 where
241 F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
242 Fut: 'static + Send + Future<Output = Result<()>>,
243 M: EnvelopedMessage,
244 {
245 self.add_handler(move |server, sender_user_id, envelope| {
246 handler(
247 server,
248 Message {
249 sender_user_id,
250 sender_connection_id: envelope.sender_id,
251 payload: envelope.payload,
252 },
253 )
254 });
255 self
256 }
257
258 /// Handle a request while holding a lock to the store. This is useful when we're registering
259 /// a connection but we want to respond on the connection before anybody else can send on it.
260 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
261 where
262 F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
263 Fut: Send + Future<Output = Result<()>>,
264 M: RequestMessage,
265 {
266 let handler = Arc::new(handler);
267 self.add_handler(move |server, sender_user_id, envelope| {
268 let receipt = envelope.receipt();
269 let handler = handler.clone();
270 async move {
271 let request = Message {
272 sender_user_id,
273 sender_connection_id: envelope.sender_id,
274 payload: envelope.payload,
275 };
276 let responded = Arc::new(AtomicBool::default());
277 let response = Response {
278 server: server.clone(),
279 responded: responded.clone(),
280 receipt,
281 };
282 match (handler)(server.clone(), request, response).await {
283 Ok(()) => {
284 if responded.load(std::sync::atomic::Ordering::SeqCst) {
285 Ok(())
286 } else {
287 Err(anyhow!("handler did not send a response"))?
288 }
289 }
290 Err(error) => {
291 server.peer.respond_with_error(
292 receipt,
293 proto::Error {
294 message: error.to_string(),
295 },
296 )?;
297 Err(error)
298 }
299 }
300 }
301 })
302 }
303
    /// Drives one client connection from registration to sign-out.
    ///
    /// The returned future does the following, in order:
    ///
    /// 1. Registers `connection` with the peer and sends `Hello`.
    /// 2. Sends the new connection id through `send_connection_id`, if one was provided.
    /// 3. For a user who has never connected before, sends `ShowContacts` and records the first
    ///    connection.
    /// 4. While holding the store lock, registers the connection in the store and sends any
    ///    pending incoming call, the initial contacts update, and invite info.
    /// 5. Refreshes the user's contacts, then dispatches incoming messages to the registered
    ///    handlers. This continues until the I/O future completes or the incoming stream ends.
    /// 6. Finally calls `sign_out` for the connection.
    ///
    /// Background messages are spawned via `executor`. Foreground messages are polled
    /// concurrently in a `FuturesUnordered` (see the comment inside).
    ///
    /// NOTE(review): each `?` between `add_connection` and the message loop returns early
    /// without calling `sign_out`. A failure there (e.g. a database error in
    /// `update_user_contacts`) therefore seems to leave the connection registered with the peer
    /// and, after `store.add_connection`, in the store. Confirm whether callers clean this up.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    // Timer factory handed to the peer, built on the executor's `sleep`.
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; that is not an error for this connection.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Both queries run concurrently; the first error aborts.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            {
                // The initial state below is sent while the store lock is held.
                let mut store = this.store().await;
                let incoming_call = store.add_connection(connection_id, user_id, user.admin);
                if let Some(incoming_call) = incoming_call {
                    this.peer.send(connection_id, incoming_call)?;
                }

                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // Biased: I/O completion is checked first, then in-flight foreground handlers,
                // then new messages.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                // Leave the span before the handler future is stored or spawned;
                                // the future is instrumented with the same span instead.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Unfinished foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
424
425 #[instrument(skip(self), err)]
426 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
427 self.peer.disconnect(connection_id);
428
429 let mut projects_to_unshare = Vec::new();
430 let mut contacts_to_update = HashSet::default();
431 let mut room_left = None;
432 {
433 let mut store = self.store().await;
434
435 #[cfg(test)]
436 let removed_connection = store.remove_connection(connection_id).unwrap();
437 #[cfg(not(test))]
438 let removed_connection = store.remove_connection(connection_id)?;
439
440 for project in removed_connection.hosted_projects {
441 projects_to_unshare.push(project.id);
442 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
443 self.peer.send(
444 conn_id,
445 proto::UnshareProject {
446 project_id: project.id.to_proto(),
447 },
448 )
449 });
450 }
451
452 for project in removed_connection.guest_projects {
453 broadcast(connection_id, project.connection_ids, |conn_id| {
454 self.peer.send(
455 conn_id,
456 proto::RemoveProjectCollaborator {
457 project_id: project.id.to_proto(),
458 peer_id: connection_id.0,
459 },
460 )
461 });
462 }
463
464 if let Some(room) = removed_connection.room {
465 self.room_updated(&room);
466 room_left = Some(self.room_left(&room, connection_id));
467 }
468
469 contacts_to_update.insert(removed_connection.user_id);
470 for connection_id in removed_connection.canceled_call_connection_ids {
471 self.peer
472 .send(connection_id, proto::CallCanceled {})
473 .trace_err();
474 contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
475 }
476 };
477
478 if let Some(room_left) = room_left {
479 room_left.await.trace_err();
480 }
481
482 for user_id in contacts_to_update {
483 self.update_user_contacts(user_id).await.trace_err();
484 }
485
486 for project_id in projects_to_unshare {
487 self.app_state
488 .db
489 .unshare_project(project_id)
490 .await
491 .trace_err();
492 }
493
494 Ok(())
495 }
496
497 pub async fn invite_code_redeemed(
498 self: &Arc<Self>,
499 inviter_id: UserId,
500 invitee_id: UserId,
501 ) -> Result<()> {
502 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
503 if let Some(code) = &user.invite_code {
504 let store = self.store().await;
505 let invitee_contact = store.contact_for_user(invitee_id, true);
506 for connection_id in store.connection_ids_for_user(inviter_id) {
507 self.peer.send(
508 connection_id,
509 proto::UpdateContacts {
510 contacts: vec![invitee_contact.clone()],
511 ..Default::default()
512 },
513 )?;
514 self.peer.send(
515 connection_id,
516 proto::UpdateInviteInfo {
517 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
518 count: user.invite_count as u32,
519 },
520 )?;
521 }
522 }
523 }
524 Ok(())
525 }
526
527 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
528 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
529 if let Some(invite_code) = &user.invite_code {
530 let store = self.store().await;
531 for connection_id in store.connection_ids_for_user(user_id) {
532 self.peer.send(
533 connection_id,
534 proto::UpdateInviteInfo {
535 url: format!(
536 "{}{}",
537 self.app_state.config.invite_link_prefix, invite_code
538 ),
539 count: user.invite_count as u32,
540 },
541 )?;
542 }
543 }
544 }
545 Ok(())
546 }
547
548 async fn ping(
549 self: Arc<Server>,
550 _: Message<proto::Ping>,
551 response: Response<proto::Ping>,
552 ) -> Result<()> {
553 response.send(proto::Ack {})?;
554 Ok(())
555 }
556
557 async fn create_room(
558 self: Arc<Server>,
559 request: Message<proto::CreateRoom>,
560 response: Response<proto::CreateRoom>,
561 ) -> Result<()> {
562 let room = self
563 .app_state
564 .db
565 .create_room(request.sender_user_id, request.sender_connection_id)
566 .await?;
567
568 let live_kit_connection_info =
569 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
570 if let Some(_) = live_kit
571 .create_room(room.live_kit_room.clone())
572 .await
573 .trace_err()
574 {
575 if let Some(token) = live_kit
576 .room_token(
577 &room.live_kit_room,
578 &request.sender_connection_id.to_string(),
579 )
580 .trace_err()
581 {
582 Some(proto::LiveKitConnectionInfo {
583 server_url: live_kit.url().into(),
584 token,
585 })
586 } else {
587 None
588 }
589 } else {
590 None
591 }
592 } else {
593 None
594 };
595
596 response.send(proto::CreateRoomResponse {
597 room: Some(room),
598 live_kit_connection_info,
599 })?;
600 self.update_user_contacts(request.sender_user_id).await?;
601 Ok(())
602 }
603
604 async fn join_room(
605 self: Arc<Server>,
606 request: Message<proto::JoinRoom>,
607 response: Response<proto::JoinRoom>,
608 ) -> Result<()> {
609 {
610 let mut store = self.store().await;
611 let (room, recipient_connection_ids) =
612 store.join_room(request.payload.id, request.sender_connection_id)?;
613 for recipient_id in recipient_connection_ids {
614 self.peer
615 .send(recipient_id, proto::CallCanceled {})
616 .trace_err();
617 }
618
619 let live_kit_connection_info =
620 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
621 if let Some(token) = live_kit
622 .room_token(
623 &room.live_kit_room,
624 &request.sender_connection_id.to_string(),
625 )
626 .trace_err()
627 {
628 Some(proto::LiveKitConnectionInfo {
629 server_url: live_kit.url().into(),
630 token,
631 })
632 } else {
633 None
634 }
635 } else {
636 None
637 };
638
639 response.send(proto::JoinRoomResponse {
640 room: Some(room.clone()),
641 live_kit_connection_info,
642 })?;
643 self.room_updated(room);
644 }
645 self.update_user_contacts(request.sender_user_id).await?;
646 Ok(())
647 }
648
649 async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
650 let mut contacts_to_update = HashSet::default();
651 let room_left;
652 {
653 let mut store = self.store().await;
654 let left_room = store.leave_room(message.payload.id, message.sender_connection_id)?;
655 contacts_to_update.insert(message.sender_user_id);
656
657 for project in left_room.unshared_projects {
658 for connection_id in project.connection_ids() {
659 self.peer.send(
660 connection_id,
661 proto::UnshareProject {
662 project_id: project.id.to_proto(),
663 },
664 )?;
665 }
666 }
667
668 for project in left_room.left_projects {
669 if project.remove_collaborator {
670 for connection_id in project.connection_ids {
671 self.peer.send(
672 connection_id,
673 proto::RemoveProjectCollaborator {
674 project_id: project.id.to_proto(),
675 peer_id: message.sender_connection_id.0,
676 },
677 )?;
678 }
679
680 self.peer.send(
681 message.sender_connection_id,
682 proto::UnshareProject {
683 project_id: project.id.to_proto(),
684 },
685 )?;
686 }
687 }
688
689 self.room_updated(&left_room.room);
690 room_left = self.room_left(&left_room.room, message.sender_connection_id);
691
692 for connection_id in left_room.canceled_call_connection_ids {
693 self.peer
694 .send(connection_id, proto::CallCanceled {})
695 .trace_err();
696 contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
697 }
698 }
699
700 room_left.await.trace_err();
701 for user_id in contacts_to_update {
702 self.update_user_contacts(user_id).await?;
703 }
704
705 Ok(())
706 }
707
708 async fn call(
709 self: Arc<Server>,
710 request: Message<proto::Call>,
711 response: Response<proto::Call>,
712 ) -> Result<()> {
713 let room_id = RoomId::from_proto(request.payload.room_id);
714 let calling_user_id = request.sender_user_id;
715 let called_user_id = UserId::from_proto(request.payload.called_user_id);
716 let initial_project_id = request
717 .payload
718 .initial_project_id
719 .map(ProjectId::from_proto);
720 if !self
721 .app_state
722 .db
723 .has_contact(calling_user_id, called_user_id)
724 .await?
725 {
726 return Err(anyhow!("cannot call a user who isn't a contact"))?;
727 }
728
729 let room = self
730 .app_state
731 .db
732 .call(room_id, calling_user_id, called_user_id, initial_project_id)
733 .await?;
734 self.room_updated(&room);
735 self.update_user_contacts(called_user_id).await?;
736
737 let incoming_call = proto::IncomingCall {
738 room_id: room_id.to_proto(),
739 calling_user_id: calling_user_id.to_proto(),
740 participant_user_ids: room
741 .participants
742 .iter()
743 .map(|participant| participant.user_id)
744 .collect(),
745 initial_project: room.participants.iter().find_map(|participant| {
746 let initial_project_id = initial_project_id?.to_proto();
747 participant
748 .projects
749 .iter()
750 .find(|project| project.id == initial_project_id)
751 .cloned()
752 }),
753 };
754
755 let mut calls = self
756 .store()
757 .await
758 .connection_ids_for_user(called_user_id)
759 .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
760 .collect::<FuturesUnordered<_>>();
761
762 while let Some(call_response) = calls.next().await {
763 match call_response.as_ref() {
764 Ok(_) => {
765 response.send(proto::Ack {})?;
766 return Ok(());
767 }
768 Err(_) => {
769 call_response.trace_err();
770 }
771 }
772 }
773
774 let room = self
775 .app_state
776 .db
777 .call_failed(room_id, called_user_id)
778 .await?;
779 self.room_updated(&room);
780 self.update_user_contacts(called_user_id).await?;
781
782 Err(anyhow!("failed to ring call recipient"))?
783 }
784
785 async fn cancel_call(
786 self: Arc<Server>,
787 request: Message<proto::CancelCall>,
788 response: Response<proto::CancelCall>,
789 ) -> Result<()> {
790 let recipient_user_id = UserId::from_proto(request.payload.called_user_id);
791 {
792 let mut store = self.store().await;
793 let (room, recipient_connection_ids) = store.cancel_call(
794 request.payload.room_id,
795 recipient_user_id,
796 request.sender_connection_id,
797 )?;
798 for recipient_id in recipient_connection_ids {
799 self.peer
800 .send(recipient_id, proto::CallCanceled {})
801 .trace_err();
802 }
803 self.room_updated(room);
804 response.send(proto::Ack {})?;
805 }
806 self.update_user_contacts(recipient_user_id).await?;
807 Ok(())
808 }
809
810 async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
811 let recipient_user_id = message.sender_user_id;
812 {
813 let mut store = self.store().await;
814 let (room, recipient_connection_ids) =
815 store.decline_call(message.payload.room_id, message.sender_connection_id)?;
816 for recipient_id in recipient_connection_ids {
817 self.peer
818 .send(recipient_id, proto::CallCanceled {})
819 .trace_err();
820 }
821 self.room_updated(room);
822 }
823 self.update_user_contacts(recipient_user_id).await?;
824 Ok(())
825 }
826
827 async fn update_participant_location(
828 self: Arc<Server>,
829 request: Message<proto::UpdateParticipantLocation>,
830 response: Response<proto::UpdateParticipantLocation>,
831 ) -> Result<()> {
832 let room_id = RoomId::from_proto(request.payload.room_id);
833 let location = request
834 .payload
835 .location
836 .ok_or_else(|| anyhow!("invalid location"))?;
837 let room = self
838 .app_state
839 .db
840 .update_room_participant_location(room_id, request.sender_user_id, location)
841 .await?;
842 self.room_updated(&room);
843 response.send(proto::Ack {})?;
844 Ok(())
845 }
846
847 fn room_updated(&self, room: &proto::Room) {
848 for participant in &room.participants {
849 self.peer
850 .send(
851 ConnectionId(participant.peer_id),
852 proto::RoomUpdated {
853 room: Some(room.clone()),
854 },
855 )
856 .trace_err();
857 }
858 }
859
860 fn room_left(
861 &self,
862 room: &proto::Room,
863 connection_id: ConnectionId,
864 ) -> impl Future<Output = Result<()>> {
865 let client = self.app_state.live_kit_client.clone();
866 let room_name = room.live_kit_room.clone();
867 let participant_count = room.participants.len();
868 async move {
869 if let Some(client) = client {
870 client
871 .remove_participant(room_name.clone(), connection_id.to_string())
872 .await?;
873
874 if participant_count == 0 {
875 client.delete_room(room_name).await?;
876 }
877 }
878
879 Ok(())
880 }
881 }
882
883 async fn share_project(
884 self: Arc<Server>,
885 request: Message<proto::ShareProject>,
886 response: Response<proto::ShareProject>,
887 ) -> Result<()> {
888 let (project_id, room) = self
889 .app_state
890 .db
891 .share_project(
892 request.sender_user_id,
893 request.sender_connection_id,
894 RoomId::from_proto(request.payload.room_id),
895 &request.payload.worktrees,
896 )
897 .await?;
898 response.send(proto::ShareProjectResponse {
899 project_id: project_id.to_proto(),
900 })?;
901 self.room_updated(&room);
902
903 Ok(())
904 }
905
906 async fn unshare_project(
907 self: Arc<Server>,
908 message: Message<proto::UnshareProject>,
909 ) -> Result<()> {
910 let project_id = ProjectId::from_proto(message.payload.project_id);
911 let mut store = self.store().await;
912 let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?;
913 broadcast(
914 message.sender_connection_id,
915 project.guest_connection_ids(),
916 |conn_id| self.peer.send(conn_id, message.payload.clone()),
917 );
918 self.room_updated(room);
919
920 Ok(())
921 }
922
923 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
924 let contacts = self.app_state.db.get_contacts(user_id).await?;
925 let store = self.store().await;
926 let updated_contact = store.contact_for_user(user_id, false);
927 for contact in contacts {
928 if let db::Contact::Accepted {
929 user_id: contact_user_id,
930 ..
931 } = contact
932 {
933 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
934 self.peer
935 .send(
936 contact_conn_id,
937 proto::UpdateContacts {
938 contacts: vec![updated_contact.clone()],
939 remove_contacts: Default::default(),
940 incoming_requests: Default::default(),
941 remove_incoming_requests: Default::default(),
942 outgoing_requests: Default::default(),
943 remove_outgoing_requests: Default::default(),
944 },
945 )
946 .trace_err();
947 }
948 }
949 }
950 Ok(())
951 }
952
953 async fn join_project(
954 self: Arc<Server>,
955 request: Message<proto::JoinProject>,
956 response: Response<proto::JoinProject>,
957 ) -> Result<()> {
958 let project_id = ProjectId::from_proto(request.payload.project_id);
959 let guest_user_id = request.sender_user_id;
960 let host_user_id;
961 let host_connection_id;
962 {
963 let state = self.store().await;
964 let project = state.project(project_id)?;
965 host_user_id = project.host.user_id;
966 host_connection_id = project.host_connection_id;
967 };
968
969 tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
970
971 let mut store = self.store().await;
972 let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?;
973 let peer_count = project.guests.len();
974 let mut collaborators = Vec::with_capacity(peer_count);
975 collaborators.push(proto::Collaborator {
976 peer_id: project.host_connection_id.0,
977 replica_id: 0,
978 user_id: project.host.user_id.to_proto(),
979 });
980 let worktrees = project
981 .worktrees
982 .iter()
983 .map(|(id, worktree)| proto::WorktreeMetadata {
984 id: *id,
985 root_name: worktree.root_name.clone(),
986 visible: worktree.visible,
987 abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
988 })
989 .collect::<Vec<_>>();
990
991 // Add all guests other than the requesting user's own connections as collaborators
992 for (guest_conn_id, guest) in &project.guests {
993 if request.sender_connection_id != *guest_conn_id {
994 collaborators.push(proto::Collaborator {
995 peer_id: guest_conn_id.0,
996 replica_id: guest.replica_id as u32,
997 user_id: guest.user_id.to_proto(),
998 });
999 }
1000 }
1001
1002 for conn_id in project.connection_ids() {
1003 if conn_id != request.sender_connection_id {
1004 self.peer
1005 .send(
1006 conn_id,
1007 proto::AddProjectCollaborator {
1008 project_id: project_id.to_proto(),
1009 collaborator: Some(proto::Collaborator {
1010 peer_id: request.sender_connection_id.0,
1011 replica_id: replica_id as u32,
1012 user_id: guest_user_id.to_proto(),
1013 }),
1014 },
1015 )
1016 .trace_err();
1017 }
1018 }
1019
1020 // First, we send the metadata associated with each worktree.
1021 response.send(proto::JoinProjectResponse {
1022 worktrees: worktrees.clone(),
1023 replica_id: replica_id as u32,
1024 collaborators: collaborators.clone(),
1025 language_servers: project.language_servers.clone(),
1026 })?;
1027
1028 for (worktree_id, worktree) in &project.worktrees {
1029 #[cfg(any(test, feature = "test-support"))]
1030 const MAX_CHUNK_SIZE: usize = 2;
1031 #[cfg(not(any(test, feature = "test-support")))]
1032 const MAX_CHUNK_SIZE: usize = 256;
1033
1034 // Stream this worktree's entries.
1035 let message = proto::UpdateWorktree {
1036 project_id: project_id.to_proto(),
1037 worktree_id: *worktree_id,
1038 abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(),
1039 root_name: worktree.root_name.clone(),
1040 updated_entries: worktree.entries.values().cloned().collect(),
1041 removed_entries: Default::default(),
1042 scan_id: worktree.scan_id,
1043 is_last_update: worktree.is_complete,
1044 };
1045 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1046 self.peer
1047 .send(request.sender_connection_id, update.clone())?;
1048 }
1049
1050 // Stream this worktree's diagnostics.
1051 for summary in worktree.diagnostic_summaries.values() {
1052 self.peer.send(
1053 request.sender_connection_id,
1054 proto::UpdateDiagnosticSummary {
1055 project_id: project_id.to_proto(),
1056 worktree_id: *worktree_id,
1057 summary: Some(summary.clone()),
1058 },
1059 )?;
1060 }
1061 }
1062
1063 for language_server in &project.language_servers {
1064 self.peer.send(
1065 request.sender_connection_id,
1066 proto::UpdateLanguageServer {
1067 project_id: project_id.to_proto(),
1068 language_server_id: language_server.id,
1069 variant: Some(
1070 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1071 proto::LspDiskBasedDiagnosticsUpdated {},
1072 ),
1073 ),
1074 },
1075 )?;
1076 }
1077
1078 Ok(())
1079 }
1080
1081 async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1082 let sender_id = request.sender_connection_id;
1083 let project_id = ProjectId::from_proto(request.payload.project_id);
1084 let project;
1085 {
1086 let mut store = self.store().await;
1087 project = store.leave_project(project_id, sender_id)?;
1088 tracing::info!(
1089 %project_id,
1090 host_user_id = %project.host_user_id,
1091 host_connection_id = %project.host_connection_id,
1092 "leave project"
1093 );
1094
1095 if project.remove_collaborator {
1096 broadcast(sender_id, project.connection_ids, |conn_id| {
1097 self.peer.send(
1098 conn_id,
1099 proto::RemoveProjectCollaborator {
1100 project_id: project_id.to_proto(),
1101 peer_id: sender_id.0,
1102 },
1103 )
1104 });
1105 }
1106 }
1107
1108 Ok(())
1109 }
1110
1111 async fn update_project(
1112 self: Arc<Server>,
1113 request: Message<proto::UpdateProject>,
1114 ) -> Result<()> {
1115 let project_id = ProjectId::from_proto(request.payload.project_id);
1116 {
1117 let mut state = self.store().await;
1118 let guest_connection_ids = state
1119 .read_project(project_id, request.sender_connection_id)?
1120 .guest_connection_ids();
1121 let room = state.update_project(
1122 project_id,
1123 &request.payload.worktrees,
1124 request.sender_connection_id,
1125 )?;
1126 broadcast(
1127 request.sender_connection_id,
1128 guest_connection_ids,
1129 |connection_id| {
1130 self.peer.forward_send(
1131 request.sender_connection_id,
1132 connection_id,
1133 request.payload.clone(),
1134 )
1135 },
1136 );
1137 self.room_updated(room);
1138 };
1139
1140 Ok(())
1141 }
1142
1143 async fn update_worktree(
1144 self: Arc<Server>,
1145 request: Message<proto::UpdateWorktree>,
1146 response: Response<proto::UpdateWorktree>,
1147 ) -> Result<()> {
1148 let project_id = ProjectId::from_proto(request.payload.project_id);
1149 let worktree_id = request.payload.worktree_id;
1150 let connection_ids = self.store().await.update_worktree(
1151 request.sender_connection_id,
1152 project_id,
1153 worktree_id,
1154 &request.payload.root_name,
1155 &request.payload.removed_entries,
1156 &request.payload.updated_entries,
1157 request.payload.scan_id,
1158 request.payload.is_last_update,
1159 )?;
1160
1161 broadcast(
1162 request.sender_connection_id,
1163 connection_ids,
1164 |connection_id| {
1165 self.peer.forward_send(
1166 request.sender_connection_id,
1167 connection_id,
1168 request.payload.clone(),
1169 )
1170 },
1171 );
1172 response.send(proto::Ack {})?;
1173 Ok(())
1174 }
1175
1176 async fn update_diagnostic_summary(
1177 self: Arc<Server>,
1178 request: Message<proto::UpdateDiagnosticSummary>,
1179 ) -> Result<()> {
1180 let summary = request
1181 .payload
1182 .summary
1183 .clone()
1184 .ok_or_else(|| anyhow!("invalid summary"))?;
1185 let receiver_ids = self.store().await.update_diagnostic_summary(
1186 ProjectId::from_proto(request.payload.project_id),
1187 request.payload.worktree_id,
1188 request.sender_connection_id,
1189 summary,
1190 )?;
1191
1192 broadcast(
1193 request.sender_connection_id,
1194 receiver_ids,
1195 |connection_id| {
1196 self.peer.forward_send(
1197 request.sender_connection_id,
1198 connection_id,
1199 request.payload.clone(),
1200 )
1201 },
1202 );
1203 Ok(())
1204 }
1205
1206 async fn start_language_server(
1207 self: Arc<Server>,
1208 request: Message<proto::StartLanguageServer>,
1209 ) -> Result<()> {
1210 let receiver_ids = self.store().await.start_language_server(
1211 ProjectId::from_proto(request.payload.project_id),
1212 request.sender_connection_id,
1213 request
1214 .payload
1215 .server
1216 .clone()
1217 .ok_or_else(|| anyhow!("invalid language server"))?,
1218 )?;
1219 broadcast(
1220 request.sender_connection_id,
1221 receiver_ids,
1222 |connection_id| {
1223 self.peer.forward_send(
1224 request.sender_connection_id,
1225 connection_id,
1226 request.payload.clone(),
1227 )
1228 },
1229 );
1230 Ok(())
1231 }
1232
1233 async fn update_language_server(
1234 self: Arc<Server>,
1235 request: Message<proto::UpdateLanguageServer>,
1236 ) -> Result<()> {
1237 let receiver_ids = self.store().await.project_connection_ids(
1238 ProjectId::from_proto(request.payload.project_id),
1239 request.sender_connection_id,
1240 )?;
1241 broadcast(
1242 request.sender_connection_id,
1243 receiver_ids,
1244 |connection_id| {
1245 self.peer.forward_send(
1246 request.sender_connection_id,
1247 connection_id,
1248 request.payload.clone(),
1249 )
1250 },
1251 );
1252 Ok(())
1253 }
1254
1255 async fn forward_project_request<T>(
1256 self: Arc<Server>,
1257 request: Message<T>,
1258 response: Response<T>,
1259 ) -> Result<()>
1260 where
1261 T: EntityMessage + RequestMessage,
1262 {
1263 let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1264 let host_connection_id = self
1265 .store()
1266 .await
1267 .read_project(project_id, request.sender_connection_id)?
1268 .host_connection_id;
1269 let payload = self
1270 .peer
1271 .forward_request(
1272 request.sender_connection_id,
1273 host_connection_id,
1274 request.payload,
1275 )
1276 .await?;
1277
1278 // Ensure project still exists by the time we get the response from the host.
1279 self.store()
1280 .await
1281 .read_project(project_id, request.sender_connection_id)?;
1282
1283 response.send(payload)?;
1284 Ok(())
1285 }
1286
1287 async fn save_buffer(
1288 self: Arc<Server>,
1289 request: Message<proto::SaveBuffer>,
1290 response: Response<proto::SaveBuffer>,
1291 ) -> Result<()> {
1292 let project_id = ProjectId::from_proto(request.payload.project_id);
1293 let host = self
1294 .store()
1295 .await
1296 .read_project(project_id, request.sender_connection_id)?
1297 .host_connection_id;
1298 let response_payload = self
1299 .peer
1300 .forward_request(request.sender_connection_id, host, request.payload.clone())
1301 .await?;
1302
1303 let mut guests = self
1304 .store()
1305 .await
1306 .read_project(project_id, request.sender_connection_id)?
1307 .connection_ids();
1308 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id);
1309 broadcast(host, guests, |conn_id| {
1310 self.peer
1311 .forward_send(host, conn_id, response_payload.clone())
1312 });
1313 response.send(response_payload)?;
1314 Ok(())
1315 }
1316
1317 async fn create_buffer_for_peer(
1318 self: Arc<Server>,
1319 request: Message<proto::CreateBufferForPeer>,
1320 ) -> Result<()> {
1321 self.peer.forward_send(
1322 request.sender_connection_id,
1323 ConnectionId(request.payload.peer_id),
1324 request.payload,
1325 )?;
1326 Ok(())
1327 }
1328
1329 async fn update_buffer(
1330 self: Arc<Server>,
1331 request: Message<proto::UpdateBuffer>,
1332 response: Response<proto::UpdateBuffer>,
1333 ) -> Result<()> {
1334 let project_id = ProjectId::from_proto(request.payload.project_id);
1335 let receiver_ids = {
1336 let store = self.store().await;
1337 store.project_connection_ids(project_id, request.sender_connection_id)?
1338 };
1339
1340 broadcast(
1341 request.sender_connection_id,
1342 receiver_ids,
1343 |connection_id| {
1344 self.peer.forward_send(
1345 request.sender_connection_id,
1346 connection_id,
1347 request.payload.clone(),
1348 )
1349 },
1350 );
1351 response.send(proto::Ack {})?;
1352 Ok(())
1353 }
1354
1355 async fn update_buffer_file(
1356 self: Arc<Server>,
1357 request: Message<proto::UpdateBufferFile>,
1358 ) -> Result<()> {
1359 let receiver_ids = self.store().await.project_connection_ids(
1360 ProjectId::from_proto(request.payload.project_id),
1361 request.sender_connection_id,
1362 )?;
1363 broadcast(
1364 request.sender_connection_id,
1365 receiver_ids,
1366 |connection_id| {
1367 self.peer.forward_send(
1368 request.sender_connection_id,
1369 connection_id,
1370 request.payload.clone(),
1371 )
1372 },
1373 );
1374 Ok(())
1375 }
1376
1377 async fn buffer_reloaded(
1378 self: Arc<Server>,
1379 request: Message<proto::BufferReloaded>,
1380 ) -> Result<()> {
1381 let receiver_ids = self.store().await.project_connection_ids(
1382 ProjectId::from_proto(request.payload.project_id),
1383 request.sender_connection_id,
1384 )?;
1385 broadcast(
1386 request.sender_connection_id,
1387 receiver_ids,
1388 |connection_id| {
1389 self.peer.forward_send(
1390 request.sender_connection_id,
1391 connection_id,
1392 request.payload.clone(),
1393 )
1394 },
1395 );
1396 Ok(())
1397 }
1398
1399 async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1400 let receiver_ids = self.store().await.project_connection_ids(
1401 ProjectId::from_proto(request.payload.project_id),
1402 request.sender_connection_id,
1403 )?;
1404 broadcast(
1405 request.sender_connection_id,
1406 receiver_ids,
1407 |connection_id| {
1408 self.peer.forward_send(
1409 request.sender_connection_id,
1410 connection_id,
1411 request.payload.clone(),
1412 )
1413 },
1414 );
1415 Ok(())
1416 }
1417
1418 async fn follow(
1419 self: Arc<Self>,
1420 request: Message<proto::Follow>,
1421 response: Response<proto::Follow>,
1422 ) -> Result<()> {
1423 let project_id = ProjectId::from_proto(request.payload.project_id);
1424 let leader_id = ConnectionId(request.payload.leader_id);
1425 let follower_id = request.sender_connection_id;
1426 {
1427 let store = self.store().await;
1428 if !store
1429 .project_connection_ids(project_id, follower_id)?
1430 .contains(&leader_id)
1431 {
1432 Err(anyhow!("no such peer"))?;
1433 }
1434 }
1435
1436 let mut response_payload = self
1437 .peer
1438 .forward_request(request.sender_connection_id, leader_id, request.payload)
1439 .await?;
1440 response_payload
1441 .views
1442 .retain(|view| view.leader_id != Some(follower_id.0));
1443 response.send(response_payload)?;
1444 Ok(())
1445 }
1446
1447 async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1448 let project_id = ProjectId::from_proto(request.payload.project_id);
1449 let leader_id = ConnectionId(request.payload.leader_id);
1450 let store = self.store().await;
1451 if !store
1452 .project_connection_ids(project_id, request.sender_connection_id)?
1453 .contains(&leader_id)
1454 {
1455 Err(anyhow!("no such peer"))?;
1456 }
1457 self.peer
1458 .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1459 Ok(())
1460 }
1461
1462 async fn update_followers(
1463 self: Arc<Self>,
1464 request: Message<proto::UpdateFollowers>,
1465 ) -> Result<()> {
1466 let project_id = ProjectId::from_proto(request.payload.project_id);
1467 let store = self.store().await;
1468 let connection_ids =
1469 store.project_connection_ids(project_id, request.sender_connection_id)?;
1470 let leader_id = request
1471 .payload
1472 .variant
1473 .as_ref()
1474 .and_then(|variant| match variant {
1475 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1476 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1477 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1478 });
1479 for follower_id in &request.payload.follower_ids {
1480 let follower_id = ConnectionId(*follower_id);
1481 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1482 self.peer.forward_send(
1483 request.sender_connection_id,
1484 follower_id,
1485 request.payload.clone(),
1486 )?;
1487 }
1488 }
1489 Ok(())
1490 }
1491
1492 async fn get_users(
1493 self: Arc<Server>,
1494 request: Message<proto::GetUsers>,
1495 response: Response<proto::GetUsers>,
1496 ) -> Result<()> {
1497 let user_ids = request
1498 .payload
1499 .user_ids
1500 .into_iter()
1501 .map(UserId::from_proto)
1502 .collect();
1503 let users = self
1504 .app_state
1505 .db
1506 .get_users_by_ids(user_ids)
1507 .await?
1508 .into_iter()
1509 .map(|user| proto::User {
1510 id: user.id.to_proto(),
1511 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1512 github_login: user.github_login,
1513 })
1514 .collect();
1515 response.send(proto::UsersResponse { users })?;
1516 Ok(())
1517 }
1518
1519 async fn fuzzy_search_users(
1520 self: Arc<Server>,
1521 request: Message<proto::FuzzySearchUsers>,
1522 response: Response<proto::FuzzySearchUsers>,
1523 ) -> Result<()> {
1524 let query = request.payload.query;
1525 let db = &self.app_state.db;
1526 let users = match query.len() {
1527 0 => vec![],
1528 1 | 2 => db
1529 .get_user_by_github_account(&query, None)
1530 .await?
1531 .into_iter()
1532 .collect(),
1533 _ => db.fuzzy_search_users(&query, 10).await?,
1534 };
1535 let users = users
1536 .into_iter()
1537 .filter(|user| user.id != request.sender_user_id)
1538 .map(|user| proto::User {
1539 id: user.id.to_proto(),
1540 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1541 github_login: user.github_login,
1542 })
1543 .collect();
1544 response.send(proto::UsersResponse { users })?;
1545 Ok(())
1546 }
1547
1548 async fn request_contact(
1549 self: Arc<Server>,
1550 request: Message<proto::RequestContact>,
1551 response: Response<proto::RequestContact>,
1552 ) -> Result<()> {
1553 let requester_id = request.sender_user_id;
1554 let responder_id = UserId::from_proto(request.payload.responder_id);
1555 if requester_id == responder_id {
1556 return Err(anyhow!("cannot add yourself as a contact"))?;
1557 }
1558
1559 self.app_state
1560 .db
1561 .send_contact_request(requester_id, responder_id)
1562 .await?;
1563
1564 // Update outgoing contact requests of requester
1565 let mut update = proto::UpdateContacts::default();
1566 update.outgoing_requests.push(responder_id.to_proto());
1567 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1568 self.peer.send(connection_id, update.clone())?;
1569 }
1570
1571 // Update incoming contact requests of responder
1572 let mut update = proto::UpdateContacts::default();
1573 update
1574 .incoming_requests
1575 .push(proto::IncomingContactRequest {
1576 requester_id: requester_id.to_proto(),
1577 should_notify: true,
1578 });
1579 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1580 self.peer.send(connection_id, update.clone())?;
1581 }
1582
1583 response.send(proto::Ack {})?;
1584 Ok(())
1585 }
1586
1587 async fn respond_to_contact_request(
1588 self: Arc<Server>,
1589 request: Message<proto::RespondToContactRequest>,
1590 response: Response<proto::RespondToContactRequest>,
1591 ) -> Result<()> {
1592 let responder_id = request.sender_user_id;
1593 let requester_id = UserId::from_proto(request.payload.requester_id);
1594 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1595 self.app_state
1596 .db
1597 .dismiss_contact_notification(responder_id, requester_id)
1598 .await?;
1599 } else {
1600 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1601 self.app_state
1602 .db
1603 .respond_to_contact_request(responder_id, requester_id, accept)
1604 .await?;
1605
1606 let store = self.store().await;
1607 // Update responder with new contact
1608 let mut update = proto::UpdateContacts::default();
1609 if accept {
1610 update
1611 .contacts
1612 .push(store.contact_for_user(requester_id, false));
1613 }
1614 update
1615 .remove_incoming_requests
1616 .push(requester_id.to_proto());
1617 for connection_id in store.connection_ids_for_user(responder_id) {
1618 self.peer.send(connection_id, update.clone())?;
1619 }
1620
1621 // Update requester with new contact
1622 let mut update = proto::UpdateContacts::default();
1623 if accept {
1624 update
1625 .contacts
1626 .push(store.contact_for_user(responder_id, true));
1627 }
1628 update
1629 .remove_outgoing_requests
1630 .push(responder_id.to_proto());
1631 for connection_id in store.connection_ids_for_user(requester_id) {
1632 self.peer.send(connection_id, update.clone())?;
1633 }
1634 }
1635
1636 response.send(proto::Ack {})?;
1637 Ok(())
1638 }
1639
1640 async fn remove_contact(
1641 self: Arc<Server>,
1642 request: Message<proto::RemoveContact>,
1643 response: Response<proto::RemoveContact>,
1644 ) -> Result<()> {
1645 let requester_id = request.sender_user_id;
1646 let responder_id = UserId::from_proto(request.payload.user_id);
1647 self.app_state
1648 .db
1649 .remove_contact(requester_id, responder_id)
1650 .await?;
1651
1652 // Update outgoing contact requests of requester
1653 let mut update = proto::UpdateContacts::default();
1654 update
1655 .remove_outgoing_requests
1656 .push(responder_id.to_proto());
1657 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1658 self.peer.send(connection_id, update.clone())?;
1659 }
1660
1661 // Update incoming contact requests of responder
1662 let mut update = proto::UpdateContacts::default();
1663 update
1664 .remove_incoming_requests
1665 .push(requester_id.to_proto());
1666 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1667 self.peer.send(connection_id, update.clone())?;
1668 }
1669
1670 response.send(proto::Ack {})?;
1671 Ok(())
1672 }
1673
1674 async fn update_diff_base(
1675 self: Arc<Server>,
1676 request: Message<proto::UpdateDiffBase>,
1677 ) -> Result<()> {
1678 let receiver_ids = self.store().await.project_connection_ids(
1679 ProjectId::from_proto(request.payload.project_id),
1680 request.sender_connection_id,
1681 )?;
1682 broadcast(
1683 request.sender_connection_id,
1684 receiver_ids,
1685 |connection_id| {
1686 self.peer.forward_send(
1687 request.sender_connection_id,
1688 connection_id,
1689 request.payload.clone(),
1690 )
1691 },
1692 );
1693 Ok(())
1694 }
1695
1696 async fn get_private_user_info(
1697 self: Arc<Self>,
1698 request: Message<proto::GetPrivateUserInfo>,
1699 response: Response<proto::GetPrivateUserInfo>,
1700 ) -> Result<()> {
1701 let metrics_id = self
1702 .app_state
1703 .db
1704 .get_user_metrics_id(request.sender_user_id)
1705 .await?;
1706 let user = self
1707 .app_state
1708 .db
1709 .get_user_by_id(request.sender_user_id)
1710 .await?
1711 .ok_or_else(|| anyhow!("user not found"))?;
1712 response.send(proto::GetPrivateUserInfoResponse {
1713 metrics_id,
1714 staff: user.admin,
1715 })?;
1716 Ok(())
1717 }
1718
1719 pub(crate) async fn store(&self) -> StoreGuard<'_> {
1720 #[cfg(test)]
1721 tokio::task::yield_now().await;
1722 let guard = self.store.lock().await;
1723 #[cfg(test)]
1724 tokio::task::yield_now().await;
1725 StoreGuard {
1726 guard,
1727 _not_send: PhantomData,
1728 }
1729 }
1730
1731 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1732 ServerSnapshot {
1733 store: self.store().await,
1734 peer: &self.peer,
1735 }
1736 }
1737}
1738
1739impl<'a> Deref for StoreGuard<'a> {
1740 type Target = Store;
1741
1742 fn deref(&self) -> &Self::Target {
1743 &*self.guard
1744 }
1745}
1746
1747impl<'a> DerefMut for StoreGuard<'a> {
1748 fn deref_mut(&mut self) -> &mut Self::Target {
1749 &mut *self.guard
1750 }
1751}
1752
1753impl<'a> Drop for StoreGuard<'a> {
1754 fn drop(&mut self) {
1755 #[cfg(test)]
1756 self.check_invariants();
1757 }
1758}
1759
1760impl Executor for RealExecutor {
1761 type Sleep = Sleep;
1762
1763 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1764 tokio::task::spawn(future);
1765 }
1766
1767 fn sleep(&self, duration: Duration) -> Self::Sleep {
1768 tokio::time::sleep(duration)
1769 }
1770}
1771
1772fn broadcast<F>(
1773 sender_id: ConnectionId,
1774 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1775 mut f: F,
1776) where
1777 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1778{
1779 for receiver_id in receiver_ids {
1780 if receiver_id != sender_id {
1781 f(receiver_id).trace_err();
1782 }
1783 }
1784}
1785
1786lazy_static! {
1787 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1788}
1789
1790pub struct ProtocolVersion(u32);
1791
1792impl Header for ProtocolVersion {
1793 fn name() -> &'static HeaderName {
1794 &ZED_PROTOCOL_VERSION
1795 }
1796
1797 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1798 where
1799 Self: Sized,
1800 I: Iterator<Item = &'i axum::http::HeaderValue>,
1801 {
1802 let version = values
1803 .next()
1804 .ok_or_else(axum::headers::Error::invalid)?
1805 .to_str()
1806 .map_err(|_| axum::headers::Error::invalid())?
1807 .parse()
1808 .map_err(|_| axum::headers::Error::invalid())?;
1809 Ok(Self(version))
1810 }
1811
1812 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1813 values.extend([self.0.to_string().parse().unwrap()]);
1814 }
1815}
1816
1817pub fn routes(server: Arc<Server>) -> Router<Body> {
1818 Router::new()
1819 .route("/rpc", get(handle_websocket_request))
1820 .layer(
1821 ServiceBuilder::new()
1822 .layer(Extension(server.app_state.clone()))
1823 .layer(middleware::from_fn(auth::validate_header)),
1824 )
1825 .route("/metrics", get(handle_metrics))
1826 .layer(Extension(server))
1827}
1828
1829pub async fn handle_websocket_request(
1830 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1831 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1832 Extension(server): Extension<Arc<Server>>,
1833 Extension(user): Extension<User>,
1834 ws: WebSocketUpgrade,
1835) -> axum::response::Response {
1836 if protocol_version != rpc::PROTOCOL_VERSION {
1837 return (
1838 StatusCode::UPGRADE_REQUIRED,
1839 "client must be upgraded".to_string(),
1840 )
1841 .into_response();
1842 }
1843 let socket_address = socket_address.to_string();
1844 ws.on_upgrade(move |socket| {
1845 use util::ResultExt;
1846 let socket = socket
1847 .map_ok(to_tungstenite_message)
1848 .err_into()
1849 .with(|message| async move { Ok(to_axum_message(message)) });
1850 let connection = Connection::new(Box::pin(socket));
1851 async move {
1852 server
1853 .handle_connection(connection, socket_address, user, None, RealExecutor)
1854 .await
1855 .log_err();
1856 }
1857 })
1858}
1859
1860pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1861 let metrics = server.store().await.metrics();
1862 METRIC_CONNECTIONS.set(metrics.connections as _);
1863 METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1864
1865 let encoder = prometheus::TextEncoder::new();
1866 let metric_families = prometheus::gather();
1867 match encoder.encode_to_string(&metric_families) {
1868 Ok(string) => (StatusCode::OK, string).into_response(),
1869 Err(error) => (
1870 StatusCode::INTERNAL_SERVER_ERROR,
1871 format!("failed to encode metrics {:?}", error),
1872 )
1873 .into_response(),
1874 }
1875}
1876
1877fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1878 match message {
1879 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1880 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1881 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1882 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1883 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1884 code: frame.code.into(),
1885 reason: frame.reason,
1886 })),
1887 }
1888}
1889
1890fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1891 match message {
1892 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1893 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1894 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1895 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1896 AxumMessage::Close(frame) => {
1897 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1898 code: frame.code.into(),
1899 reason: frame.reason,
1900 }))
1901 }
1902 }
1903}
1904
1905pub trait ResultExt {
1906 type Ok;
1907
1908 fn trace_err(self) -> Option<Self::Ok>;
1909}
1910
1911impl<T, E> ResultExt for Result<T, E>
1912where
1913 E: std::fmt::Debug,
1914{
1915 type Ok = T;
1916
1917 fn trace_err(self) -> Option<T> {
1918 match self {
1919 Ok(value) => Some(value),
1920 Err(error) => {
1921 tracing::error!("{:?}", error);
1922 None
1923 }
1924 }
1925 }
1926}