proto_client.rs

  1use anyhow::anyhow;
  2use collections::HashMap;
  3use futures::{
  4    future::{BoxFuture, LocalBoxFuture},
  5    Future, FutureExt as _,
  6};
  7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
  8// pub use prost::Message;
  9use proto::{
 10    error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
 11    RequestMessage, TypedEnvelope,
 12};
 13use std::{any::TypeId, sync::Arc};
 14
 15#[derive(Clone)]
 16pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 17
 18pub trait ProtoClient: Send + Sync {
 19    fn request(
 20        &self,
 21        envelope: Envelope,
 22        request_type: &'static str,
 23    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 24
 25    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 26
 27    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 28
 29    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 30
 31    fn is_via_collab(&self) -> bool;
 32}
 33
 34#[derive(Default)]
 35pub struct ProtoMessageHandlerSet {
 36    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 37    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 38    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 39    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 40    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 41}
 42
 43pub type ProtoMessageHandler = Arc<
 44    dyn Send
 45        + Sync
 46        + Fn(
 47            AnyModel,
 48            Box<dyn AnyTypedEnvelope>,
 49            AnyProtoClient,
 50            AsyncAppContext,
 51        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 52>;
 53
 54impl ProtoMessageHandlerSet {
 55    pub fn clear(&mut self) {
 56        self.message_handlers.clear();
 57        self.models_by_message_type.clear();
 58        self.entities_by_type_and_remote_id.clear();
 59        self.entity_id_extractors.clear();
 60    }
 61
 62    fn add_message_handler(
 63        &mut self,
 64        message_type_id: TypeId,
 65        model: gpui::AnyWeakModel,
 66        handler: ProtoMessageHandler,
 67    ) {
 68        self.models_by_message_type.insert(message_type_id, model);
 69        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 70        if prev_handler.is_some() {
 71            panic!("registered handler for the same message twice");
 72        }
 73    }
 74
 75    fn add_entity_message_handler(
 76        &mut self,
 77        message_type_id: TypeId,
 78        model_type_id: TypeId,
 79        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 80        handler: ProtoMessageHandler,
 81    ) {
 82        self.entity_id_extractors
 83            .entry(message_type_id)
 84            .or_insert(entity_id_extractor);
 85        self.entity_types_by_message_type
 86            .insert(message_type_id, model_type_id);
 87        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 88        if prev_handler.is_some() {
 89            panic!("registered handler for the same message twice");
 90        }
 91    }
 92
 93    pub fn handle_message(
 94        this: &parking_lot::Mutex<Self>,
 95        message: Box<dyn AnyTypedEnvelope>,
 96        client: AnyProtoClient,
 97        cx: AsyncAppContext,
 98    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
 99        let payload_type_id = message.payload_type_id();
100        let mut this = this.lock();
101        let handler = this.message_handlers.get(&payload_type_id)?.clone();
102        let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
103            entity.upgrade()?
104        } else {
105            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
106            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
107            let entity_id = (extract_entity_id)(message.as_ref());
108
109            match this
110                .entities_by_type_and_remote_id
111                .get_mut(&(entity_type_id, entity_id))?
112            {
113                EntityMessageSubscriber::Pending(pending) => {
114                    pending.push(message);
115                    return None;
116                }
117                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
118            }
119        };
120        drop(this);
121        Some(handler(entity, message, client, cx))
122    }
123}
124
125pub enum EntityMessageSubscriber {
126    Entity { handle: AnyWeakModel },
127    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
128}
129
130impl<T> From<Arc<T>> for AnyProtoClient
131where
132    T: ProtoClient + 'static,
133{
134    fn from(client: Arc<T>) -> Self {
135        Self(client)
136    }
137}
138
139impl AnyProtoClient {
140    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
141        Self(client)
142    }
143
144    pub fn is_via_collab(&self) -> bool {
145        self.0.is_via_collab()
146    }
147
148    pub fn request<T: RequestMessage>(
149        &self,
150        request: T,
151    ) -> impl Future<Output = anyhow::Result<T::Response>> {
152        let envelope = request.into_envelope(0, None, None);
153        let response = self.0.request(envelope, T::NAME);
154        async move {
155            T::Response::from_envelope(response.await?)
156                .ok_or_else(|| anyhow!("received response of the wrong type"))
157        }
158    }
159
160    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
161        let envelope = request.into_envelope(0, None, None);
162        self.0.send(envelope, T::NAME)
163    }
164
165    pub fn send_response<T: EnvelopedMessage>(
166        &self,
167        request_id: u32,
168        request: T,
169    ) -> anyhow::Result<()> {
170        let envelope = request.into_envelope(0, Some(request_id), None);
171        self.0.send(envelope, T::NAME)
172    }
173
174    pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
175    where
176        M: RequestMessage,
177        E: 'static,
178        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
179        F: 'static + Future<Output = anyhow::Result<M::Response>>,
180    {
181        self.0.message_handler_set().lock().add_message_handler(
182            TypeId::of::<M>(),
183            model.into(),
184            Arc::new(move |model, envelope, client, cx| {
185                let model = model.downcast::<E>().unwrap();
186                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
187                let request_id = envelope.message_id();
188                handler(model, *envelope, cx)
189                    .then(move |result| async move {
190                        match result {
191                            Ok(response) => {
192                                client.send_response(request_id, response)?;
193                                Ok(())
194                            }
195                            Err(error) => {
196                                client.send_response(request_id, error.to_proto())?;
197                                Err(error)
198                            }
199                        }
200                    })
201                    .boxed_local()
202            }),
203        )
204    }
205
206    pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
207    where
208        M: EnvelopedMessage + RequestMessage + EntityMessage,
209        E: 'static,
210        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
211        F: 'static + Future<Output = anyhow::Result<M::Response>>,
212    {
213        let message_type_id = TypeId::of::<M>();
214        let model_type_id = TypeId::of::<E>();
215        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
216            envelope
217                .as_any()
218                .downcast_ref::<TypedEnvelope<M>>()
219                .unwrap()
220                .payload
221                .remote_entity_id()
222        };
223        self.0
224            .message_handler_set()
225            .lock()
226            .add_entity_message_handler(
227                message_type_id,
228                model_type_id,
229                entity_id_extractor,
230                Arc::new(move |model, envelope, client, cx| {
231                    let model = model.downcast::<E>().unwrap();
232                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
233                    let request_id = envelope.message_id();
234                    handler(model, *envelope, cx)
235                        .then(move |result| async move {
236                            match result {
237                                Ok(response) => {
238                                    client.send_response(request_id, response)?;
239                                    Ok(())
240                                }
241                                Err(error) => {
242                                    client.send_response(request_id, error.to_proto())?;
243                                    Err(error)
244                                }
245                            }
246                        })
247                        .boxed_local()
248                }),
249            );
250    }
251
252    pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
253    where
254        M: EnvelopedMessage + EntityMessage,
255        E: 'static,
256        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
257        F: 'static + Future<Output = anyhow::Result<()>>,
258    {
259        let message_type_id = TypeId::of::<M>();
260        let model_type_id = TypeId::of::<E>();
261        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
262            envelope
263                .as_any()
264                .downcast_ref::<TypedEnvelope<M>>()
265                .unwrap()
266                .payload
267                .remote_entity_id()
268        };
269        self.0
270            .message_handler_set()
271            .lock()
272            .add_entity_message_handler(
273                message_type_id,
274                model_type_id,
275                entity_id_extractor,
276                Arc::new(move |model, envelope, _, cx| {
277                    let model = model.downcast::<E>().unwrap();
278                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
279                    handler(model, *envelope, cx).boxed_local()
280                }),
281            );
282    }
283}