1use anyhow::anyhow;
2use collections::HashMap;
3use futures::{
4 future::{BoxFuture, LocalBoxFuture},
5 Future, FutureExt as _,
6};
7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
8// pub use prost::Message;
9use proto::{
10 error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
11 RequestMessage, TypedEnvelope,
12};
13use std::{
14 any::TypeId,
15 sync::{Arc, Weak},
16};
17
18#[derive(Clone)]
19pub struct AnyProtoClient(Arc<dyn ProtoClient>);
20
21impl AnyProtoClient {
22 pub fn downgrade(&self) -> AnyWeakProtoClient {
23 AnyWeakProtoClient(Arc::downgrade(&self.0))
24 }
25}
26
27#[derive(Clone)]
28pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
29
30impl AnyWeakProtoClient {
31 pub fn upgrade(&self) -> Option<AnyProtoClient> {
32 self.0.upgrade().map(AnyProtoClient)
33 }
34}
35
36pub trait ProtoClient: Send + Sync {
37 fn request(
38 &self,
39 envelope: Envelope,
40 request_type: &'static str,
41 ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
42
43 fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
44
45 fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
46
47 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
48
49 fn is_via_collab(&self) -> bool;
50}
51
52#[derive(Default)]
53pub struct ProtoMessageHandlerSet {
54 pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
55 pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
56 pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
57 pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
58 pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
59}
60
61pub type ProtoMessageHandler = Arc<
62 dyn Send
63 + Sync
64 + Fn(
65 AnyModel,
66 Box<dyn AnyTypedEnvelope>,
67 AnyProtoClient,
68 AsyncAppContext,
69 ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
70>;
71
72impl ProtoMessageHandlerSet {
73 pub fn clear(&mut self) {
74 self.message_handlers.clear();
75 self.models_by_message_type.clear();
76 self.entities_by_type_and_remote_id.clear();
77 self.entity_id_extractors.clear();
78 }
79
80 fn add_message_handler(
81 &mut self,
82 message_type_id: TypeId,
83 model: gpui::AnyWeakModel,
84 handler: ProtoMessageHandler,
85 ) {
86 self.models_by_message_type.insert(message_type_id, model);
87 let prev_handler = self.message_handlers.insert(message_type_id, handler);
88 if prev_handler.is_some() {
89 panic!("registered handler for the same message twice");
90 }
91 }
92
93 fn add_entity_message_handler(
94 &mut self,
95 message_type_id: TypeId,
96 model_type_id: TypeId,
97 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
98 handler: ProtoMessageHandler,
99 ) {
100 self.entity_id_extractors
101 .entry(message_type_id)
102 .or_insert(entity_id_extractor);
103 self.entity_types_by_message_type
104 .insert(message_type_id, model_type_id);
105 let prev_handler = self.message_handlers.insert(message_type_id, handler);
106 if prev_handler.is_some() {
107 panic!("registered handler for the same message twice");
108 }
109 }
110
111 pub fn handle_message(
112 this: &parking_lot::Mutex<Self>,
113 message: Box<dyn AnyTypedEnvelope>,
114 client: AnyProtoClient,
115 cx: AsyncAppContext,
116 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
117 let payload_type_id = message.payload_type_id();
118 let mut this = this.lock();
119 let handler = this.message_handlers.get(&payload_type_id)?.clone();
120 let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
121 entity.upgrade()?
122 } else {
123 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
124 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
125 let entity_id = (extract_entity_id)(message.as_ref());
126
127 match this
128 .entities_by_type_and_remote_id
129 .get_mut(&(entity_type_id, entity_id))?
130 {
131 EntityMessageSubscriber::Pending(pending) => {
132 pending.push(message);
133 return None;
134 }
135 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
136 }
137 };
138 drop(this);
139 Some(handler(entity, message, client, cx))
140 }
141}
142
143pub enum EntityMessageSubscriber {
144 Entity { handle: AnyWeakModel },
145 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
146}
147
148impl<T> From<Arc<T>> for AnyProtoClient
149where
150 T: ProtoClient + 'static,
151{
152 fn from(client: Arc<T>) -> Self {
153 Self(client)
154 }
155}
156
157impl AnyProtoClient {
158 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
159 Self(client)
160 }
161
162 pub fn is_via_collab(&self) -> bool {
163 self.0.is_via_collab()
164 }
165
166 pub fn request<T: RequestMessage>(
167 &self,
168 request: T,
169 ) -> impl Future<Output = anyhow::Result<T::Response>> {
170 let envelope = request.into_envelope(0, None, None);
171 let response = self.0.request(envelope, T::NAME);
172 async move {
173 T::Response::from_envelope(response.await?)
174 .ok_or_else(|| anyhow!("received response of the wrong type"))
175 }
176 }
177
178 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
179 let envelope = request.into_envelope(0, None, None);
180 self.0.send(envelope, T::NAME)
181 }
182
183 pub fn send_response<T: EnvelopedMessage>(
184 &self,
185 request_id: u32,
186 request: T,
187 ) -> anyhow::Result<()> {
188 let envelope = request.into_envelope(0, Some(request_id), None);
189 self.0.send(envelope, T::NAME)
190 }
191
192 pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
193 where
194 M: RequestMessage,
195 E: 'static,
196 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
197 F: 'static + Future<Output = anyhow::Result<M::Response>>,
198 {
199 self.0.message_handler_set().lock().add_message_handler(
200 TypeId::of::<M>(),
201 model.into(),
202 Arc::new(move |model, envelope, client, cx| {
203 let model = model.downcast::<E>().unwrap();
204 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
205 let request_id = envelope.message_id();
206 handler(model, *envelope, cx)
207 .then(move |result| async move {
208 match result {
209 Ok(response) => {
210 client.send_response(request_id, response)?;
211 Ok(())
212 }
213 Err(error) => {
214 client.send_response(request_id, error.to_proto())?;
215 Err(error)
216 }
217 }
218 })
219 .boxed_local()
220 }),
221 )
222 }
223
224 pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
225 where
226 M: EnvelopedMessage + RequestMessage + EntityMessage,
227 E: 'static,
228 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
229 F: 'static + Future<Output = anyhow::Result<M::Response>>,
230 {
231 let message_type_id = TypeId::of::<M>();
232 let model_type_id = TypeId::of::<E>();
233 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
234 envelope
235 .as_any()
236 .downcast_ref::<TypedEnvelope<M>>()
237 .unwrap()
238 .payload
239 .remote_entity_id()
240 };
241 self.0
242 .message_handler_set()
243 .lock()
244 .add_entity_message_handler(
245 message_type_id,
246 model_type_id,
247 entity_id_extractor,
248 Arc::new(move |model, envelope, client, cx| {
249 let model = model.downcast::<E>().unwrap();
250 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
251 let request_id = envelope.message_id();
252 handler(model, *envelope, cx)
253 .then(move |result| async move {
254 match result {
255 Ok(response) => {
256 client.send_response(request_id, response)?;
257 Ok(())
258 }
259 Err(error) => {
260 client.send_response(request_id, error.to_proto())?;
261 Err(error)
262 }
263 }
264 })
265 .boxed_local()
266 }),
267 );
268 }
269
270 pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
271 where
272 M: EnvelopedMessage + EntityMessage,
273 E: 'static,
274 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
275 F: 'static + Future<Output = anyhow::Result<()>>,
276 {
277 let message_type_id = TypeId::of::<M>();
278 let model_type_id = TypeId::of::<E>();
279 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
280 envelope
281 .as_any()
282 .downcast_ref::<TypedEnvelope<M>>()
283 .unwrap()
284 .payload
285 .remote_entity_id()
286 };
287 self.0
288 .message_handler_set()
289 .lock()
290 .add_entity_message_handler(
291 message_type_id,
292 model_type_id,
293 entity_id_extractor,
294 Arc::new(move |model, envelope, _, cx| {
295 let model = model.downcast::<E>().unwrap();
296 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
297 handler(model, *envelope, cx).boxed_local()
298 }),
299 );
300 }
301}