proto_client.rs

  1use anyhow::Context;
  2use collections::HashMap;
  3use futures::{
  4    Future, FutureExt as _,
  5    future::{BoxFuture, LocalBoxFuture},
  6};
  7use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, Entity};
  8use proto::{
  9    AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, RequestMessage, TypedEnvelope,
 10    error::ErrorExt as _,
 11};
 12use std::{
 13    any::{Any, TypeId},
 14    sync::{Arc, Weak},
 15};
 16
 17#[derive(Clone)]
 18pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 19
 20impl AnyProtoClient {
 21    pub fn downgrade(&self) -> AnyWeakProtoClient {
 22        AnyWeakProtoClient(Arc::downgrade(&self.0))
 23    }
 24}
 25
 26#[derive(Clone)]
 27pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
 28
 29impl AnyWeakProtoClient {
 30    pub fn upgrade(&self) -> Option<AnyProtoClient> {
 31        self.0.upgrade().map(AnyProtoClient)
 32    }
 33}
 34
 35pub trait ProtoClient: Send + Sync {
 36    fn request(
 37        &self,
 38        envelope: Envelope,
 39        request_type: &'static str,
 40    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 41
 42    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 43
 44    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 45
 46    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 47
 48    fn is_via_collab(&self) -> bool;
 49}
 50
 51#[derive(Default)]
 52pub struct ProtoMessageHandlerSet {
 53    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 54    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 55    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 56    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
 57    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 58}
 59
 60pub type ProtoMessageHandler = Arc<
 61    dyn Send
 62        + Sync
 63        + Fn(
 64            AnyEntity,
 65            Box<dyn AnyTypedEnvelope>,
 66            AnyProtoClient,
 67            AsyncApp,
 68        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 69>;
 70
 71impl ProtoMessageHandlerSet {
 72    pub fn clear(&mut self) {
 73        self.message_handlers.clear();
 74        self.entities_by_message_type.clear();
 75        self.entities_by_type_and_remote_id.clear();
 76        self.entity_id_extractors.clear();
 77    }
 78
 79    fn add_message_handler(
 80        &mut self,
 81        message_type_id: TypeId,
 82        entity: gpui::AnyWeakEntity,
 83        handler: ProtoMessageHandler,
 84    ) {
 85        self.entities_by_message_type
 86            .insert(message_type_id, entity);
 87        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 88        if prev_handler.is_some() {
 89            panic!("registered handler for the same message twice");
 90        }
 91    }
 92
 93    fn add_entity_message_handler(
 94        &mut self,
 95        message_type_id: TypeId,
 96        entity_type_id: TypeId,
 97        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 98        handler: ProtoMessageHandler,
 99    ) {
100        self.entity_id_extractors
101            .entry(message_type_id)
102            .or_insert(entity_id_extractor);
103        self.entity_types_by_message_type
104            .insert(message_type_id, entity_type_id);
105        let prev_handler = self.message_handlers.insert(message_type_id, handler);
106        if prev_handler.is_some() {
107            panic!("registered handler for the same message twice");
108        }
109    }
110
111    pub fn handle_message(
112        this: &parking_lot::Mutex<Self>,
113        message: Box<dyn AnyTypedEnvelope>,
114        client: AnyProtoClient,
115        cx: AsyncApp,
116    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
117        let payload_type_id = message.payload_type_id();
118        let mut this = this.lock();
119        let handler = this.message_handlers.get(&payload_type_id)?.clone();
120        let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
121            entity.upgrade()?
122        } else {
123            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
124            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
125            let entity_id = (extract_entity_id)(message.as_ref());
126            match this
127                .entities_by_type_and_remote_id
128                .get_mut(&(entity_type_id, entity_id))?
129            {
130                EntityMessageSubscriber::Pending(pending) => {
131                    pending.push(message);
132                    return None;
133                }
134                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
135            }
136        };
137        drop(this);
138        Some(handler(entity, message, client, cx))
139    }
140}
141
142pub enum EntityMessageSubscriber {
143    Entity { handle: AnyWeakEntity },
144    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
145}
146
147impl std::fmt::Debug for EntityMessageSubscriber {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            EntityMessageSubscriber::Entity { handle } => f
151                .debug_struct("EntityMessageSubscriber::Entity")
152                .field("handle", handle)
153                .finish(),
154            EntityMessageSubscriber::Pending(vec) => f
155                .debug_struct("EntityMessageSubscriber::Pending")
156                .field(
157                    "envelopes",
158                    &vec.iter()
159                        .map(|envelope| envelope.payload_type_name())
160                        .collect::<Vec<_>>(),
161                )
162                .finish(),
163        }
164    }
165}
166
167impl<T> From<Arc<T>> for AnyProtoClient
168where
169    T: ProtoClient + 'static,
170{
171    fn from(client: Arc<T>) -> Self {
172        Self(client)
173    }
174}
175
176impl AnyProtoClient {
177    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
178        Self(client)
179    }
180
181    pub fn is_via_collab(&self) -> bool {
182        self.0.is_via_collab()
183    }
184
185    pub fn request<T: RequestMessage>(
186        &self,
187        request: T,
188    ) -> impl Future<Output = anyhow::Result<T::Response>> + use<T> {
189        let envelope = request.into_envelope(0, None, None);
190        let response = self.0.request(envelope, T::NAME);
191        async move {
192            T::Response::from_envelope(response.await?)
193                .context("received response of the wrong type")
194        }
195    }
196
197    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
198        let envelope = request.into_envelope(0, None, None);
199        self.0.send(envelope, T::NAME)
200    }
201
202    pub fn send_response<T: EnvelopedMessage>(
203        &self,
204        request_id: u32,
205        request: T,
206    ) -> anyhow::Result<()> {
207        let envelope = request.into_envelope(0, Some(request_id), None);
208        self.0.send(envelope, T::NAME)
209    }
210
211    pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
212    where
213        M: RequestMessage,
214        E: 'static,
215        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
216        F: 'static + Future<Output = anyhow::Result<M::Response>>,
217    {
218        self.0.message_handler_set().lock().add_message_handler(
219            TypeId::of::<M>(),
220            entity.into(),
221            Arc::new(move |entity, envelope, client, cx| {
222                let entity = entity.downcast::<E>().unwrap();
223                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
224                let request_id = envelope.message_id();
225                handler(entity, *envelope, cx)
226                    .then(move |result| async move {
227                        match result {
228                            Ok(response) => {
229                                client.send_response(request_id, response)?;
230                                Ok(())
231                            }
232                            Err(error) => {
233                                client.send_response(request_id, error.to_proto())?;
234                                Err(error)
235                            }
236                        }
237                    })
238                    .boxed_local()
239            }),
240        )
241    }
242
243    pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
244    where
245        M: EnvelopedMessage + RequestMessage + EntityMessage,
246        E: 'static,
247        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
248        F: 'static + Future<Output = anyhow::Result<M::Response>>,
249    {
250        let message_type_id = TypeId::of::<M>();
251        let entity_type_id = TypeId::of::<E>();
252        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
253            (envelope as &dyn Any)
254                .downcast_ref::<TypedEnvelope<M>>()
255                .unwrap()
256                .payload
257                .remote_entity_id()
258        };
259        self.0
260            .message_handler_set()
261            .lock()
262            .add_entity_message_handler(
263                message_type_id,
264                entity_type_id,
265                entity_id_extractor,
266                Arc::new(move |entity, envelope, client, cx| {
267                    let entity = entity.downcast::<E>().unwrap();
268                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
269                    let request_id = envelope.message_id();
270                    handler(entity, *envelope, cx)
271                        .then(move |result| async move {
272                            match result {
273                                Ok(response) => {
274                                    client.send_response(request_id, response)?;
275                                    Ok(())
276                                }
277                                Err(error) => {
278                                    client.send_response(request_id, error.to_proto())?;
279                                    Err(error)
280                                }
281                            }
282                        })
283                        .boxed_local()
284                }),
285            );
286    }
287
288    pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
289    where
290        M: EnvelopedMessage + EntityMessage,
291        E: 'static,
292        H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
293        F: 'static + Future<Output = anyhow::Result<()>>,
294    {
295        let message_type_id = TypeId::of::<M>();
296        let entity_type_id = TypeId::of::<E>();
297        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
298            (envelope as &dyn Any)
299                .downcast_ref::<TypedEnvelope<M>>()
300                .unwrap()
301                .payload
302                .remote_entity_id()
303        };
304        self.0
305            .message_handler_set()
306            .lock()
307            .add_entity_message_handler(
308                message_type_id,
309                entity_type_id,
310                entity_id_extractor,
311                Arc::new(move |entity, envelope, _, cx| {
312                    let entity = entity.downcast::<E>().unwrap();
313                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
314                    handler(entity, *envelope, cx).boxed_local()
315                }),
316            );
317    }
318
319    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
320        let id = (TypeId::of::<E>(), remote_id);
321
322        let mut message_handlers = self.0.message_handler_set().lock();
323        if message_handlers
324            .entities_by_type_and_remote_id
325            .contains_key(&id)
326        {
327            panic!("already subscribed to entity");
328        }
329
330        message_handlers.entities_by_type_and_remote_id.insert(
331            id,
332            EntityMessageSubscriber::Entity {
333                handle: entity.downgrade().into(),
334            },
335        );
336    }
337}