1mod store;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, MessageId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::HashMap;
26use futures::{
27 channel::mpsc,
28 future::{self, BoxFuture},
29 FutureExt, SinkExt, StreamExt, TryStreamExt,
30};
31use lazy_static::lazy_static;
32use prometheus::{register_int_gauge, IntGauge};
33use rpc::{
34 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
35 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
36};
37use serde::{Serialize, Serializer};
38use std::{
39 any::TypeId,
40 future::Future,
41 marker::PhantomData,
42 net::SocketAddr,
43 ops::{Deref, DerefMut},
44 rc::Rc,
45 sync::{
46 atomic::{AtomicBool, Ordering::SeqCst},
47 Arc,
48 },
49 time::Duration,
50};
51use time::OffsetDateTime;
52use tokio::{
53 sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
54 time::Sleep,
55};
56use tower::ServiceBuilder;
57use tracing::{info_span, instrument, Instrument};
58
59pub use store::{Store, Worktree};
60
// Prometheus gauges, created on first access via `lazy_static!` and registered
// through `register_int_gauge!`. Each `unwrap` panics if registration fails
// (presumably e.g. on a duplicate metric name — confirm against prometheus docs).
// NOTE(review): the code that sets these gauges is not visible in this chunk.
lazy_static! {
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    static ref METRIC_REGISTERED_PROJECTS: IntGauge =
        register_int_gauge!("registered_projects", "number of registered projects").unwrap();
    static ref METRIC_ACTIVE_PROJECTS: IntGauge =
        register_int_gauge!("active_projects", "number of active projects").unwrap();
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
74
75type MessageHandler =
76 Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
77
78struct Response<R> {
79 server: Arc<Server>,
80 receipt: Receipt<R>,
81 responded: Arc<AtomicBool>,
82}
83
84impl<R: RequestMessage> Response<R> {
85 fn send(self, payload: R::Response) -> Result<()> {
86 self.responded.store(true, SeqCst);
87 self.server.peer.respond(self.receipt, payload)?;
88 Ok(())
89 }
90
91 fn into_receipt(self) -> Receipt<R> {
92 self.responded.store(true, SeqCst);
93 self.receipt
94 }
95}
96
97pub struct Server {
98 peer: Arc<Peer>,
99 pub(crate) store: RwLock<Store>,
100 app_state: Arc<AppState>,
101 handlers: HashMap<TypeId, MessageHandler>,
102 notifications: Option<mpsc::UnboundedSender<()>>,
103}
104
/// Runtime abstraction used by `Server::handle_connection` to spawn background
/// work and create timers.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Future + Send;

    /// Starts `future` without waiting for it to finish.
    fn spawn_detached<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static;

    /// Returns a timer future for `duration` (awaited by the connection's
    /// keep-alive/timeout logic).
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
110
/// A zero-sized, cloneable [`Executor`].
// NOTE(review): its `Executor` impl is outside this chunk; the name suggests it is
// the production runtime (presumably tokio-backed) — confirm.
#[derive(Clone)]
pub struct RealExecutor;
113
/// Upper bound on a channel message's length (1024).
// NOTE(review): unit (bytes vs. chars) is decided at the use site, outside this view.
const MAX_MESSAGE_LEN: usize = 1 << 10;
/// Page size used when returning channel message history.
// NOTE(review): inferred from the name; the consumer is outside this view.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
116
/// Shared access to the server's `Store`, wrapping a tokio `RwLockReadGuard`.
struct StoreReadGuard<'a> {
    guard: RwLockReadGuard<'a, Store>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. Presumably this is so
    // that holding it across an `.await` in a future that must be `Send` is a
    // compile error rather than a latent lock-holding bug.
    _not_send: PhantomData<Rc<()>>,
}

