proto_client.rs

  1use crate::{
  2    error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
  3    RequestMessage, TypedEnvelope,
  4};
  5use anyhow::anyhow;
  6use collections::HashMap;
  7use futures::{
  8    future::{BoxFuture, LocalBoxFuture},
  9    Future, FutureExt as _,
 10};
 11use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
 12pub use prost::Message;
 13use std::{any::TypeId, sync::Arc};
 14
 15#[derive(Clone)]
 16pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 17
 18pub trait ProtoClient: Send + Sync {
 19    fn request(
 20        &self,
 21        envelope: Envelope,
 22        request_type: &'static str,
 23    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 24
 25    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 26
 27    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 28
 29    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 30}
 31
 32#[derive(Default)]
 33pub struct ProtoMessageHandlerSet {
 34    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 35    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 36    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 37    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 38    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 39}
 40
 41pub type ProtoMessageHandler = Arc<
 42    dyn Send
 43        + Sync
 44        + Fn(
 45            AnyModel,
 46            Box<dyn AnyTypedEnvelope>,
 47            AnyProtoClient,
 48            AsyncAppContext,
 49        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 50>;
 51
 52impl ProtoMessageHandlerSet {
 53    pub fn clear(&mut self) {
 54        self.message_handlers.clear();
 55        self.models_by_message_type.clear();
 56        self.entities_by_type_and_remote_id.clear();
 57        self.entity_id_extractors.clear();
 58    }
 59
 60    fn add_message_handler(
 61        &mut self,
 62        message_type_id: TypeId,
 63        model: gpui::AnyWeakModel,
 64        handler: ProtoMessageHandler,
 65    ) {
 66        self.models_by_message_type.insert(message_type_id, model);
 67        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 68        if prev_handler.is_some() {
 69            panic!("registered handler for the same message twice");
 70        }
 71    }
 72
 73    fn add_entity_message_handler(
 74        &mut self,
 75        message_type_id: TypeId,
 76        model_type_id: TypeId,
 77        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 78        handler: ProtoMessageHandler,
 79    ) {
 80        self.entity_id_extractors
 81            .entry(message_type_id)
 82            .or_insert(entity_id_extractor);
 83        self.entity_types_by_message_type
 84            .insert(message_type_id, model_type_id);
 85        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 86        if prev_handler.is_some() {
 87            panic!("registered handler for the same message twice");
 88        }
 89    }
 90
 91    pub fn handle_message(
 92        this: &parking_lot::Mutex<Self>,
 93        message: Box<dyn AnyTypedEnvelope>,
 94        client: AnyProtoClient,
 95        cx: AsyncAppContext,
 96    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
 97        let payload_type_id = message.payload_type_id();
 98        let mut this = this.lock();
 99        let handler = this.message_handlers.get(&payload_type_id)?.clone();
100        let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
101            entity.upgrade()?
102        } else {
103            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
104            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
105            let entity_id = (extract_entity_id)(message.as_ref());
106
107            match this
108                .entities_by_type_and_remote_id
109                .get_mut(&(entity_type_id, entity_id))?
110            {
111                EntityMessageSubscriber::Pending(pending) => {
112                    pending.push(message);
113                    return None;
114                }
115                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
116            }
117        };
118        drop(this);
119        Some(handler(entity, message, client, cx))
120    }
121}
122
123pub enum EntityMessageSubscriber {
124    Entity { handle: AnyWeakModel },
125    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
126}
127
128impl<T> From<Arc<T>> for AnyProtoClient
129where
130    T: ProtoClient + 'static,
131{
132    fn from(client: Arc<T>) -> Self {
133        Self(client)
134    }
135}
136
137impl AnyProtoClient {
138    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
139        Self(client)
140    }
141
142    pub fn request<T: RequestMessage>(
143        &self,
144        request: T,
145    ) -> impl Future<Output = anyhow::Result<T::Response>> {
146        let envelope = request.into_envelope(0, None, None);
147        let response = self.0.request(envelope, T::NAME);
148        async move {
149            T::Response::from_envelope(response.await?)
150                .ok_or_else(|| anyhow!("received response of the wrong type"))
151        }
152    }
153
154    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
155        let envelope = request.into_envelope(0, None, None);
156        self.0.send(envelope, T::NAME)
157    }
158
159    pub fn send_response<T: EnvelopedMessage>(
160        &self,
161        request_id: u32,
162        request: T,
163    ) -> anyhow::Result<()> {
164        let envelope = request.into_envelope(0, Some(request_id), None);
165        self.0.send(envelope, T::NAME)
166    }
167
168    pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
169    where
170        M: RequestMessage,
171        E: 'static,
172        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
173        F: 'static + Future<Output = anyhow::Result<M::Response>>,
174    {
175        self.0.message_handler_set().lock().add_message_handler(
176            TypeId::of::<M>(),
177            model.into(),
178            Arc::new(move |model, envelope, client, cx| {
179                let model = model.downcast::<E>().unwrap();
180                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
181                let request_id = envelope.message_id();
182                handler(model, *envelope, cx)
183                    .then(move |result| async move {
184                        match result {
185                            Ok(response) => {
186                                client.send_response(request_id, response)?;
187                                Ok(())
188                            }
189                            Err(error) => {
190                                client.send_response(request_id, error.to_proto())?;
191                                Err(error)
192                            }
193                        }
194                    })
195                    .boxed_local()
196            }),
197        )
198    }
199
200    pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
201    where
202        M: EnvelopedMessage + RequestMessage + EntityMessage,
203        E: 'static,
204        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
205        F: 'static + Future<Output = anyhow::Result<M::Response>>,
206    {
207        let message_type_id = TypeId::of::<M>();
208        let model_type_id = TypeId::of::<E>();
209        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
210            envelope
211                .as_any()
212                .downcast_ref::<TypedEnvelope<M>>()
213                .unwrap()
214                .payload
215                .remote_entity_id()
216        };
217        self.0
218            .message_handler_set()
219            .lock()
220            .add_entity_message_handler(
221                message_type_id,
222                model_type_id,
223                entity_id_extractor,
224                Arc::new(move |model, envelope, client, cx| {
225                    let model = model.downcast::<E>().unwrap();
226                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
227                    let request_id = envelope.message_id();
228                    handler(model, *envelope, cx)
229                        .then(move |result| async move {
230                            match result {
231                                Ok(response) => {
232                                    client.send_response(request_id, response)?;
233                                    Ok(())
234                                }
235                                Err(error) => {
236                                    client.send_response(request_id, error.to_proto())?;
237                                    Err(error)
238                                }
239                            }
240                        })
241                        .boxed_local()
242                }),
243            );
244    }
245
246    pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
247    where
248        M: EnvelopedMessage + EntityMessage,
249        E: 'static,
250        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
251        F: 'static + Future<Output = anyhow::Result<()>>,
252    {
253        let message_type_id = TypeId::of::<M>();
254        let model_type_id = TypeId::of::<E>();
255        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
256            envelope
257                .as_any()
258                .downcast_ref::<TypedEnvelope<M>>()
259                .unwrap()
260                .payload
261                .remote_entity_id()
262        };
263        self.0
264            .message_handler_set()
265            .lock()
266            .add_entity_message_handler(
267                message_type_id,
268                model_type_id,
269                entity_id_extractor,
270                Arc::new(move |model, envelope, _, cx| {
271                    let model = model.downcast::<E>().unwrap();
272                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
273                    handler(model, *envelope, cx).boxed_local()
274                }),
275            );
276    }
277}