1use anyhow::anyhow;
2use collections::HashMap;
3use futures::{
4 future::{BoxFuture, LocalBoxFuture},
5 Future, FutureExt as _,
6};
7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
8use proto::{
9 error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
10 RequestMessage, TypedEnvelope,
11};
12use std::{
13 any::TypeId,
14 sync::{Arc, Weak},
15};
16
17#[derive(Clone)]
18pub struct AnyProtoClient(Arc<dyn ProtoClient>);
19
20impl AnyProtoClient {
21 pub fn downgrade(&self) -> AnyWeakProtoClient {
22 AnyWeakProtoClient(Arc::downgrade(&self.0))
23 }
24}
25
26#[derive(Clone)]
27pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
28
29impl AnyWeakProtoClient {
30 pub fn upgrade(&self) -> Option<AnyProtoClient> {
31 self.0.upgrade().map(AnyProtoClient)
32 }
33}
34
35pub trait ProtoClient: Send + Sync {
36 fn request(
37 &self,
38 envelope: Envelope,
39 request_type: &'static str,
40 ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
41
42 fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
43
44 fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
45
46 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
47
48 fn is_via_collab(&self) -> bool;
49}
50
51#[derive(Default)]
52pub struct ProtoMessageHandlerSet {
53 pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
54 pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
55 pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
56 pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
57 pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
58}
59
60pub type ProtoMessageHandler = Arc<
61 dyn Send
62 + Sync
63 + Fn(
64 AnyModel,
65 Box<dyn AnyTypedEnvelope>,
66 AnyProtoClient,
67 AsyncAppContext,
68 ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
69>;
70
71impl ProtoMessageHandlerSet {
72 pub fn clear(&mut self) {
73 self.message_handlers.clear();
74 self.models_by_message_type.clear();
75 self.entities_by_type_and_remote_id.clear();
76 self.entity_id_extractors.clear();
77 }
78
79 fn add_message_handler(
80 &mut self,
81 message_type_id: TypeId,
82 model: gpui::AnyWeakModel,
83 handler: ProtoMessageHandler,
84 ) {
85 self.models_by_message_type.insert(message_type_id, model);
86 let prev_handler = self.message_handlers.insert(message_type_id, handler);
87 if prev_handler.is_some() {
88 panic!("registered handler for the same message twice");
89 }
90 }
91
92 fn add_entity_message_handler(
93 &mut self,
94 message_type_id: TypeId,
95 model_type_id: TypeId,
96 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
97 handler: ProtoMessageHandler,
98 ) {
99 self.entity_id_extractors
100 .entry(message_type_id)
101 .or_insert(entity_id_extractor);
102 self.entity_types_by_message_type
103 .insert(message_type_id, model_type_id);
104 let prev_handler = self.message_handlers.insert(message_type_id, handler);
105 if prev_handler.is_some() {
106 panic!("registered handler for the same message twice");
107 }
108 }
109
110 pub fn handle_message(
111 this: &parking_lot::Mutex<Self>,
112 message: Box<dyn AnyTypedEnvelope>,
113 client: AnyProtoClient,
114 cx: AsyncAppContext,
115 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
116 let payload_type_id = message.payload_type_id();
117 let mut this = this.lock();
118 let handler = this.message_handlers.get(&payload_type_id)?.clone();
119 let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
120 entity.upgrade()?
121 } else {
122 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
123 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
124 let entity_id = (extract_entity_id)(message.as_ref());
125 match this
126 .entities_by_type_and_remote_id
127 .get_mut(&(entity_type_id, entity_id))?
128 {
129 EntityMessageSubscriber::Pending(pending) => {
130 pending.push(message);
131 return None;
132 }
133 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
134 }
135 };
136 drop(this);
137 Some(handler(entity, message, client, cx))
138 }
139}
140
141pub enum EntityMessageSubscriber {
142 Entity { handle: AnyWeakModel },
143 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
144}
145
146impl std::fmt::Debug for EntityMessageSubscriber {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 match self {
149 EntityMessageSubscriber::Entity { handle } => f
150 .debug_struct("EntityMessageSubscriber::Entity")
151 .field("handle", handle)
152 .finish(),
153 EntityMessageSubscriber::Pending(vec) => f
154 .debug_struct("EntityMessageSubscriber::Pending")
155 .field(
156 "envelopes",
157 &vec.iter()
158 .map(|envelope| envelope.payload_type_name())
159 .collect::<Vec<_>>(),
160 )
161 .finish(),
162 }
163 }
164}
165
166impl<T> From<Arc<T>> for AnyProtoClient
167where
168 T: ProtoClient + 'static,
169{
170 fn from(client: Arc<T>) -> Self {
171 Self(client)
172 }
173}
174
175impl AnyProtoClient {
176 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
177 Self(client)
178 }
179
180 pub fn is_via_collab(&self) -> bool {
181 self.0.is_via_collab()
182 }
183
184 pub fn request<T: RequestMessage>(
185 &self,
186 request: T,
187 ) -> impl Future<Output = anyhow::Result<T::Response>> {
188 let envelope = request.into_envelope(0, None, None);
189 let response = self.0.request(envelope, T::NAME);
190 async move {
191 T::Response::from_envelope(response.await?)
192 .ok_or_else(|| anyhow!("received response of the wrong type"))
193 }
194 }
195
196 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
197 let envelope = request.into_envelope(0, None, None);
198 self.0.send(envelope, T::NAME)
199 }
200
201 pub fn send_response<T: EnvelopedMessage>(
202 &self,
203 request_id: u32,
204 request: T,
205 ) -> anyhow::Result<()> {
206 let envelope = request.into_envelope(0, Some(request_id), None);
207 self.0.send(envelope, T::NAME)
208 }
209
210 pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
211 where
212 M: RequestMessage,
213 E: 'static,
214 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
215 F: 'static + Future<Output = anyhow::Result<M::Response>>,
216 {
217 self.0.message_handler_set().lock().add_message_handler(
218 TypeId::of::<M>(),
219 model.into(),
220 Arc::new(move |model, envelope, client, cx| {
221 let model = model.downcast::<E>().unwrap();
222 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
223 let request_id = envelope.message_id();
224 handler(model, *envelope, cx)
225 .then(move |result| async move {
226 match result {
227 Ok(response) => {
228 client.send_response(request_id, response)?;
229 Ok(())
230 }
231 Err(error) => {
232 client.send_response(request_id, error.to_proto())?;
233 Err(error)
234 }
235 }
236 })
237 .boxed_local()
238 }),
239 )
240 }
241
242 pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
243 where
244 M: EnvelopedMessage + RequestMessage + EntityMessage,
245 E: 'static,
246 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
247 F: 'static + Future<Output = anyhow::Result<M::Response>>,
248 {
249 let message_type_id = TypeId::of::<M>();
250 let model_type_id = TypeId::of::<E>();
251 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
252 envelope
253 .as_any()
254 .downcast_ref::<TypedEnvelope<M>>()
255 .unwrap()
256 .payload
257 .remote_entity_id()
258 };
259 self.0
260 .message_handler_set()
261 .lock()
262 .add_entity_message_handler(
263 message_type_id,
264 model_type_id,
265 entity_id_extractor,
266 Arc::new(move |model, envelope, client, cx| {
267 let model = model.downcast::<E>().unwrap();
268 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
269 let request_id = envelope.message_id();
270 handler(model, *envelope, cx)
271 .then(move |result| async move {
272 match result {
273 Ok(response) => {
274 client.send_response(request_id, response)?;
275 Ok(())
276 }
277 Err(error) => {
278 client.send_response(request_id, error.to_proto())?;
279 Err(error)
280 }
281 }
282 })
283 .boxed_local()
284 }),
285 );
286 }
287
288 pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
289 where
290 M: EnvelopedMessage + EntityMessage,
291 E: 'static,
292 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
293 F: 'static + Future<Output = anyhow::Result<()>>,
294 {
295 let message_type_id = TypeId::of::<M>();
296 let model_type_id = TypeId::of::<E>();
297 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
298 envelope
299 .as_any()
300 .downcast_ref::<TypedEnvelope<M>>()
301 .unwrap()
302 .payload
303 .remote_entity_id()
304 };
305 self.0
306 .message_handler_set()
307 .lock()
308 .add_entity_message_handler(
309 message_type_id,
310 model_type_id,
311 entity_id_extractor,
312 Arc::new(move |model, envelope, _, cx| {
313 let model = model.downcast::<E>().unwrap();
314 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
315 handler(model, *envelope, cx).boxed_local()
316 }),
317 );
318 }
319}