proto_client.rs

  1use anyhow::Context;
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    future::{BoxFuture, LocalBoxFuture},
  6};
  7use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, Entity};
  8use proto::{
  9    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, RequestMessage, TypedEnvelope,
 10    error::ErrorExt as _,
 11};
 12use std::{
 13    any::{Any, TypeId},
 14    sync::{Arc, Weak},
 15};
 16
 17#[derive(Clone)]
 18pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 19
 20impl AnyProtoClient {
 21    pub fn downgrade(&self) -> AnyWeakProtoClient {
 22        AnyWeakProtoClient(Arc::downgrade(&self.0))
 23    }
 24}
 25
 26#[derive(Clone)]
 27pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
 28
 29impl AnyWeakProtoClient {
 30    pub fn upgrade(&self) -> Option<AnyProtoClient> {
 31        self.0.upgrade().map(AnyProtoClient)
 32    }
 33}
 34
 35pub trait ProtoClient: Send + Sync {
 36    fn request(
 37        &self,
 38        envelope: Envelope,
 39        request_type: &'static str,
 40    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 41
 42    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 43
 44    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 45
 46    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 47
 48    fn is_via_collab(&self) -> bool;
 49}
 50
 51#[derive(Default)]
 52pub struct ProtoMessageHandlerSet {
 53    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 54    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 55    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 56    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 57    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 58}
 59
// todo! try to remove these. we can't store handles inside send/sync stuff
//
// SAFETY: NOTE(review): soundness is NOT established in this file. The set stores
// weak entity handles (`AnyWeakEntity`), boxed envelopes inside
// `EntityMessageSubscriber::Pending`, and handler closures; nothing here shows those
// handles are safe to move or share across threads. These impls presumably exist so
// that the set can live inside the `parking_lot::Mutex` returned by
// `ProtoClient::message_handler_set`, whose implementors must be `Send + Sync`.
// TODO confirm the stored handles are only ever touched from a single thread.
unsafe impl Send for ProtoMessageHandlerSet {}
unsafe impl Sync for ProtoMessageHandlerSet {}
 63
 64pub type ProtoMessageHandler = Arc<
 65    dyn Send
 66        + Sync
 67        + Fn(
 68            AnyEntity,
 69            Box<dyn AnyTypedEnvelope>,
 70            AnyProtoClient,
 71            AsyncApp,
 72        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 73>;
 74
 75impl ProtoMessageHandlerSet {
 76    pub fn clear(&mut self) {
 77        self.message_handlers.clear();
 78        self.entities_by_message_type.clear();
 79        self.entities_by_type_and_remote_id.clear();
 80        self.entity_id_extractors.clear();
 81    }
 82
 83    fn add_message_handler(
 84        &mut self,
 85        message_type_id: TypeId,
 86        entity: gpui::AnyWeakEntity,
 87        handler: ProtoMessageHandler,
 88    ) {
 89        self.entities_by_message_type
 90            .insert(message_type_id, entity);
 91        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 92        if prev_handler.is_some() {
 93            panic!("registered handler for the same message twice");
 94        }
 95    }
 96
 97    fn add_entity_message_handler(
 98        &mut self,
 99        message_type_id: TypeId,
100        entity_type_id: TypeId,
101        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
102        handler: ProtoMessageHandler,
103    ) {
104        self.entity_id_extractors
105            .entry(message_type_id)
106            .or_insert(entity_id_extractor);
107        self.entity_types_by_message_type
108            .insert(message_type_id, entity_type_id);
109        let prev_handler = self.message_handlers.insert(message_type_id, handler);
110        if prev_handler.is_some() {
111            panic!("registered handler for the same message twice");
112        }
113    }
114
115    pub fn handle_message(
116        this: &parking_lot::Mutex<Self>,
117        message: Box<dyn AnyTypedEnvelope>,
118        client: AnyProtoClient,
119        cx: AsyncApp,
120    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
121        let payload_type_id = message.payload_type_id();
122        let mut this = this.lock();
123        let handler = this.message_handlers.get(&payload_type_id)?.clone();
124        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
125            entity.upgrade()?
126        } else {
127            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
128            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
129            let entity_id = (extract_entity_id)(message.as_ref());
130            match this
131                .entities_by_type_and_remote_id
132                .get_mut(&(entity_type_id, entity_id))?
133            {
134                EntityMessageSubscriber::Pending(pending) => {
135                    pending.push(message);
136                    return None;
137                }
138                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
139            }
140        };
141        drop(this);
142        Some(handler(entity, message, client, cx))
143    }
144}
145
146pub enum EntityMessageSubscriber {
147    Entity { handle: AnyWeakEntity },
148    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
149}
150
151impl std::fmt::Debug for EntityMessageSubscriber {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            EntityMessageSubscriber::Entity { handle } => f
155                .debug_struct("EntityMessageSubscriber::Entity")
156                .field("handle", handle)
157                .finish(),
158            EntityMessageSubscriber::Pending(vec) => f
159                .debug_struct("EntityMessageSubscriber::Pending")
160                .field(
161                    "envelopes",
162                    &vec.iter()
163                        .map(|envelope| envelope.payload_type_name())
164                        .collect::<Vec<_>>(),
165                )
166                .finish(),
167        }
168    }
169}
170
171impl<T> From<Arc<T>> for AnyProtoClient
172where
173    T: ProtoClient + 'static,
174{
175    fn from(client: Arc<T>) -> Self {
176        Self(client)
177    }
178}
179
180impl AnyProtoClient {
181    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
182        Self(client)
183    }
184
185    pub fn is_via_collab(&self) -> bool {
186        self.0.is_via_collab()
187    }
188
189    pub fn request<T: RequestMessage>(
190        &self,
191        request: T,
192    ) -> impl Future<Output = anyhow::Result<T::Response>> + use<T> {
193        let envelope = request.into_envelope(0, None, None);
194        let response = self.0.request(envelope, T::NAME);
195        async move {
196            T::Response::from_envelope(response.await?)
197                .context("received response of the wrong type")
198        }
199    }
200
201    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
202        let envelope = request.into_envelope(0, None, None);
203        self.0.send(envelope, T::NAME)
204    }
205
206    pub fn send_response<T: EnvelopedMessage>(
207        &self,
208        request_id: u32,
209        request: T,
210    ) -> anyhow::Result<()> {
211        let envelope = request.into_envelope(0, Some(request_id), None);
212        self.0.send(envelope, T::NAME)
213    }
214
215    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
216    where
217        M: RequestMessage,
218        E: 'static,
219        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
220        F: 'static + Future<Output = anyhow::Result<M::Response>>,
221    {
222        self.0.message_handler_set().lock().add_message_handler(
223            TypeId::of::<M>(),
224            entity.into(),
225            Arc::new(move |entity, envelope, client, cx| {
226                let entity = entity.downcast::<E>().unwrap();
227                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
228                let request_id = envelope.message_id();
229                handler(entity, *envelope, cx)
230                    .then(move |result| async move {
231                        match result {
232                            Ok(response) => {
233                                client.send_response(request_id, response)?;
234                                Ok(())
235                            }
236                            Err(error) => {
237                                client.send_response(request_id, error.to_proto())?;
238                                Err(error)
239                            }
240                        }
241                    })
242                    .boxed_local()
243            }),
244        )
245    }
246
247    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
248    where
249        M: EnvelopedMessage + RequestMessage + EntityMessage,
250        E: 'static,
251        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
252        F: 'static + Future<Output = anyhow::Result<M::Response>>,
253    {
254        let message_type_id = TypeId::of::<M>();
255        let entity_type_id = TypeId::of::<E>();
256        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
257            (envelope as &dyn Any)
258                .downcast_ref::<TypedEnvelope<M>>()
259                .unwrap()
260                .payload
261                .remote_entity_id()
262        };
263        self.0
264            .message_handler_set()
265            .lock()
266            .add_entity_message_handler(
267                message_type_id,
268                entity_type_id,
269                entity_id_extractor,
270                Arc::new(move |entity, envelope, client, cx| {
271                    let entity = entity.downcast::<E>().unwrap();
272                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
273                    let request_id = envelope.message_id();
274                    handler(entity, *envelope, cx)
275                        .then(move |result| async move {
276                            match result {
277                                Ok(response) => {
278                                    client.send_response(request_id, response)?;
279                                    Ok(())
280                                }
281                                Err(error) => {
282                                    client.send_response(request_id, error.to_proto())?;
283                                    Err(error)
284                                }
285                            }
286                        })
287                        .boxed_local()
288                }),
289            );
290    }
291
292    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
293    where
294        M: EnvelopedMessage + EntityMessage,
295        E: 'static,
296        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
297        F: 'static + Future<Output = anyhow::Result<()>>,
298    {
299        let message_type_id = TypeId::of::<M>();
300        let entity_type_id = TypeId::of::<E>();
301        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
302            (envelope as &dyn Any)
303                .downcast_ref::<TypedEnvelope<M>>()
304                .unwrap()
305                .payload
306                .remote_entity_id()
307        };
308        self.0
309            .message_handler_set()
310            .lock()
311            .add_entity_message_handler(
312                message_type_id,
313                entity_type_id,
314                entity_id_extractor,
315                Arc::new(move |entity, envelope, _, cx| {
316                    let entity = entity.downcast::<E>().unwrap();
317                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
318                    handler(entity, *envelope, cx).boxed_local()
319                }),
320            );
321    }
322}