proto_client.rs

  1use anyhow::{Context, Result};
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    channel::oneshot,
  6    future::{BoxFuture, LocalBoxFuture},
  7};
  8use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, BackgroundExecutor, Entity, FutureExt as _};
  9use parking_lot::Mutex;
 10use proto::{
 11    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, LspRequestId, LspRequestMessage,
 12    RequestMessage, TypedEnvelope, error::ErrorExt as _,
 13};
 14use std::{
 15    any::{Any, TypeId},
 16    sync::{
 17        Arc, OnceLock,
 18        atomic::{self, AtomicU64},
 19    },
 20    time::Duration,
 21};
 22
 23#[derive(Debug, Clone)]
 24pub struct AnyProtoClient(Arc<State>);
 25
 26type RequestIds = Arc<
 27    Mutex<
 28        HashMap<
 29            LspRequestId,
 30            oneshot::Sender<
 31                Result<
 32                    Option<TypedEnvelope<Vec<proto::ProtoLspResponse<Box<dyn AnyTypedEnvelope>>>>>,
 33                >,
 34            >,
 35        >,
 36    >,
 37>;
 38
 39static NEXT_LSP_REQUEST_ID: OnceLock<Arc<AtomicU64>> = OnceLock::new();
 40static REQUEST_IDS: OnceLock<RequestIds> = OnceLock::new();
 41
 42struct State {
 43    client: Arc<dyn ProtoClient>,
 44    next_lsp_request_id: Arc<AtomicU64>,
 45    request_ids: RequestIds,
 46}
 47
 48impl std::fmt::Debug for State {
 49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 50        f.debug_struct("State")
 51            .field("next_lsp_request_id", &self.next_lsp_request_id)
 52            .field("request_ids", &self.request_ids)
 53            .finish_non_exhaustive()
 54    }
 55}
 56
 57pub trait ProtoClient: Send + Sync {
 58    fn request(
 59        &self,
 60        envelope: Envelope,
 61        request_type: &'static str,
 62    ) -> BoxFuture<'static, Result<Envelope>>;
 63
 64    fn send(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 65
 66    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> Result<()>;
 67
 68    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 69
 70    fn is_via_collab(&self) -> bool;
 71    fn has_wsl_interop(&self) -> bool;
 72}
 73
 74#[derive(Default)]
 75pub struct ProtoMessageHandlerSet {
 76    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 77    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 78    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 79    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 80    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 81}
 82
 83pub type ProtoMessageHandler = Arc<
 84    dyn Send
 85        + Sync
 86        + Fn(
 87            AnyEntity,
 88            Box<dyn AnyTypedEnvelope>,
 89            AnyProtoClient,
 90            AsyncApp,
 91        ) -> LocalBoxFuture<'static, Result<()>>,
 92>;
 93
 94impl ProtoMessageHandlerSet {
 95    pub fn clear(&mut self) {
 96        self.message_handlers.clear();
 97        self.entities_by_message_type.clear();
 98        self.entities_by_type_and_remote_id.clear();
 99        self.entity_id_extractors.clear();
100    }
101
102    fn add_message_handler(
103        &mut self,
104        message_type_id: TypeId,
105        entity: gpui::AnyWeakEntity,
106        handler: ProtoMessageHandler,
107    ) {
108        self.entities_by_message_type
109            .insert(message_type_id, entity);
110        let prev_handler = self.message_handlers.insert(message_type_id, handler);
111        if prev_handler.is_some() {
112            panic!("registered handler for the same message twice");
113        }
114    }
115
116    fn add_entity_message_handler(
117        &mut self,
118        message_type_id: TypeId,
119        entity_type_id: TypeId,
120        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
121        handler: ProtoMessageHandler,
122    ) {
123        self.entity_id_extractors
124            .entry(message_type_id)
125            .or_insert(entity_id_extractor);
126        self.entity_types_by_message_type
127            .insert(message_type_id, entity_type_id);
128        let prev_handler = self.message_handlers.insert(message_type_id, handler);
129        if prev_handler.is_some() {
130            panic!("registered handler for the same message twice");
131        }
132    }
133
134    pub fn handle_message(
135        this: &parking_lot::Mutex<Self>,
136        message: Box<dyn AnyTypedEnvelope>,
137        client: AnyProtoClient,
138        cx: AsyncApp,
139    ) -> Option<LocalBoxFuture<'static, Result<()>>> {
140        let payload_type_id = message.payload_type_id();
141        let mut this = this.lock();
142        let handler = this.message_handlers.get(&payload_type_id)?.clone();
143        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
144            entity.upgrade()?
145        } else {
146            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
147            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
148            let entity_id = (extract_entity_id)(message.as_ref());
149            match this
150                .entities_by_type_and_remote_id
151                .get_mut(&(entity_type_id, entity_id))?
152            {
153                EntityMessageSubscriber::Pending(pending) => {
154                    pending.push(message);
155                    return None;
156                }
157                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
158            }
159        };
160        drop(this);
161        Some(handler(entity, message, client, cx))
162    }
163}
164
165pub enum EntityMessageSubscriber {
166    Entity { handle: AnyWeakEntity },
167    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
168}
169
170impl std::fmt::Debug for EntityMessageSubscriber {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        match self {
173            EntityMessageSubscriber::Entity { handle } => f
174                .debug_struct("EntityMessageSubscriber::Entity")
175                .field("handle", handle)
176                .finish(),
177            EntityMessageSubscriber::Pending(vec) => f
178                .debug_struct("EntityMessageSubscriber::Pending")
179                .field(
180                    "envelopes",
181                    &vec.iter()
182                        .map(|envelope| envelope.payload_type_name())
183                        .collect::<Vec<_>>(),
184                )
185                .finish(),
186        }
187    }
188}
189
190impl<T> From<Arc<T>> for AnyProtoClient
191where
192    T: ProtoClient + 'static,
193{
194    fn from(client: Arc<T>) -> Self {
195        Self::new(client)
196    }
197}
198
199impl AnyProtoClient {
200    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
201        Self(Arc::new(State {
202            client,
203            next_lsp_request_id: NEXT_LSP_REQUEST_ID
204                .get_or_init(|| Arc::new(AtomicU64::new(0)))
205                .clone(),
206            request_ids: REQUEST_IDS.get_or_init(RequestIds::default).clone(),
207        }))
208    }
209
210    pub fn is_via_collab(&self) -> bool {
211        self.0.client.is_via_collab()
212    }
213
214    pub fn request<T: RequestMessage>(
215        &self,
216        request: T,
217    ) -> impl Future<Output = Result<T::Response>> + use<T> {
218        let envelope = request.into_envelope(0, None, None);
219        let response = self.0.client.request(envelope, T::NAME);
220        async move {
221            T::Response::from_envelope(response.await?)
222                .context("received response of the wrong type")
223        }
224    }
225
226    pub fn send<T: EnvelopedMessage>(&self, request: T) -> Result<()> {
227        let envelope = request.into_envelope(0, None, None);
228        self.0.client.send(envelope, T::NAME)
229    }
230
231    pub fn send_response<T: EnvelopedMessage>(&self, request_id: u32, request: T) -> Result<()> {
232        let envelope = request.into_envelope(0, Some(request_id), None);
233        self.0.client.send(envelope, T::NAME)
234    }
235
236    pub fn request_lsp<T>(
237        &self,
238        project_id: u64,
239        server_id: Option<u64>,
240        timeout: Duration,
241        executor: BackgroundExecutor,
242        request: T,
243    ) -> impl Future<
244        Output = Result<Option<TypedEnvelope<Vec<proto::ProtoLspResponse<T::Response>>>>>,
245    > + use<T>
246    where
247        T: LspRequestMessage,
248    {
249        let new_id = LspRequestId(
250            self.0
251                .next_lsp_request_id
252                .fetch_add(1, atomic::Ordering::Acquire),
253        );
254        let (tx, rx) = oneshot::channel();
255        {
256            self.0.request_ids.lock().insert(new_id, tx);
257        }
258
259        let query = proto::LspQuery {
260            project_id,
261            server_id,
262            lsp_request_id: new_id.0,
263            request: Some(request.to_proto_query()),
264        };
265        let request = self.request(query);
266        let request_ids = self.0.request_ids.clone();
267        async move {
268            match request.await {
269                Ok(_request_enqueued) => {}
270                Err(e) => {
271                    request_ids.lock().remove(&new_id);
272                    return Err(e).context("sending LSP proto request");
273                }
274            }
275
276            let response = rx.with_timeout(timeout, &executor).await;
277            {
278                request_ids.lock().remove(&new_id);
279            }
280            match response {
281                Ok(Ok(response)) => {
282                    let response = response
283                        .context("waiting for LSP proto response")?
284                        .map(|response| {
285                            anyhow::Ok(TypedEnvelope {
286                                payload: response
287                                    .payload
288                                    .into_iter()
289                                    .map(|lsp_response| lsp_response.into_response::<T>())
290                                    .collect::<Result<Vec<_>>>()?,
291                                sender_id: response.sender_id,
292                                original_sender_id: response.original_sender_id,
293                                message_id: response.message_id,
294                                received_at: response.received_at,
295                            })
296                        })
297                        .transpose()
298                        .context("converting LSP proto response")?;
299                    Ok(response)
300                }
301                Err(_cancelled_due_timeout) => Ok(None),
302                Ok(Err(_channel_dropped)) => Ok(None),
303            }
304        }
305    }
306
307    pub fn send_lsp_response<T: LspRequestMessage>(
308        &self,
309        project_id: u64,
310        lsp_request_id: LspRequestId,
311        server_responses: HashMap<u64, T::Response>,
312    ) -> Result<()> {
313        self.send(proto::LspQueryResponse {
314            project_id,
315            lsp_request_id: lsp_request_id.0,
316            responses: server_responses
317                .into_iter()
318                .map(|(server_id, response)| proto::LspResponse {
319                    server_id,
320                    response: Some(T::response_to_proto_query(response)),
321                })
322                .collect(),
323        })
324    }
325
326    pub fn handle_lsp_response(&self, mut envelope: TypedEnvelope<proto::LspQueryResponse>) {
327        let request_id = LspRequestId(envelope.payload.lsp_request_id);
328        let mut response_senders = self.0.request_ids.lock();
329        if let Some(tx) = response_senders.remove(&request_id) {
330            let responses = envelope.payload.responses.drain(..).collect::<Vec<_>>();
331            tx.send(Ok(Some(proto::TypedEnvelope {
332                sender_id: envelope.sender_id,
333                original_sender_id: envelope.original_sender_id,
334                message_id: envelope.message_id,
335                received_at: envelope.received_at,
336                payload: responses
337                    .into_iter()
338                    .filter_map(|response| {
339                        use proto::lsp_response::Response;
340
341                        let server_id = response.server_id;
342                        let response = match response.response? {
343                            Response::GetReferencesResponse(response) => {
344                                to_any_envelope(&envelope, response)
345                            }
346                            Response::GetDocumentColorResponse(response) => {
347                                to_any_envelope(&envelope, response)
348                            }
349                            Response::GetHoverResponse(response) => {
350                                to_any_envelope(&envelope, response)
351                            }
352                            Response::GetCodeActionsResponse(response) => {
353                                to_any_envelope(&envelope, response)
354                            }
355                            Response::GetSignatureHelpResponse(response) => {
356                                to_any_envelope(&envelope, response)
357                            }
358                            Response::GetCodeLensResponse(response) => {
359                                to_any_envelope(&envelope, response)
360                            }
361                            Response::GetDocumentDiagnosticsResponse(response) => {
362                                to_any_envelope(&envelope, response)
363                            }
364                            Response::GetDefinitionResponse(response) => {
365                                to_any_envelope(&envelope, response)
366                            }
367                            Response::GetDeclarationResponse(response) => {
368                                to_any_envelope(&envelope, response)
369                            }
370                            Response::GetTypeDefinitionResponse(response) => {
371                                to_any_envelope(&envelope, response)
372                            }
373                            Response::GetImplementationResponse(response) => {
374                                to_any_envelope(&envelope, response)
375                            }
376                            Response::InlayHintsResponse(response) => {
377                                to_any_envelope(&envelope, response)
378                            }
379                            Response::SemanticTokensResponse(response) => {
380                                to_any_envelope(&envelope, response)
381                            }
382                        };
383                        Some(proto::ProtoLspResponse {
384                            server_id,
385                            response,
386                        })
387                    })
388                    .collect(),
389            })))
390            .ok();
391        }
392    }
393
394    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
395    where
396        M: RequestMessage,
397        E: 'static,
398        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
399        F: 'static + Future<Output = Result<M::Response>>,
400    {
401        self.0
402            .client
403            .message_handler_set()
404            .lock()
405            .add_message_handler(
406                TypeId::of::<M>(),
407                entity.into(),
408                Arc::new(move |entity, envelope, client, cx| {
409                    let entity = entity.downcast::<E>().unwrap();
410                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
411                    let request_id = envelope.message_id();
412                    handler(entity, *envelope, cx)
413                        .then(move |result| async move {
414                            match result {
415                                Ok(response) => {
416                                    client.send_response(request_id, response)?;
417                                    Ok(())
418                                }
419                                Err(error) => {
420                                    client.send_response(request_id, error.to_proto())?;
421                                    Err(error)
422                                }
423                            }
424                        })
425                        .boxed_local()
426                }),
427            )
428    }
429
430    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
431    where
432        M: EnvelopedMessage + RequestMessage + EntityMessage,
433        E: 'static,
434        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
435        F: 'static + Future<Output = Result<M::Response>>,
436    {
437        let message_type_id = TypeId::of::<M>();
438        let entity_type_id = TypeId::of::<E>();
439        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
440            (envelope as &dyn Any)
441                .downcast_ref::<TypedEnvelope<M>>()
442                .unwrap()
443                .payload
444                .remote_entity_id()
445        };
446        self.0
447            .client
448            .message_handler_set()
449            .lock()
450            .add_entity_message_handler(
451                message_type_id,
452                entity_type_id,
453                entity_id_extractor,
454                Arc::new(move |entity, envelope, client, cx| {
455                    let entity = entity.downcast::<E>().unwrap();
456                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
457                    let request_id = envelope.message_id();
458                    handler(entity, *envelope, cx)
459                        .then(move |result| async move {
460                            match result {
461                                Ok(response) => {
462                                    client.send_response(request_id, response)?;
463                                    Ok(())
464                                }
465                                Err(error) => {
466                                    client.send_response(request_id, error.to_proto())?;
467                                    Err(error)
468                                }
469                            }
470                        })
471                        .boxed_local()
472                }),
473            );
474    }
475
476    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
477    where
478        M: EnvelopedMessage + EntityMessage,
479        E: 'static,
480        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
481        F: 'static + Future<Output = Result<()>>,
482    {
483        let message_type_id = TypeId::of::<M>();
484        let entity_type_id = TypeId::of::<E>();
485        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
486            (envelope as &dyn Any)
487                .downcast_ref::<TypedEnvelope<M>>()
488                .unwrap()
489                .payload
490                .remote_entity_id()
491        };
492        self.0
493            .client
494            .message_handler_set()
495            .lock()
496            .add_entity_message_handler(
497                message_type_id,
498                entity_type_id,
499                entity_id_extractor,
500                Arc::new(move |entity, envelope, _, cx| {
501                    let entity = entity.downcast::<E>().unwrap();
502                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
503                    handler(entity, *envelope, cx).boxed_local()
504                }),
505            );
506    }
507
508    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
509        let id = (TypeId::of::<E>(), remote_id);
510
511        let mut message_handlers = self.0.client.message_handler_set().lock();
512        if message_handlers
513            .entities_by_type_and_remote_id
514            .contains_key(&id)
515        {
516            panic!("already subscribed to entity");
517        }
518
519        message_handlers.entities_by_type_and_remote_id.insert(
520            id,
521            EntityMessageSubscriber::Entity {
522                handle: entity.downgrade().into(),
523            },
524        );
525    }
526
527    pub fn has_wsl_interop(&self) -> bool {
528        self.0.client.has_wsl_interop()
529    }
530}
531
532fn to_any_envelope<T: EnvelopedMessage>(
533    envelope: &TypedEnvelope<proto::LspQueryResponse>,
534    response: T,
535) -> Box<dyn AnyTypedEnvelope> {
536    Box::new(proto::TypedEnvelope {
537        sender_id: envelope.sender_id,
538        original_sender_id: envelope.original_sender_id,
539        message_id: envelope.message_id,
540        received_at: envelope.received_at,
541        payload: response,
542    }) as Box<_>
543}