proto_client.rs

  1use anyhow::anyhow;
  2use collections::HashMap;
  3use futures::{
  4    future::{BoxFuture, LocalBoxFuture},
  5    Future, FutureExt as _,
  6};
  7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
  8// pub use prost::Message;
  9use proto::{
 10    error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
 11    RequestMessage, TypedEnvelope,
 12};
 13use std::{
 14    any::TypeId,
 15    sync::{Arc, Weak},
 16};
 17
 18#[derive(Clone)]
 19pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 20
 21impl AnyProtoClient {
 22    pub fn downgrade(&self) -> AnyWeakProtoClient {
 23        AnyWeakProtoClient(Arc::downgrade(&self.0))
 24    }
 25}
 26
 27#[derive(Clone)]
 28pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
 29
 30impl AnyWeakProtoClient {
 31    pub fn upgrade(&self) -> Option<AnyProtoClient> {
 32        self.0.upgrade().map(AnyProtoClient)
 33    }
 34}
 35
 36pub trait ProtoClient: Send + Sync {
 37    fn request(
 38        &self,
 39        envelope: Envelope,
 40        request_type: &'static str,
 41    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 42
 43    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 44
 45    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 46
 47    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 48
 49    fn is_via_collab(&self) -> bool;
 50}
 51
 52#[derive(Default)]
 53pub struct ProtoMessageHandlerSet {
 54    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 55    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 56    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 57    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 58    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 59}
 60
 61pub type ProtoMessageHandler = Arc<
 62    dyn Send
 63        + Sync
 64        + Fn(
 65            AnyModel,
 66            Box<dyn AnyTypedEnvelope>,
 67            AnyProtoClient,
 68            AsyncAppContext,
 69        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 70>;
 71
 72impl ProtoMessageHandlerSet {
 73    pub fn clear(&mut self) {
 74        self.message_handlers.clear();
 75        self.models_by_message_type.clear();
 76        self.entities_by_type_and_remote_id.clear();
 77        self.entity_id_extractors.clear();
 78    }
 79
 80    fn add_message_handler(
 81        &mut self,
 82        message_type_id: TypeId,
 83        model: gpui::AnyWeakModel,
 84        handler: ProtoMessageHandler,
 85    ) {
 86        self.models_by_message_type.insert(message_type_id, model);
 87        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 88        if prev_handler.is_some() {
 89            panic!("registered handler for the same message twice");
 90        }
 91    }
 92
 93    fn add_entity_message_handler(
 94        &mut self,
 95        message_type_id: TypeId,
 96        model_type_id: TypeId,
 97        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 98        handler: ProtoMessageHandler,
 99    ) {
100        self.entity_id_extractors
101            .entry(message_type_id)
102            .or_insert(entity_id_extractor);
103        self.entity_types_by_message_type
104            .insert(message_type_id, model_type_id);
105        let prev_handler = self.message_handlers.insert(message_type_id, handler);
106        if prev_handler.is_some() {
107            panic!("registered handler for the same message twice");
108        }
109    }
110
111    pub fn handle_message(
112        this: &parking_lot::Mutex<Self>,
113        message: Box<dyn AnyTypedEnvelope>,
114        client: AnyProtoClient,
115        cx: AsyncAppContext,
116    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
117        let payload_type_id = message.payload_type_id();
118        let mut this = this.lock();
119        let handler = this.message_handlers.get(&payload_type_id)?.clone();
120        let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
121            entity.upgrade()?
122        } else {
123            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
124            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
125            let entity_id = (extract_entity_id)(message.as_ref());
126            match this
127                .entities_by_type_and_remote_id
128                .get_mut(&(entity_type_id, entity_id))?
129            {
130                EntityMessageSubscriber::Pending(pending) => {
131                    pending.push(message);
132                    return None;
133                }
134                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
135            }
136        };
137        drop(this);
138        Some(handler(entity, message, client, cx))
139    }
140}
141
142pub enum EntityMessageSubscriber {
143    Entity { handle: AnyWeakModel },
144    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
145}
146
147impl std::fmt::Debug for EntityMessageSubscriber {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            EntityMessageSubscriber::Entity { handle } => f
151                .debug_struct("EntityMessageSubscriber::Entity")
152                .field("handle", handle)
153                .finish(),
154            EntityMessageSubscriber::Pending(vec) => f
155                .debug_struct("EntityMessageSubscriber::Pending")
156                .field(
157                    "envelopes",
158                    &vec.iter()
159                        .map(|envelope| envelope.payload_type_name())
160                        .collect::<Vec<_>>(),
161                )
162                .finish(),
163        }
164    }
165}
166
167impl<T> From<Arc<T>> for AnyProtoClient
168where
169    T: ProtoClient + 'static,
170{
171    fn from(client: Arc<T>) -> Self {
172        Self(client)
173    }
174}
175
176impl AnyProtoClient {
177    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
178        Self(client)
179    }
180
181    pub fn is_via_collab(&self) -> bool {
182        self.0.is_via_collab()
183    }
184
185    pub fn request<T: RequestMessage>(
186        &self,
187        request: T,
188    ) -> impl Future<Output = anyhow::Result<T::Response>> {
189        let envelope = request.into_envelope(0, None, None);
190        let response = self.0.request(envelope, T::NAME);
191        async move {
192            T::Response::from_envelope(response.await?)
193                .ok_or_else(|| anyhow!("received response of the wrong type"))
194        }
195    }
196
197    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
198        let envelope = request.into_envelope(0, None, None);
199        self.0.send(envelope, T::NAME)
200    }
201
202    pub fn send_response<T: EnvelopedMessage>(
203        &self,
204        request_id: u32,
205        request: T,
206    ) -> anyhow::Result<()> {
207        let envelope = request.into_envelope(0, Some(request_id), None);
208        self.0.send(envelope, T::NAME)
209    }
210
211    pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
212    where
213        M: RequestMessage,
214        E: 'static,
215        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
216        F: 'static + Future<Output = anyhow::Result<M::Response>>,
217    {
218        self.0.message_handler_set().lock().add_message_handler(
219            TypeId::of::<M>(),
220            model.into(),
221            Arc::new(move |model, envelope, client, cx| {
222                let model = model.downcast::<E>().unwrap();
223                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
224                let request_id = envelope.message_id();
225                handler(model, *envelope, cx)
226                    .then(move |result| async move {
227                        match result {
228                            Ok(response) => {
229                                client.send_response(request_id, response)?;
230                                Ok(())
231                            }
232                            Err(error) => {
233                                client.send_response(request_id, error.to_proto())?;
234                                Err(error)
235                            }
236                        }
237                    })
238                    .boxed_local()
239            }),
240        )
241    }
242
243    pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
244    where
245        M: EnvelopedMessage + RequestMessage + EntityMessage,
246        E: 'static,
247        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
248        F: 'static + Future<Output = anyhow::Result<M::Response>>,
249    {
250        let message_type_id = TypeId::of::<M>();
251        let model_type_id = TypeId::of::<E>();
252        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
253            envelope
254                .as_any()
255                .downcast_ref::<TypedEnvelope<M>>()
256                .unwrap()
257                .payload
258                .remote_entity_id()
259        };
260        self.0
261            .message_handler_set()
262            .lock()
263            .add_entity_message_handler(
264                message_type_id,
265                model_type_id,
266                entity_id_extractor,
267                Arc::new(move |model, envelope, client, cx| {
268                    let model = model.downcast::<E>().unwrap();
269                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
270                    let request_id = envelope.message_id();
271                    handler(model, *envelope, cx)
272                        .then(move |result| async move {
273                            match result {
274                                Ok(response) => {
275                                    client.send_response(request_id, response)?;
276                                    Ok(())
277                                }
278                                Err(error) => {
279                                    client.send_response(request_id, error.to_proto())?;
280                                    Err(error)
281                                }
282                            }
283                        })
284                        .boxed_local()
285                }),
286            );
287    }
288
289    pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
290    where
291        M: EnvelopedMessage + EntityMessage,
292        E: 'static,
293        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
294        F: 'static + Future<Output = anyhow::Result<()>>,
295    {
296        let message_type_id = TypeId::of::<M>();
297        let model_type_id = TypeId::of::<E>();
298        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
299            envelope
300                .as_any()
301                .downcast_ref::<TypedEnvelope<M>>()
302                .unwrap()
303                .payload
304                .remote_entity_id()
305        };
306        self.0
307            .message_handler_set()
308            .lock()
309            .add_entity_message_handler(
310                message_type_id,
311                model_type_id,
312                entity_id_extractor,
313                Arc::new(move |model, envelope, _, cx| {
314                    let model = model.downcast::<E>().unwrap();
315                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
316                    handler(model, *envelope, cx).boxed_local()
317                }),
318            );
319    }
320}