1use anyhow::Context;
2use collections::HashMap;
3use futures::{
4 Future, FutureExt as _,
5 future::{BoxFuture, LocalBoxFuture},
6};
7use gpui::{AnyEntity, AnyWeakEntity, AsyncApp, Entity};
8use proto::{
9 AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, RequestMessage, TypedEnvelope,
10 error::ErrorExt as _,
11};
12use std::{
13 any::{Any, TypeId},
14 sync::{Arc, Weak},
15};
16
17#[derive(Clone)]
18pub struct AnyProtoClient(Arc<dyn ProtoClient>);
19
20impl AnyProtoClient {
21 pub fn downgrade(&self) -> AnyWeakProtoClient {
22 AnyWeakProtoClient(Arc::downgrade(&self.0))
23 }
24}
25
26#[derive(Clone)]
27pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
28
29impl AnyWeakProtoClient {
30 pub fn upgrade(&self) -> Option<AnyProtoClient> {
31 self.0.upgrade().map(AnyProtoClient)
32 }
33}
34
35pub trait ProtoClient: Send + Sync {
36 fn request(
37 &self,
38 envelope: Envelope,
39 request_type: &'static str,
40 ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
41
42 fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
43
44 fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
45
46 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
47
48 fn is_via_collab(&self) -> bool;
49}
50
51#[derive(Default)]
52pub struct ProtoMessageHandlerSet {
53 pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
54 pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
55 pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
56 pub entities_by_message_type: HashMap<TypeId, AnyWeakEntity>,
57 pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
58}
59
60pub type ProtoMessageHandler = Arc<
61 dyn Send
62 + Sync
63 + Fn(
64 AnyEntity,
65 Box<dyn AnyTypedEnvelope>,
66 AnyProtoClient,
67 AsyncApp,
68 ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
69>;
70
71impl ProtoMessageHandlerSet {
72 pub fn clear(&mut self) {
73 self.message_handlers.clear();
74 self.entities_by_message_type.clear();
75 self.entities_by_type_and_remote_id.clear();
76 self.entity_id_extractors.clear();
77 }
78
79 fn add_message_handler(
80 &mut self,
81 message_type_id: TypeId,
82 entity: gpui::AnyWeakEntity,
83 handler: ProtoMessageHandler,
84 ) {
85 self.entities_by_message_type
86 .insert(message_type_id, entity);
87 let prev_handler = self.message_handlers.insert(message_type_id, handler);
88 if prev_handler.is_some() {
89 panic!("registered handler for the same message twice");
90 }
91 }
92
93 fn add_entity_message_handler(
94 &mut self,
95 message_type_id: TypeId,
96 entity_type_id: TypeId,
97 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
98 handler: ProtoMessageHandler,
99 ) {
100 self.entity_id_extractors
101 .entry(message_type_id)
102 .or_insert(entity_id_extractor);
103 self.entity_types_by_message_type
104 .insert(message_type_id, entity_type_id);
105 let prev_handler = self.message_handlers.insert(message_type_id, handler);
106 if prev_handler.is_some() {
107 panic!("registered handler for the same message twice");
108 }
109 }
110
111 pub fn handle_message(
112 this: &parking_lot::Mutex<Self>,
113 message: Box<dyn AnyTypedEnvelope>,
114 client: AnyProtoClient,
115 cx: AsyncApp,
116 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
117 let payload_type_id = message.payload_type_id();
118 let mut this = this.lock();
119 let handler = this.message_handlers.get(&payload_type_id)?.clone();
120 let entity = if let Some(entity) = this.entities_by_message_type.get(&payload_type_id) {
121 entity.upgrade()?
122 } else {
123 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
124 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
125 let entity_id = (extract_entity_id)(message.as_ref());
126 match this
127 .entities_by_type_and_remote_id
128 .get_mut(&(entity_type_id, entity_id))?
129 {
130 EntityMessageSubscriber::Pending(pending) => {
131 pending.push(message);
132 return None;
133 }
134 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
135 }
136 };
137 drop(this);
138 Some(handler(entity, message, client, cx))
139 }
140}
141
142pub enum EntityMessageSubscriber {
143 Entity { handle: AnyWeakEntity },
144 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
145}
146
147impl std::fmt::Debug for EntityMessageSubscriber {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 EntityMessageSubscriber::Entity { handle } => f
151 .debug_struct("EntityMessageSubscriber::Entity")
152 .field("handle", handle)
153 .finish(),
154 EntityMessageSubscriber::Pending(vec) => f
155 .debug_struct("EntityMessageSubscriber::Pending")
156 .field(
157 "envelopes",
158 &vec.iter()
159 .map(|envelope| envelope.payload_type_name())
160 .collect::<Vec<_>>(),
161 )
162 .finish(),
163 }
164 }
165}
166
167impl<T> From<Arc<T>> for AnyProtoClient
168where
169 T: ProtoClient + 'static,
170{
171 fn from(client: Arc<T>) -> Self {
172 Self(client)
173 }
174}
175
176impl AnyProtoClient {
177 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
178 Self(client)
179 }
180
181 pub fn is_via_collab(&self) -> bool {
182 self.0.is_via_collab()
183 }
184
185 pub fn request<T: RequestMessage>(
186 &self,
187 request: T,
188 ) -> impl Future<Output = anyhow::Result<T::Response>> + use<T> {
189 let envelope = request.into_envelope(0, None, None);
190 let response = self.0.request(envelope, T::NAME);
191 async move {
192 T::Response::from_envelope(response.await?)
193 .context("received response of the wrong type")
194 }
195 }
196
197 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
198 let envelope = request.into_envelope(0, None, None);
199 self.0.send(envelope, T::NAME)
200 }
201
202 pub fn send_response<T: EnvelopedMessage>(
203 &self,
204 request_id: u32,
205 request: T,
206 ) -> anyhow::Result<()> {
207 let envelope = request.into_envelope(0, Some(request_id), None);
208 self.0.send(envelope, T::NAME)
209 }
210
211 pub fn add_request_handler<M, E, H, F>(&self, entity: gpui::WeakEntity<E>, handler: H)
212 where
213 M: RequestMessage,
214 E: 'static,
215 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
216 F: 'static + Future<Output = anyhow::Result<M::Response>>,
217 {
218 self.0.message_handler_set().lock().add_message_handler(
219 TypeId::of::<M>(),
220 entity.into(),
221 Arc::new(move |entity, envelope, client, cx| {
222 let entity = entity.downcast::<E>().unwrap();
223 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
224 let request_id = envelope.message_id();
225 handler(entity, *envelope, cx)
226 .then(move |result| async move {
227 match result {
228 Ok(response) => {
229 client.send_response(request_id, response)?;
230 Ok(())
231 }
232 Err(error) => {
233 client.send_response(request_id, error.to_proto())?;
234 Err(error)
235 }
236 }
237 })
238 .boxed_local()
239 }),
240 )
241 }
242
243 pub fn add_entity_request_handler<M, E, H, F>(&self, handler: H)
244 where
245 M: EnvelopedMessage + RequestMessage + EntityMessage,
246 E: 'static,
247 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
248 F: 'static + Future<Output = anyhow::Result<M::Response>>,
249 {
250 let message_type_id = TypeId::of::<M>();
251 let entity_type_id = TypeId::of::<E>();
252 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
253 (envelope as &dyn Any)
254 .downcast_ref::<TypedEnvelope<M>>()
255 .unwrap()
256 .payload
257 .remote_entity_id()
258 };
259 self.0
260 .message_handler_set()
261 .lock()
262 .add_entity_message_handler(
263 message_type_id,
264 entity_type_id,
265 entity_id_extractor,
266 Arc::new(move |entity, envelope, client, cx| {
267 let entity = entity.downcast::<E>().unwrap();
268 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
269 let request_id = envelope.message_id();
270 handler(entity, *envelope, cx)
271 .then(move |result| async move {
272 match result {
273 Ok(response) => {
274 client.send_response(request_id, response)?;
275 Ok(())
276 }
277 Err(error) => {
278 client.send_response(request_id, error.to_proto())?;
279 Err(error)
280 }
281 }
282 })
283 .boxed_local()
284 }),
285 );
286 }
287
288 pub fn add_entity_message_handler<M, E, H, F>(&self, handler: H)
289 where
290 M: EnvelopedMessage + EntityMessage,
291 E: 'static,
292 H: 'static + Sync + Send + Fn(gpui::Entity<E>, TypedEnvelope<M>, AsyncApp) -> F,
293 F: 'static + Future<Output = anyhow::Result<()>>,
294 {
295 let message_type_id = TypeId::of::<M>();
296 let entity_type_id = TypeId::of::<E>();
297 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
298 (envelope as &dyn Any)
299 .downcast_ref::<TypedEnvelope<M>>()
300 .unwrap()
301 .payload
302 .remote_entity_id()
303 };
304 self.0
305 .message_handler_set()
306 .lock()
307 .add_entity_message_handler(
308 message_type_id,
309 entity_type_id,
310 entity_id_extractor,
311 Arc::new(move |entity, envelope, _, cx| {
312 let entity = entity.downcast::<E>().unwrap();
313 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
314 handler(entity, *envelope, cx).boxed_local()
315 }),
316 );
317 }
318
319 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
320 let id = (TypeId::of::<E>(), remote_id);
321
322 let mut message_handlers = self.0.message_handler_set().lock();
323 if message_handlers
324 .entities_by_type_and_remote_id
325 .contains_key(&id)
326 {
327 panic!("already subscribed to entity");
328 }
329
330 message_handlers.entities_by_type_and_remote_id.insert(
331 id,
332 EntityMessageSubscriber::Entity {
333 handle: entity.downgrade().into(),
334 },
335 );
336 }
337}