1mod store;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, MessageId, ProjectId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::HashMap;
26use futures::{
27 channel::mpsc,
28 future::{self, BoxFuture},
29 FutureExt, SinkExt, StreamExt, TryStreamExt,
30};
31use lazy_static::lazy_static;
32use prometheus::{register_int_gauge, IntGauge};
33use rpc::{
34 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
35 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
36};
37use serde::{Serialize, Serializer};
38use std::{
39 any::TypeId,
40 future::Future,
41 marker::PhantomData,
42 net::SocketAddr,
43 ops::{Deref, DerefMut},
44 rc::Rc,
45 sync::{
46 atomic::{AtomicBool, Ordering::SeqCst},
47 Arc,
48 },
49 time::Duration,
50};
51use time::OffsetDateTime;
52use tokio::{
53 sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
54 time::Sleep,
55};
56use tower::ServiceBuilder;
57use tracing::{info_span, instrument, Instrument};
58
59pub use store::{Store, Worktree};
60
// Prometheus gauges describing the server's live state. `lazy_static` creates each one on
// first access, registering it via `register_int_gauge!`; the `unwrap` panics if that
// registration fails.
lazy_static! {
    // Gauge named "connections".
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge named "registered_projects".
    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
    // Gauge named "active_projects".
    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
        register_int_gauge!("active_projects", "number of active projects").unwrap();
    // Gauge named "shared_projects" (projects with at least one guest, per its help text).
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
74
75type MessageHandler =
76 Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
77
78struct Response<R> {
79 server: Arc<Server>,
80 receipt: Receipt<R>,
81 responded: Arc<AtomicBool>,
82}
83
84impl<R: RequestMessage> Response<R> {
85 fn send(self, payload: R::Response) -> Result<()> {
86 self.responded.store(true, SeqCst);
87 self.server.peer.respond(self.receipt, payload)?;
88 Ok(())
89 }
90
91 fn into_receipt(self) -> Receipt<R> {
92 self.responded.store(true, SeqCst);
93 self.receipt
94 }
95}
96
/// Server state shared by every connection.
///
/// Holds the RPC peer, the in-memory [`Store`] behind an async `RwLock`, the application
/// state (database handle etc.), and the message-handler dispatch table.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) store: RwLock<Store>,
    app_state: Arc<AppState>,
    // Dispatch table from a payload's `TypeId` to its handler; populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // When set, `handle_connection` sends `()` here after each message is handled.
    // Presumably used by tests to observe progress — TODO confirm.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
104
/// Abstraction over background task spawning and timers used by the server.
///
/// It allows the server to run on different executors; presumably a real runtime in
/// production and a deterministic one in tests — verify.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` to run independently; no handle is returned to the caller.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future that completes once `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}

/// Production [`Executor`]. Its trait implementation is not visible in this section.
#[derive(Clone)]
pub struct RealExecutor;
113
// Presumably the page size for channel-message queries (its use is not visible here).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Presumably the maximum accepted channel-message length (its use is not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
116
/// Shared-access guard over the [`Store`].
///
/// `_not_send` (a `PhantomData<Rc<()>>`) makes the guard `!Send`. Any future that keeps it
/// alive across an `.await` is therefore also `!Send`. This presumably exists so store locks
/// cannot be held across suspension points in spawned handlers — confirm.
struct StoreReadGuard<'a> {
    guard: RwLockReadGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}

