lib.rs

  1use anyhow::{anyhow, Context, Result};
  2use gpui::{executor, AppContext, Task};
  3use parking_lot::Mutex;
  4use serde::{Deserialize, Serialize};
  5use serde_json::value::RawValue;
  6use smol::{
  7    channel,
  8    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
  9    process::Command,
 10};
 11use std::{
 12    collections::HashMap,
 13    future::Future,
 14    io::Write,
 15    sync::{
 16        atomic::{AtomicUsize, Ordering::SeqCst},
 17        Arc,
 18    },
 19};
 20use std::{path::Path, process::Stdio};
 21use util::TryFutureExt;
 22
 23const JSON_RPC_VERSION: &'static str = "2.0";
 24const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 25
/// Handle to a spawned language-server process, exchanging JSON-RPC messages with it over
/// the process's stdin/stdout.
// NOTE: fields are dropped in declaration order; keep the order stable.
pub struct LanguageServer {
    /// Source of request ids; incremented for every outbound request.
    next_id: AtomicUsize,
    /// Queue of serialized JSON message bodies; the output task frames each one with a
    /// `Content-Length` header and writes it to the server's stdin.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Pending requests keyed by request id. A handler is removed and invoked when the
    /// response carrying that id arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Background task that reads and dispatches messages from the server's stdout.
    _input_task: Task<Option<()>>,
    /// Background task that writes queued messages to the server's stdin.
    _output_task: Task<Option<()>>,
}
 33
 34type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 35
/// Envelope for an outbound JSON-RPC request; serialized with `serde_json::to_vec` and
/// queued on the outbound channel. Field order here is the key order in the emitted JSON.
#[derive(Serialize)]
struct Request<T> {
    /// Protocol version; `LanguageServer::request` sets it to `JSON_RPC_VERSION`.
    jsonrpc: &'static str,
    /// Request id; the response handler is registered under this same id.
    id: usize,
    /// Method name, taken from the request type's `METHOD`.
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
 43
 44#[derive(Deserialize)]
 45struct Response<'a> {
 46    id: usize,
 47    #[serde(default)]
 48    error: Option<Error>,
 49    #[serde(default, borrow)]
 50    result: Option<&'a RawValue>,
 51}
 52
