proto_client.rs

  1use anyhow::anyhow;
  2use collections::HashMap;
  3use futures::{
  4    future::{BoxFuture, LocalBoxFuture},
  5    Future, FutureExt as _,
  6};
  7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
  8// pub use prost::Message;
  9use proto::{
 10    error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
 11    RequestMessage, TypedEnvelope,
 12};
 13use std::{
 14    any::TypeId,
 15    sync::{Arc, Weak},
 16};
 17
 18#[derive(Clone)]
 19pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 20
 21impl AnyProtoClient {
 22    pub fn downgrade(&self) -> AnyWeakProtoClient {
 23        AnyWeakProtoClient(Arc::downgrade(&self.0))
 24    }
 25}
 26
 27#[derive(Clone)]
 28pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
 29
 30impl AnyWeakProtoClient {
 31    pub fn upgrade(&self) -> Option<AnyProtoClient> {
 32        self.0.upgrade().map(AnyProtoClient)
 33    }
 34}
 35
 36pub trait ProtoClient: Send + Sync {
 37    fn request(
 38        &self,
 39        envelope: Envelope,
 40        request_type: &'static str,
 41    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 42
 43    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 44
 45    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 46
 47    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 48
 49    fn is_via_collab(&self) -> bool;
 50}
 51
 52#[derive(Default)]
 53pub struct ProtoMessageHandlerSet {
 54    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 55    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 56    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 57    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 58    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 59}
 60
 61pub type ProtoMessageHandler = Arc<
 62    dyn Send
 63        + Sync
 64        + Fn(
 65            AnyModel,
 66            Box<dyn AnyTypedEnvelope>,
 67            AnyProtoClient,
 68            AsyncAppContext,
 69        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 70>;
 71
 72impl ProtoMessageHandlerSet {
 73    pub fn clear(&mut self) {
 74        self.message_handlers.clear();
 75        self.models_by_message_type.clear();
 76        self.entities_by_type_and_remote_id.clear();
 77        self.entity_id_extractors.clear();
 78    }
 79
 80    fn add_message_handler(
 81        &mut self,
 82        message_type_id: TypeId,
 83        model: gpui::AnyWeakModel,
 84        handler: ProtoMessageHandler,
 85    ) {
 86        self.models_by_message_type.insert(message_type_id, model);
 87        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 88        if prev_handler.is_some() {
 89            panic!("registered handler for the same message twice");
 90        }
 91    }
 92
 93    fn add_entity_message_handler(
 94        &mut self,
 95        message_type_id: TypeId,
 96        model_type_id: TypeId,
 97        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 98        handler: ProtoMessageHandler,
 99    ) {
100        self.entity_id_extractors
101            .entry(message_type_id)
102            .or_insert(entity_id_extractor);
103        self.entity_types_by_message_type
104            .insert(message_type_id, model_type_id);
105        let prev_handler = self.message_handlers.insert(message_type_id, handler);
106        if prev_handler.is_some() {
107            panic!("registered handler for the same message twice");
108        }
109    }
110
111    pub fn handle_message(
112        this: &parking_lot::Mutex<Self>,
113        message: Box<dyn AnyTypedEnvelope>,
114        client: AnyProtoClient,
115        cx: AsyncAppContext,
116    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
117        let payload_type_id = message.payload_type_id();
118        let mut this = this.lock();
119        let handler = this.message_handlers.get(&payload_type_id)?.clone();
120        let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
121            entity.upgrade()?
122        } else {
123            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
124            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
125            let entity_id = (extract_entity_id)(message.as_ref());
126
127            match this
128                .entities_by_type_and_remote_id
129                .get_mut(&(entity_type_id, entity_id))?
130            {
131                EntityMessageSubscriber::Pending(pending) => {
132                    pending.push(message);
133                    return None;
134                }
135                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
136            }
137        };
138        drop(this);
139        Some(handler(entity, message, client, cx))
140    }
141}
142
143pub enum EntityMessageSubscriber {
144    Entity { handle: AnyWeakModel },
145    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
146}
147
148impl<T> From<Arc<T>> for AnyProtoClient
149where
150    T: ProtoClient + 'static,
151{
152    fn from(client: Arc<T>) -> Self {
153        Self(client)
154    }
155}
156
157impl AnyProtoClient {
158    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
159        Self(client)
160    }
161
162    pub fn is_via_collab(&self) -> bool {
163        self.0.is_via_collab()
164    }
165
166    pub fn request<T: RequestMessage>(
167        &self,
168        request: T,
169    ) -> impl Future<Output = anyhow::Result<T::Response>> {
170        let envelope = request.into_envelope(0, None, None);
171        let response = self.0.request(envelope, T::NAME);
172        async move {
173            T::Response::from_envelope(response.await?)
174                .ok_or_else(|| anyhow!("received response of the wrong type"))
175        }
176    }
177
178    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
179        let envelope = request.into_envelope(0, None, None);
180        self.0.send(envelope, T::NAME)
181    }
182
183    pub fn send_response<T: EnvelopedMessage>(
184        &self,
185        request_id: u32,
186        request: T,
187    ) -> anyhow::Result<()> {
188        let envelope = request.into_envelope(0, Some(request_id), None);
189        self.0.send(envelope, T::NAME)
190    }
191
192    pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
193    where
194        M: RequestMessage,
195        E: 'static,
196        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
197        F: 'static + Future<Output = anyhow::Result<M::Response>>,
198    {
199        self.0.message_handler_set().lock().add_message_handler(
200            TypeId::of::<M>(),
201            model.into(),
202            Arc::new(move |model, envelope, client, cx| {
203                let model = model.downcast::<E>().unwrap();
204                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
205                let request_id = envelope.message_id();
206                handler(model, *envelope, cx)
207                    .then(move |result| async move {
208                        match result {
209                            Ok(response) => {
210                                client.send_response(request_id, response)?;
211                                Ok(())
212                            }
213                            Err(error) => {
214                                client.send_response(request_id, error.to_proto())?;
215                                Err(error)
216                            }
217                        }
218                    })
219                    .boxed_local()
220            }),
221        )
222    }
223
224    pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
225    where
226        M: EnvelopedMessage + RequestMessage + EntityMessage,
227        E: 'static,
228        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
229        F: 'static + Future<Output = anyhow::Result<M::Response>>,
230    {
231        let message_type_id = TypeId::of::<M>();
232        let model_type_id = TypeId::of::<E>();
233        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
234            envelope
235                .as_any()
236                .downcast_ref::<TypedEnvelope<M>>()
237                .unwrap()
238                .payload
239                .remote_entity_id()
240        };
241        self.0
242            .message_handler_set()
243            .lock()
244            .add_entity_message_handler(
245                message_type_id,
246                model_type_id,
247                entity_id_extractor,
248                Arc::new(move |model, envelope, client, cx| {
249                    let model = model.downcast::<E>().unwrap();
250                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
251                    let request_id = envelope.message_id();
252                    handler(model, *envelope, cx)
253                        .then(move |result| async move {
254                            match result {
255                                Ok(response) => {
256                                    client.send_response(request_id, response)?;
257                                    Ok(())
258                                }
259                                Err(error) => {
260                                    client.send_response(request_id, error.to_proto())?;
261                                    Err(error)
262                                }
263                            }
264                        })
265                        .boxed_local()
266                }),
267            );
268    }
269
270    pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
271    where
272        M: EnvelopedMessage + EntityMessage,
273        E: 'static,
274        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
275        F: 'static + Future<Output = anyhow::Result<()>>,
276    {
277        let message_type_id = TypeId::of::<M>();
278        let model_type_id = TypeId::of::<E>();
279        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
280            envelope
281                .as_any()
282                .downcast_ref::<TypedEnvelope<M>>()
283                .unwrap()
284                .payload
285                .remote_entity_id()
286        };
287        self.0
288            .message_handler_set()
289            .lock()
290            .add_entity_message_handler(
291                message_type_id,
292                model_type_id,
293                entity_id_extractor,
294                Arc::new(move |model, envelope, _, cx| {
295                    let model = model.downcast::<E>().unwrap();
296                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
297                    handler(model, *envelope, cx).boxed_local()
298                }),
299            );
300    }
301}