proto_client.rs

  1use anyhow::anyhow;
  2use collections::HashMap;
  3use futures::{
  4    future::{BoxFuture, LocalBoxFuture},
  5    Future, FutureExt as _,
  6};
  7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
  8use proto::{
  9    error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
 10    RequestMessage, TypedEnvelope,
 11};
 12use std::{
 13    any::TypeId,
 14    sync::{Arc, Weak},
 15};
 16
 17#[derive(Clone)]
 18pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 19
 20impl AnyProtoClient {
 21    pub fn downgrade(&self) -> AnyWeakProtoClient {
 22        AnyWeakProtoClient(Arc::downgrade(&self.0))
 23    }
 24}
 25
 26#[derive(Clone)]
 27pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
 28
 29impl AnyWeakProtoClient {
 30    pub fn upgrade(&self) -> Option<AnyProtoClient> {
 31        self.0.upgrade().map(AnyProtoClient)
 32    }
 33}
 34
 35pub trait ProtoClient: Send + Sync {
 36    fn request(
 37        &self,
 38        envelope: Envelope,
 39        request_type: &'static str,
 40    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
 41
 42    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 43
 44    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
 45
 46    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
 47
 48    fn is_via_collab(&self) -> bool;
 49}
 50
 51#[derive(Default)]
 52pub struct ProtoMessageHandlerSet {
 53    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
 54    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
 55    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 56    pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 57    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
 58}
 59
 60pub type ProtoMessageHandler = Arc<
 61    dyn Send
 62        + Sync
 63        + Fn(
 64            AnyModel,
 65            Box<dyn AnyTypedEnvelope>,
 66            AnyProtoClient,
 67            AsyncAppContext,
 68        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
 69>;
 70
 71impl ProtoMessageHandlerSet {
 72    pub fn clear(&mut self) {
 73        self.message_handlers.clear();
 74        self.models_by_message_type.clear();
 75        self.entities_by_type_and_remote_id.clear();
 76        self.entity_id_extractors.clear();
 77    }
 78
 79    fn add_message_handler(
 80        &mut self,
 81        message_type_id: TypeId,
 82        model: gpui::AnyWeakModel,
 83        handler: ProtoMessageHandler,
 84    ) {
 85        self.models_by_message_type.insert(message_type_id, model);
 86        let prev_handler = self.message_handlers.insert(message_type_id, handler);
 87        if prev_handler.is_some() {
 88            panic!("registered handler for the same message twice");
 89        }
 90    }
 91
 92    fn add_entity_message_handler(
 93        &mut self,
 94        message_type_id: TypeId,
 95        model_type_id: TypeId,
 96        entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
 97        handler: ProtoMessageHandler,
 98    ) {
 99        self.entity_id_extractors
100            .entry(message_type_id)
101            .or_insert(entity_id_extractor);
102        self.entity_types_by_message_type
103            .insert(message_type_id, model_type_id);
104        let prev_handler = self.message_handlers.insert(message_type_id, handler);
105        if prev_handler.is_some() {
106            panic!("registered handler for the same message twice");
107        }
108    }
109
110    pub fn handle_message(
111        this: &parking_lot::Mutex<Self>,
112        message: Box<dyn AnyTypedEnvelope>,
113        client: AnyProtoClient,
114        cx: AsyncAppContext,
115    ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
116        let payload_type_id = message.payload_type_id();
117        let mut this = this.lock();
118        let handler = this.message_handlers.get(&payload_type_id)?.clone();
119        let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
120            entity.upgrade()?
121        } else {
122            let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
123            let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
124            let entity_id = (extract_entity_id)(message.as_ref());
125            match this
126                .entities_by_type_and_remote_id
127                .get_mut(&(entity_type_id, entity_id))?
128            {
129                EntityMessageSubscriber::Pending(pending) => {
130                    pending.push(message);
131                    return None;
132                }
133                EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
134            }
135        };
136        drop(this);
137        Some(handler(entity, message, client, cx))
138    }
139}
140
141pub enum EntityMessageSubscriber {
142    Entity { handle: AnyWeakModel },
143    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
144}
145
146impl std::fmt::Debug for EntityMessageSubscriber {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        match self {
149            EntityMessageSubscriber::Entity { handle } => f
150                .debug_struct("EntityMessageSubscriber::Entity")
151                .field("handle", handle)
152                .finish(),
153            EntityMessageSubscriber::Pending(vec) => f
154                .debug_struct("EntityMessageSubscriber::Pending")
155                .field(
156                    "envelopes",
157                    &vec.iter()
158                        .map(|envelope| envelope.payload_type_name())
159                        .collect::<Vec<_>>(),
160                )
161                .finish(),
162        }
163    }
164}
165
166impl<T> From<Arc<T>> for AnyProtoClient
167where
168    T: ProtoClient + 'static,
169{
170    fn from(client: Arc<T>) -> Self {
171        Self(client)
172    }
173}
174
175impl AnyProtoClient {
176    pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
177        Self(client)
178    }
179
180    pub fn is_via_collab(&self) -> bool {
181        self.0.is_via_collab()
182    }
183
184    pub fn request<T: RequestMessage>(
185        &self,
186        request: T,
187    ) -> impl Future<Output = anyhow::Result<T::Response>> {
188        let envelope = request.into_envelope(0, None, None);
189        let response = self.0.request(envelope, T::NAME);
190        async move {
191            T::Response::from_envelope(response.await?)
192                .ok_or_else(|| anyhow!("received response of the wrong type"))
193        }
194    }
195
196    pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
197        let envelope = request.into_envelope(0, None, None);
198        self.0.send(envelope, T::NAME)
199    }
200
201    pub fn send_response<T: EnvelopedMessage>(
202        &self,
203        request_id: u32,
204        request: T,
205    ) -> anyhow::Result<()> {
206        let envelope = request.into_envelope(0, Some(request_id), None);
207        self.0.send(envelope, T::NAME)
208    }
209
210    pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
211    where
212        M: RequestMessage,
213        E: 'static,
214        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
215        F: 'static + Future<Output = anyhow::Result<M::Response>>,
216    {
217        self.0.message_handler_set().lock().add_message_handler(
218            TypeId::of::<M>(),
219            model.into(),
220            Arc::new(move |model, envelope, client, cx| {
221                let model = model.downcast::<E>().unwrap();
222                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
223                let request_id = envelope.message_id();
224                handler(model, *envelope, cx)
225                    .then(move |result| async move {
226                        match result {
227                            Ok(response) => {
228                                client.send_response(request_id, response)?;
229                                Ok(())
230                            }
231                            Err(error) => {
232                                client.send_response(request_id, error.to_proto())?;
233                                Err(error)
234                            }
235                        }
236                    })
237                    .boxed_local()
238            }),
239        )
240    }
241
242    pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
243    where
244        M: EnvelopedMessage + RequestMessage + EntityMessage,
245        E: 'static,
246        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
247        F: 'static + Future<Output = anyhow::Result<M::Response>>,
248    {
249        let message_type_id = TypeId::of::<M>();
250        let model_type_id = TypeId::of::<E>();
251        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
252            envelope
253                .as_any()
254                .downcast_ref::<TypedEnvelope<M>>()
255                .unwrap()
256                .payload
257                .remote_entity_id()
258        };
259        self.0
260            .message_handler_set()
261            .lock()
262            .add_entity_message_handler(
263                message_type_id,
264                model_type_id,
265                entity_id_extractor,
266                Arc::new(move |model, envelope, client, cx| {
267                    let model = model.downcast::<E>().unwrap();
268                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
269                    let request_id = envelope.message_id();
270                    handler(model, *envelope, cx)
271                        .then(move |result| async move {
272                            match result {
273                                Ok(response) => {
274                                    client.send_response(request_id, response)?;
275                                    Ok(())
276                                }
277                                Err(error) => {
278                                    client.send_response(request_id, error.to_proto())?;
279                                    Err(error)
280                                }
281                            }
282                        })
283                        .boxed_local()
284                }),
285            );
286    }
287
288    pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
289    where
290        M: EnvelopedMessage + EntityMessage,
291        E: 'static,
292        H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
293        F: 'static + Future<Output = anyhow::Result<()>>,
294    {
295        let message_type_id = TypeId::of::<M>();
296        let model_type_id = TypeId::of::<E>();
297        let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
298            envelope
299                .as_any()
300                .downcast_ref::<TypedEnvelope<M>>()
301                .unwrap()
302                .payload
303                .remote_entity_id()
304        };
305        self.0
306            .message_handler_set()
307            .lock()
308            .add_entity_message_handler(
309                message_type_id,
310                model_type_id,
311                entity_id_extractor,
312                Arc::new(move |model, envelope, _, cx| {
313                    let model = model.downcast::<E>().unwrap();
314                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
315                    handler(model, *envelope, cx).boxed_local()
316                }),
317            );
318    }
319}