1use anyhow::anyhow;
2use collections::HashMap;
3use futures::{
4 future::{BoxFuture, LocalBoxFuture},
5 Future, FutureExt as _,
6};
7use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model};
8// pub use prost::Message;
9use proto::{
10 error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
11 RequestMessage, TypedEnvelope,
12};
13use std::{any::TypeId, sync::Arc};
14
15#[derive(Clone)]
16pub struct AnyProtoClient(Arc<dyn ProtoClient>);
17
18pub trait ProtoClient: Send + Sync {
19 fn request(
20 &self,
21 envelope: Envelope,
22 request_type: &'static str,
23 ) -> BoxFuture<'static, anyhow::Result<Envelope>>;
24
25 fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
26
27 fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>;
28
29 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet>;
30
31 fn is_via_collab(&self) -> bool;
32}
33
34#[derive(Default)]
35pub struct ProtoMessageHandlerSet {
36 pub entity_types_by_message_type: HashMap<TypeId, TypeId>,
37 pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>,
38 pub entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
39 pub models_by_message_type: HashMap<TypeId, AnyWeakModel>,
40 pub message_handlers: HashMap<TypeId, ProtoMessageHandler>,
41}
42
43pub type ProtoMessageHandler = Arc<
44 dyn Send
45 + Sync
46 + Fn(
47 AnyModel,
48 Box<dyn AnyTypedEnvelope>,
49 AnyProtoClient,
50 AsyncAppContext,
51 ) -> LocalBoxFuture<'static, anyhow::Result<()>>,
52>;
53
54impl ProtoMessageHandlerSet {
55 pub fn clear(&mut self) {
56 self.message_handlers.clear();
57 self.models_by_message_type.clear();
58 self.entities_by_type_and_remote_id.clear();
59 self.entity_id_extractors.clear();
60 }
61
62 fn add_message_handler(
63 &mut self,
64 message_type_id: TypeId,
65 model: gpui::AnyWeakModel,
66 handler: ProtoMessageHandler,
67 ) {
68 self.models_by_message_type.insert(message_type_id, model);
69 let prev_handler = self.message_handlers.insert(message_type_id, handler);
70 if prev_handler.is_some() {
71 panic!("registered handler for the same message twice");
72 }
73 }
74
75 fn add_entity_message_handler(
76 &mut self,
77 message_type_id: TypeId,
78 model_type_id: TypeId,
79 entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64,
80 handler: ProtoMessageHandler,
81 ) {
82 self.entity_id_extractors
83 .entry(message_type_id)
84 .or_insert(entity_id_extractor);
85 self.entity_types_by_message_type
86 .insert(message_type_id, model_type_id);
87 let prev_handler = self.message_handlers.insert(message_type_id, handler);
88 if prev_handler.is_some() {
89 panic!("registered handler for the same message twice");
90 }
91 }
92
93 pub fn handle_message(
94 this: &parking_lot::Mutex<Self>,
95 message: Box<dyn AnyTypedEnvelope>,
96 client: AnyProtoClient,
97 cx: AsyncAppContext,
98 ) -> Option<LocalBoxFuture<'static, anyhow::Result<()>>> {
99 let payload_type_id = message.payload_type_id();
100 let mut this = this.lock();
101 let handler = this.message_handlers.get(&payload_type_id)?.clone();
102 let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) {
103 entity.upgrade()?
104 } else {
105 let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?;
106 let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?;
107 let entity_id = (extract_entity_id)(message.as_ref());
108
109 match this
110 .entities_by_type_and_remote_id
111 .get_mut(&(entity_type_id, entity_id))?
112 {
113 EntityMessageSubscriber::Pending(pending) => {
114 pending.push(message);
115 return None;
116 }
117 EntityMessageSubscriber::Entity { handle } => handle.upgrade()?,
118 }
119 };
120 drop(this);
121 Some(handler(entity, message, client, cx))
122 }
123}
124
125pub enum EntityMessageSubscriber {
126 Entity { handle: AnyWeakModel },
127 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
128}
129
130impl<T> From<Arc<T>> for AnyProtoClient
131where
132 T: ProtoClient + 'static,
133{
134 fn from(client: Arc<T>) -> Self {
135 Self(client)
136 }
137}
138
139impl AnyProtoClient {
140 pub fn new<T: ProtoClient + 'static>(client: Arc<T>) -> Self {
141 Self(client)
142 }
143
144 pub fn is_via_collab(&self) -> bool {
145 self.0.is_via_collab()
146 }
147
148 pub fn request<T: RequestMessage>(
149 &self,
150 request: T,
151 ) -> impl Future<Output = anyhow::Result<T::Response>> {
152 let envelope = request.into_envelope(0, None, None);
153 let response = self.0.request(envelope, T::NAME);
154 async move {
155 T::Response::from_envelope(response.await?)
156 .ok_or_else(|| anyhow!("received response of the wrong type"))
157 }
158 }
159
160 pub fn send<T: EnvelopedMessage>(&self, request: T) -> anyhow::Result<()> {
161 let envelope = request.into_envelope(0, None, None);
162 self.0.send(envelope, T::NAME)
163 }
164
165 pub fn send_response<T: EnvelopedMessage>(
166 &self,
167 request_id: u32,
168 request: T,
169 ) -> anyhow::Result<()> {
170 let envelope = request.into_envelope(0, Some(request_id), None);
171 self.0.send(envelope, T::NAME)
172 }
173
174 pub fn add_request_handler<M, E, H, F>(&self, model: gpui::WeakModel<E>, handler: H)
175 where
176 M: RequestMessage,
177 E: 'static,
178 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
179 F: 'static + Future<Output = anyhow::Result<M::Response>>,
180 {
181 self.0.message_handler_set().lock().add_message_handler(
182 TypeId::of::<M>(),
183 model.into(),
184 Arc::new(move |model, envelope, client, cx| {
185 let model = model.downcast::<E>().unwrap();
186 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
187 let request_id = envelope.message_id();
188 handler(model, *envelope, cx)
189 .then(move |result| async move {
190 match result {
191 Ok(response) => {
192 client.send_response(request_id, response)?;
193 Ok(())
194 }
195 Err(error) => {
196 client.send_response(request_id, error.to_proto())?;
197 Err(error)
198 }
199 }
200 })
201 .boxed_local()
202 }),
203 )
204 }
205
206 pub fn add_model_request_handler<M, E, H, F>(&self, handler: H)
207 where
208 M: EnvelopedMessage + RequestMessage + EntityMessage,
209 E: 'static,
210 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
211 F: 'static + Future<Output = anyhow::Result<M::Response>>,
212 {
213 let message_type_id = TypeId::of::<M>();
214 let model_type_id = TypeId::of::<E>();
215 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
216 envelope
217 .as_any()
218 .downcast_ref::<TypedEnvelope<M>>()
219 .unwrap()
220 .payload
221 .remote_entity_id()
222 };
223 self.0
224 .message_handler_set()
225 .lock()
226 .add_entity_message_handler(
227 message_type_id,
228 model_type_id,
229 entity_id_extractor,
230 Arc::new(move |model, envelope, client, cx| {
231 let model = model.downcast::<E>().unwrap();
232 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
233 let request_id = envelope.message_id();
234 handler(model, *envelope, cx)
235 .then(move |result| async move {
236 match result {
237 Ok(response) => {
238 client.send_response(request_id, response)?;
239 Ok(())
240 }
241 Err(error) => {
242 client.send_response(request_id, error.to_proto())?;
243 Err(error)
244 }
245 }
246 })
247 .boxed_local()
248 }),
249 );
250 }
251
252 pub fn add_model_message_handler<M, E, H, F>(&self, handler: H)
253 where
254 M: EnvelopedMessage + EntityMessage,
255 E: 'static,
256 H: 'static + Sync + Send + Fn(gpui::Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F,
257 F: 'static + Future<Output = anyhow::Result<()>>,
258 {
259 let message_type_id = TypeId::of::<M>();
260 let model_type_id = TypeId::of::<E>();
261 let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| {
262 envelope
263 .as_any()
264 .downcast_ref::<TypedEnvelope<M>>()
265 .unwrap()
266 .payload
267 .remote_entity_id()
268 };
269 self.0
270 .message_handler_set()
271 .lock()
272 .add_entity_message_handler(
273 message_type_id,
274 model_type_id,
275 entity_id_extractor,
276 Arc::new(move |model, envelope, _, cx| {
277 let model = model.downcast::<E>().unwrap();
278 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
279 handler(model, *envelope, cx).boxed_local()
280 }),
281 );
282 }
283}