/// Envelope for an outbound JSON-RPC notification (no `id`, so no response is expected).
/// Field order here is the key order in the emitted JSON.
#[derive(Serialize)]
struct OutboundNotification<T> {
    /// Protocol version; `LanguageServer::notify` sets it to `JSON_RPC_VERSION`.
    jsonrpc: &'static str,
    /// Method name, taken from the notification type's `METHOD`.
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
 59
 60#[derive(Deserialize)]
 61struct InboundNotification<'a> {
 62    #[serde(borrow)]
 63    method: &'a str,
 64    #[serde(borrow)]
 65    params: &'a RawValue,
 66}
 67
 68#[derive(Deserialize)]
 69struct Error {
 70    message: String,
 71}
 72
 73impl LanguageServer {
 74    pub fn rust(cx: &AppContext) -> Result<Arc<Self>> {
 75        const BUNDLE: Option<&'static str> = option_env!("BUNDLE");
 76        const TARGET: &'static str = env!("TARGET");
 77
 78        let rust_analyzer_name = format!("rust-analyzer-{}", TARGET);
 79        if BUNDLE.map_or(Ok(false), |b| b.parse())? {
 80            let rust_analyzer_path = cx
 81                .platform()
 82                .path_for_resource(Some(&rust_analyzer_name), None)?;
 83            Self::new(&rust_analyzer_path, cx.background())
 84        } else {
 85            Self::new(Path::new(&rust_analyzer_name), cx.background())
 86        }
 87    }
 88
 89    pub fn new(path: &Path, background: &executor::Background) -> Result<Arc<Self>> {
 90        let mut server = Command::new(path)
 91            .stdin(Stdio::piped())
 92            .stdout(Stdio::piped())
 93            .stderr(Stdio::inherit())
 94            .spawn()?;
 95        let mut stdin = server.stdin.take().unwrap();
 96        let mut stdout = BufReader::new(server.stdout.take().unwrap());
 97        let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
 98        let response_handlers = Arc::new(Mutex::new(HashMap::<usize, ResponseHandler>::new()));
 99        let _input_task = background.spawn(
100            {
101                let response_handlers = response_handlers.clone();
102                async move {
103                    let mut buffer = Vec::new();
104                    loop {
105                        buffer.clear();
106
107                        stdout.read_until(b'\n', &mut buffer).await?;
108                        stdout.read_until(b'\n', &mut buffer).await?;
109                        let message_len: usize = std::str::from_utf8(&buffer)?
110                            .strip_prefix(CONTENT_LEN_HEADER)
111                            .ok_or_else(|| anyhow!("invalid header"))?
112                            .trim_end()
113                            .parse()?;
114
115                        buffer.resize(message_len, 0);
116                        stdout.read_exact(&mut buffer).await?;
117                        if let Ok(InboundNotification { .. }) = serde_json::from_slice(&buffer) {
118                        } else if let Ok(Response { id, error, result }) =
119                            serde_json::from_slice(&buffer)
120                        {
121                            if let Some(handler) = response_handlers.lock().remove(&id) {
122                                if let Some(result) = result {
123                                    handler(Ok(result.get()));
124                                } else if let Some(error) = error {
125                                    handler(Err(error));
126                                }
127                            }
128                        } else {
129                            return Err(anyhow!(
130                                "failed to deserialize message:\n{}",
131                                std::str::from_utf8(&buffer)?
132                            ));
133                        }
134                    }
135                }
136            }
137            .log_err(),
138        );
139        let _output_task = background.spawn(
140            async move {
141                let mut content_len_buffer = Vec::new();
142                loop {
143                    let message = outbound_rx.recv().await?;
144                    write!(content_len_buffer, "{}", message.len()).unwrap();
145                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
146                    stdin.write_all(&content_len_buffer).await?;
147                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
148                    stdin.write_all(&message).await?;
149                }
150            }
151            .log_err(),
152        );
153
154        let this = Arc::new(Self {
155            response_handlers,
156            next_id: Default::default(),
157            outbound_tx,
158            _input_task,
159            _output_task,
160        });
161        background.spawn(this.clone().init().log_err()).detach();
162
163        Ok(this)
164    }
165
166    async fn init(self: Arc<Self>) -> Result<()> {
167        self.request::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
168            process_id: Default::default(),
169            root_path: Default::default(),
170            root_uri: Default::default(),
171            initialization_options: Default::default(),
172            capabilities: Default::default(),
173            trace: Default::default(),
174            workspace_folders: Default::default(),
175            client_info: Default::default(),
176            locale: Default::default(),
177        })
178        .await?;
179        self.notify::<lsp_types::notification::Initialized>(lsp_types::InitializedParams {})?;
180        Ok(())
181    }
182
183    pub fn request<T: lsp_types::request::Request>(
184        self: &Arc<Self>,
185        params: T::Params,
186    ) -> impl Future<Output = Result<T::Result>>
187    where
188        T::Result: 'static + Send,
189    {
190        let id = self.next_id.fetch_add(1, SeqCst);
191        let message = serde_json::to_vec(&Request {
192            jsonrpc: JSON_RPC_VERSION,
193            id,
194            method: T::METHOD,
195            params,
196        })
197        .unwrap();
198        let mut response_handlers = self.response_handlers.lock();
199        let (tx, rx) = smol::channel::bounded(1);
200        response_handlers.insert(
201            id,
202            Box::new(move |result| {
203                let response = match result {
204                    Ok(response) => {
205                        serde_json::from_str(response).context("failed to deserialize response")
206                    }
207                    Err(error) => Err(anyhow!("{}", error.message)),
208                };
209                let _ = smol::block_on(tx.send(response));
210            }),
211        );
212
213        let outbound_tx = self.outbound_tx.clone();
214        async move {
215            outbound_tx.send(message).await?;
216            rx.recv().await?
217        }
218    }
219
220    pub fn notify<T: lsp_types::notification::Notification>(
221        &self,
222        params: T::Params,
223    ) -> Result<()> {
224        let message = serde_json::to_vec(&OutboundNotification {
225            jsonrpc: JSON_RPC_VERSION,
226            method: T::METHOD,
227            params,
228        })
229        .unwrap();
230        smol::block_on(self.outbound_tx.send(message))?;
231        Ok(())
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        // Bound to `_server` (not `_`) so the server and its I/O tasks stay alive for the
        // rest of the test without tripping the `unused_variables` lint.
        let _server = cx.read(|cx| LanguageServer::rust(cx).unwrap());
    }
}