proto_client.rs

  1use anyhow::{Context, Result};
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    channel::oneshot,
  6    future::{BoxFuture, LocalBoxFuture},
  7};
  8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
  9use parking_lot::Mutex;
 10use proto::{
 11    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
 12    RequestMessage, TypedEnvelope, error::ErrorExt as _,
 13};
 14use std::{
 15    any::{Any, TypeId},
 16    sync::{
 17        Arc, OnceLock,
 18        atomic::{self, AtomicU64},
 19    },
 20    time::Duration,
 21};
 22
 23#[derive(Debug, Clone)]
 24pub struct AnyProtoClient(Arc<State>);
 25
 26type RequestIds = Arc<
 27    Mutex<
 28        HashMap<
 29            LspRequestId,
 30            oneshot::Sender<
 31                Result<
 32                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
 33                >,
 34            >,
 35        >,
 36    >,
 37>;
 38
 39static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
 40static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
 41
 42struct State {
 43    client: Arc<dyn ProtoClient>,
 44    next_lsp_request_id: Arc<AtomicU64>,
 45    request_ids: RequestIds,
 46}
 47
 48impl std::fmt::Debug for State {
 49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 50        f.debug_struct("State")
 51            .field("next_lsp_request_id", &self.next_lsp_request_id)
 52            .field("request_ids", &self.request_ids)
 53            .finish_non_exhaustive()
 54    }
 55}
 56
 57pub trait ProtoClient: Send + Sync {
 58    fn request(
 59        &self,
 60        envelope: Envelope,
 61        request_type: &'static str,
 62    ) -> BoxFuture<'static, Result<Envelope>>;
 63
 64    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 65
 66    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 67
 68    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 69
 70    fn is_via_collab(&self) -> bool;
 71    fn has_wsl_interop(&self) -> bool;
 72}
 73
 74#[derive(Default)]
 75pub struct ProtoMessageHandlerSet {
 76    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 77    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 78    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 79    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 80    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 81}
 82
 83pub type ProtoMessageHandler = Arc<
 84    dyn Send
 85        + Sync
 86        + Fn(
 87            AnyEntity,
 88            Box<dyn AnyTypedEnvelope>,
 89            AnyProtoClient,
 90            AsyncApp,
 91        ) -> LocalBoxFuture<'static, Result<()>>,
 92>;
 93
 94impl ProtoMessageHandlerSet {
 95    pub fn clear(&mut self) {
 96        self.message_handlers.clear();
 97        self.entities_by_message_type.clear();
 98        self.entities_by_type_and_remote_id.clear();
 99        self.entity_id_extractors.clear();
100    }
101
102    fn add_message_handler(
103        &mut self,
104        message_type_id: TypeId,
105        entity: gpui::AnyWeakEntity,
106        handler: ProtoMessageHandler,
107    ) {
108        self.entities_by_message_type
109            .insert(message_type_id, entity);
110        let prev_handler = self.message_handlers.insert(message_type_id, handler);
111        if prev_handler.is_some() {
112            panic!("registered handler for the same message twice");
113        }
114    }
115
116    fn add_entity_message_handler(
117        &mut self,
118        message_type_id: TypeId,
119        entity_type_id: TypeId,
120        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
121        handler: ProtoMessageHandler,
122    ) {
123        self.entity_id_extractors
124            .entry(message_type_id)
125            .or_insert(entity_id_extractor);
126        self.entity_types_by_message_type
127            .insert(message_type_id, entity_type_id);
128        let prev_handler = self.message_handlers.insert(message_type_id, handler);
129        if prev_handler.is_some() {
130            panic!("registered handler for the same message twice");
131        }
132    }
133
134    pub fn handle_message(
135        this: &parking_lot::Mutex<Self>,
136        message: Box<dyn AnyTypedEnvelope>,
137        client: AnyProtoClient,
138        cx: AsyncApp,
139    ) -> Option<LocalBoxFuture<'static, Result<()>>> {
140        let payload_type_id = message.payload_type_id();
141        let mut this = this.lock();
142        let handler = this.message_handlers.get(&payload_type_id)?.clone();
143        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
144            entity.upgrade()?
145        } else {
146            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
147            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
148            let entity_id = (extract_entity_id)(message.as_ref());
149            match this
150                .entities_by_type_and_remote_id
151                .get_mut(&(entity_type_id, entity_id))?
152            {
153                EntityMessageSubscriber::Pending(pending) => {
154                    pending.push(message);
155                    return None;
156                }
157                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
158            }
159        };
160        drop(this);
161        Some(handler(entity, message, client, cx))
162    }
163}
164
165pub enum EntityMessageSubscriber {
166    Entity { handle: AnyWeakEntity },
167    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
168}
169
170impl std::fmt::Debug for EntityMessageSubscriber {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        match self {
173            EntityMessageSubscriber::Entity { handle } => f
174                .debug_struct("EntityMessageSubscriber::Entity")
175                .field("handle", handle)
176                .finish(),
177            EntityMessageSubscriber::Pending(vec) => f
178                .debug_struct("EntityMessageSubscriber::Pending")
179                .field(
180                    "envelopes",
181                    &vec.iter()
182                        .map(|envelope| envelope.payload_type_name())
183                        .collect::<Vec<_>>(),
184                )
185                .finish(),
186        }
187    }
188}
189
190impl<T> From<Arc<T>> for AnyProtoClient
191where
192    T: ProtoClient + 'static,
193{
194    fn from(client: Arc<T>) -> Self {
195        Self::new(client)
196    }
197}
198
199impl AnyProtoClient {
200    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
201        Self(Arc::new(State {
202            client,
203            next_lsp_request_id: NEXT_LSP_REQUEST_ID
204                .get_or_init(|| Arc::new(AtomicU64::new(0)))
205                .clone(),
206            request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
207        }))
208    }
209
210    pub fn is_via_collab(&self) -> bool {
211        self.0.client.is_via_collab()
212    }
213
214    pub fn request<T: RequestMessage>(
215        &self,
216        request: T,
217    ) -> impl Future<Output = Result<T::Response>> + use<T> {
218        let envelope = request.into_envelope(0, None, None);
219        let response = self.0.client.request(envelope, T::NAME);
220        async move {
221            T::Response::from_envelope(response.await?)
222                .context("received response of the wrong type")
223        }
224    }
225
226    pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
227        let envelope = request.into_envelope(0, None, None);
228        self.0.client.send(envelope, T::NAME)
229    }
230
231    pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
232        let envelope = request.into_envelope(0, Some(request_id), None);
233        self.0.client.send(envelope, T::NAME)
234    }
235
236    pub fn request_lsp<T>(
237        &self,
238        project_id: u64,
239        server_id: Option<u64>,
240        timeout: Duration,
241        executor: BackgroundExecutor,
242        request: T,
243    ) -> impl Future<
244        Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
245    > + use<T>
246    where
247        T: LspRequestMessage,
248    {
249        let new_id = LspRequestId(
250            self.0
251                .next_lsp_request_id
252                .fetch_add(1, atomic::Ordering::Acquire),
253        );
254        let (tx, rx) = oneshot::channel();
255        {
256            self.0.request_ids.lock().insert(new_id, tx);
257        }
258
259        let query = proto::LspQuery {
260            project_id,
261            server_id,
262            lsp_request_id: new_id.0,
263            request: Some(request.to_proto_query()),
264        };
265        let request = self.request(query);
266        let request_ids = self.0.request_ids.clone();
267        async move {
268            match request.await {
269                Ok(_request_enqueued) => {}
270                Err(e) => {
271                    request_ids.lock().remove(&new_id);
272                    return Err(e).context("sending LSP proto request");
273                }
274            }
275
276            let response = rx.with_timeout(timeout, &executor).await;
277            {
278                request_ids.lock().remove(&new_id);
279            }
280            match response {
281                Ok(Ok(response)) => {
282                    let response = response
283                        .context("waiting for LSP proto response")?
284                        .map(|response| {
285                            anyhow::Ok(TypedEnvelope {
286                                payload: response
287                                    .payload
288                                    .into_iter()
289                                    .map(|lsp_response| lsp_response.into_response::<T>())
290                                    .collect::<Result<Vec<_>>>()?,
291                                sender_id: response.sender_id,
292                                original_sender_id: response.original_sender_id,
293                                message_id: response.message_id,
294                                received_at: response.received_at,
295                            })
296                        })
297                        .transpose()
298                        .context("converting LSP proto response")?;
299                    Ok(response)
300                }
301                Err(_cancelled_due_timeout) => Ok(None),
302                Ok(Err(_channel_dropped)) => Ok(None),
303            }
304        }
305    }
306
307    pub fn send_lsp_response<T: LspRequestMessage>(
308        &self,
309        project_id: u64,
310        lsp_request_id: LspRequestId,
311        server_responses: HashMap<u64, T::Response>,
312    ) -> Result<()> {
313        self.send(proto::LspQueryResponse {
314            project_id,
315            lsp_request_id: lsp_request_id.0,
316            responses: server_responses
317                .into_iter()
318                .map(|(server_id, response)| proto::LspResponse {
319                    server_id,
320                    response: Some(T::response_to_proto_query(response)),
321                })
322                .collect(),
323        })
324    }
325
326    pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
327        let request_id = LspRequestId(envelope.payload.lsp_request_id);
328        let mut response_senders = self.0.request_ids.lock();
329        if let Some(tx) = response_senders.remove(&request_id) {
330            let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
331            tx.send(Ok(Some(proto::TypedEnvelope {
332                sender_id: envelope.sender_id,
333                original_sender_id: envelope.original_sender_id,
334                message_id: envelope.message_id,
335                received_at: envelope.received_at,
336                payload: responses
337                    .into_iter()
338                    .filter_map(|response| {
339                        use proto::lsp_response::Response;
340
341                        let server_id = response.server_id;
342                        let response = match response.response? {
343                            Response::GetReferencesResponse(response) => {
344                                to_any_envelope(&envelope, response)
345                            }
346                            Response::GetDocumentColorResponse(response) => {
347                                to_any_envelope(&envelope, response)
348                            }
349                            Response::GetHoverResponse(response) => {
350                                to_any_envelope(&envelope, response)
351                            }
352                            Response::GetCodeActionsResponse(response) => {
353                                to_any_envelope(&envelope, response)
354                            }
355                            Response::GetSignatureHelpResponse(response) => {
356                                to_any_envelope(&envelope, response)
357                            }
358                            Response::GetCodeLensResponse(response) => {
359                                to_any_envelope(&envelope, response)
360                            }
361                            Response::GetDocumentDiagnosticsResponse(response) => {
362                                to_any_envelope(&envelope, response)
363                            }
364                            Response::GetDefinitionResponse(response) => {
365                                to_any_envelope(&envelope, response)
366                            }
367                            Response::GetDeclarationResponse(response) => {
368                                to_any_envelope(&envelope, response)
369                            }
370                            Response::GetTypeDefinitionResponse(response) => {
371                                to_any_envelope(&envelope, response)
372                            }
373                            Response::GetImplementationResponse(response) => {
374                                to_any_envelope(&envelope, response)
375                            }
376                            Response::InlayHintsResponse(response) => {
377                                to_any_envelope(&envelope, response)
378                            }
379                        };
380                        Some(proto::ProtoLspResponse {
381                            server_id,
382                            response,
383                        })
384                    })
385                    .collect(),
386            })))
387            .ok();
388        }
389    }
390
391    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
392    where
393        M: RequestMessage,
394        E: 'static,
395        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
396        F: 'static + Future<Output = Result<M::Response>>,
397    {
398        self.0
399            .client
400            .message_handler_set()
401            .lock()
402            .add_message_handler(
403                TypeId::of::<M>(),
404                entity.into(),
405                Arc::new(move |entity, envelope, client, cx| {
406                    let entity = entity.downcast::<E>().unwrap();
407                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
408                    let request_id = envelope.message_id();
409                    handler(entity, *envelope, cx)
410                        .then(move |result| async move {
411                            match result {
412                                Ok(response) => {
413                                    client.send_response(request_id, response)?;
414                                    Ok(())
415                                }
416                                Err(error) => {
417                                    client.send_response(request_id, error.to_proto())?;
418                                    Err(error)
419                                }
420                            }
421                        })
422                        .boxed_local()
423                }),
424            )
425    }
426
427    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
428    where
429        M: EnvelopedMessage + RequestMessage + EntityMessage,
430        E: 'static,
431        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
432        F: 'static + Future<Output = Result<M::Response>>,
433    {
434        let message_type_id = TypeId::of::<M>();
435        let entity_type_id = TypeId::of::<E>();
436        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
437            (envelope as &dyn Any)
438                .downcast_ref::<TypedEnvelope<M>>()
439                .unwrap()
440                .payload
441                .remote_entity_id()
442        };
443        self.0
444            .client
445            .message_handler_set()
446            .lock()
447            .add_entity_message_handler(
448                message_type_id,
449                entity_type_id,
450                entity_id_extractor,
451                Arc::new(move |entity, envelope, client, cx| {
452                    let entity = entity.downcast::<E>().unwrap();
453                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
454                    let request_id = envelope.message_id();
455                    handler(entity, *envelope, cx)
456                        .then(move |result| async move {
457                            match result {
458                                Ok(response) => {
459                                    client.send_response(request_id, response)?;
460                                    Ok(())
461                                }
462                                Err(error) => {
463                                    client.send_response(request_id, error.to_proto())?;
464                                    Err(error)
465                                }
466                            }
467                        })
468                        .boxed_local()
469                }),
470            );
471    }
472
473    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
474    where
475        M: EnvelopedMessage + EntityMessage,
476        E: 'static,
477        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
478        F: 'static + Future<Output = Result<()>>,
479    {
480        let message_type_id = TypeId::of::<M>();
481        let entity_type_id = TypeId::of::<E>();
482        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
483            (envelope as &dyn Any)
484                .downcast_ref::<TypedEnvelope<M>>()
485                .unwrap()
486                .payload
487                .remote_entity_id()
488        };
489        self.0
490            .client
491            .message_handler_set()
492            .lock()
493            .add_entity_message_handler(
494                message_type_id,
495                entity_type_id,
496                entity_id_extractor,
497                Arc::new(move |entity, envelope, _, cx| {
498                    let entity = entity.downcast::<E>().unwrap();
499                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
500                    handler(entity, *envelope, cx).boxed_local()
501                }),
502            );
503    }
504
505    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
506        let id = (TypeId::of::<E>(), remote_id);
507
508        let mut message_handlers = self.0.client.message_handler_set().lock();
509        if message_handlers
510            .entities_by_type_and_remote_id
511            .contains_key(&id)
512        {
513            panic!("already subscribed to entity");
514        }
515
516        message_handlers.entities_by_type_and_remote_id.insert(
517            id,
518            EntityMessageSubscriber::Entity {
519                handle: entity.downgrade().into(),
520            },
521        );
522    }
523
524    pub fn has_wsl_interop(&self) -> bool {
525        self.0.client.has_wsl_interop()
526    }
527}
528
529fn to_any_envelope<T: EnvelopedMessage>(
530    envelope: &TypedEnvelope<proto::LspQueryResponse>,
531    response: T,
532) -> Box<dyn AnyTypedEnvelope> {
533    Box::new(proto::TypedEnvelope {
534        sender_id: envelope.sender_id,
535        original_sender_id: envelope.original_sender_id,
536        message_id: envelope.message_id,
537        received_at: envelope.received_at,
538        payload: response,
539    }) as Box<_>
540}