/// Exclusive-access guard over the [`Store`].
///
/// Like `StoreReadGuard`, it is made `!Send` by its `PhantomData<Rc<()>>` field.
struct StoreWriteGuard<'a> {
    guard: RwLockWriteGuard<'a, Store>,
    _not_send: PhantomData<Rc<()>>,
}
126
/// Serializable view of the server: the peer plus a read-locked [`Store`].
///
/// The store's read lock stays held for as long as this value is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `Store` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
133
134pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
135where
136 S: Serializer,
137 T: Deref<Target = U>,
138 U: Serialize,
139{
140 Serialize::serialize(value.deref(), serializer)
141}
142
143impl Server {
144 pub fn new(
145 app_state: Arc<AppState>,
146 notifications: Option<mpsc::UnboundedSender<()>>,
147 ) -> Arc<Self> {
148 let mut server = Self {
149 peer: Peer::new(),
150 app_state,
151 store: Default::default(),
152 handlers: Default::default(),
153 notifications,
154 };
155
156 server
157 .add_request_handler(Server::ping)
158 .add_request_handler(Server::register_project)
159 .add_request_handler(Server::unregister_project)
160 .add_request_handler(Server::join_project)
161 .add_message_handler(Server::leave_project)
162 .add_message_handler(Server::respond_to_join_project_request)
163 .add_message_handler(Server::update_project)
164 .add_message_handler(Server::register_project_activity)
165 .add_request_handler(Server::update_worktree)
166 .add_message_handler(Server::start_language_server)
167 .add_message_handler(Server::update_language_server)
168 .add_message_handler(Server::update_diagnostic_summary)
169 .add_request_handler(Server::forward_project_request::<proto::GetHover>)
170 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
171 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
172 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
173 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
174 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
175 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
176 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
177 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
178 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
179 .add_request_handler(
180 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
181 )
182 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
183 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
184 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
185 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
186 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
187 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
188 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
189 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
190 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
191 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
192 .add_request_handler(Server::update_buffer)
193 .add_message_handler(Server::update_buffer_file)
194 .add_message_handler(Server::buffer_reloaded)
195 .add_message_handler(Server::buffer_saved)
196 .add_request_handler(Server::save_buffer)
197 .add_request_handler(Server::get_channels)
198 .add_request_handler(Server::get_users)
199 .add_request_handler(Server::fuzzy_search_users)
200 .add_request_handler(Server::request_contact)
201 .add_request_handler(Server::remove_contact)
202 .add_request_handler(Server::respond_to_contact_request)
203 .add_request_handler(Server::join_channel)
204 .add_message_handler(Server::leave_channel)
205 .add_request_handler(Server::send_channel_message)
206 .add_request_handler(Server::follow)
207 .add_message_handler(Server::unfollow)
208 .add_message_handler(Server::update_followers)
209 .add_request_handler(Server::get_channel_messages);
210
211 Arc::new(server)
212 }
213
214 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
215 where
216 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
217 Fut: 'static + Send + Future<Output = Result<()>>,
218 M: EnvelopedMessage,
219 {
220 let prev_handler = self.handlers.insert(
221 TypeId::of::<M>(),
222 Box::new(move |server, envelope| {
223 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
224 let span = info_span!(
225 "handle message",
226 payload_type = envelope.payload_type_name()
227 );
228 span.in_scope(|| {
229 tracing::info!(
230 payload = format!("{:?}", envelope.payload).as_str(),
231 "message payload"
232 );
233 });
234 let future = (handler)(server, *envelope);
235 async move {
236 if let Err(error) = future.await {
237 tracing::error!(%error, "error handling message");
238 }
239 }
240 .instrument(span)
241 .boxed()
242 }),
243 );
244 if prev_handler.is_some() {
245 panic!("registered a handler for the same message twice");
246 }
247 self
248 }
249
250 /// Handle a request while holding a lock to the store. This is useful when we're registering
251 /// a connection but we want to respond on the connection before anybody else can send on it.
252 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
253 where
254 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
255 Fut: Send + Future<Output = Result<()>>,
256 M: RequestMessage,
257 {
258 let handler = Arc::new(handler);
259 self.add_message_handler(move |server, envelope| {
260 let receipt = envelope.receipt();
261 let handler = handler.clone();
262 async move {
263 let responded = Arc::new(AtomicBool::default());
264 let response = Response {
265 server: server.clone(),
266 responded: responded.clone(),
267 receipt: envelope.receipt(),
268 };
269 match (handler)(server.clone(), envelope, response).await {
270 Ok(()) => {
271 if responded.load(std::sync::atomic::Ordering::SeqCst) {
272 Ok(())
273 } else {
274 Err(anyhow!("handler did not send a response"))?
275 }
276 }
277 Err(error) => {
278 server.peer.respond_with_error(
279 receipt,
280 proto::Error {
281 message: error.to_string(),
282 },
283 )?;
284 Err(error)
285 }
286 }
287 }
288 })
289 }
290
291 /// Start a long lived task that records which users are active in which projects.
292 pub fn start_recording_project_activity<E: 'static + Executor>(
293 self: &Arc<Self>,
294 interval: Duration,
295 executor: E,
296 ) {
297 executor.spawn_detached({
298 let this = Arc::downgrade(self);
299 let executor = executor.clone();
300 async move {
301 let mut period_start = OffsetDateTime::now_utc();
302 let mut active_projects = Vec::<(UserId, ProjectId)>::new();
303 loop {
304 let sleep = executor.sleep(interval);
305 sleep.await;
306 let this = if let Some(this) = this.upgrade() {
307 this
308 } else {
309 break;
310 };
311
312 active_projects.clear();
313 active_projects.extend(this.store().await.projects().flat_map(
314 |(project_id, project)| {
315 project.guests.values().chain([&project.host]).filter_map(
316 |collaborator| {
317 if !collaborator.admin
318 && collaborator
319 .last_activity
320 .map_or(false, |activity| activity > period_start)
321 {
322 Some((collaborator.user_id, *project_id))
323 } else {
324 None
325 }
326 },
327 )
328 },
329 ));
330
331 let period_end = OffsetDateTime::now_utc();
332 this.app_state
333 .db
334 .record_project_activity(period_start..period_end, &active_projects)
335 .await
336 .trace_err();
337 period_start = period_end;
338 }
339 }
340 });
341 }
342
343 pub fn handle_connection<E: Executor>(
344 self: &Arc<Self>,
345 connection: Connection,
346 address: String,
347 user: User,
348 mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
349 executor: E,
350 ) -> impl Future<Output = Result<()>> {
351 let mut this = self.clone();
352 let user_id = user.id;
353 let login = user.github_login;
354 let span = info_span!("handle connection", %user_id, %login, %address);
355 async move {
356 let (connection_id, handle_io, mut incoming_rx) = this
357 .peer
358 .add_connection(connection, {
359 let executor = executor.clone();
360 move |duration| {
361 let timer = executor.sleep(duration);
362 async move {
363 timer.await;
364 }
365 }
366 })
367 .await;
368
369 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
370
371 if let Some(send_connection_id) = send_connection_id.as_mut() {
372 let _ = send_connection_id.send(connection_id).await;
373 }
374
375 if !user.connected_once {
376 this.peer.send(connection_id, proto::ShowContacts {})?;
377 this.app_state.db.set_user_connected_once(user_id, true).await?;
378 }
379
380 let (contacts, invite_code) = future::try_join(
381 this.app_state.db.get_contacts(user_id),
382 this.app_state.db.get_invite_code_for_user(user_id)
383 ).await?;
384
385 {
386 let mut store = this.store_mut().await;
387 store.add_connection(connection_id, user_id, user.admin);
388 this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
389
390 if let Some((code, count)) = invite_code {
391 this.peer.send(connection_id, proto::UpdateInviteInfo {
392 url: format!("{}{}", this.app_state.invite_link_prefix, code),
393 count,
394 })?;
395 }
396 }
397 this.update_user_contacts(user_id).await?;
398
399 let handle_io = handle_io.fuse();
400 futures::pin_mut!(handle_io);
401 loop {
402 let next_message = incoming_rx.next().fuse();
403 futures::pin_mut!(next_message);
404 futures::select_biased! {
405 result = handle_io => {
406 if let Err(error) = result {
407 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
408 }
409 break;
410 }
411 message = next_message => {
412 if let Some(message) = message {
413 let type_name = message.payload_type_name();
414 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
415 async {
416 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
417 let notifications = this.notifications.clone();
418 let is_background = message.is_background();
419 let handle_message = (handler)(this.clone(), message);
420 let handle_message = async move {
421 handle_message.await;
422 if let Some(mut notifications) = notifications {
423 let _ = notifications.send(()).await;
424 }
425 };
426 if is_background {
427 executor.spawn_detached(handle_message);
428 } else {
429 handle_message.await;
430 }
431 } else {
432 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
433 }
434 }.instrument(span).await;
435 } else {
436 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
437 break;
438 }
439 }
440 }
441 }
442
443 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
444 if let Err(error) = this.sign_out(connection_id).await {
445 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
446 }
447
448 Ok(())
449 }.instrument(span)
450 }
451
452 #[instrument(skip(self), err)]
453 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
454 self.peer.disconnect(connection_id);
455
456 let mut projects_to_unregister = Vec::new();
457 let removed_user_id;
458 {
459 let mut store = self.store_mut().await;
460 let removed_connection = store.remove_connection(connection_id)?;
461
462 for (project_id, project) in removed_connection.hosted_projects {
463 projects_to_unregister.push(project_id);
464 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
465 self.peer.send(
466 conn_id,
467 proto::UnregisterProject {
468 project_id: project_id.to_proto(),
469 },
470 )
471 });
472
473 for (_, receipts) in project.join_requests {
474 for receipt in receipts {
475 self.peer.respond(
476 receipt,
477 proto::JoinProjectResponse {
478 variant: Some(proto::join_project_response::Variant::Decline(
479 proto::join_project_response::Decline {
480 reason: proto::join_project_response::decline::Reason::WentOffline as i32
481 },
482 )),
483 },
484 )?;
485 }
486 }
487 }
488
489 for project_id in removed_connection.guest_project_ids {
490 if let Some(project) = store.project(project_id).trace_err() {
491 broadcast(connection_id, project.connection_ids(), |conn_id| {
492 self.peer.send(
493 conn_id,
494 proto::RemoveProjectCollaborator {
495 project_id: project_id.to_proto(),
496 peer_id: connection_id.0,
497 },
498 )
499 });
500 if project.guests.is_empty() {
501 self.peer
502 .send(
503 project.host_connection_id,
504 proto::ProjectUnshared {
505 project_id: project_id.to_proto(),
506 },
507 )
508 .trace_err();
509 }
510 }
511 }
512
513 removed_user_id = removed_connection.user_id;
514 };
515
516 self.update_user_contacts(removed_user_id).await.trace_err();
517
518 for project_id in projects_to_unregister {
519 self.app_state
520 .db
521 .unregister_project(project_id)
522 .await
523 .trace_err();
524 }
525
526 Ok(())
527 }
528
529 pub async fn invite_code_redeemed(
530 self: &Arc<Self>,
531 code: &str,
532 invitee_id: UserId,
533 ) -> Result<()> {
534 let user = self.app_state.db.get_user_for_invite_code(code).await?;
535 let store = self.store().await;
536 let invitee_contact = store.contact_for_user(invitee_id, true);
537 for connection_id in store.connection_ids_for_user(user.id) {
538 self.peer.send(
539 connection_id,
540 proto::UpdateContacts {
541 contacts: vec![invitee_contact.clone()],
542 ..Default::default()
543 },
544 )?;
545 self.peer.send(
546 connection_id,
547 proto::UpdateInviteInfo {
548 url: format!("{}{}", self.app_state.invite_link_prefix, code),
549 count: user.invite_count as u32,
550 },
551 )?;
552 }
553 Ok(())
554 }
555
556 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
557 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
558 if let Some(invite_code) = &user.invite_code {
559 let store = self.store().await;
560 for connection_id in store.connection_ids_for_user(user_id) {
561 self.peer.send(
562 connection_id,
563 proto::UpdateInviteInfo {
564 url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
565 count: user.invite_count as u32,
566 },
567 )?;
568 }
569 }
570 }
571 Ok(())
572 }
573
574 async fn ping(
575 self: Arc<Server>,
576 _: TypedEnvelope<proto::Ping>,
577 response: Response<proto::Ping>,
578 ) -> Result<()> {
579 response.send(proto::Ack {})?;
580 Ok(())
581 }
582
583 async fn register_project(
584 self: Arc<Server>,
585 request: TypedEnvelope<proto::RegisterProject>,
586 response: Response<proto::RegisterProject>,
587 ) -> Result<()> {
588 let user_id = self
589 .store()
590 .await
591 .user_id_for_connection(request.sender_id)?;
592 let project_id = self.app_state.db.register_project(user_id).await?;
593 self.store_mut()
594 .await
595 .register_project(request.sender_id, project_id)?;
596
597 response.send(proto::RegisterProjectResponse {
598 project_id: project_id.to_proto(),
599 })?;
600
601 Ok(())
602 }
603
604 async fn unregister_project(
605 self: Arc<Server>,
606 request: TypedEnvelope<proto::UnregisterProject>,
607 response: Response<proto::UnregisterProject>,
608 ) -> Result<()> {
609 let project_id = ProjectId::from_proto(request.payload.project_id);
610 let (user_id, project) = {
611 let mut state = self.store_mut().await;
612 let project = state.unregister_project(project_id, request.sender_id)?;
613 (state.user_id_for_connection(request.sender_id)?, project)
614 };
615 self.app_state.db.unregister_project(project_id).await?;
616
617 broadcast(
618 request.sender_id,
619 project.guests.keys().copied(),
620 |conn_id| {
621 self.peer.send(
622 conn_id,
623 proto::UnregisterProject {
624 project_id: project_id.to_proto(),
625 },
626 )
627 },
628 );
629 for (_, receipts) in project.join_requests {
630 for receipt in receipts {
631 self.peer.respond(
632 receipt,
633 proto::JoinProjectResponse {
634 variant: Some(proto::join_project_response::Variant::Decline(
635 proto::join_project_response::Decline {
636 reason: proto::join_project_response::decline::Reason::Closed
637 as i32,
638 },
639 )),
640 },
641 )?;
642 }
643 }
644
645 // Send out the `UpdateContacts` message before responding to the unregister
646 // request. This way, when the project's host can keep track of the project's
647 // remote id until after they've received the `UpdateContacts` message for
648 // themself.
649 self.update_user_contacts(user_id).await?;
650 response.send(proto::Ack {})?;
651
652 Ok(())
653 }
654
655 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
656 let contacts = self.app_state.db.get_contacts(user_id).await?;
657 let store = self.store().await;
658 let updated_contact = store.contact_for_user(user_id, false);
659 for contact in contacts {
660 if let db::Contact::Accepted {
661 user_id: contact_user_id,
662 ..
663 } = contact
664 {
665 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
666 self.peer
667 .send(
668 contact_conn_id,
669 proto::UpdateContacts {
670 contacts: vec![updated_contact.clone()],
671 remove_contacts: Default::default(),
672 incoming_requests: Default::default(),
673 remove_incoming_requests: Default::default(),
674 outgoing_requests: Default::default(),
675 remove_outgoing_requests: Default::default(),
676 },
677 )
678 .trace_err();
679 }
680 }
681 }
682 Ok(())
683 }
684
685 async fn join_project(
686 self: Arc<Server>,
687 request: TypedEnvelope<proto::JoinProject>,
688 response: Response<proto::JoinProject>,
689 ) -> Result<()> {
690 let project_id = ProjectId::from_proto(request.payload.project_id);
691
692 let host_user_id;
693 let guest_user_id;
694 let host_connection_id;
695 {
696 let state = self.store().await;
697 let project = state.project(project_id)?;
698 host_user_id = project.host.user_id;
699 host_connection_id = project.host_connection_id;
700 guest_user_id = state.user_id_for_connection(request.sender_id)?;
701 };
702
703 tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
704 let has_contact = self
705 .app_state
706 .db
707 .has_contact(guest_user_id, host_user_id)
708 .await?;
709 if !has_contact {
710 return Err(anyhow!("no such project"))?;
711 }
712
713 self.store_mut().await.request_join_project(
714 guest_user_id,
715 project_id,
716 response.into_receipt(),
717 )?;
718 self.peer.send(
719 host_connection_id,
720 proto::RequestJoinProject {
721 project_id: project_id.to_proto(),
722 requester_id: guest_user_id.to_proto(),
723 },
724 )?;
725 Ok(())
726 }
727
728 async fn respond_to_join_project_request(
729 self: Arc<Server>,
730 request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
731 ) -> Result<()> {
732 let host_user_id;
733
734 {
735 let mut state = self.store_mut().await;
736 let project_id = ProjectId::from_proto(request.payload.project_id);
737 let project = state.project(project_id)?;
738 if project.host_connection_id != request.sender_id {
739 Err(anyhow!("no such connection"))?;
740 }
741
742 host_user_id = project.host.user_id;
743 let guest_user_id = UserId::from_proto(request.payload.requester_id);
744
745 if !request.payload.allow {
746 let receipts = state
747 .deny_join_project_request(request.sender_id, guest_user_id, project_id)
748 .ok_or_else(|| anyhow!("no such request"))?;
749 for receipt in receipts {
750 self.peer.respond(
751 receipt,
752 proto::JoinProjectResponse {
753 variant: Some(proto::join_project_response::Variant::Decline(
754 proto::join_project_response::Decline {
755 reason: proto::join_project_response::decline::Reason::Declined
756 as i32,
757 },
758 )),
759 },
760 )?;
761 }
762 return Ok(());
763 }
764
765 let (receipts_with_replica_ids, project) = state
766 .accept_join_project_request(request.sender_id, guest_user_id, project_id)
767 .ok_or_else(|| anyhow!("no such request"))?;
768
769 let peer_count = project.guests.len();
770 let mut collaborators = Vec::with_capacity(peer_count);
771 collaborators.push(proto::Collaborator {
772 peer_id: project.host_connection_id.0,
773 replica_id: 0,
774 user_id: project.host.user_id.to_proto(),
775 });
776 let worktrees = project
777 .worktrees
778 .iter()
779 .filter_map(|(id, shared_worktree)| {
780 let worktree = project.worktrees.get(&id)?;
781 Some(proto::Worktree {
782 id: *id,
783 root_name: worktree.root_name.clone(),
784 entries: shared_worktree.entries.values().cloned().collect(),
785 diagnostic_summaries: shared_worktree
786 .diagnostic_summaries
787 .values()
788 .cloned()
789 .collect(),
790 visible: worktree.visible,
791 scan_id: shared_worktree.scan_id,
792 })
793 })
794 .collect::<Vec<_>>();
795
796 // Add all guests other than the requesting user's own connections as collaborators
797 for (guest_conn_id, guest) in &project.guests {
798 if receipts_with_replica_ids
799 .iter()
800 .all(|(receipt, _)| receipt.sender_id != *guest_conn_id)
801 {
802 collaborators.push(proto::Collaborator {
803 peer_id: guest_conn_id.0,
804 replica_id: guest.replica_id as u32,
805 user_id: guest.user_id.to_proto(),
806 });
807 }
808 }
809
810 for conn_id in project.connection_ids() {
811 for (receipt, replica_id) in &receipts_with_replica_ids {
812 if conn_id != receipt.sender_id {
813 self.peer.send(
814 conn_id,
815 proto::AddProjectCollaborator {
816 project_id: project_id.to_proto(),
817 collaborator: Some(proto::Collaborator {
818 peer_id: receipt.sender_id.0,
819 replica_id: *replica_id as u32,
820 user_id: guest_user_id.to_proto(),
821 }),
822 },
823 )?;
824 }
825 }
826 }
827
828 for (receipt, replica_id) in receipts_with_replica_ids {
829 self.peer.respond(
830 receipt,
831 proto::JoinProjectResponse {
832 variant: Some(proto::join_project_response::Variant::Accept(
833 proto::join_project_response::Accept {
834 worktrees: worktrees.clone(),
835 replica_id: replica_id as u32,
836 collaborators: collaborators.clone(),
837 language_servers: project.language_servers.clone(),
838 },
839 )),
840 },
841 )?;
842 }
843 }
844
845 self.update_user_contacts(host_user_id).await?;
846 Ok(())
847 }
848
849 async fn leave_project(
850 self: Arc<Server>,
851 request: TypedEnvelope<proto::LeaveProject>,
852 ) -> Result<()> {
853 let sender_id = request.sender_id;
854 let project_id = ProjectId::from_proto(request.payload.project_id);
855 let project;
856 {
857 let mut store = self.store_mut().await;
858 project = store.leave_project(sender_id, project_id)?;
859 tracing::info!(
860 %project_id,
861 host_user_id = %project.host_user_id,
862 host_connection_id = %project.host_connection_id,
863 "leave project"
864 );
865
866 if project.remove_collaborator {
867 broadcast(sender_id, project.connection_ids, |conn_id| {
868 self.peer.send(
869 conn_id,
870 proto::RemoveProjectCollaborator {
871 project_id: project_id.to_proto(),
872 peer_id: sender_id.0,
873 },
874 )
875 });
876 }
877
878 if let Some(requester_id) = project.cancel_request {
879 self.peer.send(
880 project.host_connection_id,
881 proto::JoinProjectRequestCancelled {
882 project_id: project_id.to_proto(),
883 requester_id: requester_id.to_proto(),
884 },
885 )?;
886 }
887
888 if project.unshare {
889 self.peer.send(
890 project.host_connection_id,
891 proto::ProjectUnshared {
892 project_id: project_id.to_proto(),
893 },
894 )?;
895 }
896 }
897 self.update_user_contacts(project.host_user_id).await?;
898 Ok(())
899 }
900
901 async fn update_project(
902 self: Arc<Server>,
903 request: TypedEnvelope<proto::UpdateProject>,
904 ) -> Result<()> {
905 let project_id = ProjectId::from_proto(request.payload.project_id);
906 let user_id;
907 {
908 let mut state = self.store_mut().await;
909 user_id = state.user_id_for_connection(request.sender_id)?;
910 let guest_connection_ids = state
911 .read_project(project_id, request.sender_id)?
912 .guest_connection_ids();
913 state.update_project(project_id, &request.payload.worktrees, request.sender_id)?;
914 broadcast(request.sender_id, guest_connection_ids, |connection_id| {
915 self.peer
916 .forward_send(request.sender_id, connection_id, request.payload.clone())
917 });
918 };
919 self.update_user_contacts(user_id).await?;
920 Ok(())
921 }
922
923 async fn register_project_activity(
924 self: Arc<Server>,
925 request: TypedEnvelope<proto::RegisterProjectActivity>,
926 ) -> Result<()> {
927 self.store_mut().await.register_project_activity(
928 ProjectId::from_proto(request.payload.project_id),
929 request.sender_id,
930 )?;
931 Ok(())
932 }
933
934 async fn update_worktree(
935 self: Arc<Server>,
936 request: TypedEnvelope<proto::UpdateWorktree>,
937 response: Response<proto::UpdateWorktree>,
938 ) -> Result<()> {
939 let project_id = ProjectId::from_proto(request.payload.project_id);
940 let worktree_id = request.payload.worktree_id;
941 let (connection_ids, metadata_changed, extension_counts) = {
942 let mut store = self.store_mut().await;
943 let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
944 request.sender_id,
945 project_id,
946 worktree_id,
947 &request.payload.root_name,
948 &request.payload.removed_entries,
949 &request.payload.updated_entries,
950 request.payload.scan_id,
951 )?;
952 (connection_ids, metadata_changed, extension_counts.clone())
953 };
954 self.app_state
955 .db
956 .update_worktree_extensions(project_id, worktree_id, extension_counts)
957 .await?;
958
959 broadcast(request.sender_id, connection_ids, |connection_id| {
960 self.peer
961 .forward_send(request.sender_id, connection_id, request.payload.clone())
962 });
963 if metadata_changed {
964 let user_id = self
965 .store()
966 .await
967 .user_id_for_connection(request.sender_id)?;
968 self.update_user_contacts(user_id).await?;
969 }
970 response.send(proto::Ack {})?;
971 Ok(())
972 }
973
974 async fn update_diagnostic_summary(
975 self: Arc<Server>,
976 request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
977 ) -> Result<()> {
978 let summary = request
979 .payload
980 .summary
981 .clone()
982 .ok_or_else(|| anyhow!("invalid summary"))?;
983 let receiver_ids = self.store_mut().await.update_diagnostic_summary(
984 ProjectId::from_proto(request.payload.project_id),
985 request.payload.worktree_id,
986 request.sender_id,
987 summary,
988 )?;
989
990 broadcast(request.sender_id, receiver_ids, |connection_id| {
991 self.peer
992 .forward_send(request.sender_id, connection_id, request.payload.clone())
993 });
994 Ok(())
995 }
996
997 async fn start_language_server(
998 self: Arc<Server>,
999 request: TypedEnvelope<proto::StartLanguageServer>,
1000 ) -> Result<()> {
1001 let receiver_ids = self.store_mut().await.start_language_server(
1002 ProjectId::from_proto(request.payload.project_id),
1003 request.sender_id,
1004 request
1005 .payload
1006 .server
1007 .clone()
1008 .ok_or_else(|| anyhow!("invalid language server"))?,
1009 )?;
1010 broadcast(request.sender_id, receiver_ids, |connection_id| {
1011 self.peer
1012 .forward_send(request.sender_id, connection_id, request.payload.clone())
1013 });
1014 Ok(())
1015 }
1016
1017 async fn update_language_server(
1018 self: Arc<Server>,
1019 request: TypedEnvelope<proto::UpdateLanguageServer>,
1020 ) -> Result<()> {
1021 let receiver_ids = self.store().await.project_connection_ids(
1022 ProjectId::from_proto(request.payload.project_id),
1023 request.sender_id,
1024 )?;
1025 broadcast(request.sender_id, receiver_ids, |connection_id| {
1026 self.peer
1027 .forward_send(request.sender_id, connection_id, request.payload.clone())
1028 });
1029 Ok(())
1030 }
1031
1032 async fn forward_project_request<T>(
1033 self: Arc<Server>,
1034 request: TypedEnvelope<T>,
1035 response: Response<T>,
1036 ) -> Result<()>
1037 where
1038 T: EntityMessage + RequestMessage,
1039 {
1040 let host_connection_id = self
1041 .store()
1042 .await
1043 .read_project(
1044 ProjectId::from_proto(request.payload.remote_entity_id()),
1045 request.sender_id,
1046 )?
1047 .host_connection_id;
1048
1049 response.send(
1050 self.peer
1051 .forward_request(request.sender_id, host_connection_id, request.payload)
1052 .await?,
1053 )?;
1054 Ok(())
1055 }
1056
1057 async fn save_buffer(
1058 self: Arc<Server>,
1059 request: TypedEnvelope<proto::SaveBuffer>,
1060 response: Response<proto::SaveBuffer>,
1061 ) -> Result<()> {
1062 let project_id = ProjectId::from_proto(request.payload.project_id);
1063 let host = self
1064 .store()
1065 .await
1066 .read_project(project_id, request.sender_id)?
1067 .host_connection_id;
1068 let response_payload = self
1069 .peer
1070 .forward_request(request.sender_id, host, request.payload.clone())
1071 .await?;
1072
1073 let mut guests = self
1074 .store()
1075 .await
1076 .read_project(project_id, request.sender_id)?
1077 .connection_ids();
1078 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1079 broadcast(host, guests, |conn_id| {
1080 self.peer
1081 .forward_send(host, conn_id, response_payload.clone())
1082 });
1083 response.send(response_payload)?;
1084 Ok(())
1085 }
1086
1087 async fn update_buffer(
1088 self: Arc<Server>,
1089 request: TypedEnvelope<proto::UpdateBuffer>,
1090 response: Response<proto::UpdateBuffer>,
1091 ) -> Result<()> {
1092 let project_id = ProjectId::from_proto(request.payload.project_id);
1093 let receiver_ids = {
1094 let mut store = self.store_mut().await;
1095 store.register_project_activity(project_id, request.sender_id)?;
1096 store.project_connection_ids(project_id, request.sender_id)?
1097 };
1098
1099 broadcast(request.sender_id, receiver_ids, |connection_id| {
1100 self.peer
1101 .forward_send(request.sender_id, connection_id, request.payload.clone())
1102 });
1103 response.send(proto::Ack {})?;
1104 Ok(())
1105 }
1106
1107 async fn update_buffer_file(
1108 self: Arc<Server>,
1109 request: TypedEnvelope<proto::UpdateBufferFile>,
1110 ) -> Result<()> {
1111 let receiver_ids = self.store().await.project_connection_ids(
1112 ProjectId::from_proto(request.payload.project_id),
1113 request.sender_id,
1114 )?;
1115 broadcast(request.sender_id, receiver_ids, |connection_id| {
1116 self.peer
1117 .forward_send(request.sender_id, connection_id, request.payload.clone())
1118 });
1119 Ok(())
1120 }
1121
1122 async fn buffer_reloaded(
1123 self: Arc<Server>,
1124 request: TypedEnvelope<proto::BufferReloaded>,
1125 ) -> Result<()> {
1126 let receiver_ids = self.store().await.project_connection_ids(
1127 ProjectId::from_proto(request.payload.project_id),
1128 request.sender_id,
1129 )?;
1130 broadcast(request.sender_id, receiver_ids, |connection_id| {
1131 self.peer
1132 .forward_send(request.sender_id, connection_id, request.payload.clone())
1133 });
1134 Ok(())
1135 }
1136
1137 async fn buffer_saved(
1138 self: Arc<Server>,
1139 request: TypedEnvelope<proto::BufferSaved>,
1140 ) -> Result<()> {
1141 let receiver_ids = self.store().await.project_connection_ids(
1142 ProjectId::from_proto(request.payload.project_id),
1143 request.sender_id,
1144 )?;
1145 broadcast(request.sender_id, receiver_ids, |connection_id| {
1146 self.peer
1147 .forward_send(request.sender_id, connection_id, request.payload.clone())
1148 });
1149 Ok(())
1150 }
1151
1152 async fn follow(
1153 self: Arc<Self>,
1154 request: TypedEnvelope<proto::Follow>,
1155 response: Response<proto::Follow>,
1156 ) -> Result<()> {
1157 let project_id = ProjectId::from_proto(request.payload.project_id);
1158 let leader_id = ConnectionId(request.payload.leader_id);
1159 let follower_id = request.sender_id;
1160 {
1161 let mut store = self.store_mut().await;
1162 if !store
1163 .project_connection_ids(project_id, follower_id)?
1164 .contains(&leader_id)
1165 {
1166 Err(anyhow!("no such peer"))?;
1167 }
1168
1169 store.register_project_activity(project_id, follower_id)?;
1170 }
1171
1172 let mut response_payload = self
1173 .peer
1174 .forward_request(request.sender_id, leader_id, request.payload)
1175 .await?;
1176 response_payload
1177 .views
1178 .retain(|view| view.leader_id != Some(follower_id.0));
1179 response.send(response_payload)?;
1180 Ok(())
1181 }
1182
1183 async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1184 let project_id = ProjectId::from_proto(request.payload.project_id);
1185 let leader_id = ConnectionId(request.payload.leader_id);
1186 let mut store = self.store_mut().await;
1187 if !store
1188 .project_connection_ids(project_id, request.sender_id)?
1189 .contains(&leader_id)
1190 {
1191 Err(anyhow!("no such peer"))?;
1192 }
1193 store.register_project_activity(project_id, request.sender_id)?;
1194 self.peer
1195 .forward_send(request.sender_id, leader_id, request.payload)?;
1196 Ok(())
1197 }
1198
1199 async fn update_followers(
1200 self: Arc<Self>,
1201 request: TypedEnvelope<proto::UpdateFollowers>,
1202 ) -> Result<()> {
1203 let project_id = ProjectId::from_proto(request.payload.project_id);
1204 let mut store = self.store_mut().await;
1205 store.register_project_activity(project_id, request.sender_id)?;
1206 let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
1207 let leader_id = request
1208 .payload
1209 .variant
1210 .as_ref()
1211 .and_then(|variant| match variant {
1212 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1213 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1214 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1215 });
1216 for follower_id in &request.payload.follower_ids {
1217 let follower_id = ConnectionId(*follower_id);
1218 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1219 self.peer
1220 .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1221 }
1222 }
1223 Ok(())
1224 }
1225
1226 async fn get_channels(
1227 self: Arc<Server>,
1228 request: TypedEnvelope<proto::GetChannels>,
1229 response: Response<proto::GetChannels>,
1230 ) -> Result<()> {
1231 let user_id = self
1232 .store()
1233 .await
1234 .user_id_for_connection(request.sender_id)?;
1235 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1236 response.send(proto::GetChannelsResponse {
1237 channels: channels
1238 .into_iter()
1239 .map(|chan| proto::Channel {
1240 id: chan.id.to_proto(),
1241 name: chan.name,
1242 })
1243 .collect(),
1244 })?;
1245 Ok(())
1246 }
1247
1248 async fn get_users(
1249 self: Arc<Server>,
1250 request: TypedEnvelope<proto::GetUsers>,
1251 response: Response<proto::GetUsers>,
1252 ) -> Result<()> {
1253 let user_ids = request
1254 .payload
1255 .user_ids
1256 .into_iter()
1257 .map(UserId::from_proto)
1258 .collect();
1259 let users = self
1260 .app_state
1261 .db
1262 .get_users_by_ids(user_ids)
1263 .await?
1264 .into_iter()
1265 .map(|user| proto::User {
1266 id: user.id.to_proto(),
1267 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1268 github_login: user.github_login,
1269 })
1270 .collect();
1271 response.send(proto::UsersResponse { users })?;
1272 Ok(())
1273 }
1274
1275 async fn fuzzy_search_users(
1276 self: Arc<Server>,
1277 request: TypedEnvelope<proto::FuzzySearchUsers>,
1278 response: Response<proto::FuzzySearchUsers>,
1279 ) -> Result<()> {
1280 let user_id = self
1281 .store()
1282 .await
1283 .user_id_for_connection(request.sender_id)?;
1284 let query = request.payload.query;
1285 let db = &self.app_state.db;
1286 let users = match query.len() {
1287 0 => vec![],
1288 1 | 2 => db
1289 .get_user_by_github_login(&query)
1290 .await?
1291 .into_iter()
1292 .collect(),
1293 _ => db.fuzzy_search_users(&query, 10).await?,
1294 };
1295 let users = users
1296 .into_iter()
1297 .filter(|user| user.id != user_id)
1298 .map(|user| proto::User {
1299 id: user.id.to_proto(),
1300 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1301 github_login: user.github_login,
1302 })
1303 .collect();
1304 response.send(proto::UsersResponse { users })?;
1305 Ok(())
1306 }
1307
1308 async fn request_contact(
1309 self: Arc<Server>,
1310 request: TypedEnvelope<proto::RequestContact>,
1311 response: Response<proto::RequestContact>,
1312 ) -> Result<()> {
1313 let requester_id = self
1314 .store()
1315 .await
1316 .user_id_for_connection(request.sender_id)?;
1317 let responder_id = UserId::from_proto(request.payload.responder_id);
1318 if requester_id == responder_id {
1319 return Err(anyhow!("cannot add yourself as a contact"))?;
1320 }
1321
1322 self.app_state
1323 .db
1324 .send_contact_request(requester_id, responder_id)
1325 .await?;
1326
1327 // Update outgoing contact requests of requester
1328 let mut update = proto::UpdateContacts::default();
1329 update.outgoing_requests.push(responder_id.to_proto());
1330 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1331 self.peer.send(connection_id, update.clone())?;
1332 }
1333
1334 // Update incoming contact requests of responder
1335 let mut update = proto::UpdateContacts::default();
1336 update
1337 .incoming_requests
1338 .push(proto::IncomingContactRequest {
1339 requester_id: requester_id.to_proto(),
1340 should_notify: true,
1341 });
1342 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1343 self.peer.send(connection_id, update.clone())?;
1344 }
1345
1346 response.send(proto::Ack {})?;
1347 Ok(())
1348 }
1349
1350 async fn respond_to_contact_request(
1351 self: Arc<Server>,
1352 request: TypedEnvelope<proto::RespondToContactRequest>,
1353 response: Response<proto::RespondToContactRequest>,
1354 ) -> Result<()> {
1355 let responder_id = self
1356 .store()
1357 .await
1358 .user_id_for_connection(request.sender_id)?;
1359 let requester_id = UserId::from_proto(request.payload.requester_id);
1360 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1361 self.app_state
1362 .db
1363 .dismiss_contact_notification(responder_id, requester_id)
1364 .await?;
1365 } else {
1366 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1367 self.app_state
1368 .db
1369 .respond_to_contact_request(responder_id, requester_id, accept)
1370 .await?;
1371
1372 let store = self.store().await;
1373 // Update responder with new contact
1374 let mut update = proto::UpdateContacts::default();
1375 if accept {
1376 update
1377 .contacts
1378 .push(store.contact_for_user(requester_id, false));
1379 }
1380 update
1381 .remove_incoming_requests
1382 .push(requester_id.to_proto());
1383 for connection_id in store.connection_ids_for_user(responder_id) {
1384 self.peer.send(connection_id, update.clone())?;
1385 }
1386
1387 // Update requester with new contact
1388 let mut update = proto::UpdateContacts::default();
1389 if accept {
1390 update
1391 .contacts
1392 .push(store.contact_for_user(responder_id, true));
1393 }
1394 update
1395 .remove_outgoing_requests
1396 .push(responder_id.to_proto());
1397 for connection_id in store.connection_ids_for_user(requester_id) {
1398 self.peer.send(connection_id, update.clone())?;
1399 }
1400 }
1401
1402 response.send(proto::Ack {})?;
1403 Ok(())
1404 }
1405
1406 async fn remove_contact(
1407 self: Arc<Server>,
1408 request: TypedEnvelope<proto::RemoveContact>,
1409 response: Response<proto::RemoveContact>,
1410 ) -> Result<()> {
1411 let requester_id = self
1412 .store()
1413 .await
1414 .user_id_for_connection(request.sender_id)?;
1415 let responder_id = UserId::from_proto(request.payload.user_id);
1416 self.app_state
1417 .db
1418 .remove_contact(requester_id, responder_id)
1419 .await?;
1420
1421 // Update outgoing contact requests of requester
1422 let mut update = proto::UpdateContacts::default();
1423 update
1424 .remove_outgoing_requests
1425 .push(responder_id.to_proto());
1426 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1427 self.peer.send(connection_id, update.clone())?;
1428 }
1429
1430 // Update incoming contact requests of responder
1431 let mut update = proto::UpdateContacts::default();
1432 update
1433 .remove_incoming_requests
1434 .push(requester_id.to_proto());
1435 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1436 self.peer.send(connection_id, update.clone())?;
1437 }
1438
1439 response.send(proto::Ack {})?;
1440 Ok(())
1441 }
1442
1443 async fn join_channel(
1444 self: Arc<Self>,
1445 request: TypedEnvelope<proto::JoinChannel>,
1446 response: Response<proto::JoinChannel>,
1447 ) -> Result<()> {
1448 let user_id = self
1449 .store()
1450 .await
1451 .user_id_for_connection(request.sender_id)?;
1452 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1453 if !self
1454 .app_state
1455 .db
1456 .can_user_access_channel(user_id, channel_id)
1457 .await?
1458 {
1459 Err(anyhow!("access denied"))?;
1460 }
1461
1462 self.store_mut()
1463 .await
1464 .join_channel(request.sender_id, channel_id);
1465 let messages = self
1466 .app_state
1467 .db
1468 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1469 .await?
1470 .into_iter()
1471 .map(|msg| proto::ChannelMessage {
1472 id: msg.id.to_proto(),
1473 body: msg.body,
1474 timestamp: msg.sent_at.unix_timestamp() as u64,
1475 sender_id: msg.sender_id.to_proto(),
1476 nonce: Some(msg.nonce.as_u128().into()),
1477 })
1478 .collect::<Vec<_>>();
1479 response.send(proto::JoinChannelResponse {
1480 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1481 messages,
1482 })?;
1483 Ok(())
1484 }
1485
1486 async fn leave_channel(
1487 self: Arc<Self>,
1488 request: TypedEnvelope<proto::LeaveChannel>,
1489 ) -> Result<()> {
1490 let user_id = self
1491 .store()
1492 .await
1493 .user_id_for_connection(request.sender_id)?;
1494 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1495 if !self
1496 .app_state
1497 .db
1498 .can_user_access_channel(user_id, channel_id)
1499 .await?
1500 {
1501 Err(anyhow!("access denied"))?;
1502 }
1503
1504 self.store_mut()
1505 .await
1506 .leave_channel(request.sender_id, channel_id);
1507
1508 Ok(())
1509 }
1510
1511 async fn send_channel_message(
1512 self: Arc<Self>,
1513 request: TypedEnvelope<proto::SendChannelMessage>,
1514 response: Response<proto::SendChannelMessage>,
1515 ) -> Result<()> {
1516 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1517 let user_id;
1518 let connection_ids;
1519 {
1520 let state = self.store().await;
1521 user_id = state.user_id_for_connection(request.sender_id)?;
1522 connection_ids = state.channel_connection_ids(channel_id)?;
1523 }
1524
1525 // Validate the message body.
1526 let body = request.payload.body.trim().to_string();
1527 if body.len() > MAX_MESSAGE_LEN {
1528 return Err(anyhow!("message is too long"))?;
1529 }
1530 if body.is_empty() {
1531 return Err(anyhow!("message can't be blank"))?;
1532 }
1533
1534 let timestamp = OffsetDateTime::now_utc();
1535 let nonce = request
1536 .payload
1537 .nonce
1538 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1539
1540 let message_id = self
1541 .app_state
1542 .db
1543 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1544 .await?
1545 .to_proto();
1546 let message = proto::ChannelMessage {
1547 sender_id: user_id.to_proto(),
1548 id: message_id,
1549 body,
1550 timestamp: timestamp.unix_timestamp() as u64,
1551 nonce: Some(nonce),
1552 };
1553 broadcast(request.sender_id, connection_ids, |conn_id| {
1554 self.peer.send(
1555 conn_id,
1556 proto::ChannelMessageSent {
1557 channel_id: channel_id.to_proto(),
1558 message: Some(message.clone()),
1559 },
1560 )
1561 });
1562 response.send(proto::SendChannelMessageResponse {
1563 message: Some(message),
1564 })?;
1565 Ok(())
1566 }
1567
1568 async fn get_channel_messages(
1569 self: Arc<Self>,
1570 request: TypedEnvelope<proto::GetChannelMessages>,
1571 response: Response<proto::GetChannelMessages>,
1572 ) -> Result<()> {
1573 let user_id = self
1574 .store()
1575 .await
1576 .user_id_for_connection(request.sender_id)?;
1577 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1578 if !self
1579 .app_state
1580 .db
1581 .can_user_access_channel(user_id, channel_id)
1582 .await?
1583 {
1584 Err(anyhow!("access denied"))?;
1585 }
1586
1587 let messages = self
1588 .app_state
1589 .db
1590 .get_channel_messages(
1591 channel_id,
1592 MESSAGE_COUNT_PER_PAGE,
1593 Some(MessageId::from_proto(request.payload.before_message_id)),
1594 )
1595 .await?
1596 .into_iter()
1597 .map(|msg| proto::ChannelMessage {
1598 id: msg.id.to_proto(),
1599 body: msg.body,
1600 timestamp: msg.sent_at.unix_timestamp() as u64,
1601 sender_id: msg.sender_id.to_proto(),
1602 nonce: Some(msg.nonce.as_u128().into()),
1603 })
1604 .collect::<Vec<_>>();
1605 response.send(proto::GetChannelMessagesResponse {
1606 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1607 messages,
1608 })?;
1609 Ok(())
1610 }
1611
    /// Acquires a shared (read) lock on the server's `Store` and wraps it in a
    /// `StoreReadGuard`.
    ///
    /// In `cfg(test)` builds the task yields both before and after acquiring the lock;
    /// presumably this is to expose more task interleavings in tests — confirm intent.
    /// NOTE(review): the guard carries a `_not_send` marker, presumably making it `!Send` so
    /// it cannot be held across an `.await` in a `Send` future — verify against the
    /// `StoreReadGuard` definition.
    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.read().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreReadGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1623
    /// Acquires an exclusive (write) lock on the server's `Store` and wraps it in a
    /// `StoreWriteGuard`.
    ///
    /// When the returned guard is dropped, its `Drop` impl refreshes the metrics gauges (and
    /// checks store invariants in test builds). In `cfg(test)` builds the task yields both
    /// before and after acquiring the lock; presumably this is to expose more task
    /// interleavings in tests — confirm intent.
    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.store.write().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        StoreWriteGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1635
