1mod store;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, MessageId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::HashMap;
26use futures::{
27 channel::mpsc,
28 future::{self, BoxFuture},
29 FutureExt, SinkExt, StreamExt, TryStreamExt,
30};
31use lazy_static::lazy_static;
32use rpc::{
33 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
34 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
35};
36use serde::{Serialize, Serializer};
37use std::{
38 any::TypeId,
39 future::Future,
40 marker::PhantomData,
41 net::SocketAddr,
42 ops::{Deref, DerefMut},
43 rc::Rc,
44 sync::{
45 atomic::{AtomicBool, Ordering::SeqCst},
46 Arc,
47 },
48 time::Duration,
49};
50use time::OffsetDateTime;
51use tokio::{
52 sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
53 time::Sleep,
54};
55use tower::ServiceBuilder;
56use tracing::{info_span, instrument, Instrument};
57
58pub use store::{Store, Worktree};
59
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It is called with the server and a boxed envelope of any message type, and returns a boxed
/// `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
62
/// Handle passed to request handlers for answering one request of type `R`.
struct Response<R> {
    /// Server whose peer delivers the reply.
    server: Arc<Server>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` once `send` or `into_receipt` has been called on this handle.
    responded: Arc<AtomicBool>,
}
68
69impl<R: RequestMessage> Response<R> {
70 fn send(self, payload: R::Response) -> Result<()> {
71 self.responded.store(true, SeqCst);
72 self.server.peer.respond(self.receipt, payload)?;
73 Ok(())
74 }
75
76 fn into_receipt(self) -> Receipt<R> {
77 self.responded.store(true, SeqCst);
78 self.receipt
79 }
80}
81
/// RPC server state shared by every connection handled in this module.
pub struct Server {
    /// Peer used to send, respond to and forward messages between connections.
    peer: Arc<Peer>,
    /// In-memory `Store`, guarded by an async read/write lock.
    pub(crate) store: RwLock<Store>,
    /// Application state; this file reads `db` and `invite_link_prefix` from it.
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message they accept; filled in `new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Optional channel that is sent a `()` after each message handler finishes.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
89
/// Abstraction over the runtime facilities the server needs: running futures it does not
/// await itself, and timers.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Hands `future` to the executor; nothing is returned to await or cancel it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
95
/// Marker type with no state.
// NOTE(review): presumably the production `Executor` implementation; its `impl` is not in
// this part of the file — confirm.
#[derive(Clone)]
pub struct RealExecutor;
98
// NOTE(review): neither constant is used in this part of the file. Judging by the names they
// are the page size and maximum length for channel messages — confirm against the
// channel-message handlers.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
101
/// Read guard over the `Store` that is deliberately not `Send`.
struct StoreReadGuard<'a> {
    /// The underlying read lock guard.
    guard: RwLockReadGuard<'a, Store>,
    /// `Rc` is not `Send`, so this marker makes the whole guard not `Send`.
    // NOTE(review): presumably this keeps the guard from being held across an `.await` in
    // futures that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
106
/// Write guard over the `Store` that is deliberately not `Send`.
struct StoreWriteGuard<'a> {
    /// The underlying write lock guard.
    guard: RwLockWriteGuard<'a, Store>,
    /// `Rc` is not `Send`, so this marker makes the whole guard not `Send`.
    // NOTE(review): presumably this keeps the guard from being held across an `.await` in
    // futures that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
111
/// Serializable view of the server: its peer plus a read-locked `Store`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer, serialized directly.
    peer: &'a Peer,
    /// Read guard on the store; `serialize_deref` serializes the `Store` it points to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
118
119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
120where
121 S: Serializer,
122 T: Deref<Target = U>,
123 U: Serialize,
124{
125 Serialize::serialize(value.deref(), serializer)
126}
127
128impl Server {
129 pub fn new(
130 app_state: Arc<AppState>,
131 notifications: Option<mpsc::UnboundedSender<()>>,
132 ) -> Arc<Self> {
133 let mut server = Self {
134 peer: Peer::new(),
135 app_state,
136 store: Default::default(),
137 handlers: Default::default(),
138 notifications,
139 };
140
141 server
142 .add_request_handler(Server::ping)
143 .add_request_handler(Server::register_project)
144 .add_message_handler(Server::unregister_project)
145 .add_request_handler(Server::join_project)
146 .add_message_handler(Server::leave_project)
147 .add_message_handler(Server::respond_to_join_project_request)
148 .add_request_handler(Server::register_worktree)
149 .add_message_handler(Server::unregister_worktree)
150 .add_request_handler(Server::update_worktree)
151 .add_message_handler(Server::start_language_server)
152 .add_message_handler(Server::update_language_server)
153 .add_message_handler(Server::update_diagnostic_summary)
154 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
155 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
156 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
157 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
158 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
159 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
160 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
161 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
162 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
163 .add_request_handler(
164 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
165 )
166 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
167 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
168 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
169 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
170 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
171 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
172 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
173 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
174 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
175 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
176 .add_request_handler(Server::update_buffer)
177 .add_message_handler(Server::update_buffer_file)
178 .add_message_handler(Server::buffer_reloaded)
179 .add_message_handler(Server::buffer_saved)
180 .add_request_handler(Server::save_buffer)
181 .add_request_handler(Server::get_channels)
182 .add_request_handler(Server::get_users)
183 .add_request_handler(Server::fuzzy_search_users)
184 .add_request_handler(Server::request_contact)
185 .add_request_handler(Server::remove_contact)
186 .add_request_handler(Server::respond_to_contact_request)
187 .add_request_handler(Server::join_channel)
188 .add_message_handler(Server::leave_channel)
189 .add_request_handler(Server::send_channel_message)
190 .add_request_handler(Server::follow)
191 .add_message_handler(Server::unfollow)
192 .add_message_handler(Server::update_followers)
193 .add_request_handler(Server::get_channel_messages);
194
195 Arc::new(server)
196 }
197
198 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
199 where
200 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
201 Fut: 'static + Send + Future<Output = Result<()>>,
202 M: EnvelopedMessage,
203 {
204 let prev_handler = self.handlers.insert(
205 TypeId::of::<M>(),
206 Box::new(move |server, envelope| {
207 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
208 let span = info_span!(
209 "handle message",
210 payload_type = envelope.payload_type_name(),
211 payload = format!("{:?}", envelope.payload).as_str(),
212 );
213 let future = (handler)(server, *envelope);
214 async move {
215 if let Err(error) = future.await {
216 tracing::error!(%error, "error handling message");
217 }
218 }
219 .instrument(span)
220 .boxed()
221 }),
222 );
223 if prev_handler.is_some() {
224 panic!("registered a handler for the same message twice");
225 }
226 self
227 }
228
229 /// Handle a request while holding a lock to the store. This is useful when we're registering
230 /// a connection but we want to respond on the connection before anybody else can send on it.
231 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
232 where
233 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
234 Fut: Send + Future<Output = Result<()>>,
235 M: RequestMessage,
236 {
237 let handler = Arc::new(handler);
238 self.add_message_handler(move |server, envelope| {
239 let receipt = envelope.receipt();
240 let handler = handler.clone();
241 async move {
242 let responded = Arc::new(AtomicBool::default());
243 let response = Response {
244 server: server.clone(),
245 responded: responded.clone(),
246 receipt: envelope.receipt(),
247 };
248 match (handler)(server.clone(), envelope, response).await {
249 Ok(()) => {
250 if responded.load(std::sync::atomic::Ordering::SeqCst) {
251 Ok(())
252 } else {
253 Err(anyhow!("handler did not send a response"))?
254 }
255 }
256 Err(error) => {
257 server.peer.respond_with_error(
258 receipt,
259 proto::Error {
260 message: error.to_string(),
261 },
262 )?;
263 Err(error)
264 }
265 }
266 }
267 })
268 }
269
    /// Drives a single client connection from registration until it closes.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. registers `connection` with the peer, giving it a timer built on `executor.sleep`,
    /// 2. reports the new connection id through `send_connection_id` when one is provided,
    /// 3. sends the initial contacts update (plus invite info when the user has an invite
    ///    code) and records the connection in the store,
    /// 4. dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes or the incoming stream ends, and
    /// 5. signs the connection out.
    ///
    /// Errors during setup (steps 1-3) are returned through `?`. I/O errors and sign-out
    /// errors are only logged, after which the future resolves to `Ok(())`.
    ///
    /// NOTE(review): an early `?` return after `store.add_connection` (for example from
    /// `update_user_contacts`) skips `sign_out` — confirm that this is handled elsewhere.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // Mutable because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                })
                .await;

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");

            // A failure to report the connection id is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.as_mut() {
                let _ = send_connection_id.send(connection_id).await;
            }

