1use anyhow::{anyhow, Context, Result};
2use gpui::{executor, AppContext, Task};
3use parking_lot::{Mutex, RwLock};
4use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
5use serde::{Deserialize, Serialize};
6use serde_json::{json, value::RawValue};
7use smol::{
8 channel,
9 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
10 process::Command,
11};
12use std::{
13 collections::HashMap,
14 future::Future,
15 io::Write,
16 sync::{
17 atomic::{AtomicUsize, Ordering::SeqCst},
18 Arc,
19 },
20};
21use std::{path::Path, process::Stdio};
22use util::TryFutureExt;
23
/// Value of the `jsonrpc` member written into every outgoing request and notification.
const JSON_RPC_VERSION: &str = "2.0";
/// Header prefix that precedes each message's byte length in the LSP base-protocol framing.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
26
27type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
28type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
29
/// Client side of a language server child process, exchanging JSON-RPC messages over the
/// process's stdin/stdout.
// NOTE(review): field order is also drop order; keep the tasks where they are unless the
// shutdown sequence is re-checked.
pub struct LanguageServer {
    /// Source of ids for outgoing requests; bumped once per request.
    next_id: AtomicUsize,
    /// Serialized message bodies queued for the output task, which frames them with a
    /// `Content-Length` header and writes them to the server's stdin.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Callbacks for server notifications, keyed by LSP method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests, keyed by request id; each is removed when the
    /// response carrying its id arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Task reading and dispatching messages from the server's stdout (never read; held so
    /// the task lives as long as the server).
    _input_task: Task<Option<()>>,
    /// Task writing queued messages to the server's stdin (never read; held so the task
    /// lives as long as the server).
    _output_task: Task<Option<()>>,
    /// Released once the `initialize`/`initialized` handshake has finished, whether or not
    /// it succeeded; the public `request`/`notify` methods wait on it.
    initialized: barrier::Receiver,
}
39
40pub struct Subscription {
41 method: &'static str,
42 notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
43}
44
45#[derive(Serialize)]
46struct Request<T> {
47 jsonrpc: &'static str,
48 id: usize,
49 method: &'static str,
50 params: T,
51}
52
53#[derive(Deserialize)]
54struct Response<'a> {
55 id: usize,
56 #[serde(default)]
57 error: Option<Error>,
58 #[serde(borrow)]
59 result: &'a RawValue,
60}
61
62#[derive(Serialize)]
63struct OutboundNotification<T> {
64 jsonrpc: &'static str,
65 method: &'static str,
66 params: T,
67}
68
69#[derive(Deserialize)]
70struct InboundNotification<'a> {
71 #[serde(borrow)]
72 method: &'a str,
73 #[serde(borrow)]
74 params: &'a RawValue,
75}
76
77#[derive(Debug, Deserialize)]
78struct Error {
79 message: String,
80}
81
82impl LanguageServer {
83 pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
84 const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
85 const ZED_TARGET: &'static str = env!("ZED_TARGET");
86
87 let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
88 if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
89 let rust_analyzer_path = cx
90 .platform()
91 .path_for_resource(Some(&rust_analyzer_name), None)?;
92 Self::new(root_path, &rust_analyzer_path, cx.background())
93 } else {
94 Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
95 }
96 }
97
98 pub fn new(
99 root_path: &Path,
100 server_path: &Path,
101 background: &executor::Background,
102 ) -> Result<Arc<Self>> {
103 let mut server = Command::new(server_path)
104 .stdin(Stdio::piped())
105 .stdout(Stdio::piped())
106 .stderr(Stdio::inherit())
107 .spawn()?;
108 let mut stdin = server.stdin.take().unwrap();
109 let mut stdout = BufReader::new(server.stdout.take().unwrap());
110 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
111 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
112 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
113 let _input_task = background.spawn(
114 {
115 let notification_handlers = notification_handlers.clone();
116 let response_handlers = response_handlers.clone();
117 async move {
118 let mut buffer = Vec::new();
119 loop {
120 buffer.clear();
121
122 stdout.read_until(b'\n', &mut buffer).await?;
123 stdout.read_until(b'\n', &mut buffer).await?;
124 let message_len: usize = std::str::from_utf8(&buffer)?
125 .strip_prefix(CONTENT_LEN_HEADER)
126 .ok_or_else(|| anyhow!("invalid header"))?
127 .trim_end()
128 .parse()?;
129
130 buffer.resize(message_len, 0);
131 stdout.read_exact(&mut buffer).await?;
132
133 if let Ok(InboundNotification { method, params }) =
134 serde_json::from_slice(&buffer)
135 {
136 if let Some(handler) = notification_handlers.read().get(method) {
137 handler(params.get());
138 }
139 } else if let Ok(Response { id, error, result }) =
140 serde_json::from_slice(&buffer)
141 {
142 if let Some(handler) = response_handlers.lock().remove(&id) {
143 if let Some(error) = error {
144 handler(Err(error));
145 } else {
146 handler(Ok(result.get()));
147 }
148 }
149 } else {
150 return Err(anyhow!(
151 "failed to deserialize message:\n{}",
152 std::str::from_utf8(&buffer)?
153 ));
154 }
155 }
156 }
157 }
158 .log_err(),
159 );
160 let _output_task = background.spawn(
161 async move {
162 let mut content_len_buffer = Vec::new();
163 loop {
164 content_len_buffer.clear();
165
166 let message = outbound_rx.recv().await?;
167 write!(content_len_buffer, "{}", message.len()).unwrap();
168 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
169 stdin.write_all(&content_len_buffer).await?;
170 stdin.write_all("\r\n\r\n".as_bytes()).await?;
171 stdin.write_all(&message).await?;
172 }
173 }
174 .log_err(),
175 );
176
177 let (initialized_tx, initialized_rx) = barrier::channel();
178 let this = Arc::new(Self {
179 notification_handlers,
180 response_handlers,
181 next_id: Default::default(),
182 outbound_tx,
183 _input_task,
184 _output_task,
185 initialized: initialized_rx,
186 });
187
188 let root_uri =
189 lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
190 background
191 .spawn({
192 let this = this.clone();
193 async move {
194 this.init(root_uri).log_err().await;
195 drop(initialized_tx);
196 }
197 })
198 .detach();
199
200 Ok(this)
201 }
202
203 async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
204 self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
205 process_id: Default::default(),
206 root_path: Default::default(),
207 root_uri: Some(root_uri),
208 initialization_options: Default::default(),
209 capabilities: lsp_types::ClientCapabilities {
210 experimental: Some(json!({
211 "serverStatusNotification": true,
212 })),
213 ..Default::default()
214 },
215 trace: Default::default(),
216 workspace_folders: Default::default(),
217 client_info: Default::default(),
218 locale: Default::default(),
219 })
220 .await?;
221 self.notify_internal::<lsp_types::notification::Initialized>(
222 lsp_types::InitializedParams {},
223 )
224 .await?;
225 Ok(())
226 }
227
228 pub fn on_notification<T, F>(&self, f: F) -> Subscription
229 where
230 T: lsp_types::notification::Notification,
231 F: 'static + Send + Sync + Fn(T::Params),
232 {
233 let prev_handler = self.notification_handlers.write().insert(
234 T::METHOD,
235 Box::new(
236 move |notification| match serde_json::from_str(notification) {
237 Ok(notification) => f(notification),
238 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
239 },
240 ),
241 );
242
243 assert!(
244 prev_handler.is_none(),
245 "registered multiple handlers for the same notification"
246 );
247
248 Subscription {
249 method: T::METHOD,
250 notification_handlers: self.notification_handlers.clone(),
251 }
252 }
253
254 pub fn request<T: lsp_types::request::Request>(
255 self: Arc<Self>,
256 params: T::Params,
257 ) -> impl Future<Output = Result<T::Result>>
258 where
259 T::Result: 'static + Send,
260 {
261 let this = self.clone();
262 async move {
263 this.initialized.clone().recv().await;
264 this.request_internal::<T>(params).await
265 }
266 }
267
268 fn request_internal<T: lsp_types::request::Request>(
269 self: &Arc<Self>,
270 params: T::Params,
271 ) -> impl Future<Output = Result<T::Result>>
272 where
273 T::Result: 'static + Send,
274 {
275 let id = self.next_id.fetch_add(1, SeqCst);
276 let message = serde_json::to_vec(&Request {
277 jsonrpc: JSON_RPC_VERSION,
278 id,
279 method: T::METHOD,
280 params,
281 })
282 .unwrap();
283 let mut response_handlers = self.response_handlers.lock();
284 let (mut tx, mut rx) = oneshot::channel();
285 response_handlers.insert(
286 id,
287 Box::new(move |result| {
288 let response = match result {
289 Ok(response) => {
290 serde_json::from_str(response).context("failed to deserialize response")
291 }
292 Err(error) => Err(anyhow!("{}", error.message)),
293 };
294 let _ = tx.try_send(response);
295 }),
296 );
297
298 let this = self.clone();
299 async move {
300 this.outbound_tx.send(message).await?;
301 rx.recv().await.unwrap()
302 }
303 }
304
305 pub fn notify<T: lsp_types::notification::Notification>(
306 self: &Arc<Self>,
307 params: T::Params,
308 ) -> impl Future<Output = Result<()>> {
309 let this = self.clone();
310 async move {
311 this.initialized.clone().recv().await;
312 this.notify_internal::<T>(params).await
313 }
314 }
315
316 fn notify_internal<T: lsp_types::notification::Notification>(
317 self: &Arc<Self>,
318 params: T::Params,
319 ) -> impl Future<Output = Result<()>> {
320 let message = serde_json::to_vec(&OutboundNotification {
321 jsonrpc: JSON_RPC_VERSION,
322 method: T::METHOD,
323 params,
324 })
325 .unwrap();
326
327 let this = self.clone();
328 async move {
329 this.outbound_tx.send(message).await?;
330 Ok(())
331 }
332 }
333}
334
335impl Drop for Subscription {
336 fn drop(&mut self) {
337 self.notification_handlers.write().remove(self.method);
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// End-to-end check against a real `rust-analyzer`.
    ///
    /// It waits for the server to go idle on a one-file crate, then hovers inside the
    /// `"world"` literal and expects the Markdown type `&str`.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let root_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": r#"
                    fn fun() {
                        let hello = "world";
                    }
                "#.unindent()
            }
        }));

        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
        server.next_idle_notification().await;

        let lib_rs_uri =
            lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();
        let hover_params = lsp_types::HoverParams {
            text_document_position_params: lsp_types::TextDocumentPositionParams {
                text_document: lsp_types::TextDocumentIdentifier::new(lib_rs_uri),
                position: lsp_types::Position::new(1, 21),
            },
            work_done_progress_params: Default::default(),
        };
        let hover = server
            .request::<lsp_types::request::HoverRequest>(hover_params)
            .await
            .unwrap()
            .unwrap();

        let expected_contents = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover.contents, expected_contents);
    }

    impl LanguageServer {
        /// Resolves once the server sends an `experimental/serverStatus` notification with
        /// `quiescent` set. The temporary subscription is dropped before the channel.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            let _ = idle_rx.recv().await;
        }
    }

    /// rust-analyzer's `experimental/serverStatus` notification, enabled by the
    /// `serverStatusNotification` experimental capability sent during initialization.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// The subset of the server-status payload this test reads.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        /// Set when the server reports it has no pending work.
        pub quiescent: bool,
    }
}