1636 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1637 ServerSnapshot {
1638 store: self.store.read().await,
1639 peer: &self.peer,
1640 }
1641 }
1642}
1643
1644impl<'a> Deref for StoreReadGuard<'a> {
1645 type Target = Store;
1646
1647 fn deref(&self) -> &Self::Target {
1648 &*self.guard
1649 }
1650}
1651
1652impl<'a> Deref for StoreWriteGuard<'a> {
1653 type Target = Store;
1654
1655 fn deref(&self) -> &Self::Target {
1656 &*self.guard
1657 }
1658}
1659
1660impl<'a> DerefMut for StoreWriteGuard<'a> {
1661 fn deref_mut(&mut self) -> &mut Self::Target {
1662 &mut *self.guard
1663 }
1664}
1665
// Every time a write guard is released, the store's current metrics are published to the
// Prometheus gauges and logged via `tracing`. In test builds the store's invariants are
// checked first, so any mutation made through the guard is validated on release.
impl<'a> Drop for StoreWriteGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();

        let metrics = self.metrics();

        // `as _` converts each count to the gauge's value type (`i64` for `IntGauge`).
        METRIC_CONNECTIONS.set(metrics.connections as _);
        METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
        METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
        METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);

        tracing::info!(
            connections = metrics.connections,
            registered_projects = metrics.registered_projects,
            active_projects = metrics.active_projects,
            shared_projects = metrics.shared_projects,
            "metrics"
        );
    }
}
1687
1688impl Executor for RealExecutor {
1689 type Sleep = Sleep;
1690
1691 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1692 tokio::task::spawn(future);
1693 }
1694
1695 fn sleep(&self, duration: Duration) -> Self::Sleep {
1696 tokio::time::sleep(duration)
1697 }
1698}
1699
1700fn broadcast<F>(
1701 sender_id: ConnectionId,
1702 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1703 mut f: F,
1704) where
1705 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1706{
1707 for receiver_id in receiver_ids {
1708 if receiver_id != sender_id {
1709 f(receiver_id).trace_err();
1710 }
1711 }
1712}
1713
// Name of the HTTP header through which clients announce their RPC protocol version; it is
// returned by `ProtocolVersion::name` below.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1717
/// Typed value of the `x-zed-protocol-version` request header.
///
/// `handle_websocket_request` compares it against `rpc::PROTOCOL_VERSION` and refuses
/// mismatching clients.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ProtocolVersion(u32);
1719
1720impl Header for ProtocolVersion {
1721 fn name() -> &'static HeaderName {
1722 &ZED_PROTOCOL_VERSION
1723 }
1724
1725 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1726 where
1727 Self: Sized,
1728 I: Iterator<Item = &'i axum::http::HeaderValue>,
1729 {
1730 let version = values
1731 .next()
1732 .ok_or_else(|| axum::headers::Error::invalid())?
1733 .to_str()
1734 .map_err(|_| axum::headers::Error::invalid())?
1735 .parse()
1736 .map_err(|_| axum::headers::Error::invalid())?;
1737 Ok(Self(version))
1738 }
1739
1740 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1741 values.extend([self.0.to_string().parse().unwrap()]);
1742 }
1743}
1744
1745pub fn routes(server: Arc<Server>) -> Router<Body> {
1746 Router::new()
1747 .route("/rpc", get(handle_websocket_request))
1748 .layer(
1749 ServiceBuilder::new()
1750 .layer(Extension(server.app_state.clone()))
1751 .layer(middleware::from_fn(auth::validate_header)),
1752 )
1753 .route("/metrics", get(handle_metrics))
1754 .layer(Extension(server))
1755}
1756
1757pub async fn handle_websocket_request(
1758 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1759 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1760 Extension(server): Extension<Arc<Server>>,
1761 Extension(user): Extension<User>,
1762 ws: WebSocketUpgrade,
1763) -> axum::response::Response {
1764 if protocol_version != rpc::PROTOCOL_VERSION {
1765 return (
1766 StatusCode::UPGRADE_REQUIRED,
1767 "client must be upgraded".to_string(),
1768 )
1769 .into_response();
1770 }
1771 let socket_address = socket_address.to_string();
1772 ws.on_upgrade(move |socket| {
1773 use util::ResultExt;
1774 let socket = socket
1775 .map_ok(to_tungstenite_message)
1776 .err_into()
1777 .with(|message| async move { Ok(to_axum_message(message)) });
1778 let connection = Connection::new(Box::pin(socket));
1779 async move {
1780 server
1781 .handle_connection(connection, socket_address, user, None, RealExecutor)
1782 .await
1783 .log_err();
1784 }
1785 })
1786}
1787
1788pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1789 // We call `store_mut` here for its side effects of updating metrics.
1790 server.store_mut().await;
1791
1792 let encoder = prometheus::TextEncoder::new();
1793 let metric_families = prometheus::gather();
1794 match encoder.encode_to_string(&metric_families) {
1795 Ok(string) => (StatusCode::OK, string).into_response(),
1796 Err(error) => (
1797 StatusCode::INTERNAL_SERVER_ERROR,
1798 format!("failed to encode metrics {:?}", error),
1799 )
1800 .into_response(),
1801 }
1802}
1803
1804fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1805 match message {
1806 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1807 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1808 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1809 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1810 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1811 code: frame.code.into(),
1812 reason: frame.reason,
1813 })),
1814 }
1815}
1816
1817fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1818 match message {
1819 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1820 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1821 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1822 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1823 AxumMessage::Close(frame) => {
1824 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1825 code: frame.code.into(),
1826 reason: frame.reason,
1827 }))
1828 }
1829 }
1830}
1831
/// Extension for results whose error should be reported through `tracing` instead of being
/// propagated.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Converts `self` into an `Option` of its success value.
    ///
    /// The implementation for `Result` in this module logs the error with
    /// `tracing::error!` before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1837
1838impl<T, E> ResultExt for Result<T, E>
1839where
1840 E: std::fmt::Debug,
1841{
1842 type Ok = T;
1843
1844 fn trace_err(self) -> Option<T> {
1845 match self {
1846 Ok(value) => Some(value),
1847 Err(error) => {
1848 tracing::error!("{:?}", error);
1849 None
1850 }
1851 }
1852 }
1853}