1use anyhow::Context;
2use collections::HashMap;
3use futures::{
4 Future, FutureExt as _,
5 future::{BoxFuture, LocalBoxFuture},
6};
7use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, Entity};
8use proto::{
9 AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, RequestMessage, TypedEnvelope,
10 error::ErrorExt as _,
11};
12use std::{
13 any::{Any, TypeId},
14 sync::{Arc, Weak},
15};
16
/// A cloneable, type-erased handle to a [`ProtoClient`] implementation.
///
/// Cloning only bumps the reference count of the shared `Arc`; every clone refers
/// to the same underlying client.
#[derive(Clone)]
pub struct AnyProtoClient(Arc<dyn ProtoClient>);
19
20impl AnyProtoClient {
21 pub fn downgrade(&self) -> AnyWeakProtoClient {
22 AnyWeakProtoClient(Arc::downgrade(&self.0))
23 }
24}
25
/// A weak, type-erased handle to a [`ProtoClient`].
///
/// Obtained from [`AnyProtoClient::downgrade`]. It does not keep the client alive.
#[derive(Clone)]
pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
28
29impl AnyWeakProtoClient {
30 pub fn upgrade(&self) -> Option<AnyProtoClient> {
31 self.0.upgrade().map(AnyProtoClient)
32 }
33}
34
/// Transport behind an [`AnyProtoClient`].
///
/// A `ProtoClient` sends proto envelopes and exposes the registry of handlers for
/// incoming messages. The `Send + Sync` bound allows implementations to be shared
/// across threads as `Arc<dyn ProtoClient>`.
pub trait ProtoClient: Send + Sync {
    /// Sends `envelope` as a request and resolves to the reply envelope.
    ///
    /// `request_type` is the static name of the request message; `AnyProtoClient`
    /// passes `T::NAME`. NOTE(review): presumably used for logging/diagnostics —
    /// confirm against implementations.
    fn request(
        &self,
        envelope: Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, anyhow::Result<Envelope>>;

    /// Sends `envelope` without waiting for a reply.
    ///
    /// `message_type` is the static name of the message type (`T::NAME`).
    fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;

    /// Sends `envelope` as the reply to a previously received request.
    ///
    /// `message_type` is the static name of the response message type.
    fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;

    /// Returns the registry that `AnyProtoClient`'s `add_*_handler` methods register
    /// incoming-message handlers into.
    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;

    /// NOTE(review): judging by the name, reports whether this client's traffic is
    /// routed through collab — confirm against implementations.
    fn is_via_collab(&self) -> bool;
}
50
/// Registry of incoming-message handlers and the entities they are routed to.
///
/// Unless noted otherwise, every map is keyed by the `TypeId` of a message payload
/// type. Incoming messages are dispatched by
/// [`ProtoMessageHandlerSet::handle_message`].
#[derive(Default)]
pub struct ProtoMessageHandlerSet {
    /// For per-entity message types: the `TypeId` of the entity type that handles them.
    pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Per-entity subscribers, keyed by (entity type `TypeId`, remote entity id).
    pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
    /// For per-entity message types: pulls the remote entity id out of an envelope.
    pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// For message types bound to one fixed entity: that entity, held weakly.
    pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
    /// The type-erased handler registered for each message type.
    pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
}
59
// todo! try to remove these manual impls: we can't store (entity) handles inside
// Send/Sync types.
//
// SAFETY: NOTE(review): these impls assert that `ProtoMessageHandlerSet` may be
// sent to and shared between threads, even though (per the todo above) some of
// its contents are not thread-safe on their own. Nothing in this file enforces
// that. Soundness presumably relies on the set only ever being accessed, through
// its `parking_lot::Mutex`, from threads where touching those handles is allowed.
// Confirm this invariant before relying on it.
unsafe impl Send for ProtoMessageHandlerSet {}
unsafe impl Sync for ProtoMessageHandlerSet {}
63
/// A type-erased handler for one incoming message type.
///
/// [`ProtoMessageHandlerSet::handle_message`] calls it with four arguments:
/// - the resolved target entity,
/// - the boxed envelope,
/// - the client passed to `handle_message`,
/// - an [`AsyncApp`].
///
/// The returned future is a `LocalBoxFuture` (not `Send`), so it cannot be moved to
/// another thread to be polled.
pub type ProtoMessageHandler = Arc<
    dyn Send
        + Sync
        + Fn(
            AnyEntity,
            Box<dyn AnyTypedEnvelope>,
            AnyProtoClient,
            AsyncApp,
        ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
>;
74
75impl ProtoMessageHandlerSet {
76 pub fn clear(&mut self) {
77 self.message_handlers.clear();
78 self.entities_by_message_type.clear();
79 self.entities_by_type_and_remote_id.clear();
80 self.entity_id_extractors.clear();
81 }
82
83 fn add_message_handler(
84 &mut self,
85 message_type_id: TypeId,
86 entity: gpui::AnyWeakEntity,
87 handler: ProtoMessageHandler,
88 ) {
89 self.entities_by_message_type
90 .insert(message_type_id, entity);
91 let prev_handler = self.message_handlers.insert(message_type_id, handler);
92 if prev_handler.is_some() {
93 panic!("registered handler for the same message twice");
94 }
95 }
96
97 fn add_entity_message_handler(
98 &mut self,
99 message_type_id: TypeId,
100 entity_type_id: TypeId,
101 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
102 handler: ProtoMessageHandler,
103 ) {
104 self.entity_id_extractors
105 .entry(message_type_id)
106 .or_insert(entity_id_extractor);
107 self.entity_types_by_message_type
108 .insert(message_type_id, entity_type_id);
109 let prev_handler = self.message_handlers.insert(message_type_id, handler);
110 if prev_handler.is_some() {
111 panic!("registered handler for the same message twice");
112 }
113 }
114
115 pub fn handle_message(
116 this: &parking_lot::Mutex<Self>,
117 message: Box<dyn AnyTypedEnvelope>,
118 client: AnyProtoClient,
119 cx: AsyncApp,
120 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
121 let payload_type_id = message.payload_type_id();
122 let mut this = this.lock();
123 let handler = this.message_handlers.get(&payload_type_id)?.clone();
124 let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
125 entity.upgrade()?
126 } else {
127 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
128 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
129 let entity_id = (extract_entity_id)(message.as_ref());
130 match this
131 .entities_by_type_and_remote_id
132 .get_mut(&(entity_type_id, entity_id))?
133 {
134 EntityMessageSubscriber::Pending(pending) => {
135 pending.push(message);
136 return None;
137 }
138 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
139 }
140 };
141 drop(this);
142 Some(handler(entity, message, client, cx))
143 }
144}
145
/// The subscription state for one (entity type, remote entity id) pair in
/// [`ProtoMessageHandlerSet`].
pub enum EntityMessageSubscriber {
    /// A live subscription: messages are dispatched to this entity, held weakly.
    Entity { handle: AnyWeakEntity },
    /// Not yet live: `handle_message` buffers incoming envelopes here instead of
    /// dispatching them. Draining this buffer happens outside this file.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
150
151impl std::fmt::Debug for EntityMessageSubscriber {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 match self {
154 EntityMessageSubscriber::Entity { handle } => f
155 .debug_struct("EntityMessageSubscriber::Entity")
156 .field("handle", handle)
157 .finish(),
158 EntityMessageSubscriber::Pending(vec) => f
159 .debug_struct("EntityMessageSubscriber::Pending")
160 .field(
161 "envelopes",
162 &vec.iter()
163 .map(|envelope| envelope.payload_type_name())
164 .collect::<Vec<_>>(),
165 )
166 .finish(),
167 }
168 }
169}
170
171impl<T> From<Arc<T>> for AnyProtoClient
172where
173 T: ProtoClient + 'static,
174{
175 fn from(client: Arc<T>) -> Self {
176 Self(client)
177 }
178}
179
180impl AnyProtoClient {
181 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
182 Self(client)
183 }
184
185 pub fn is_via_collab(&self) -> bool {
186 self.0.is_via_collab()
187 }
188
189 pub fn request<T: RequestMessage>(
190 &self,
191 request: T,
192 ) -> impl Future<Output = anyhow::Result<T::Response>> + use<T> {
193 let envelope = request.into_envelope(0, None, None);
194 let response = self.0.request(envelope, T::NAME);
195 async move {
196 T::Response::from_envelope(response.await?)
197 .context("received response of the wrong type")
198 }
199 }
200
201 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
202 let envelope = request.into_envelope(0, None, None);
203 self.0.send(envelope, T::NAME)
204 }
205
206 pub fn send_response<T: EnvelopedMessage>(
207 &self,
208 request_id: u32,
209 request: T,
210 ) -> anyhow::Result<()> {
211 let envelope = request.into_envelope(0, Some(request_id), None);
212 self.0.send(envelope, T::NAME)
213 }
214
215 pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
216 where
217 M: RequestMessage,
218 E: 'static,
219 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
220 F: 'static + Future<Output = anyhow::Result<M::Response>>,
221 {
222 self.0.message_handler_set().lock().add_message_handler(
223 TypeId::of::<M>(),
224 entity.into(),
225 Arc::new(move |entity, envelope, client, cx| {
226 let entity = entity.downcast::<E>().unwrap();
227 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
228 let request_id = envelope.message_id();
229 handler(entity, *envelope, cx)
230 .then(move |result| async move {
231 match result {
232 Ok(response) => {
233 client.send_response(request_id, response)?;
234 Ok(())
235 }
236 Err(error) => {
237 client.send_response(request_id, error.to_proto())?;
238 Err(error)
239 }
240 }
241 })
242 .boxed_local()
243 }),
244 )
245 }
246
247 pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
248 where
249 M: EnvelopedMessage + RequestMessage + EntityMessage,
250 E: 'static,
251 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
252 F: 'static + Future<Output = anyhow::Result<M::Response>>,
253 {
254 let message_type_id = TypeId::of::<M>();
255 let entity_type_id = TypeId::of::<E>();
256 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
257 (envelope as &dyn Any)
258 .downcast_ref::<TypedEnvelope<M>>()
259 .unwrap()
260 .payload
261 .remote_entity_id()
262 };
263 self.0
264 .message_handler_set()
265 .lock()
266 .add_entity_message_handler(
267 message_type_id,
268 entity_type_id,
269 entity_id_extractor,
270 Arc::new(move |entity, envelope, client, cx| {
271 let entity = entity.downcast::<E>().unwrap();
272 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
273 let request_id = envelope.message_id();
274 handler(entity, *envelope, cx)
275 .then(move |result| async move {
276 match result {
277 Ok(response) => {
278 client.send_response(request_id, response)?;
279 Ok(())
280 }
281 Err(error) => {
282 client.send_response(request_id, error.to_proto())?;
283 Err(error)
284 }
285 }
286 })
287 .boxed_local()
288 }),
289 );
290 }
291
292 pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
293 where
294 M: EnvelopedMessage + EntityMessage,
295 E: 'static,
296 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
297 F: 'static + Future<Output = anyhow::Result<()>>,
298 {
299 let message_type_id = TypeId::of::<M>();
300 let entity_type_id = TypeId::of::<E>();
301 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
302 (envelope as &dyn Any)
303 .downcast_ref::<TypedEnvelope<M>>()
304 .unwrap()
305 .payload
306 .remote_entity_id()
307 };
308 self.0
309 .message_handler_set()
310 .lock()
311 .add_entity_message_handler(
312 message_type_id,
313 entity_type_id,
314 entity_id_extractor,
315 Arc::new(move |entity, envelope, _, cx| {
316 let entity = entity.downcast::<E>().unwrap();
317 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
318 handler(entity, *envelope, cx).boxed_local()
319 }),
320 );
321 }
322}