@@ -1,4 +1,5 @@
use anyhow::{anyhow, Context, Result};
+use futures::{io::BufWriter, AsyncRead, AsyncWrite};
use gpui::{executor, AppContext, Task};
use parking_lot::{Mutex, RwLock};
use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
@@ -13,6 +14,7 @@ use std::{
collections::HashMap,
future::Future,
io::Write,
+ marker::PhantomData,
str::FromStr,
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
@@ -22,6 +24,8 @@ use std::{
use std::{path::Path, process::Stdio};
use util::TryFutureExt;
+pub use lsp_types::*;
+
const JSON_RPC_VERSION: &'static str = "2.0";
const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
@@ -43,16 +47,16 @@ pub struct Subscription {
notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
}
-#[derive(Serialize)]
-struct Request<T> {
- jsonrpc: &'static str,
+#[derive(Serialize, Deserialize)]
+struct Request<'a, T> {
+ jsonrpc: &'a str,
id: usize,
- method: &'static str,
+ method: &'a str,
params: T,
}
-#[derive(Deserialize)]
-struct Response<'a> {
+#[derive(Serialize, Deserialize)]
+struct AnyResponse<'a> {
id: usize,
#[serde(default)]
error: Option<Error>,
@@ -60,22 +64,24 @@ struct Response<'a> {
result: &'a RawValue,
}
-#[derive(Serialize)]
-struct OutboundNotification<T> {
- jsonrpc: &'static str,
- method: &'static str,
+#[derive(Serialize, Deserialize)]
+struct Notification<'a, T> {
+ #[serde(borrow)]
+ jsonrpc: &'a str,
+ #[serde(borrow)]
+ method: &'a str,
params: T,
}
#[derive(Deserialize)]
-struct InboundNotification<'a> {
+struct AnyNotification<'a> {
#[serde(borrow)]
method: &'a str,
#[serde(borrow)]
params: &'a RawValue,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
struct Error {
message: String,
}
@@ -90,24 +96,46 @@ impl LanguageServer {
let rust_analyzer_path = cx
.platform()
.path_for_resource(Some(&rust_analyzer_name), None)?;
- Self::new(root_path, &rust_analyzer_path, cx.background())
+ Self::new(root_path, &rust_analyzer_path, &[], cx.background())
} else {
- Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
+ Self::new(
+ root_path,
+ Path::new(&rust_analyzer_name),
+ &[],
+ cx.background(),
+ )
}
}
pub fn new(
root_path: &Path,
server_path: &Path,
+ server_args: &[&str],
background: &executor::Background,
) -> Result<Arc<Self>> {
let mut server = Command::new(server_path)
+ .args(server_args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
- let mut stdin = server.stdin.take().unwrap();
- let mut stdout = BufReader::new(server.stdout.take().unwrap());
+ let stdin = server.stdin.take().unwrap();
+ let stdout = server.stdout.take().unwrap();
+ Self::new_internal(root_path, stdin, stdout, background)
+ }
+
+ fn new_internal<Stdin, Stdout>(
+ root_path: &Path,
+ stdin: Stdin,
+ stdout: Stdout,
+ background: &executor::Background,
+ ) -> Result<Arc<Self>>
+ where
+ Stdin: AsyncWrite + Unpin + Send + 'static,
+ Stdout: AsyncRead + Unpin + Send + 'static,
+ {
+ let mut stdin = BufWriter::new(stdin);
+ let mut stdout = BufReader::new(stdout);
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
@@ -119,7 +147,6 @@ impl LanguageServer {
let mut buffer = Vec::new();
loop {
buffer.clear();
-
stdout.read_until(b'\n', &mut buffer).await?;
stdout.read_until(b'\n', &mut buffer).await?;
let message_len: usize = std::str::from_utf8(&buffer)?
@@ -131,7 +158,7 @@ impl LanguageServer {
buffer.resize(message_len, 0);
stdout.read_exact(&mut buffer).await?;
- if let Ok(InboundNotification { method, params }) =
+ if let Ok(AnyNotification { method, params }) =
serde_json::from_slice(&buffer)
{
if let Some(handler) = notification_handlers.read().get(method) {
@@ -146,7 +173,7 @@ impl LanguageServer {
.unwrap()
);
}
- } else if let Ok(Response { id, error, result }) =
+ } else if let Ok(AnyResponse { id, error, result }) =
serde_json::from_slice(&buffer)
{
if let Some(handler) = response_handlers.lock().remove(&id) {
@@ -179,6 +206,7 @@ impl LanguageServer {
stdin.write_all(&content_len_buffer).await?;
stdin.write_all("\r\n\r\n".as_bytes()).await?;
stdin.write_all(&message).await?;
+ stdin.flush().await?;
}
}
.log_err(),
@@ -211,7 +239,8 @@ impl LanguageServer {
}
async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
- self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
+ #[allow(deprecated)]
+ let params = lsp_types::InitializeParams {
process_id: Default::default(),
root_path: Default::default(),
root_uri: Some(root_uri),
@@ -226,8 +255,10 @@ impl LanguageServer {
workspace_folders: Default::default(),
client_info: Default::default(),
locale: Default::default(),
- })
- .await?;
+ };
+
+ self.request_internal::<lsp_types::request::Initialize>(params)
+ .await?;
self.notify_internal::<lsp_types::notification::Initialized>(
lsp_types::InitializedParams {},
)
@@ -327,7 +358,7 @@ impl LanguageServer {
self: &Arc<Self>,
params: T::Params,
) -> impl Future<Output = Result<()>> {
- let message = serde_json::to_vec(&OutboundNotification {
+ let message = serde_json::to_vec(&Notification {
jsonrpc: JSON_RPC_VERSION,
method: T::METHOD,
params,
@@ -342,16 +373,136 @@ impl LanguageServer {
}
}
+impl Subscription {
+ pub fn detach(mut self) {
+ self.method = "";
+ }
+}
+
impl Drop for Subscription {
fn drop(&mut self) {
self.notification_handlers.write().remove(self.method);
}
}
+#[cfg(any(test, feature = "test-support"))]
+pub struct FakeLanguageServer {
+ buffer: Vec<u8>,
+ stdin: smol::io::BufReader<async_pipe::PipeReader>,
+ stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct RequestId<T> {
+ id: usize,
+ _type: std::marker::PhantomData<T>,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl LanguageServer {
+ pub async fn fake(executor: &executor::Background) -> (Arc<Self>, FakeLanguageServer) {
+ let stdin = async_pipe::pipe();
+ let stdout = async_pipe::pipe();
+ (
+ Self::new_internal(Path::new("/"), stdin.0, stdout.1, executor).unwrap(),
+ FakeLanguageServer {
+ stdin: smol::io::BufReader::new(stdin.1),
+ stdout: smol::io::BufWriter::new(stdout.0),
+ buffer: Vec::new(),
+ },
+ )
+ }
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl FakeLanguageServer {
+ pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
+ let message = serde_json::to_vec(&Notification {
+ jsonrpc: JSON_RPC_VERSION,
+ method: T::METHOD,
+ params,
+ })
+ .unwrap();
+ self.send(message).await;
+ }
+
+ pub async fn respond<'a, T: request::Request>(
+ &mut self,
+ request_id: RequestId<T>,
+ result: T::Result,
+ ) {
+ let result = serde_json::to_string(&result).unwrap();
+ let message = serde_json::to_vec(&AnyResponse {
+ id: request_id.id,
+ error: None,
+ result: &RawValue::from_string(result).unwrap(),
+ })
+ .unwrap();
+ self.send(message).await;
+ }
+
+ pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
+ self.receive().await;
+ let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
+ assert_eq!(request.method, T::METHOD);
+ assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
+ (
+ RequestId {
+ id: request.id,
+ _type: PhantomData,
+ },
+ request.params,
+ )
+ }
+
+ pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
+ self.receive().await;
+ let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
+ assert_eq!(notification.method, T::METHOD);
+ notification.params
+ }
+
+ async fn send(&mut self, message: Vec<u8>) {
+ self.stdout
+ .write_all(CONTENT_LEN_HEADER.as_bytes())
+ .await
+ .unwrap();
+ self.stdout
+ .write_all((format!("{}", message.len())).as_bytes())
+ .await
+ .unwrap();
+ self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
+ self.stdout.write_all(&message).await.unwrap();
+ self.stdout.flush().await.unwrap();
+ }
+
+ async fn receive(&mut self) {
+ self.buffer.clear();
+ self.stdin
+ .read_until(b'\n', &mut self.buffer)
+ .await
+ .unwrap();
+ self.stdin
+ .read_until(b'\n', &mut self.buffer)
+ .await
+ .unwrap();
+ let message_len: usize = std::str::from_utf8(&self.buffer)
+ .unwrap()
+ .strip_prefix(CONTENT_LEN_HEADER)
+ .unwrap()
+ .trim_end()
+ .parse()
+ .unwrap();
+ self.buffer.resize(message_len, 0);
+ self.stdin.read_exact(&mut self.buffer).await.unwrap();
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
+ use simplelog::SimpleLogger;
use unindent::Unindent;
use util::test::temp_tree;
@@ -414,6 +565,68 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_fake(cx: TestAppContext) {
+ SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();
+
+ let (server, mut fake) = LanguageServer::fake(&cx.background()).await;
+
+ let (message_tx, message_rx) = channel::unbounded();
+ let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
+ server
+ .on_notification::<notification::ShowMessage, _>(move |params| {
+ message_tx.try_send(params).unwrap()
+ })
+ .detach();
+ server
+ .on_notification::<notification::PublishDiagnostics, _>(move |params| {
+ diagnostics_tx.try_send(params).unwrap()
+ })
+ .detach();
+
+ let (init_id, _) = fake.receive_request::<request::Initialize>().await;
+ fake.respond(init_id, InitializeResult::default()).await;
+ fake.receive_notification::<notification::Initialized>()
+ .await;
+
+ server
+ .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
+ text_document: TextDocumentItem::new(
+ Url::from_str("file://a/b").unwrap(),
+ "rust".to_string(),
+ 0,
+ "".to_string(),
+ ),
+ })
+ .await
+ .unwrap();
+ assert_eq!(
+ fake.receive_notification::<notification::DidOpenTextDocument>()
+ .await
+ .text_document
+ .uri
+ .as_str(),
+ "file://a/b"
+ );
+
+ fake.notify::<notification::ShowMessage>(ShowMessageParams {
+ typ: MessageType::ERROR,
+ message: "ok".to_string(),
+ })
+ .await;
+ fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
+ uri: Url::from_str("file://b/c").unwrap(),
+ version: Some(5),
+ diagnostics: vec![],
+ })
+ .await;
+ assert_eq!(message_rx.recv().await.unwrap().message, "ok");
+ assert_eq!(
+ diagnostics_rx.recv().await.unwrap().uri.as_str(),
+ "file://b/c"
+ );
+ }
+
impl LanguageServer {
async fn next_idle_notification(self: &Arc<Self>) {
let (tx, rx) = channel::unbounded();