/// Exclusive access to the server's `Store`, wrapping a tokio `RwLockWriteGuard`.
struct StoreWriteGuard<'a> {
    guard: RwLockWriteGuard<'a, Store>,
    // Same `!Send` marker as `StoreReadGuard::_not_send`.
    _not_send: PhantomData<Rc<()>>,
}
126
/// Serializable view of the server: the RPC peer plus a read lock on the store.
///
/// The serialized field order follows the declaration order below.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The lock guard is serialized as the `Store` it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
133
134pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
135where
136 S: Serializer,
137 T: Deref<Target = U>,
138 U: Serialize,
139{
140 Serialize::serialize(value.deref(), serializer)
141}
142
143impl Server {
144 pub fn new(
145 app_state: Arc<AppState>,
146 notifications: Option<mpsc::UnboundedSender<()>>,
147 ) -> Arc<Self> {
148 let mut server = Self {
149 peer: Peer::new(),
150 app_state,
151 store: Default::default(),
152 handlers: Default::default(),
153 notifications,
154 };
155
156 server
157 .add_request_handler(Server::ping)
158 .add_request_handler(Server::register_project)
159 .add_request_handler(Server::unregister_project)
160 .add_request_handler(Server::join_project)
161 .add_message_handler(Server::leave_project)
162 .add_message_handler(Server::respond_to_join_project_request)
163 .add_message_handler(Server::update_project)
164 .add_message_handler(Server::register_project_activity)
165 .add_request_handler(Server::update_worktree)
166 .add_message_handler(Server::start_language_server)
167 .add_message_handler(Server::update_language_server)
168 .add_message_handler(Server::update_diagnostic_summary)
169 .add_request_handler(Server::forward_project_request::<proto::GetHover>)
170 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
171 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
172 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
173 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
174 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
175 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
176 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
177 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
178 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
179 .add_request_handler(
180 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
181 )
182 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
183 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
184 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
185 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
186 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
187 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
188 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
189 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
190 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
191 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
192 .add_request_handler(Server::update_buffer)
193 .add_message_handler(Server::update_buffer_file)
194 .add_message_handler(Server::buffer_reloaded)
195 .add_message_handler(Server::buffer_saved)
196 .add_request_handler(Server::save_buffer)
197 .add_request_handler(Server::get_channels)
198 .add_request_handler(Server::get_users)
199 .add_request_handler(Server::fuzzy_search_users)
200 .add_request_handler(Server::request_contact)
201 .add_request_handler(Server::remove_contact)
202 .add_request_handler(Server::respond_to_contact_request)
203 .add_request_handler(Server::join_channel)
204 .add_message_handler(Server::leave_channel)
205 .add_request_handler(Server::send_channel_message)
206 .add_request_handler(Server::follow)
207 .add_message_handler(Server::unfollow)
208 .add_message_handler(Server::update_followers)
209 .add_request_handler(Server::get_channel_messages);
210
211 Arc::new(server)
212 }
213
214 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
215 where
216 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
217 Fut: 'static + Send + Future<Output = Result<()>>,
218 M: EnvelopedMessage,
219 {
220 let prev_handler = self.handlers.insert(
221 TypeId::of::<M>(),
222 Box::new(move |server, envelope| {
223 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
224 let span = info_span!(
225 "handle message",
226 payload_type = envelope.payload_type_name()
227 );
228 span.in_scope(|| {
229 tracing::info!(
230 payload = format!("{:?}", envelope.payload).as_str(),
231 "message payload"
232 );
233 });
234 let future = (handler)(server, *envelope);
235 async move {
236 if let Err(error) = future.await {
237 tracing::error!(%error, "error handling message");
238 }
239 }
240 .instrument(span)
241 .boxed()
242 }),
243 );
244 if prev_handler.is_some() {
245 panic!("registered a handler for the same message twice");
246 }
247 self
248 }
249
250 /// Handle a request while holding a lock to the store. This is useful when we're registering
251 /// a connection but we want to respond on the connection before anybody else can send on it.
252 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
253 where
254 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
255 Fut: Send + Future<Output = Result<()>>,
256 M: RequestMessage,
257 {
258 let handler = Arc::new(handler);
259 self.add_message_handler(move |server, envelope| {
260 let receipt = envelope.receipt();
261 let handler = handler.clone();
262 async move {
263 let responded = Arc::new(AtomicBool::default());
264 let response = Response {
265 server: server.clone(),
266 responded: responded.clone(),
267 receipt: envelope.receipt(),
268 };
269 match (handler)(server.clone(), envelope, response).await {
270 Ok(()) => {
271 if responded.load(std::sync::atomic::Ordering::SeqCst) {
272 Ok(())
273 } else {
274 Err(anyhow!("handler did not send a response"))?
275 }
276 }
277 Err(error) => {
278 server.peer.respond_with_error(
279 receipt,
280 proto::Error {
281 message: error.to_string(),
282 },
283 )?;
284 Err(error)
285 }
286 }
287 }
288 })
289 }
290
291 pub fn handle_connection<E: Executor>(
292 self: &Arc<Self>,
293 connection: Connection,
294 address: String,
295 user: User,
296 mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
297 executor: E,
298 ) -> impl Future<Output = Result<()>> {
299 let mut this = self.clone();
300 let user_id = user.id;
301 let login = user.github_login;
302 let span = info_span!("handle connection", %user_id, %login, %address);
303 async move {
304 let (connection_id, handle_io, mut incoming_rx) = this
305 .peer
306 .add_connection(connection, {
307 let executor = executor.clone();
308 move |duration| {
309 let timer = executor.sleep(duration);
310 async move {
311 timer.await;
312 }
313 }
314 })
315 .await;
316
317 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
318
319 if let Some(send_connection_id) = send_connection_id.as_mut() {
320 let _ = send_connection_id.send(connection_id).await;
321 }
322
323 if !user.connected_once {
324 this.peer.send(connection_id, proto::ShowContacts {})?;
325 this.app_state.db.set_user_connected_once(user_id, true).await?;
326 }
327
328 let (contacts, invite_code) = future::try_join(
329 this.app_state.db.get_contacts(user_id),
330 this.app_state.db.get_invite_code_for_user(user_id)
331 ).await?;
332
333 {
334 let mut store = this.store_mut().await;
335 store.add_connection(connection_id, user_id, user.admin);
336 this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
337
338 if let Some((code, count)) = invite_code {
339 this.peer.send(connection_id, proto::UpdateInviteInfo {
340 url: format!("{}{}", this.app_state.invite_link_prefix, code),
341 count,
342 })?;
343 }
344 }
345 this.update_user_contacts(user_id).await?;
346
347 let handle_io = handle_io.fuse();
348 futures::pin_mut!(handle_io);
349 loop {
350 let next_message = incoming_rx.next().fuse();
351 futures::pin_mut!(next_message);
352 futures::select_biased! {
353 result = handle_io => {
354 if let Err(error) = result {
355 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
356 }
357 break;
358 }
359 message = next_message => {
360 if let Some(message) = message {
361 let type_name = message.payload_type_name();
362 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
363 async {
364 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
365 let notifications = this.notifications.clone();
366 let is_background = message.is_background();
367 let handle_message = (handler)(this.clone(), message);
368 let handle_message = async move {
369 handle_message.await;
370 if let Some(mut notifications) = notifications {
371 let _ = notifications.send(()).await;
372 }
373 };
374 if is_background {
375 executor.spawn_detached(handle_message);
376 } else {
377 handle_message.await;
378 }
379 } else {
380 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
381 }
382 }.instrument(span).await;
383 } else {
384 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
385 break;
386 }
387 }
388 }
389 }
390
391 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
392 if let Err(error) = this.sign_out(connection_id).await {
393 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
394 }
395
396 Ok(())
397 }.instrument(span)
398 }
399
400 #[instrument(skip(self), err)]
401 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
402 self.peer.disconnect(connection_id);
403
404 let removed_user_id = {
405 let mut store = self.store_mut().await;
406 let removed_connection = store.remove_connection(connection_id)?;
407
408 for (project_id, project) in removed_connection.hosted_projects {
409 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
410 self.peer
411 .send(conn_id, proto::UnregisterProject { project_id })
412 });
413
414 for (_, receipts) in project.join_requests {
415 for receipt in receipts {
416 self.peer.respond(
417 receipt,
418 proto::JoinProjectResponse {
419 variant: Some(proto::join_project_response::Variant::Decline(
420 proto::join_project_response::Decline {
421 reason: proto::join_project_response::decline::Reason::WentOffline as i32
422 },
423 )),
424 },
425 )?;
426 }
427 }
428 }
429
430 for project_id in removed_connection.guest_project_ids {
431 if let Some(project) = store.project(project_id).trace_err() {
432 broadcast(connection_id, project.connection_ids(), |conn_id| {
433 self.peer.send(
434 conn_id,
435 proto::RemoveProjectCollaborator {
436 project_id,
437 peer_id: connection_id.0,
438 },
439 )
440 });
441 if project.guests.is_empty() {
442 self.peer
443 .send(
444 project.host_connection_id,
445 proto::ProjectUnshared { project_id },
446 )
447 .trace_err();
448 }
449 }
450 }
451
452 removed_connection.user_id
453 };
454
455 self.update_user_contacts(removed_user_id).await?;
456
457 Ok(())
458 }
459
460 pub async fn invite_code_redeemed(
461 self: &Arc<Self>,
462 code: &str,
463 invitee_id: UserId,
464 ) -> Result<()> {
465 let user = self.app_state.db.get_user_for_invite_code(code).await?;
466 let store = self.store().await;
467 let invitee_contact = store.contact_for_user(invitee_id, true);
468 for connection_id in store.connection_ids_for_user(user.id) {
469 self.peer.send(
470 connection_id,
471 proto::UpdateContacts {
472 contacts: vec![invitee_contact.clone()],
473 ..Default::default()
474 },
475 )?;
476 self.peer.send(
477 connection_id,
478 proto::UpdateInviteInfo {
479 url: format!("{}{}", self.app_state.invite_link_prefix, code),
480 count: user.invite_count as u32,
481 },
482 )?;
483 }
484 Ok(())
485 }
486
487 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
488 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
489 if let Some(invite_code) = &user.invite_code {
490 let store = self.store().await;
491 for connection_id in store.connection_ids_for_user(user_id) {
492 self.peer.send(
493 connection_id,
494 proto::UpdateInviteInfo {
495 url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
496 count: user.invite_count as u32,
497 },
498 )?;
499 }
500 }
501 }
502 Ok(())
503 }
504
505 async fn ping(
506 self: Arc<Server>,
507 _: TypedEnvelope<proto::Ping>,
508 response: Response<proto::Ping>,
509 ) -> Result<()> {
510 response.send(proto::Ack {})?;
511 Ok(())
512 }
513
514 async fn register_project(
515 self: Arc<Server>,
516 request: TypedEnvelope<proto::RegisterProject>,
517 response: Response<proto::RegisterProject>,
518 ) -> Result<()> {
519 let project_id;
520 {
521 let mut state = self.store_mut().await;
522 let user_id = state.user_id_for_connection(request.sender_id)?;
523 project_id = state.register_project(request.sender_id, user_id);
524 };
525
526 response.send(proto::RegisterProjectResponse { project_id })?;
527
528 Ok(())
529 }
530
531 async fn unregister_project(
532 self: Arc<Server>,
533 request: TypedEnvelope<proto::UnregisterProject>,
534 response: Response<proto::UnregisterProject>,
535 ) -> Result<()> {
536 let (user_id, project) = {
537 let mut state = self.store_mut().await;
538 let project =
539 state.unregister_project(request.payload.project_id, request.sender_id)?;
540 (state.user_id_for_connection(request.sender_id)?, project)
541 };
542
543 broadcast(
544 request.sender_id,
545 project.guests.keys().copied(),
546 |conn_id| {
547 self.peer.send(
548 conn_id,
549 proto::UnregisterProject {
550 project_id: request.payload.project_id,
551 },
552 )
553 },
554 );
555 for (_, receipts) in project.join_requests {
556 for receipt in receipts {
557 self.peer.respond(
558 receipt,
559 proto::JoinProjectResponse {
560 variant: Some(proto::join_project_response::Variant::Decline(
561 proto::join_project_response::Decline {
562 reason: proto::join_project_response::decline::Reason::Closed
563 as i32,
564 },
565 )),
566 },
567 )?;
568 }
569 }
570
571 // Send out the `UpdateContacts` message before responding to the unregister
572 // request. This way, when the project's host can keep track of the project's
573 // remote id until after they've received the `UpdateContacts` message for
574 // themself.
575 self.update_user_contacts(user_id).await?;
576 response.send(proto::Ack {})?;
577
578 Ok(())
579 }
580
581 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
582 let contacts = self.app_state.db.get_contacts(user_id).await?;
583 let store = self.store().await;
584 let updated_contact = store.contact_for_user(user_id, false);
585 for contact in contacts {
586 if let db::Contact::Accepted {
587 user_id: contact_user_id,
588 ..
589 } = contact
590 {
591 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
592 self.peer
593 .send(
594 contact_conn_id,
595 proto::UpdateContacts {
596 contacts: vec![updated_contact.clone()],
597 remove_contacts: Default::default(),
598 incoming_requests: Default::default(),
599 remove_incoming_requests: Default::default(),
600 outgoing_requests: Default::default(),
601 remove_outgoing_requests: Default::default(),
602 },
603 )
604 .trace_err();
605 }
606 }
607 }
608 Ok(())
609 }
610
611 async fn join_project(
612 self: Arc<Server>,
613 request: TypedEnvelope<proto::JoinProject>,
614 response: Response<proto::JoinProject>,
615 ) -> Result<()> {
616 let project_id = request.payload.project_id;
617
618 let host_user_id;
619 let guest_user_id;
620 let host_connection_id;
621 {
622 let state = self.store().await;
623 let project = state.project(project_id)?;
624 host_user_id = project.host_user_id;
625 host_connection_id = project.host_connection_id;
626 guest_user_id = state.user_id_for_connection(request.sender_id)?;
627 };
628
629 tracing::info!(project_id, %host_user_id, %host_connection_id, "join project");
630 let has_contact = self
631 .app_state
632 .db
633 .has_contact(guest_user_id, host_user_id)
634 .await?;
635 if !has_contact {
636 return Err(anyhow!("no such project"))?;
637 }
638
639 self.store_mut().await.request_join_project(
640 guest_user_id,
641 project_id,
642 response.into_receipt(),
643 )?;
644 self.peer.send(
645 host_connection_id,
646 proto::RequestJoinProject {
647 project_id,
648 requester_id: guest_user_id.to_proto(),
649 },
650 )?;
651 Ok(())
652 }
653
654 async fn respond_to_join_project_request(
655 self: Arc<Server>,
656 request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
657 ) -> Result<()> {
658 let host_user_id;
659
660 {
661 let mut state = self.store_mut().await;
662 let project_id = request.payload.project_id;
663 let project = state.project(project_id)?;
664 if project.host_connection_id != request.sender_id {
665 Err(anyhow!("no such connection"))?;
666 }
667
668 host_user_id = project.host_user_id;
669 let guest_user_id = UserId::from_proto(request.payload.requester_id);
670
671 if !request.payload.allow {
672 let receipts = state
673 .deny_join_project_request(request.sender_id, guest_user_id, project_id)
674 .ok_or_else(|| anyhow!("no such request"))?;
675 for receipt in receipts {
676 self.peer.respond(
677 receipt,
678 proto::JoinProjectResponse {
679 variant: Some(proto::join_project_response::Variant::Decline(
680 proto::join_project_response::Decline {
681 reason: proto::join_project_response::decline::Reason::Declined
682 as i32,
683 },
684 )),
685 },
686 )?;
687 }
688 return Ok(());
689 }
690
691 let (receipts_with_replica_ids, project) = state
692 .accept_join_project_request(request.sender_id, guest_user_id, project_id)
693 .ok_or_else(|| anyhow!("no such request"))?;
694
695 let peer_count = project.guests.len();
696 let mut collaborators = Vec::with_capacity(peer_count);
697 collaborators.push(proto::Collaborator {
698 peer_id: project.host_connection_id.0,
699 replica_id: 0,
700 user_id: project.host_user_id.to_proto(),
701 });
702 let worktrees = project
703 .worktrees
704 .iter()
705 .filter_map(|(id, shared_worktree)| {
706 let worktree = project.worktrees.get(&id)?;
707 Some(proto::Worktree {
708 id: *id,
709 root_name: worktree.root_name.clone(),
710 entries: shared_worktree.entries.values().cloned().collect(),
711 diagnostic_summaries: shared_worktree
712 .diagnostic_summaries
713 .values()
714 .cloned()
715 .collect(),
716 visible: worktree.visible,
717 scan_id: shared_worktree.scan_id,
718 })
719 })
720 .collect::<Vec<_>>();
721
722 // Add all guests other than the requesting user's own connections as collaborators
723 for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
724 if receipts_with_replica_ids
725 .iter()
726 .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
727 {
728 collaborators.push(proto::Collaborator {
729 peer_id: peer_conn_id.0,
730 replica_id: *peer_replica_id as u32,
731 user_id: peer_user_id.to_proto(),
732 });
733 }
734 }
735
736 for conn_id in project.connection_ids() {
737 for (receipt, replica_id) in &receipts_with_replica_ids {
738 if conn_id != receipt.sender_id {
739 self.peer.send(
740 conn_id,
741 proto::AddProjectCollaborator {
742 project_id,
743 collaborator: Some(proto::Collaborator {
744 peer_id: receipt.sender_id.0,
745 replica_id: *replica_id as u32,
746 user_id: guest_user_id.to_proto(),
747 }),
748 },
749 )?;
750 }
751 }
752 }
753
754 for (receipt, replica_id) in receipts_with_replica_ids {
755 self.peer.respond(
756 receipt,
757 proto::JoinProjectResponse {
758 variant: Some(proto::join_project_response::Variant::Accept(
759 proto::join_project_response::Accept {
760 worktrees: worktrees.clone(),
761 replica_id: replica_id as u32,
762 collaborators: collaborators.clone(),
763 language_servers: project.language_servers.clone(),
764 },
765 )),
766 },
767 )?;
768 }
769 }
770
771 self.update_user_contacts(host_user_id).await?;
772 Ok(())
773 }
774
775 async fn leave_project(
776 self: Arc<Server>,
777 request: TypedEnvelope<proto::LeaveProject>,
778 ) -> Result<()> {
779 let sender_id = request.sender_id;
780 let project_id = request.payload.project_id;
781 let project;
782 {
783 let mut store = self.store_mut().await;
784 project = store.leave_project(sender_id, project_id)?;
785 tracing::info!(
786 project_id,
787 host_user_id = %project.host_user_id,
788 host_connection_id = %project.host_connection_id,
789 "leave project"
790 );
791
792 if project.remove_collaborator {
793 broadcast(sender_id, project.connection_ids, |conn_id| {
794 self.peer.send(
795 conn_id,
796 proto::RemoveProjectCollaborator {
797 project_id,
798 peer_id: sender_id.0,
799 },
800 )
801 });
802 }
803
804 if let Some(requester_id) = project.cancel_request {
805 self.peer.send(
806 project.host_connection_id,
807 proto::JoinProjectRequestCancelled {
808 project_id,
809 requester_id: requester_id.to_proto(),
810 },
811 )?;
812 }
813
814 if project.unshare {
815 self.peer.send(
816 project.host_connection_id,
817 proto::ProjectUnshared { project_id },
818 )?;
819 }
820 }
821 self.update_user_contacts(project.host_user_id).await?;
822 Ok(())
823 }
824
825 async fn update_project(
826 self: Arc<Server>,
827 request: TypedEnvelope<proto::UpdateProject>,
828 ) -> Result<()> {
829 let user_id;
830 {
831 let mut state = self.store_mut().await;
832 user_id = state.user_id_for_connection(request.sender_id)?;
833 let guest_connection_ids = state
834 .read_project(request.payload.project_id, request.sender_id)?
835 .guest_connection_ids();
836 state.update_project(
837 request.payload.project_id,
838 &request.payload.worktrees,
839 request.sender_id,
840 )?;
841 broadcast(request.sender_id, guest_connection_ids, |connection_id| {
842 self.peer
843 .forward_send(request.sender_id, connection_id, request.payload.clone())
844 });
845 };
846 self.update_user_contacts(user_id).await?;
847 Ok(())
848 }
849
850 async fn register_project_activity(
851 self: Arc<Server>,
852 request: TypedEnvelope<proto::RegisterProjectActivity>,
853 ) -> Result<()> {
854 self.store_mut()
855 .await
856 .register_project_activity(request.payload.project_id, request.sender_id)?;
857 Ok(())
858 }
859
860 async fn update_worktree(
861 self: Arc<Server>,
862 request: TypedEnvelope<proto::UpdateWorktree>,
863 response: Response<proto::UpdateWorktree>,
864 ) -> Result<()> {
865 let (connection_ids, metadata_changed) = {
866 let mut store = self.store_mut().await;
867 let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
868 request.sender_id,
869 request.payload.project_id,
870 request.payload.worktree_id,
871 &request.payload.root_name,
872 &request.payload.removed_entries,
873 &request.payload.updated_entries,
874 request.payload.scan_id,
875 )?;
876 for (extension, count) in extension_counts {
877 tracing::info!(
878 project_id = request.payload.project_id,
879 worktree_id = request.payload.worktree_id,
880 ?extension,
881 %count,
882 "worktree updated"
883 );
884 }
885 (connection_ids, metadata_changed)
886 };
887
888 broadcast(request.sender_id, connection_ids, |connection_id| {
889 self.peer
890 .forward_send(request.sender_id, connection_id, request.payload.clone())
891 });
892 if metadata_changed {
893 let user_id = self
894 .store()
895 .await
896 .user_id_for_connection(request.sender_id)?;
897 self.update_user_contacts(user_id).await?;
898 }
899 response.send(proto::Ack {})?;
900 Ok(())
901 }
902
903 async fn update_diagnostic_summary(
904 self: Arc<Server>,
905 request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
906 ) -> Result<()> {
907 let summary = request
908 .payload
909 .summary
910 .clone()
911 .ok_or_else(|| anyhow!("invalid summary"))?;
912 let receiver_ids = self.store_mut().await.update_diagnostic_summary(
913 request.payload.project_id,
914 request.payload.worktree_id,
915 request.sender_id,
916 summary,
917 )?;
918
919 broadcast(request.sender_id, receiver_ids, |connection_id| {
920 self.peer
921 .forward_send(request.sender_id, connection_id, request.payload.clone())
922 });
923 Ok(())
924 }
925
926 async fn start_language_server(
927 self: Arc<Server>,
928 request: TypedEnvelope<proto::StartLanguageServer>,
929 ) -> Result<()> {
930 let receiver_ids = self.store_mut().await.start_language_server(
931 request.payload.project_id,
932 request.sender_id,
933 request
934 .payload
935 .server
936 .clone()
937 .ok_or_else(|| anyhow!("invalid language server"))?,
938 )?;
939 broadcast(request.sender_id, receiver_ids, |connection_id| {
940 self.peer
941 .forward_send(request.sender_id, connection_id, request.payload.clone())
942 });
943 Ok(())
944 }
945
946 async fn update_language_server(
947 self: Arc<Server>,
948 request: TypedEnvelope<proto::UpdateLanguageServer>,
949 ) -> Result<()> {
950 let receiver_ids = self
951 .store()
952 .await
953 .project_connection_ids(request.payload.project_id, request.sender_id)?;
954 broadcast(request.sender_id, receiver_ids, |connection_id| {
955 self.peer
956 .forward_send(request.sender_id, connection_id, request.payload.clone())
957 });
958 Ok(())
959 }
960
961 async fn forward_project_request<T>(
962 self: Arc<Server>,
963 request: TypedEnvelope<T>,
964 response: Response<T>,
965 ) -> Result<()>
966 where
967 T: EntityMessage + RequestMessage,
968 {
969 let host_connection_id = self
970 .store()
971 .await
972 .read_project(request.payload.remote_entity_id(), request.sender_id)?
973 .host_connection_id;
974
975 response.send(
976 self.peer
977 .forward_request(request.sender_id, host_connection_id, request.payload)
978 .await?,
979 )?;
980 Ok(())
981 }
982
983 async fn save_buffer(
984 self: Arc<Server>,
985 request: TypedEnvelope<proto::SaveBuffer>,
986 response: Response<proto::SaveBuffer>,
987 ) -> Result<()> {
988 let host = self
989 .store()
990 .await
991 .read_project(request.payload.project_id, request.sender_id)?
992 .host_connection_id;
993 let response_payload = self
994 .peer
995 .forward_request(request.sender_id, host, request.payload.clone())
996 .await?;
997
998 let mut guests = self
999 .store()
1000 .await
1001 .read_project(request.payload.project_id, request.sender_id)?
1002 .connection_ids();
1003 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
1004 broadcast(host, guests, |conn_id| {
1005 self.peer
1006 .forward_send(host, conn_id, response_payload.clone())
1007 });
1008 response.send(response_payload)?;
1009 Ok(())
1010 }
1011
1012 async fn update_buffer(
1013 self: Arc<Server>,
1014 request: TypedEnvelope<proto::UpdateBuffer>,
1015 response: Response<proto::UpdateBuffer>,
1016 ) -> Result<()> {
1017 let receiver_ids = {
1018 let mut store = self.store_mut().await;
1019 store.register_project_activity(request.payload.project_id, request.sender_id)?;
1020 store.project_connection_ids(request.payload.project_id, request.sender_id)?
1021 };
1022
1023 broadcast(request.sender_id, receiver_ids, |connection_id| {
1024 self.peer
1025 .forward_send(request.sender_id, connection_id, request.payload.clone())
1026 });
1027 response.send(proto::Ack {})?;
1028 Ok(())
1029 }
1030
1031 async fn update_buffer_file(
1032 self: Arc<Server>,
1033 request: TypedEnvelope<proto::UpdateBufferFile>,
1034 ) -> Result<()> {
1035 let receiver_ids = self
1036 .store()
1037 .await
1038 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1039 broadcast(request.sender_id, receiver_ids, |connection_id| {
1040 self.peer
1041 .forward_send(request.sender_id, connection_id, request.payload.clone())
1042 });
1043 Ok(())
1044 }
1045
1046 async fn buffer_reloaded(
1047 self: Arc<Server>,
1048 request: TypedEnvelope<proto::BufferReloaded>,
1049 ) -> Result<()> {
1050 let receiver_ids = self
1051 .store()
1052 .await
1053 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1054 broadcast(request.sender_id, receiver_ids, |connection_id| {
1055 self.peer
1056 .forward_send(request.sender_id, connection_id, request.payload.clone())
1057 });
1058 Ok(())
1059 }
1060
1061 async fn buffer_saved(
1062 self: Arc<Server>,
1063 request: TypedEnvelope<proto::BufferSaved>,
1064 ) -> Result<()> {
1065 let receiver_ids = self
1066 .store()
1067 .await
1068 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1069 broadcast(request.sender_id, receiver_ids, |connection_id| {
1070 self.peer
1071 .forward_send(request.sender_id, connection_id, request.payload.clone())
1072 });
1073 Ok(())
1074 }
1075
1076 async fn follow(
1077 self: Arc<Self>,
1078 request: TypedEnvelope<proto::Follow>,
1079 response: Response<proto::Follow>,
1080 ) -> Result<()> {
1081 let leader_id = ConnectionId(request.payload.leader_id);
1082 let follower_id = request.sender_id;
1083 {
1084 let mut store = self.store_mut().await;
1085 if !store
1086 .project_connection_ids(request.payload.project_id, follower_id)?
1087 .contains(&leader_id)
1088 {
1089 Err(anyhow!("no such peer"))?;
1090 }
1091
1092 store.register_project_activity(request.payload.project_id, follower_id)?;
1093 }
1094
1095 let mut response_payload = self
1096 .peer
1097 .forward_request(request.sender_id, leader_id, request.payload)
1098 .await?;
1099 response_payload
1100 .views
1101 .retain(|view| view.leader_id != Some(follower_id.0));
1102 response.send(response_payload)?;
1103 Ok(())
1104 }
1105
1106 async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1107 let leader_id = ConnectionId(request.payload.leader_id);
1108 let mut store = self.store_mut().await;
1109 if !store
1110 .project_connection_ids(request.payload.project_id, request.sender_id)?
1111 .contains(&leader_id)
1112 {
1113 Err(anyhow!("no such peer"))?;
1114 }
1115 store.register_project_activity(request.payload.project_id, request.sender_id)?;
1116 self.peer
1117 .forward_send(request.sender_id, leader_id, request.payload)?;
1118 Ok(())
1119 }
1120
1121 async fn update_followers(
1122 self: Arc<Self>,
1123 request: TypedEnvelope<proto::UpdateFollowers>,
1124 ) -> Result<()> {
1125 let mut store = self.store_mut().await;
1126 store.register_project_activity(request.payload.project_id, request.sender_id)?;
1127 let connection_ids =
1128 store.project_connection_ids(request.payload.project_id, request.sender_id)?;
1129 let leader_id = request
1130 .payload
1131 .variant
1132 .as_ref()
1133 .and_then(|variant| match variant {
1134 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1135 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1136 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1137 });
1138 for follower_id in &request.payload.follower_ids {
1139 let follower_id = ConnectionId(*follower_id);
1140 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1141 self.peer
1142 .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1143 }
1144 }
1145 Ok(())
1146 }
1147
1148 async fn get_channels(
1149 self: Arc<Server>,
1150 request: TypedEnvelope<proto::GetChannels>,
1151 response: Response<proto::GetChannels>,
1152 ) -> Result<()> {
1153 let user_id = self
1154 .store()
1155 .await
1156 .user_id_for_connection(request.sender_id)?;
1157 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1158 response.send(proto::GetChannelsResponse {
1159 channels: channels
1160 .into_iter()
1161 .map(|chan| proto::Channel {
1162 id: chan.id.to_proto(),
1163 name: chan.name,
1164 })
1165 .collect(),
1166 })?;
1167 Ok(())
1168 }
1169
1170 async fn get_users(
1171 self: Arc<Server>,
1172 request: TypedEnvelope<proto::GetUsers>,
1173 response: Response<proto::GetUsers>,
1174 ) -> Result<()> {
1175 let user_ids = request
1176 .payload
1177 .user_ids
1178 .into_iter()
1179 .map(UserId::from_proto)
1180 .collect();
1181 let users = self
1182 .app_state
1183 .db
1184 .get_users_by_ids(user_ids)
1185 .await?
1186 .into_iter()
1187 .map(|user| proto::User {
1188 id: user.id.to_proto(),
1189 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1190 github_login: user.github_login,
1191 })
1192 .collect();
1193 response.send(proto::UsersResponse { users })?;
1194 Ok(())
1195 }
1196
1197 async fn fuzzy_search_users(
1198 self: Arc<Server>,
1199 request: TypedEnvelope<proto::FuzzySearchUsers>,
1200 response: Response<proto::FuzzySearchUsers>,
1201 ) -> Result<()> {
1202 let user_id = self
1203 .store()
1204 .await
1205 .user_id_for_connection(request.sender_id)?;
1206 let query = request.payload.query;
1207 let db = &self.app_state.db;
1208 let users = match query.len() {
1209 0 => vec![],
1210 1 | 2 => db
1211 .get_user_by_github_login(&query)
1212 .await?
1213 .into_iter()
1214 .collect(),
1215 _ => db.fuzzy_search_users(&query, 10).await?,
1216 };
1217 let users = users
1218 .into_iter()
1219 .filter(|user| user.id != user_id)
1220 .map(|user| proto::User {
1221 id: user.id.to_proto(),
1222 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1223 github_login: user.github_login,
1224 })
1225 .collect();
1226 response.send(proto::UsersResponse { users })?;
1227 Ok(())
1228 }
1229
1230 async fn request_contact(
1231 self: Arc<Server>,
1232 request: TypedEnvelope<proto::RequestContact>,
1233 response: Response<proto::RequestContact>,
1234 ) -> Result<()> {
1235 let requester_id = self
1236 .store()
1237 .await
1238 .user_id_for_connection(request.sender_id)?;
1239 let responder_id = UserId::from_proto(request.payload.responder_id);
1240 if requester_id == responder_id {
1241 return Err(anyhow!("cannot add yourself as a contact"))?;
1242 }
1243
1244 self.app_state
1245 .db
1246 .send_contact_request(requester_id, responder_id)
1247 .await?;
1248
1249 // Update outgoing contact requests of requester
1250 let mut update = proto::UpdateContacts::default();
1251 update.outgoing_requests.push(responder_id.to_proto());
1252 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1253 self.peer.send(connection_id, update.clone())?;
1254 }
1255
1256 // Update incoming contact requests of responder
1257 let mut update = proto::UpdateContacts::default();
1258 update
1259 .incoming_requests
1260 .push(proto::IncomingContactRequest {
1261 requester_id: requester_id.to_proto(),
1262 should_notify: true,
1263 });
1264 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1265 self.peer.send(connection_id, update.clone())?;
1266 }
1267
1268 response.send(proto::Ack {})?;
1269 Ok(())
1270 }
1271
1272 async fn respond_to_contact_request(
1273 self: Arc<Server>,
1274 request: TypedEnvelope<proto::RespondToContactRequest>,
1275 response: Response<proto::RespondToContactRequest>,
1276 ) -> Result<()> {
1277 let responder_id = self
1278 .store()
1279 .await
1280 .user_id_for_connection(request.sender_id)?;
1281 let requester_id = UserId::from_proto(request.payload.requester_id);
1282 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1283 self.app_state
1284 .db
1285 .dismiss_contact_notification(responder_id, requester_id)
1286 .await?;
1287 } else {
1288 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1289 self.app_state
1290 .db
1291 .respond_to_contact_request(responder_id, requester_id, accept)
1292 .await?;
1293
1294 let store = self.store().await;
1295 // Update responder with new contact
1296 let mut update = proto::UpdateContacts::default();
1297 if accept {
1298 update
1299 .contacts
1300 .push(store.contact_for_user(requester_id, false));
1301 }
1302 update
1303 .remove_incoming_requests
1304 .push(requester_id.to_proto());
1305 for connection_id in store.connection_ids_for_user(responder_id) {
1306 self.peer.send(connection_id, update.clone())?;
1307 }
1308
1309 // Update requester with new contact
1310 let mut update = proto::UpdateContacts::default();
1311 if accept {
1312 update
1313 .contacts
1314 .push(store.contact_for_user(responder_id, true));
1315 }
1316 update
1317 .remove_outgoing_requests
1318 .push(responder_id.to_proto());
1319 for connection_id in store.connection_ids_for_user(requester_id) {
1320 self.peer.send(connection_id, update.clone())?;
1321 }
1322 }
1323
1324 response.send(proto::Ack {})?;
1325 Ok(())
1326 }
1327
1328 async fn remove_contact(
1329 self: Arc<Server>,
1330 request: TypedEnvelope<proto::RemoveContact>,
1331 response: Response<proto::RemoveContact>,
1332 ) -> Result<()> {
1333 let requester_id = self
1334 .store()
1335 .await
1336 .user_id_for_connection(request.sender_id)?;
1337 let responder_id = UserId::from_proto(request.payload.user_id);
1338 self.app_state
1339 .db
1340 .remove_contact(requester_id, responder_id)
1341 .await?;
1342
1343 // Update outgoing contact requests of requester
1344 let mut update = proto::UpdateContacts::default();
1345 update
1346 .remove_outgoing_requests
1347 .push(responder_id.to_proto());
1348 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1349 self.peer.send(connection_id, update.clone())?;
1350 }
1351
1352 // Update incoming contact requests of responder
1353 let mut update = proto::UpdateContacts::default();
1354 update
1355 .remove_incoming_requests
1356 .push(requester_id.to_proto());
1357 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1358 self.peer.send(connection_id, update.clone())?;
1359 }
1360
1361 response.send(proto::Ack {})?;
1362 Ok(())
1363 }
1364
1365 async fn join_channel(
1366 self: Arc<Self>,
1367 request: TypedEnvelope<proto::JoinChannel>,
1368 response: Response<proto::JoinChannel>,
1369 ) -> Result<()> {
1370 let user_id = self
1371 .store()
1372 .await
1373 .user_id_for_connection(request.sender_id)?;
1374 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1375 if !self
1376 .app_state
1377 .db
1378 .can_user_access_channel(user_id, channel_id)
1379 .await?
1380 {
1381 Err(anyhow!("access denied"))?;
1382 }
1383
1384 self.store_mut()
1385 .await
1386 .join_channel(request.sender_id, channel_id);
1387 let messages = self
1388 .app_state
1389 .db
1390 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1391 .await?
1392 .into_iter()
1393 .map(|msg| proto::ChannelMessage {
1394 id: msg.id.to_proto(),
1395 body: msg.body,
1396 timestamp: msg.sent_at.unix_timestamp() as u64,
1397 sender_id: msg.sender_id.to_proto(),
1398 nonce: Some(msg.nonce.as_u128().into()),
1399 })
1400 .collect::<Vec<_>>();
1401 response.send(proto::JoinChannelResponse {
1402 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1403 messages,
1404 })?;
1405 Ok(())
1406 }
1407
1408 async fn leave_channel(
1409 self: Arc<Self>,
1410 request: TypedEnvelope<proto::LeaveChannel>,
1411 ) -> Result<()> {
1412 let user_id = self
1413 .store()
1414 .await
1415 .user_id_for_connection(request.sender_id)?;
1416 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1417 if !self
1418 .app_state
1419 .db
1420 .can_user_access_channel(user_id, channel_id)
1421 .await?
1422 {
1423 Err(anyhow!("access denied"))?;
1424 }
1425
1426 self.store_mut()
1427 .await
1428 .leave_channel(request.sender_id, channel_id);
1429
1430 Ok(())
1431 }
1432
1433 async fn send_channel_message(
1434 self: Arc<Self>,
1435 request: TypedEnvelope<proto::SendChannelMessage>,
1436 response: Response<proto::SendChannelMessage>,
1437 ) -> Result<()> {
1438 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1439 let user_id;
1440 let connection_ids;
1441 {
1442 let state = self.store().await;
1443 user_id = state.user_id_for_connection(request.sender_id)?;
1444 connection_ids = state.channel_connection_ids(channel_id)?;
1445 }
1446
1447 // Validate the message body.
1448 let body = request.payload.body.trim().to_string();
1449 if body.len() > MAX_MESSAGE_LEN {
1450 return Err(anyhow!("message is too long"))?;
1451 }
1452 if body.is_empty() {
1453 return Err(anyhow!("message can't be blank"))?;
1454 }
1455
1456 let timestamp = OffsetDateTime::now_utc();
1457 let nonce = request
1458 .payload
1459 .nonce
1460 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1461
1462 let message_id = self
1463 .app_state
1464 .db
1465 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1466 .await?
1467 .to_proto();
1468 let message = proto::ChannelMessage {
1469 sender_id: user_id.to_proto(),
1470 id: message_id,
1471 body,
1472 timestamp: timestamp.unix_timestamp() as u64,
1473 nonce: Some(nonce),
1474 };
1475 broadcast(request.sender_id, connection_ids, |conn_id| {
1476 self.peer.send(
1477 conn_id,
1478 proto::ChannelMessageSent {
1479 channel_id: channel_id.to_proto(),
1480 message: Some(message.clone()),
1481 },
1482 )
1483 });
1484 response.send(proto::SendChannelMessageResponse {
1485 message: Some(message),
1486 })?;
1487 Ok(())
1488 }
1489
1490 async fn get_channel_messages(
1491 self: Arc<Self>,
1492 request: TypedEnvelope<proto::GetChannelMessages>,
1493 response: Response<proto::GetChannelMessages>,
1494 ) -> Result<()> {
1495 let user_id = self
1496 .store()
1497 .await
1498 .user_id_for_connection(request.sender_id)?;
1499 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1500 if !self
1501 .app_state
1502 .db
1503 .can_user_access_channel(user_id, channel_id)
1504 .await?
1505 {
1506 Err(anyhow!("access denied"))?;
1507 }
1508
1509 let messages = self
1510 .app_state
1511 .db
1512 .get_channel_messages(
1513 channel_id,
1514 MESSAGE_COUNT_PER_PAGE,
1515 Some(MessageId::from_proto(request.payload.before_message_id)),
1516 )
1517 .await?
1518 .into_iter()
1519 .map(|msg| proto::ChannelMessage {
1520 id: msg.id.to_proto(),
1521 body: msg.body,
1522 timestamp: msg.sent_at.unix_timestamp() as u64,
1523 sender_id: msg.sender_id.to_proto(),
1524 nonce: Some(msg.nonce.as_u128().into()),
1525 })
1526 .collect::<Vec<_>>();
1527 response.send(proto::GetChannelMessagesResponse {
1528 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1529 messages,
1530 })?;
1531 Ok(())
1532 }
1533
1534 async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1535 #[cfg(test)]
1536 tokio::task::yield_now().await;
1537 let guard = self.store.read().await;
1538 #[cfg(test)]
1539 tokio::task::yield_now().await;
1540 StoreReadGuard {
1541 guard,
1542 _not_send: PhantomData,
1543 }
1544 }
1545
1546 async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1547 #[cfg(test)]
1548 tokio::task::yield_now().await;
1549 let guard = self.store.write().await;
1550 #[cfg(test)]
1551 tokio::task::yield_now().await;
1552 StoreWriteGuard {
1553 guard,
1554 _not_send: PhantomData,
1555 }
1556 }
1557
1558 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1559 ServerSnapshot {
1560 store: self.store.read().await,
1561 peer: &self.peer,
1562 }
1563 }
1564}
1565
1566impl<'a> Deref for StoreReadGuard<'a> {
1567 type Target = Store;
1568
1569 fn deref(&self) -> &Self::Target {
1570 &*self.guard
1571 }
1572}
1573
1574impl<'a> Deref for StoreWriteGuard<'a> {
1575 type Target = Store;
1576
1577 fn deref(&self) -> &Self::Target {
1578 &*self.guard
1579 }
1580}
1581
1582impl<'a> DerefMut for StoreWriteGuard<'a> {
1583 fn deref_mut(&mut self) -> &mut Self::Target {
1584 &mut *self.guard
1585 }
1586}
1587
1588impl<'a> Drop for StoreWriteGuard<'a> {
1589 fn drop(&mut self) {
1590 #[cfg(test)]
1591 self.check_invariants();
1592
1593 let metrics = self.metrics();
1594
1595 METRIC_CONNECTIONS.set(metrics.connections as _);
1596 METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _);
1597 METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _);
1598 METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1599
1600 tracing::info!(
1601 connections = metrics.connections,
1602 registered_projects = metrics.registered_projects,
1603 active_projects = metrics.active_projects,
1604 shared_projects = metrics.shared_projects,
1605 "metrics"
1606 );
1607 }
1608}
1609
1610impl Executor for RealExecutor {
1611 type Sleep = Sleep;
1612
1613 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1614 tokio::task::spawn(future);
1615 }
1616
1617 fn sleep(&self, duration: Duration) -> Self::Sleep {
1618 tokio::time::sleep(duration)
1619 }
1620}
1621
1622fn broadcast<F>(
1623 sender_id: ConnectionId,
1624 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1625 mut f: F,
1626) where
1627 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1628{
1629 for receiver_id in receiver_ids {
1630 if receiver_id != sender_id {
1631 f(receiver_id).trace_err();
1632 }
1633 }
1634}
1635
1636lazy_static! {
1637 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1638}
1639
1640pub struct ProtocolVersion(u32);
1641
1642impl Header for ProtocolVersion {
1643 fn name() -> &'static HeaderName {
1644 &ZED_PROTOCOL_VERSION
1645 }
1646
1647 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1648 where
1649 Self: Sized,
1650 I: Iterator<Item = &'i axum::http::HeaderValue>,
1651 {
1652 let version = values
1653 .next()
1654 .ok_or_else(|| axum::headers::Error::invalid())?
1655 .to_str()
1656 .map_err(|_| axum::headers::Error::invalid())?
1657 .parse()
1658 .map_err(|_| axum::headers::Error::invalid())?;
1659 Ok(Self(version))
1660 }
1661
1662 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1663 values.extend([self.0.to_string().parse().unwrap()]);
1664 }
1665}
1666
1667pub fn routes(server: Arc<Server>) -> Router<Body> {
1668 Router::new()
1669 .route("/rpc", get(handle_websocket_request))
1670 .layer(
1671 ServiceBuilder::new()
1672 .layer(Extension(server.app_state.clone()))
1673 .layer(middleware::from_fn(auth::validate_header)),
1674 )
1675 .route("/metrics", get(handle_metrics))
1676 .layer(Extension(server))
1677}
1678
1679pub async fn handle_websocket_request(
1680 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1681 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1682 Extension(server): Extension<Arc<Server>>,
1683 Extension(user): Extension<User>,
1684 ws: WebSocketUpgrade,
1685) -> axum::response::Response {
1686 if protocol_version != rpc::PROTOCOL_VERSION {
1687 return (
1688 StatusCode::UPGRADE_REQUIRED,
1689 "client must be upgraded".to_string(),
1690 )
1691 .into_response();
1692 }
1693 let socket_address = socket_address.to_string();
1694 ws.on_upgrade(move |socket| {
1695 use util::ResultExt;
1696 let socket = socket
1697 .map_ok(to_tungstenite_message)
1698 .err_into()
1699 .with(|message| async move { Ok(to_axum_message(message)) });
1700 let connection = Connection::new(Box::pin(socket));
1701 async move {
1702 server
1703 .handle_connection(connection, socket_address, user, None, RealExecutor)
1704 .await
1705 .log_err();
1706 }
1707 })
1708}
1709
1710pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1711 // We call `store_mut` here for its side effects of updating metrics.
1712 server.store_mut().await;
1713
1714 let encoder = prometheus::TextEncoder::new();
1715 let metric_families = prometheus::gather();
1716 match encoder.encode_to_string(&metric_families) {
1717 Ok(string) => (StatusCode::OK, string).into_response(),
1718 Err(error) => (
1719 StatusCode::INTERNAL_SERVER_ERROR,
1720 format!("failed to encode metrics {:?}", error),
1721 )
1722 .into_response(),
1723 }
1724}
1725
1726fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1727 match message {
1728 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1729 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1730 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1731 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1732 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1733 code: frame.code.into(),
1734 reason: frame.reason,
1735 })),
1736 }
1737}
1738
1739fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1740 match message {
1741 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1742 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1743 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1744 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1745 AxumMessage::Close(frame) => {
1746 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1747 code: frame.code.into(),
1748 reason: frame.reason,
1749 }))
1750 }
1751 }
1752}
1753
1754pub trait ResultExt {
1755 type Ok;
1756
1757 fn trace_err(self) -> Option<Self::Ok>;
1758}
1759
1760impl<T, E> ResultExt for Result<T, E>
1761where
1762 E: std::fmt::Debug,
1763{
1764 type Ok = T;
1765
1766 fn trace_err(self) -> Option<T> {
1767 match self {
1768 Ok(value) => Some(value),
1769 Err(error) => {
1770 tracing::error!("{:?}", error);
1771 None
1772 }
1773 }
1774 }
1775}