            // First-ever connection for this user: ask the client to show its contacts and
            // remember that this happened.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            {
                // The store write lock is held while the initial state is sent and is
                // released at the end of this block, before the next `.await`.
                let mut store = this.store_mut().await;
                store.add_connection(connection_id, user_id);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
                        count,
                    })?;
                }
            }
            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in declaration order, so the I/O
                // future is checked before the next incoming message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            async {
                                if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                    let notifications = this.notifications.clone();
                                    let is_background = message.is_background();
                                    let handle_message = (handler)(this.clone(), message);
                                    // After the handler finishes, signal the optional
                                    // notifications channel; a send failure is ignored.
                                    let handle_message = async move {
                                        handle_message.await;
                                        if let Some(mut notifications) = notifications {
                                            let _ = notifications.send(()).await;
                                        }
                                    };
                                    // Background messages are handed to the executor; all
                                    // others are awaited here before the next message is read.
                                    if is_background {
                                        executor.spawn_detached(handle_message);
                                    } else {
                                        handle_message.await;
                                    }
                                } else {
                                    tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                                }
                            }.instrument(span).await;
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
378
379 #[instrument(skip(self), err)]
380 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
381 self.peer.disconnect(connection_id);
382
383 let removed_user_id = {
384 let mut store = self.store_mut().await;
385 let removed_connection = store.remove_connection(connection_id)?;
386
387 for (project_id, project) in removed_connection.hosted_projects {
388 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
389 self.peer
390 .send(conn_id, proto::UnregisterProject { project_id })
391 });
392
393 for (_, receipts) in project.join_requests {
394 for receipt in receipts {
395 self.peer.respond(
396 receipt,
397 proto::JoinProjectResponse {
398 variant: Some(proto::join_project_response::Variant::Decline(
399 proto::join_project_response::Decline {
400 reason: proto::join_project_response::decline::Reason::WentOffline as i32
401 },
402 )),
403 },
404 )?;
405 }
406 }
407 }
408
409 for project_id in removed_connection.guest_project_ids {
410 if let Some(project) = store.project(project_id).trace_err() {
411 broadcast(connection_id, project.connection_ids(), |conn_id| {
412 self.peer.send(
413 conn_id,
414 proto::RemoveProjectCollaborator {
415 project_id,
416 peer_id: connection_id.0,
417 },
418 )
419 });
420 if project.guests.is_empty() {
421 self.peer
422 .send(
423 project.host_connection_id,
424 proto::ProjectUnshared { project_id },
425 )
426 .trace_err();
427 }
428 }
429 }
430
431 removed_connection.user_id
432 };
433
434 self.update_user_contacts(removed_user_id).await?;
435
436 Ok(())
437 }
438
439 pub async fn invite_code_redeemed(
440 self: &Arc<Self>,
441 code: &str,
442 invitee_id: UserId,
443 ) -> Result<()> {
444 let user = self.app_state.db.get_user_for_invite_code(code).await?;
445 let store = self.store().await;
446 let invitee_contact = store.contact_for_user(invitee_id, true);
447 for connection_id in store.connection_ids_for_user(user.id) {
448 self.peer.send(
449 connection_id,
450 proto::UpdateContacts {
451 contacts: vec![invitee_contact.clone()],
452 ..Default::default()
453 },
454 )?;
455 self.peer.send(
456 connection_id,
457 proto::UpdateInviteInfo {
458 url: format!("{}{}", self.app_state.invite_link_prefix, code),
459 count: user.invite_count as u32,
460 },
461 )?;
462 }
463 Ok(())
464 }
465
466 async fn ping(
467 self: Arc<Server>,
468 _: TypedEnvelope<proto::Ping>,
469 response: Response<proto::Ping>,
470 ) -> Result<()> {
471 response.send(proto::Ack {})?;
472 Ok(())
473 }
474
475 async fn register_project(
476 self: Arc<Server>,
477 request: TypedEnvelope<proto::RegisterProject>,
478 response: Response<proto::RegisterProject>,
479 ) -> Result<()> {
480 let user_id;
481 let project_id;
482 {
483 let mut state = self.store_mut().await;
484 user_id = state.user_id_for_connection(request.sender_id)?;
485 project_id = state.register_project(request.sender_id, user_id);
486 };
487 self.update_user_contacts(user_id).await?;
488 response.send(proto::RegisterProjectResponse { project_id })?;
489 Ok(())
490 }
491
492 async fn unregister_project(
493 self: Arc<Server>,
494 request: TypedEnvelope<proto::UnregisterProject>,
495 ) -> Result<()> {
496 let (user_id, project) = {
497 let mut state = self.store_mut().await;
498 let project =
499 state.unregister_project(request.payload.project_id, request.sender_id)?;
500 (state.user_id_for_connection(request.sender_id)?, project)
501 };
502
503 broadcast(
504 request.sender_id,
505 project.guests.keys().copied(),
506 |conn_id| {
507 self.peer.send(
508 conn_id,
509 proto::UnregisterProject {
510 project_id: request.payload.project_id,
511 },
512 )
513 },
514 );
515 for (_, receipts) in project.join_requests {
516 for receipt in receipts {
517 self.peer.respond(
518 receipt,
519 proto::JoinProjectResponse {
520 variant: Some(proto::join_project_response::Variant::Decline(
521 proto::join_project_response::Decline {
522 reason: proto::join_project_response::decline::Reason::Closed
523 as i32,
524 },
525 )),
526 },
527 )?;
528 }
529 }
530
531 self.update_user_contacts(user_id).await?;
532 Ok(())
533 }
534
535 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
536 let contacts = self.app_state.db.get_contacts(user_id).await?;
537 let store = self.store().await;
538 let updated_contact = store.contact_for_user(user_id, false);
539 for contact in contacts {
540 if let db::Contact::Accepted {
541 user_id: contact_user_id,
542 ..
543 } = contact
544 {
545 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
546 self.peer
547 .send(
548 contact_conn_id,
549 proto::UpdateContacts {
550 contacts: vec![updated_contact.clone()],
551 remove_contacts: Default::default(),
552 incoming_requests: Default::default(),
553 remove_incoming_requests: Default::default(),
554 outgoing_requests: Default::default(),
555 remove_outgoing_requests: Default::default(),
556 },
557 )
558 .trace_err();
559 }
560 }
561 }
562 Ok(())
563 }
564
565 async fn join_project(
566 self: Arc<Server>,
567 request: TypedEnvelope<proto::JoinProject>,
568 response: Response<proto::JoinProject>,
569 ) -> Result<()> {
570 let project_id = request.payload.project_id;
571 let host_user_id;
572 let guest_user_id;
573 let host_connection_id;
574 {
575 let state = self.store().await;
576 let project = state.project(project_id)?;
577 host_user_id = project.host_user_id;
578 host_connection_id = project.host_connection_id;
579 guest_user_id = state.user_id_for_connection(request.sender_id)?;
580 };
581
582 let has_contact = self
583 .app_state
584 .db
585 .has_contact(guest_user_id, host_user_id)
586 .await?;
587 if !has_contact {
588 return Err(anyhow!("no such project"))?;
589 }
590
591 self.store_mut().await.request_join_project(
592 guest_user_id,
593 project_id,
594 response.into_receipt(),
595 )?;
596 self.peer.send(
597 host_connection_id,
598 proto::RequestJoinProject {
599 project_id,
600 requester_id: guest_user_id.to_proto(),
601 },
602 )?;
603 Ok(())
604 }
605
606 async fn respond_to_join_project_request(
607 self: Arc<Server>,
608 request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
609 ) -> Result<()> {
610 let host_user_id;
611
612 {
613 let mut state = self.store_mut().await;
614 let project_id = request.payload.project_id;
615 let project = state.project(project_id)?;
616 if project.host_connection_id != request.sender_id {
617 Err(anyhow!("no such connection"))?;
618 }
619
620 host_user_id = project.host_user_id;
621 let guest_user_id = UserId::from_proto(request.payload.requester_id);
622
623 if !request.payload.allow {
624 let receipts = state
625 .deny_join_project_request(request.sender_id, guest_user_id, project_id)
626 .ok_or_else(|| anyhow!("no such request"))?;
627 for receipt in receipts {
628 self.peer.respond(
629 receipt,
630 proto::JoinProjectResponse {
631 variant: Some(proto::join_project_response::Variant::Decline(
632 proto::join_project_response::Decline {
633 reason: proto::join_project_response::decline::Reason::Declined
634 as i32,
635 },
636 )),
637 },
638 )?;
639 }
640 return Ok(());
641 }
642
643 let (receipts_with_replica_ids, project) = state
644 .accept_join_project_request(request.sender_id, guest_user_id, project_id)
645 .ok_or_else(|| anyhow!("no such request"))?;
646
647 let peer_count = project.guests.len();
648 let mut collaborators = Vec::with_capacity(peer_count);
649 collaborators.push(proto::Collaborator {
650 peer_id: project.host_connection_id.0,
651 replica_id: 0,
652 user_id: project.host_user_id.to_proto(),
653 });
654 let worktrees = project
655 .worktrees
656 .iter()
657 .filter_map(|(id, shared_worktree)| {
658 let worktree = project.worktrees.get(&id)?;
659 Some(proto::Worktree {
660 id: *id,
661 root_name: worktree.root_name.clone(),
662 entries: shared_worktree.entries.values().cloned().collect(),
663 diagnostic_summaries: shared_worktree
664 .diagnostic_summaries
665 .values()
666 .cloned()
667 .collect(),
668 visible: worktree.visible,
669 scan_id: shared_worktree.scan_id,
670 })
671 })
672 .collect::<Vec<_>>();
673
674 // Add all guests other than the requesting user's own connections as collaborators
675 for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
676 if receipts_with_replica_ids
677 .iter()
678 .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
679 {
680 collaborators.push(proto::Collaborator {
681 peer_id: peer_conn_id.0,
682 replica_id: *peer_replica_id as u32,
683 user_id: peer_user_id.to_proto(),
684 });
685 }
686 }
687
688 for conn_id in project.connection_ids() {
689 for (receipt, replica_id) in &receipts_with_replica_ids {
690 if conn_id != receipt.sender_id {
691 self.peer.send(
692 conn_id,
693 proto::AddProjectCollaborator {
694 project_id,
695 collaborator: Some(proto::Collaborator {
696 peer_id: receipt.sender_id.0,
697 replica_id: *replica_id as u32,
698 user_id: guest_user_id.to_proto(),
699 }),
700 },
701 )?;
702 }
703 }
704 }
705
706 for (receipt, replica_id) in receipts_with_replica_ids {
707 self.peer.respond(
708 receipt,
709 proto::JoinProjectResponse {
710 variant: Some(proto::join_project_response::Variant::Accept(
711 proto::join_project_response::Accept {
712 worktrees: worktrees.clone(),
713 replica_id: replica_id as u32,
714 collaborators: collaborators.clone(),
715 language_servers: project.language_servers.clone(),
716 },
717 )),
718 },
719 )?;
720 }
721 }
722
723 self.update_user_contacts(host_user_id).await?;
724 Ok(())
725 }
726
727 async fn leave_project(
728 self: Arc<Server>,
729 request: TypedEnvelope<proto::LeaveProject>,
730 ) -> Result<()> {
731 let sender_id = request.sender_id;
732 let project_id = request.payload.project_id;
733 let project;
734 {
735 let mut store = self.store_mut().await;
736 project = store.leave_project(sender_id, project_id)?;
737
738 if project.remove_collaborator {
739 broadcast(sender_id, project.connection_ids, |conn_id| {
740 self.peer.send(
741 conn_id,
742 proto::RemoveProjectCollaborator {
743 project_id,
744 peer_id: sender_id.0,
745 },
746 )
747 });
748 }
749
750 if let Some(requester_id) = project.cancel_request {
751 self.peer.send(
752 project.host_connection_id,
753 proto::JoinProjectRequestCancelled {
754 project_id,
755 requester_id: requester_id.to_proto(),
756 },
757 )?;
758 }
759
760 if project.unshare {
761 self.peer.send(
762 project.host_connection_id,
763 proto::ProjectUnshared { project_id },
764 )?;
765 }
766 }
767 self.update_user_contacts(project.host_user_id).await?;
768 Ok(())
769 }
770
771 async fn register_worktree(
772 self: Arc<Server>,
773 request: TypedEnvelope<proto::RegisterWorktree>,
774 response: Response<proto::RegisterWorktree>,
775 ) -> Result<()> {
776 let host_user_id;
777 {
778 let mut state = self.store_mut().await;
779 host_user_id = state.user_id_for_connection(request.sender_id)?;
780
781 let guest_connection_ids = state
782 .read_project(request.payload.project_id, request.sender_id)?
783 .guest_connection_ids();
784 state.register_worktree(
785 request.payload.project_id,
786 request.payload.worktree_id,
787 request.sender_id,
788 Worktree {
789 root_name: request.payload.root_name.clone(),
790 visible: request.payload.visible,
791 ..Default::default()
792 },
793 )?;
794
795 broadcast(request.sender_id, guest_connection_ids, |connection_id| {
796 self.peer
797 .forward_send(request.sender_id, connection_id, request.payload.clone())
798 });
799 }
800 self.update_user_contacts(host_user_id).await?;
801 response.send(proto::Ack {})?;
802 Ok(())
803 }
804
805 async fn unregister_worktree(
806 self: Arc<Server>,
807 request: TypedEnvelope<proto::UnregisterWorktree>,
808 ) -> Result<()> {
809 let host_user_id;
810 let project_id = request.payload.project_id;
811 let worktree_id = request.payload.worktree_id;
812 {
813 let mut state = self.store_mut().await;
814 let (_, guest_connection_ids) =
815 state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
816 host_user_id = state.user_id_for_connection(request.sender_id)?;
817 broadcast(request.sender_id, guest_connection_ids, |conn_id| {
818 self.peer.send(
819 conn_id,
820 proto::UnregisterWorktree {
821 project_id,
822 worktree_id,
823 },
824 )
825 });
826 }
827 self.update_user_contacts(host_user_id).await?;
828 Ok(())
829 }
830
831 async fn update_worktree(
832 self: Arc<Server>,
833 request: TypedEnvelope<proto::UpdateWorktree>,
834 response: Response<proto::UpdateWorktree>,
835 ) -> Result<()> {
836 let connection_ids = self.store_mut().await.update_worktree(
837 request.sender_id,
838 request.payload.project_id,
839 request.payload.worktree_id,
840 &request.payload.removed_entries,
841 &request.payload.updated_entries,
842 request.payload.scan_id,
843 )?;
844
845 broadcast(request.sender_id, connection_ids, |connection_id| {
846 self.peer
847 .forward_send(request.sender_id, connection_id, request.payload.clone())
848 });
849 response.send(proto::Ack {})?;
850 Ok(())
851 }
852
853 async fn update_diagnostic_summary(
854 self: Arc<Server>,
855 request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
856 ) -> Result<()> {
857 let summary = request
858 .payload
859 .summary
860 .clone()
861 .ok_or_else(|| anyhow!("invalid summary"))?;
862 let receiver_ids = self.store_mut().await.update_diagnostic_summary(
863 request.payload.project_id,
864 request.payload.worktree_id,
865 request.sender_id,
866 summary,
867 )?;
868
869 broadcast(request.sender_id, receiver_ids, |connection_id| {
870 self.peer
871 .forward_send(request.sender_id, connection_id, request.payload.clone())
872 });
873 Ok(())
874 }
875
876 async fn start_language_server(
877 self: Arc<Server>,
878 request: TypedEnvelope<proto::StartLanguageServer>,
879 ) -> Result<()> {
880 let receiver_ids = self.store_mut().await.start_language_server(
881 request.payload.project_id,
882 request.sender_id,
883 request
884 .payload
885 .server
886 .clone()
887 .ok_or_else(|| anyhow!("invalid language server"))?,
888 )?;
889 broadcast(request.sender_id, receiver_ids, |connection_id| {
890 self.peer
891 .forward_send(request.sender_id, connection_id, request.payload.clone())
892 });
893 Ok(())
894 }
895
896 async fn update_language_server(
897 self: Arc<Server>,
898 request: TypedEnvelope<proto::UpdateLanguageServer>,
899 ) -> Result<()> {
900 let receiver_ids = self
901 .store()
902 .await
903 .project_connection_ids(request.payload.project_id, request.sender_id)?;
904 broadcast(request.sender_id, receiver_ids, |connection_id| {
905 self.peer
906 .forward_send(request.sender_id, connection_id, request.payload.clone())
907 });
908 Ok(())
909 }
910
911 async fn forward_project_request<T>(
912 self: Arc<Server>,
913 request: TypedEnvelope<T>,
914 response: Response<T>,
915 ) -> Result<()>
916 where
917 T: EntityMessage + RequestMessage,
918 {
919 let host_connection_id = self
920 .store()
921 .await
922 .read_project(request.payload.remote_entity_id(), request.sender_id)?
923 .host_connection_id;
924
925 response.send(
926 self.peer
927 .forward_request(request.sender_id, host_connection_id, request.payload)
928 .await?,
929 )?;
930 Ok(())
931 }
932
933 async fn save_buffer(
934 self: Arc<Server>,
935 request: TypedEnvelope<proto::SaveBuffer>,
936 response: Response<proto::SaveBuffer>,
937 ) -> Result<()> {
938 let host = self
939 .store()
940 .await
941 .read_project(request.payload.project_id, request.sender_id)?
942 .host_connection_id;
943 let response_payload = self
944 .peer
945 .forward_request(request.sender_id, host, request.payload.clone())
946 .await?;
947
948 let mut guests = self
949 .store()
950 .await
951 .read_project(request.payload.project_id, request.sender_id)?
952 .connection_ids();
953 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
954 broadcast(host, guests, |conn_id| {
955 self.peer
956 .forward_send(host, conn_id, response_payload.clone())
957 });
958 response.send(response_payload)?;
959 Ok(())
960 }
961
962 async fn update_buffer(
963 self: Arc<Server>,
964 request: TypedEnvelope<proto::UpdateBuffer>,
965 response: Response<proto::UpdateBuffer>,
966 ) -> Result<()> {
967 let receiver_ids = self
968 .store()
969 .await
970 .project_connection_ids(request.payload.project_id, request.sender_id)?;
971 broadcast(request.sender_id, receiver_ids, |connection_id| {
972 self.peer
973 .forward_send(request.sender_id, connection_id, request.payload.clone())
974 });
975 response.send(proto::Ack {})?;
976 Ok(())
977 }
978
979 async fn update_buffer_file(
980 self: Arc<Server>,
981 request: TypedEnvelope<proto::UpdateBufferFile>,
982 ) -> Result<()> {
983 let receiver_ids = self
984 .store()
985 .await
986 .project_connection_ids(request.payload.project_id, request.sender_id)?;
987 broadcast(request.sender_id, receiver_ids, |connection_id| {
988 self.peer
989 .forward_send(request.sender_id, connection_id, request.payload.clone())
990 });
991 Ok(())
992 }
993
994 async fn buffer_reloaded(
995 self: Arc<Server>,
996 request: TypedEnvelope<proto::BufferReloaded>,
997 ) -> Result<()> {
998 let receiver_ids = self
999 .store()
1000 .await
1001 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1002 broadcast(request.sender_id, receiver_ids, |connection_id| {
1003 self.peer
1004 .forward_send(request.sender_id, connection_id, request.payload.clone())
1005 });
1006 Ok(())
1007 }
1008
1009 async fn buffer_saved(
1010 self: Arc<Server>,
1011 request: TypedEnvelope<proto::BufferSaved>,
1012 ) -> Result<()> {
1013 let receiver_ids = self
1014 .store()
1015 .await
1016 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1017 broadcast(request.sender_id, receiver_ids, |connection_id| {
1018 self.peer
1019 .forward_send(request.sender_id, connection_id, request.payload.clone())
1020 });
1021 Ok(())
1022 }
1023
1024 async fn follow(
1025 self: Arc<Self>,
1026 request: TypedEnvelope<proto::Follow>,
1027 response: Response<proto::Follow>,
1028 ) -> Result<()> {
1029 let leader_id = ConnectionId(request.payload.leader_id);
1030 let follower_id = request.sender_id;
1031 if !self
1032 .store()
1033 .await
1034 .project_connection_ids(request.payload.project_id, follower_id)?
1035 .contains(&leader_id)
1036 {
1037 Err(anyhow!("no such peer"))?;
1038 }
1039 let mut response_payload = self
1040 .peer
1041 .forward_request(request.sender_id, leader_id, request.payload)
1042 .await?;
1043 response_payload
1044 .views
1045 .retain(|view| view.leader_id != Some(follower_id.0));
1046 response.send(response_payload)?;
1047 Ok(())
1048 }
1049
1050 async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1051 let leader_id = ConnectionId(request.payload.leader_id);
1052 if !self
1053 .store()
1054 .await
1055 .project_connection_ids(request.payload.project_id, request.sender_id)?
1056 .contains(&leader_id)
1057 {
1058 Err(anyhow!("no such peer"))?;
1059 }
1060 self.peer
1061 .forward_send(request.sender_id, leader_id, request.payload)?;
1062 Ok(())
1063 }
1064
1065 async fn update_followers(
1066 self: Arc<Self>,
1067 request: TypedEnvelope<proto::UpdateFollowers>,
1068 ) -> Result<()> {
1069 let connection_ids = self
1070 .store()
1071 .await
1072 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1073 let leader_id = request
1074 .payload
1075 .variant
1076 .as_ref()
1077 .and_then(|variant| match variant {
1078 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1079 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1080 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1081 });
1082 for follower_id in &request.payload.follower_ids {
1083 let follower_id = ConnectionId(*follower_id);
1084 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1085 self.peer
1086 .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1087 }
1088 }
1089 Ok(())
1090 }
1091
1092 async fn get_channels(
1093 self: Arc<Server>,
1094 request: TypedEnvelope<proto::GetChannels>,
1095 response: Response<proto::GetChannels>,
1096 ) -> Result<()> {
1097 let user_id = self
1098 .store()
1099 .await
1100 .user_id_for_connection(request.sender_id)?;
1101 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1102 response.send(proto::GetChannelsResponse {
1103 channels: channels
1104 .into_iter()
1105 .map(|chan| proto::Channel {
1106 id: chan.id.to_proto(),
1107 name: chan.name,
1108 })
1109 .collect(),
1110 })?;
1111 Ok(())
1112 }
1113
1114 async fn get_users(
1115 self: Arc<Server>,
1116 request: TypedEnvelope<proto::GetUsers>,
1117 response: Response<proto::GetUsers>,
1118 ) -> Result<()> {
1119 let user_ids = request
1120 .payload
1121 .user_ids
1122 .into_iter()
1123 .map(UserId::from_proto)
1124 .collect();
1125 let users = self
1126 .app_state
1127 .db
1128 .get_users_by_ids(user_ids)
1129 .await?
1130 .into_iter()
1131 .map(|user| proto::User {
1132 id: user.id.to_proto(),
1133 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1134 github_login: user.github_login,
1135 })
1136 .collect();
1137 response.send(proto::UsersResponse { users })?;
1138 Ok(())
1139 }
1140
1141 async fn fuzzy_search_users(
1142 self: Arc<Server>,
1143 request: TypedEnvelope<proto::FuzzySearchUsers>,
1144 response: Response<proto::FuzzySearchUsers>,
1145 ) -> Result<()> {
1146 let user_id = self
1147 .store()
1148 .await
1149 .user_id_for_connection(request.sender_id)?;
1150 let query = request.payload.query;
1151 let db = &self.app_state.db;
1152 let users = match query.len() {
1153 0 => vec![],
1154 1 | 2 => db
1155 .get_user_by_github_login(&query)
1156 .await?
1157 .into_iter()
1158 .collect(),
1159 _ => db.fuzzy_search_users(&query, 10).await?,
1160 };
1161 let users = users
1162 .into_iter()
1163 .filter(|user| user.id != user_id)
1164 .map(|user| proto::User {
1165 id: user.id.to_proto(),
1166 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1167 github_login: user.github_login,
1168 })
1169 .collect();
1170 response.send(proto::UsersResponse { users })?;
1171 Ok(())
1172 }
1173
1174 async fn request_contact(
1175 self: Arc<Server>,
1176 request: TypedEnvelope<proto::RequestContact>,
1177 response: Response<proto::RequestContact>,
1178 ) -> Result<()> {
1179 let requester_id = self
1180 .store()
1181 .await
1182 .user_id_for_connection(request.sender_id)?;
1183 let responder_id = UserId::from_proto(request.payload.responder_id);
1184 if requester_id == responder_id {
1185 return Err(anyhow!("cannot add yourself as a contact"))?;
1186 }
1187
1188 self.app_state
1189 .db
1190 .send_contact_request(requester_id, responder_id)
1191 .await?;
1192
1193 // Update outgoing contact requests of requester
1194 let mut update = proto::UpdateContacts::default();
1195 update.outgoing_requests.push(responder_id.to_proto());
1196 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1197 self.peer.send(connection_id, update.clone())?;
1198 }
1199
1200 // Update incoming contact requests of responder
1201 let mut update = proto::UpdateContacts::default();
1202 update
1203 .incoming_requests
1204 .push(proto::IncomingContactRequest {
1205 requester_id: requester_id.to_proto(),
1206 should_notify: true,
1207 });
1208 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1209 self.peer.send(connection_id, update.clone())?;
1210 }
1211
1212 response.send(proto::Ack {})?;
1213 Ok(())
1214 }
1215
1216 async fn respond_to_contact_request(
1217 self: Arc<Server>,
1218 request: TypedEnvelope<proto::RespondToContactRequest>,
1219 response: Response<proto::RespondToContactRequest>,
1220 ) -> Result<()> {
1221 let responder_id = self
1222 .store()
1223 .await
1224 .user_id_for_connection(request.sender_id)?;
1225 let requester_id = UserId::from_proto(request.payload.requester_id);
1226 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1227 self.app_state
1228 .db
1229 .dismiss_contact_notification(responder_id, requester_id)
1230 .await?;
1231 } else {
1232 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1233 self.app_state
1234 .db
1235 .respond_to_contact_request(responder_id, requester_id, accept)
1236 .await?;
1237
1238 let store = self.store().await;
1239 // Update responder with new contact
1240 let mut update = proto::UpdateContacts::default();
1241 if accept {
1242 update
1243 .contacts
1244 .push(store.contact_for_user(requester_id, false));
1245 }
1246 update
1247 .remove_incoming_requests
1248 .push(requester_id.to_proto());
1249 for connection_id in store.connection_ids_for_user(responder_id) {
1250 self.peer.send(connection_id, update.clone())?;
1251 }
1252
1253 // Update requester with new contact
1254 let mut update = proto::UpdateContacts::default();
1255 if accept {
1256 update
1257 .contacts
1258 .push(store.contact_for_user(responder_id, true));
1259 }
1260 update
1261 .remove_outgoing_requests
1262 .push(responder_id.to_proto());
1263 for connection_id in store.connection_ids_for_user(requester_id) {
1264 self.peer.send(connection_id, update.clone())?;
1265 }
1266 }
1267
1268 response.send(proto::Ack {})?;
1269 Ok(())
1270 }
1271
1272 async fn remove_contact(
1273 self: Arc<Server>,
1274 request: TypedEnvelope<proto::RemoveContact>,
1275 response: Response<proto::RemoveContact>,
1276 ) -> Result<()> {
1277 let requester_id = self
1278 .store()
1279 .await
1280 .user_id_for_connection(request.sender_id)?;
1281 let responder_id = UserId::from_proto(request.payload.user_id);
1282 self.app_state
1283 .db
1284 .remove_contact(requester_id, responder_id)
1285 .await?;
1286
1287 // Update outgoing contact requests of requester
1288 let mut update = proto::UpdateContacts::default();
1289 update
1290 .remove_outgoing_requests
1291 .push(responder_id.to_proto());
1292 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1293 self.peer.send(connection_id, update.clone())?;
1294 }
1295
1296 // Update incoming contact requests of responder
1297 let mut update = proto::UpdateContacts::default();
1298 update
1299 .remove_incoming_requests
1300 .push(requester_id.to_proto());
1301 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1302 self.peer.send(connection_id, update.clone())?;
1303 }
1304
1305 response.send(proto::Ack {})?;
1306 Ok(())
1307 }
1308
1309 async fn join_channel(
1310 self: Arc<Self>,
1311 request: TypedEnvelope<proto::JoinChannel>,
1312 response: Response<proto::JoinChannel>,
1313 ) -> Result<()> {
1314 let user_id = self
1315 .store()
1316 .await
1317 .user_id_for_connection(request.sender_id)?;
1318 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1319 if !self
1320 .app_state
1321 .db
1322 .can_user_access_channel(user_id, channel_id)
1323 .await?
1324 {
1325 Err(anyhow!("access denied"))?;
1326 }
1327
1328 self.store_mut()
1329 .await
1330 .join_channel(request.sender_id, channel_id);
1331 let messages = self
1332 .app_state
1333 .db
1334 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1335 .await?
1336 .into_iter()
1337 .map(|msg| proto::ChannelMessage {
1338 id: msg.id.to_proto(),
1339 body: msg.body,
1340 timestamp: msg.sent_at.unix_timestamp() as u64,
1341 sender_id: msg.sender_id.to_proto(),
1342 nonce: Some(msg.nonce.as_u128().into()),
1343 })
1344 .collect::<Vec<_>>();
1345 response.send(proto::JoinChannelResponse {
1346 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1347 messages,
1348 })?;
1349 Ok(())
1350 }
1351
1352 async fn leave_channel(
1353 self: Arc<Self>,
1354 request: TypedEnvelope<proto::LeaveChannel>,
1355 ) -> Result<()> {
1356 let user_id = self
1357 .store()
1358 .await
1359 .user_id_for_connection(request.sender_id)?;
1360 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1361 if !self
1362 .app_state
1363 .db
1364 .can_user_access_channel(user_id, channel_id)
1365 .await?
1366 {
1367 Err(anyhow!("access denied"))?;
1368 }
1369
1370 self.store_mut()
1371 .await
1372 .leave_channel(request.sender_id, channel_id);
1373
1374 Ok(())
1375 }
1376
1377 async fn send_channel_message(
1378 self: Arc<Self>,
1379 request: TypedEnvelope<proto::SendChannelMessage>,
1380 response: Response<proto::SendChannelMessage>,
1381 ) -> Result<()> {
1382 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1383 let user_id;
1384 let connection_ids;
1385 {
1386 let state = self.store().await;
1387 user_id = state.user_id_for_connection(request.sender_id)?;
1388 connection_ids = state.channel_connection_ids(channel_id)?;
1389 }
1390
1391 // Validate the message body.
1392 let body = request.payload.body.trim().to_string();
1393 if body.len() > MAX_MESSAGE_LEN {
1394 return Err(anyhow!("message is too long"))?;
1395 }
1396 if body.is_empty() {
1397 return Err(anyhow!("message can't be blank"))?;
1398 }
1399
1400 let timestamp = OffsetDateTime::now_utc();
1401 let nonce = request
1402 .payload
1403 .nonce
1404 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1405
1406 let message_id = self
1407 .app_state
1408 .db
1409 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1410 .await?
1411 .to_proto();
1412 let message = proto::ChannelMessage {
1413 sender_id: user_id.to_proto(),
1414 id: message_id,
1415 body,
1416 timestamp: timestamp.unix_timestamp() as u64,
1417 nonce: Some(nonce),
1418 };
1419 broadcast(request.sender_id, connection_ids, |conn_id| {
1420 self.peer.send(
1421 conn_id,
1422 proto::ChannelMessageSent {
1423 channel_id: channel_id.to_proto(),
1424 message: Some(message.clone()),
1425 },
1426 )
1427 });
1428 response.send(proto::SendChannelMessageResponse {
1429 message: Some(message),
1430 })?;
1431 Ok(())
1432 }
1433
1434 async fn get_channel_messages(
1435 self: Arc<Self>,
1436 request: TypedEnvelope<proto::GetChannelMessages>,
1437 response: Response<proto::GetChannelMessages>,
1438 ) -> Result<()> {
1439 let user_id = self
1440 .store()
1441 .await
1442 .user_id_for_connection(request.sender_id)?;
1443 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1444 if !self
1445 .app_state
1446 .db
1447 .can_user_access_channel(user_id, channel_id)
1448 .await?
1449 {
1450 Err(anyhow!("access denied"))?;
1451 }
1452
1453 let messages = self
1454 .app_state
1455 .db
1456 .get_channel_messages(
1457 channel_id,
1458 MESSAGE_COUNT_PER_PAGE,
1459 Some(MessageId::from_proto(request.payload.before_message_id)),
1460 )
1461 .await?
1462 .into_iter()
1463 .map(|msg| proto::ChannelMessage {
1464 id: msg.id.to_proto(),
1465 body: msg.body,
1466 timestamp: msg.sent_at.unix_timestamp() as u64,
1467 sender_id: msg.sender_id.to_proto(),
1468 nonce: Some(msg.nonce.as_u128().into()),
1469 })
1470 .collect::<Vec<_>>();
1471 response.send(proto::GetChannelMessagesResponse {
1472 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1473 messages,
1474 })?;
1475 Ok(())
1476 }
1477
1478 async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1479 #[cfg(test)]
1480 tokio::task::yield_now().await;
1481 let guard = self.store.read().await;
1482 #[cfg(test)]
1483 tokio::task::yield_now().await;
1484 StoreReadGuard {
1485 guard,
1486 _not_send: PhantomData,
1487 }
1488 }
1489
1490 async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1491 #[cfg(test)]
1492 tokio::task::yield_now().await;
1493 let guard = self.store.write().await;
1494 #[cfg(test)]
1495 tokio::task::yield_now().await;
1496 StoreWriteGuard {
1497 guard,
1498 _not_send: PhantomData,
1499 }
1500 }
1501
1502 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1503 ServerSnapshot {
1504 store: self.store.read().await,
1505 peer: &self.peer,
1506 }
1507 }
1508}
1509
1510impl<'a> Deref for StoreReadGuard<'a> {
1511 type Target = Store;
1512
1513 fn deref(&self) -> &Self::Target {
1514 &*self.guard
1515 }
1516}
1517
1518impl<'a> Deref for StoreWriteGuard<'a> {
1519 type Target = Store;
1520
1521 fn deref(&self) -> &Self::Target {
1522 &*self.guard
1523 }
1524}
1525
1526impl<'a> DerefMut for StoreWriteGuard<'a> {
1527 fn deref_mut(&mut self) -> &mut Self::Target {
1528 &mut *self.guard
1529 }
1530}
1531
1532impl<'a> Drop for StoreWriteGuard<'a> {
1533 fn drop(&mut self) {
1534 #[cfg(test)]
1535 self.check_invariants();
1536
1537 let metrics = self.metrics();
1538 tracing::info!(
1539 connections = metrics.connections,
1540 registered_projects = metrics.registered_projects,
1541 shared_projects = metrics.shared_projects,
1542 "metrics"
1543 );
1544 }
1545}
1546
1547impl Executor for RealExecutor {
1548 type Sleep = Sleep;
1549
1550 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1551 tokio::task::spawn(future);
1552 }
1553
1554 fn sleep(&self, duration: Duration) -> Self::Sleep {
1555 tokio::time::sleep(duration)
1556 }
1557}
1558
1559fn broadcast<F>(
1560 sender_id: ConnectionId,
1561 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1562 mut f: F,
1563) where
1564 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1565{
1566 for receiver_id in receiver_ids {
1567 if receiver_id != sender_id {
1568 f(receiver_id).trace_err();
1569 }
1570 }
1571}
1572
// Lazily-initialized name of the `x-zed-protocol-version` header; this is the name that
// `ProtocolVersion` reports from `Header::name`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1576
1577pub struct ProtocolVersion(u32);
1578
1579impl Header for ProtocolVersion {
1580 fn name() -> &'static HeaderName {
1581 &ZED_PROTOCOL_VERSION
1582 }
1583
1584 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1585 where
1586 Self: Sized,
1587 I: Iterator<Item = &'i axum::http::HeaderValue>,
1588 {
1589 let version = values
1590 .next()
1591 .ok_or_else(|| axum::headers::Error::invalid())?
1592 .to_str()
1593 .map_err(|_| axum::headers::Error::invalid())?
1594 .parse()
1595 .map_err(|_| axum::headers::Error::invalid())?;
1596 Ok(Self(version))
1597 }
1598
1599 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1600 values.extend([self.0.to_string().parse().unwrap()]);
1601 }
1602}
1603
1604pub fn routes(server: Arc<Server>) -> Router<Body> {
1605 Router::new()
1606 .route("/rpc", get(handle_websocket_request))
1607 .layer(
1608 ServiceBuilder::new()
1609 .layer(Extension(server.app_state.clone()))
1610 .layer(middleware::from_fn(auth::validate_header))
1611 .layer(Extension(server)),
1612 )
1613}
1614
1615pub async fn handle_websocket_request(
1616 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1617 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1618 Extension(server): Extension<Arc<Server>>,
1619 Extension(user): Extension<User>,
1620 ws: WebSocketUpgrade,
1621) -> axum::response::Response {
1622 if protocol_version != rpc::PROTOCOL_VERSION {
1623 return (
1624 StatusCode::UPGRADE_REQUIRED,
1625 "client must be upgraded".to_string(),
1626 )
1627 .into_response();
1628 }
1629 let socket_address = socket_address.to_string();
1630 ws.on_upgrade(move |socket| {
1631 use util::ResultExt;
1632 let socket = socket
1633 .map_ok(to_tungstenite_message)
1634 .err_into()
1635 .with(|message| async move { Ok(to_axum_message(message)) });
1636 let connection = Connection::new(Box::pin(socket));
1637 async move {
1638 server
1639 .handle_connection(connection, socket_address, user, None, RealExecutor)
1640 .await
1641 .log_err();
1642 }
1643 })
1644}
1645
1646fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1647 match message {
1648 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1649 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1650 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1651 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1652 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1653 code: frame.code.into(),
1654 reason: frame.reason,
1655 })),
1656 }
1657}
1658
1659fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1660 match message {
1661 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1662 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1663 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1664 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1665 AxumMessage::Close(frame) => {
1666 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1667 code: frame.code.into(),
1668 reason: frame.reason,
1669 }))
1670 }
1671 }
1672}
1673
1674pub trait ResultExt {
1675 type Ok;
1676
1677 fn trace_err(self) -> Option<Self::Ok>;
1678}
1679
1680impl<T, E> ResultExt for Result<T, E>
1681where
1682 E: std::fmt::Debug,
1683{
1684 type Ok = T;
1685
1686 fn trace_err(self) -> Option<T> {
1687 match self {
1688 Ok(value) => Some(value),
1689 Err(error) => {
1690 tracing::error!("{:?}", error);
1691 None
1692 }
1693 }
1694 }
1695}