proto_client.rs

  1use anyhow::{Context, Result};
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    channel::oneshot,
  6    future::{BoxFuture, LocalBoxFuture},
  7};
  8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
  9use parking_lot::Mutex;
 10use proto::{
 11    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
 12    RequestMessage, TypedEnvelope, error::ErrorExt as _,
 13};
 14use std::{
 15    any::{Any, TypeId},
 16    sync::{
 17        Arc, OnceLock,
 18        atomic::{self, AtomicU64},
 19    },
 20    time::Duration,
 21};
 22
 23#[derive(Clone)]
 24pub struct AnyProtoClient(Arc<State>);
 25
 26type RequestIds = Arc<
 27    Mutex<
 28        HashMap<
 29            LspRequestId,
 30            oneshot::Sender<
 31                Result<
 32                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
 33                >,
 34            >,
 35        >,
 36    >,
 37>;
 38
 39static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
 40static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
 41
 42struct State {
 43    client: Arc<dyn ProtoClient>,
 44    next_lsp_request_id: Arc<AtomicU64>,
 45    request_ids: RequestIds,
 46}
 47
 48pub trait ProtoClient: Send + Sync {
 49    fn request(
 50        &self,
 51        envelope: Envelope,
 52        request_type: &'static str,
 53    ) -> BoxFuture<'static, Result<Envelope>>;
 54
 55    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 56
 57    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 58
 59    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 60
 61    fn is_via_collab(&self) -> bool;
 62}
 63
 64#[derive(Default)]
 65pub struct ProtoMessageHandlerSet {
 66    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 67    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 68    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 69    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 70    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 71}
 72
 73pub type ProtoMessageHandler = Arc<
 74    dyn Send
 75        + Sync
 76        + Fn(
 77            AnyEntity,
 78            Box<dyn AnyTypedEnvelope>,
 79            AnyProtoClient,
 80            AsyncApp,
 81        ) -> LocalBoxFuture<'static, Result<()>>,
 82>;
 83
 84impl ProtoMessageHandlerSet {
 85    pub fn clear(&mut self) {
 86        self.message_handlers.clear();
 87        self.entities_by_message_type.clear();
 88        self.entities_by_type_and_remote_id.clear();
 89        self.entity_id_extractors.clear();
 90    }
 91
 92    fn add_message_handler(
 93        &mut self,
 94        message_type_id: TypeId,
 95        entity: gpui::AnyWeakEntity,
 96        handler: ProtoMessageHandler,
 97    ) {
 98        self.entities_by_message_type
 99            .insert(message_type_id, entity);
100        let prev_handler = self.message_handlers.insert(message_type_id, handler);
101        if prev_handler.is_some() {
102            panic!("registered handler for the same message twice");
103        }
104    }
105
106    fn add_entity_message_handler(
107        &mut self,
108        message_type_id: TypeId,
109        entity_type_id: TypeId,
110        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
111        handler: ProtoMessageHandler,
112    ) {
113        self.entity_id_extractors
114            .entry(message_type_id)
115            .or_insert(entity_id_extractor);
116        self.entity_types_by_message_type
117            .insert(message_type_id, entity_type_id);
118        let prev_handler = self.message_handlers.insert(message_type_id, handler);
119        if prev_handler.is_some() {
120            panic!("registered handler for the same message twice");
121        }
122    }
123
124    pub fn handle_message(
125        this: &parking_lot::Mutex<Self>,
126        message: Box<dyn AnyTypedEnvelope>,
127        client: AnyProtoClient,
128        cx: AsyncApp,
129    ) -> Option<LocalBoxFuture<'static, Result<()>>> {
130        let payload_type_id = message.payload_type_id();
131        let mut this = this.lock();
132        let handler = this.message_handlers.get(&payload_type_id)?.clone();
133        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
134            entity.upgrade()?
135        } else {
136            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
137            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
138            let entity_id = (extract_entity_id)(message.as_ref());
139            match this
140                .entities_by_type_and_remote_id
141                .get_mut(&(entity_type_id, entity_id))?
142            {
143                EntityMessageSubscriber::Pending(pending) => {
144                    pending.push(message);
145                    return None;
146                }
147                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
148            }
149        };
150        drop(this);
151        Some(handler(entity, message, client, cx))
152    }
153}
154
155pub enum EntityMessageSubscriber {
156    Entity { handle: AnyWeakEntity },
157    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
158}
159
160impl std::fmt::Debug for EntityMessageSubscriber {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            EntityMessageSubscriber::Entity { handle } => f
164                .debug_struct("EntityMessageSubscriber::Entity")
165                .field("handle", handle)
166                .finish(),
167            EntityMessageSubscriber::Pending(vec) => f
168                .debug_struct("EntityMessageSubscriber::Pending")
169                .field(
170                    "envelopes",
171                    &vec.iter()
172                        .map(|envelope| envelope.payload_type_name())
173                        .collect::<Vec<_>>(),
174                )
175                .finish(),
176        }
177    }
178}
179
180impl<T> From<Arc<T>> for AnyProtoClient
181where
182    T: ProtoClient + 'static,
183{
184    fn from(client: Arc<T>) -> Self {
185        Self::new(client)
186    }
187}
188
189impl AnyProtoClient {
190    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
191        Self(Arc::new(State {
192            client,
193            next_lsp_request_id: NEXT_LSP_REQUEST_ID
194                .get_or_init(|| Arc::new(AtomicU64::new(0)))
195                .clone(),
196            request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
197        }))
198    }
199
200    pub fn is_via_collab(&self) -> bool {
201        self.0.client.is_via_collab()
202    }
203
204    pub fn request<T: RequestMessage>(
205        &self,
206        request: T,
207    ) -> impl Future<Output = Result<T::Response>> + use<T> {
208        let envelope = request.into_envelope(0, None, None);
209        let response = self.0.client.request(envelope, T::NAME);
210        async move {
211            T::Response::from_envelope(response.await?)
212                .context("received response of the wrong type")
213        }
214    }
215
216    pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
217        let envelope = request.into_envelope(0, None, None);
218        self.0.client.send(envelope, T::NAME)
219    }
220
221    pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
222        let envelope = request.into_envelope(0, Some(request_id), None);
223        self.0.client.send(envelope, T::NAME)
224    }
225
226    pub fn request_lsp<T>(
227        &self,
228        project_id: u64,
229        server_id: Option<u64>,
230        timeout: Duration,
231        executor: BackgroundExecutor,
232        request: T,
233    ) -> impl Future<
234        Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
235    > + use<T>
236    where
237        T: LspRequestMessage,
238    {
239        let new_id = LspRequestId(
240            self.0
241                .next_lsp_request_id
242                .fetch_add(1, atomic::Ordering::Acquire),
243        );
244        let (tx, rx) = oneshot::channel();
245        {
246            self.0.request_ids.lock().insert(new_id, tx);
247        }
248
249        let query = proto::LspQuery {
250            project_id,
251            server_id,
252            lsp_request_id: new_id.0,
253            request: Some(request.to_proto_query()),
254        };
255        let request = self.request(query);
256        let request_ids = self.0.request_ids.clone();
257        async move {
258            match request.await {
259                Ok(_request_enqueued) => {}
260                Err(e) => {
261                    request_ids.lock().remove(&new_id);
262                    return Err(e).context("sending LSP proto request");
263                }
264            }
265
266            let response = rx.with_timeout(timeout, &executor).await;
267            {
268                request_ids.lock().remove(&new_id);
269            }
270            match response {
271                Ok(Ok(response)) => {
272                    let response = response
273                        .context("waiting for LSP proto response")?
274                        .map(|response| {
275                            anyhow::Ok(TypedEnvelope {
276                                payload: response
277                                    .payload
278                                    .into_iter()
279                                    .map(|lsp_response| lsp_response.into_response::<T>())
280                                    .collect::<Result<Vec<_>>>()?,
281                                sender_id: response.sender_id,
282                                original_sender_id: response.original_sender_id,
283                                message_id: response.message_id,
284                                received_at: response.received_at,
285                            })
286                        })
287                        .transpose()
288                        .context("converting LSP proto response")?;
289                    Ok(response)
290                }
291                Err(_cancelled_due_timeout) => Ok(None),
292                Ok(Err(_channel_dropped)) => Ok(None),
293            }
294        }
295    }
296
297    pub fn send_lsp_response<T: LspRequestMessage>(
298        &self,
299        project_id: u64,
300        lsp_request_id: LspRequestId,
301        server_responses: HashMap<u64, T::Response>,
302    ) -> Result<()> {
303        self.send(proto::LspQueryResponse {
304            project_id,
305            lsp_request_id: lsp_request_id.0,
306            responses: server_responses
307                .into_iter()
308                .map(|(server_id, response)| proto::LspResponse {
309                    server_id,
310                    response: Some(T::response_to_proto_query(response)),
311                })
312                .collect(),
313        })
314    }
315
316    pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
317        let request_id = LspRequestId(envelope.payload.lsp_request_id);
318        let mut response_senders = self.0.request_ids.lock();
319        if let Some(tx) = response_senders.remove(&request_id) {
320            let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
321            tx.send(Ok(Some(proto::TypedEnvelope {
322                sender_id: envelope.sender_id,
323                original_sender_id: envelope.original_sender_id,
324                message_id: envelope.message_id,
325                received_at: envelope.received_at,
326                payload: responses
327                    .into_iter()
328                    .filter_map(|response| {
329                        use proto::lsp_response::Response;
330
331                        let server_id = response.server_id;
332                        let response = match response.response? {
333                            Response::GetReferencesResponse(response) => {
334                                to_any_envelope(&envelope, response)
335                            }
336                            Response::GetDocumentColorResponse(response) => {
337                                to_any_envelope(&envelope, response)
338                            }
339                            Response::GetHoverResponse(response) => {
340                                to_any_envelope(&envelope, response)
341                            }
342                            Response::GetCodeActionsResponse(response) => {
343                                to_any_envelope(&envelope, response)
344                            }
345                            Response::GetSignatureHelpResponse(response) => {
346                                to_any_envelope(&envelope, response)
347                            }
348                            Response::GetCodeLensResponse(response) => {
349                                to_any_envelope(&envelope, response)
350                            }
351                            Response::GetDocumentDiagnosticsResponse(response) => {
352                                to_any_envelope(&envelope, response)
353                            }
354                            Response::GetDefinitionResponse(response) => {
355                                to_any_envelope(&envelope, response)
356                            }
357                            Response::GetDeclarationResponse(response) => {
358                                to_any_envelope(&envelope, response)
359                            }
360                            Response::GetTypeDefinitionResponse(response) => {
361                                to_any_envelope(&envelope, response)
362                            }
363                            Response::GetImplementationResponse(response) => {
364                                to_any_envelope(&envelope, response)
365                            }
366                            Response::InlayHintsResponse(response) => {
367                                to_any_envelope(&envelope, response)
368                            }
369                        };
370                        Some(proto::ProtoLspResponse {
371                            server_id,
372                            response,
373                        })
374                    })
375                    .collect(),
376            })))
377            .ok();
378        }
379    }
380
381    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
382    where
383        M: RequestMessage,
384        E: 'static,
385        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
386        F: 'static + Future<Output = Result<M::Response>>,
387    {
388        self.0
389            .client
390            .message_handler_set()
391            .lock()
392            .add_message_handler(
393                TypeId::of::<M>(),
394                entity.into(),
395                Arc::new(move |entity, envelope, client, cx| {
396                    let entity = entity.downcast::<E>().unwrap();
397                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
398                    let request_id = envelope.message_id();
399                    handler(entity, *envelope, cx)
400                        .then(move |result| async move {
401                            match result {
402                                Ok(response) => {
403                                    client.send_response(request_id, response)?;
404                                    Ok(())
405                                }
406                                Err(error) => {
407                                    client.send_response(request_id, error.to_proto())?;
408                                    Err(error)
409                                }
410                            }
411                        })
412                        .boxed_local()
413                }),
414            )
415    }
416
417    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
418    where
419        M: EnvelopedMessage + RequestMessage + EntityMessage,
420        E: 'static,
421        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
422        F: 'static + Future<Output = Result<M::Response>>,
423    {
424        let message_type_id = TypeId::of::<M>();
425        let entity_type_id = TypeId::of::<E>();
426        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
427            (envelope as &dyn Any)
428                .downcast_ref::<TypedEnvelope<M>>()
429                .unwrap()
430                .payload
431                .remote_entity_id()
432        };
433        self.0
434            .client
435            .message_handler_set()
436            .lock()
437            .add_entity_message_handler(
438                message_type_id,
439                entity_type_id,
440                entity_id_extractor,
441                Arc::new(move |entity, envelope, client, cx| {
442                    let entity = entity.downcast::<E>().unwrap();
443                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
444                    let request_id = envelope.message_id();
445                    handler(entity, *envelope, cx)
446                        .then(move |result| async move {
447                            match result {
448                                Ok(response) => {
449                                    client.send_response(request_id, response)?;
450                                    Ok(())
451                                }
452                                Err(error) => {
453                                    client.send_response(request_id, error.to_proto())?;
454                                    Err(error)
455                                }
456                            }
457                        })
458                        .boxed_local()
459                }),
460            );
461    }
462
463    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
464    where
465        M: EnvelopedMessage + EntityMessage,
466        E: 'static,
467        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
468        F: 'static + Future<Output = Result<()>>,
469    {
470        let message_type_id = TypeId::of::<M>();
471        let entity_type_id = TypeId::of::<E>();
472        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
473            (envelope as &dyn Any)
474                .downcast_ref::<TypedEnvelope<M>>()
475                .unwrap()
476                .payload
477                .remote_entity_id()
478        };
479        self.0
480            .client
481            .message_handler_set()
482            .lock()
483            .add_entity_message_handler(
484                message_type_id,
485                entity_type_id,
486                entity_id_extractor,
487                Arc::new(move |entity, envelope, _, cx| {
488                    let entity = entity.downcast::<E>().unwrap();
489                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
490                    handler(entity, *envelope, cx).boxed_local()
491                }),
492            );
493    }
494
495    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
496        let id = (TypeId::of::<E>(), remote_id);
497
498        let mut message_handlers = self.0.client.message_handler_set().lock();
499        if message_handlers
500            .entities_by_type_and_remote_id
501            .contains_key(&id)
502        {
503            panic!("already subscribed to entity");
504        }
505
506        message_handlers.entities_by_type_and_remote_id.insert(
507            id,
508            EntityMessageSubscriber::Entity {
509                handle: entity.downgrade().into(),
510            },
511        );
512    }
513}
514
515fn to_any_envelope<T: EnvelopedMessage>(
516    envelope: &TypedEnvelope<proto::LspQueryResponse>,
517    response: T,
518) -> Box<dyn AnyTypedEnvelope> {
519    Box::new(proto::TypedEnvelope {
520        sender_id: envelope.sender_id,
521        original_sender_id: envelope.original_sender_id,
522        message_id: envelope.message_id,
523        received_at: envelope.received_at,
524        payload: response,
525    }) as Box<_>
526}