proto_client.rs

  1use anyhow::{Context, Result};
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    channel::oneshot,
  6    future::{BoxFuture, LocalBoxFuture},
  7};
  8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
  9use parking_lot::Mutex;
 10use proto::{
 11    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
 12    RequestMessage, TypedEnvelope, error::ErrorExt as _,
 13};
 14use std::{
 15    any::{Any, TypeId},
 16    sync::{
 17        Arc, OnceLock,
 18        atomic::{self, AtomicU64},
 19    },
 20    time::Duration,
 21};
 22
 23#[derive(Clone)]
 24pub struct AnyProtoClient(Arc<State>);
 25
 26type RequestIds = Arc<
 27    Mutex<
 28        HashMap<
 29            LspRequestId,
 30            oneshot::Sender<
 31                Result<
 32                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
 33                >,
 34            >,
 35        >,
 36    >,
 37>;
 38
 39static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
 40static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
 41
 42struct State {
 43    client: Arc<dyn ProtoClient>,
 44    next_lsp_request_id: Arc<AtomicU64>,
 45    request_ids: RequestIds,
 46}
 47
 48pub trait ProtoClient: Send + Sync {
 49    fn request(
 50        &self,
 51        envelope: Envelope,
 52        request_type: &'static str,
 53    ) -> BoxFuture<'static, Result<Envelope>>;
 54
 55    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 56
 57    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 58
 59    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 60
 61    fn is_via_collab(&self) -> bool;
 62    fn has_wsl_interop(&self) -> bool;
 63}
 64
 65#[derive(Default)]
 66pub struct ProtoMessageHandlerSet {
 67    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 68    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 69    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 70    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 71    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 72}
 73
 74pub type ProtoMessageHandler = Arc<
 75    dyn Send
 76        + Sync
 77        + Fn(
 78            AnyEntity,
 79            Box<dyn AnyTypedEnvelope>,
 80            AnyProtoClient,
 81            AsyncApp,
 82        ) -> LocalBoxFuture<'static, Result<()>>,
 83>;
 84
 85impl ProtoMessageHandlerSet {
 86    pub fn clear(&mut self) {
 87        self.message_handlers.clear();
 88        self.entities_by_message_type.clear();
 89        self.entities_by_type_and_remote_id.clear();
 90        self.entity_id_extractors.clear();
 91    }
 92
 93    fn add_message_handler(
 94        &mut self,
 95        message_type_id: TypeId,
 96        entity: gpui::AnyWeakEntity,
 97        handler: ProtoMessageHandler,
 98    ) {
 99        self.entities_by_message_type
100            .insert(message_type_id, entity);
101        let prev_handler = self.message_handlers.insert(message_type_id, handler);
102        if prev_handler.is_some() {
103            panic!("registered handler for the same message twice");
104        }
105    }
106
107    fn add_entity_message_handler(
108        &mut self,
109        message_type_id: TypeId,
110        entity_type_id: TypeId,
111        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
112        handler: ProtoMessageHandler,
113    ) {
114        self.entity_id_extractors
115            .entry(message_type_id)
116            .or_insert(entity_id_extractor);
117        self.entity_types_by_message_type
118            .insert(message_type_id, entity_type_id);
119        let prev_handler = self.message_handlers.insert(message_type_id, handler);
120        if prev_handler.is_some() {
121            panic!("registered handler for the same message twice");
122        }
123    }
124
125    pub fn handle_message(
126        this: &parking_lot::Mutex<Self>,
127        message: Box<dyn AnyTypedEnvelope>,
128        client: AnyProtoClient,
129        cx: AsyncApp,
130    ) -> Option<LocalBoxFuture<'static, Result<()>>> {
131        let payload_type_id = message.payload_type_id();
132        let mut this = this.lock();
133        let handler = this.message_handlers.get(&payload_type_id)?.clone();
134        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
135            entity.upgrade()?
136        } else {
137            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
138            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
139            let entity_id = (extract_entity_id)(message.as_ref());
140            match this
141                .entities_by_type_and_remote_id
142                .get_mut(&(entity_type_id, entity_id))?
143            {
144                EntityMessageSubscriber::Pending(pending) => {
145                    pending.push(message);
146                    return None;
147                }
148                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
149            }
150        };
151        drop(this);
152        Some(handler(entity, message, client, cx))
153    }
154}
155
156pub enum EntityMessageSubscriber {
157    Entity { handle: AnyWeakEntity },
158    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
159}
160
161impl std::fmt::Debug for EntityMessageSubscriber {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        match self {
164            EntityMessageSubscriber::Entity { handle } => f
165                .debug_struct("EntityMessageSubscriber::Entity")
166                .field("handle", handle)
167                .finish(),
168            EntityMessageSubscriber::Pending(vec) => f
169                .debug_struct("EntityMessageSubscriber::Pending")
170                .field(
171                    "envelopes",
172                    &vec.iter()
173                        .map(|envelope| envelope.payload_type_name())
174                        .collect::<Vec<_>>(),
175                )
176                .finish(),
177        }
178    }
179}
180
181impl<T> From<Arc<T>> for AnyProtoClient
182where
183    T: ProtoClient + 'static,
184{
185    fn from(client: Arc<T>) -> Self {
186        Self::new(client)
187    }
188}
189
190impl AnyProtoClient {
191    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
192        Self(Arc::new(State {
193            client,
194            next_lsp_request_id: NEXT_LSP_REQUEST_ID
195                .get_or_init(|| Arc::new(AtomicU64::new(0)))
196                .clone(),
197            request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
198        }))
199    }
200
201    pub fn is_via_collab(&self) -> bool {
202        self.0.client.is_via_collab()
203    }
204
205    pub fn request<T: RequestMessage>(
206        &self,
207        request: T,
208    ) -> impl Future<Output = Result<T::Response>> + use<T> {
209        let envelope = request.into_envelope(0, None, None);
210        let response = self.0.client.request(envelope, T::NAME);
211        async move {
212            T::Response::from_envelope(response.await?)
213                .context("received response of the wrong type")
214        }
215    }
216
217    pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
218        let envelope = request.into_envelope(0, None, None);
219        self.0.client.send(envelope, T::NAME)
220    }
221
222    pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
223        let envelope = request.into_envelope(0, Some(request_id), None);
224        self.0.client.send(envelope, T::NAME)
225    }
226
227    pub fn request_lsp<T>(
228        &self,
229        project_id: u64,
230        server_id: Option<u64>,
231        timeout: Duration,
232        executor: BackgroundExecutor,
233        request: T,
234    ) -> impl Future<
235        Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
236    > + use<T>
237    where
238        T: LspRequestMessage,
239    {
240        let new_id = LspRequestId(
241            self.0
242                .next_lsp_request_id
243                .fetch_add(1, atomic::Ordering::Acquire),
244        );
245        let (tx, rx) = oneshot::channel();
246        {
247            self.0.request_ids.lock().insert(new_id, tx);
248        }
249
250        let query = proto::LspQuery {
251            project_id,
252            server_id,
253            lsp_request_id: new_id.0,
254            request: Some(request.to_proto_query()),
255        };
256        let request = self.request(query);
257        let request_ids = self.0.request_ids.clone();
258        async move {
259            match request.await {
260                Ok(_request_enqueued) => {}
261                Err(e) => {
262                    request_ids.lock().remove(&new_id);
263                    return Err(e).context("sending LSP proto request");
264                }
265            }
266
267            let response = rx.with_timeout(timeout, &executor).await;
268            {
269                request_ids.lock().remove(&new_id);
270            }
271            match response {
272                Ok(Ok(response)) => {
273                    let response = response
274                        .context("waiting for LSP proto response")?
275                        .map(|response| {
276                            anyhow::Ok(TypedEnvelope {
277                                payload: response
278                                    .payload
279                                    .into_iter()
280                                    .map(|lsp_response| lsp_response.into_response::<T>())
281                                    .collect::<Result<Vec<_>>>()?,
282                                sender_id: response.sender_id,
283                                original_sender_id: response.original_sender_id,
284                                message_id: response.message_id,
285                                received_at: response.received_at,
286                            })
287                        })
288                        .transpose()
289                        .context("converting LSP proto response")?;
290                    Ok(response)
291                }
292                Err(_cancelled_due_timeout) => Ok(None),
293                Ok(Err(_channel_dropped)) => Ok(None),
294            }
295        }
296    }
297
298    pub fn send_lsp_response<T: LspRequestMessage>(
299        &self,
300        project_id: u64,
301        lsp_request_id: LspRequestId,
302        server_responses: HashMap<u64, T::Response>,
303    ) -> Result<()> {
304        self.send(proto::LspQueryResponse {
305            project_id,
306            lsp_request_id: lsp_request_id.0,
307            responses: server_responses
308                .into_iter()
309                .map(|(server_id, response)| proto::LspResponse {
310                    server_id,
311                    response: Some(T::response_to_proto_query(response)),
312                })
313                .collect(),
314        })
315    }
316
317    pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
318        let request_id = LspRequestId(envelope.payload.lsp_request_id);
319        let mut response_senders = self.0.request_ids.lock();
320        if let Some(tx) = response_senders.remove(&request_id) {
321            let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
322            tx.send(Ok(Some(proto::TypedEnvelope {
323                sender_id: envelope.sender_id,
324                original_sender_id: envelope.original_sender_id,
325                message_id: envelope.message_id,
326                received_at: envelope.received_at,
327                payload: responses
328                    .into_iter()
329                    .filter_map(|response| {
330                        use proto::lsp_response::Response;
331
332                        let server_id = response.server_id;
333                        let response = match response.response? {
334                            Response::GetReferencesResponse(response) => {
335                                to_any_envelope(&envelope, response)
336                            }
337                            Response::GetDocumentColorResponse(response) => {
338                                to_any_envelope(&envelope, response)
339                            }
340                            Response::GetHoverResponse(response) => {
341                                to_any_envelope(&envelope, response)
342                            }
343                            Response::GetCodeActionsResponse(response) => {
344                                to_any_envelope(&envelope, response)
345                            }
346                            Response::GetSignatureHelpResponse(response) => {
347                                to_any_envelope(&envelope, response)
348                            }
349                            Response::GetCodeLensResponse(response) => {
350                                to_any_envelope(&envelope, response)
351                            }
352                            Response::GetDocumentDiagnosticsResponse(response) => {
353                                to_any_envelope(&envelope, response)
354                            }
355                            Response::GetDefinitionResponse(response) => {
356                                to_any_envelope(&envelope, response)
357                            }
358                            Response::GetDeclarationResponse(response) => {
359                                to_any_envelope(&envelope, response)
360                            }
361                            Response::GetTypeDefinitionResponse(response) => {
362                                to_any_envelope(&envelope, response)
363                            }
364                            Response::GetImplementationResponse(response) => {
365                                to_any_envelope(&envelope, response)
366                            }
367                            Response::InlayHintsResponse(response) => {
368                                to_any_envelope(&envelope, response)
369                            }
370                        };
371                        Some(proto::ProtoLspResponse {
372                            server_id,
373                            response,
374                        })
375                    })
376                    .collect(),
377            })))
378            .ok();
379        }
380    }
381
382    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
383    where
384        M: RequestMessage,
385        E: 'static,
386        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
387        F: 'static + Future<Output = Result<M::Response>>,
388    {
389        self.0
390            .client
391            .message_handler_set()
392            .lock()
393            .add_message_handler(
394                TypeId::of::<M>(),
395                entity.into(),
396                Arc::new(move |entity, envelope, client, cx| {
397                    let entity = entity.downcast::<E>().unwrap();
398                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
399                    let request_id = envelope.message_id();
400                    handler(entity, *envelope, cx)
401                        .then(move |result| async move {
402                            match result {
403                                Ok(response) => {
404                                    client.send_response(request_id, response)?;
405                                    Ok(())
406                                }
407                                Err(error) => {
408                                    client.send_response(request_id, error.to_proto())?;
409                                    Err(error)
410                                }
411                            }
412                        })
413                        .boxed_local()
414                }),
415            )
416    }
417
418    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
419    where
420        M: EnvelopedMessage + RequestMessage + EntityMessage,
421        E: 'static,
422        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
423        F: 'static + Future<Output = Result<M::Response>>,
424    {
425        let message_type_id = TypeId::of::<M>();
426        let entity_type_id = TypeId::of::<E>();
427        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
428            (envelope as &dyn Any)
429                .downcast_ref::<TypedEnvelope<M>>()
430                .unwrap()
431                .payload
432                .remote_entity_id()
433        };
434        self.0
435            .client
436            .message_handler_set()
437            .lock()
438            .add_entity_message_handler(
439                message_type_id,
440                entity_type_id,
441                entity_id_extractor,
442                Arc::new(move |entity, envelope, client, cx| {
443                    let entity = entity.downcast::<E>().unwrap();
444                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
445                    let request_id = envelope.message_id();
446                    handler(entity, *envelope, cx)
447                        .then(move |result| async move {
448                            match result {
449                                Ok(response) => {
450                                    client.send_response(request_id, response)?;
451                                    Ok(())
452                                }
453                                Err(error) => {
454                                    client.send_response(request_id, error.to_proto())?;
455                                    Err(error)
456                                }
457                            }
458                        })
459                        .boxed_local()
460                }),
461            );
462    }
463
464    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
465    where
466        M: EnvelopedMessage + EntityMessage,
467        E: 'static,
468        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
469        F: 'static + Future<Output = Result<()>>,
470    {
471        let message_type_id = TypeId::of::<M>();
472        let entity_type_id = TypeId::of::<E>();
473        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
474            (envelope as &dyn Any)
475                .downcast_ref::<TypedEnvelope<M>>()
476                .unwrap()
477                .payload
478                .remote_entity_id()
479        };
480        self.0
481            .client
482            .message_handler_set()
483            .lock()
484            .add_entity_message_handler(
485                message_type_id,
486                entity_type_id,
487                entity_id_extractor,
488                Arc::new(move |entity, envelope, _, cx| {
489                    let entity = entity.downcast::<E>().unwrap();
490                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
491                    handler(entity, *envelope, cx).boxed_local()
492                }),
493            );
494    }
495
496    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
497        let id = (TypeId::of::<E>(), remote_id);
498
499        let mut message_handlers = self.0.client.message_handler_set().lock();
500        if message_handlers
501            .entities_by_type_and_remote_id
502            .contains_key(&id)
503        {
504            panic!("already subscribed to entity");
505        }
506
507        message_handlers.entities_by_type_and_remote_id.insert(
508            id,
509            EntityMessageSubscriber::Entity {
510                handle: entity.downgrade().into(),
511            },
512        );
513    }
514
515    pub fn has_wsl_interop(&self) -> bool {
516        self.0.client.has_wsl_interop()
517    }
518}
519
520fn to_any_envelope<T: EnvelopedMessage>(
521    envelope: &TypedEnvelope<proto::LspQueryResponse>,
522    response: T,
523) -> Box<dyn AnyTypedEnvelope> {
524    Box::new(proto::TypedEnvelope {
525        sender_id: envelope.sender_id,
526        original_sender_id: envelope.original_sender_id,
527        message_id: envelope.message_id,
528        received_at: envelope.received_at,
529        payload: response,
530    }) as Box<_>
531}