@@ -4,7 +4,7 @@ use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
use gpui::{executor, Task};
use parking_lot::{Mutex, RwLock};
use postage::{barrier, prelude::Stream};
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{json, value::RawValue, Value};
use smol::{
channel,
@@ -29,7 +29,8 @@ pub use lsp_types::*;
const JSON_RPC_VERSION: &'static str = "2.0";
const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
-type NotificationHandler = Box<dyn Send + Sync + FnMut(&str)>;
+type NotificationHandler =
+ Box<dyn Send + Sync + FnMut(Option<usize>, &str, &mut channel::Sender<Vec<u8>>) -> Result<()>>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
pub struct LanguageServer {
@@ -80,6 +81,12 @@ struct AnyResponse<'a> {
result: Option<&'a RawValue>,
}
+#[derive(Serialize)]
+struct Response<T> {
+ id: usize,
+ result: T,
+}
+
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
#[serde(borrow)]
@@ -91,6 +98,8 @@ struct Notification<'a, T> {
#[derive(Deserialize)]
struct AnyNotification<'a> {
+ #[serde(default)]
+ id: Option<usize>,
#[serde(borrow)]
method: &'a str,
#[serde(borrow)]
@@ -152,6 +161,7 @@ impl LanguageServer {
{
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
+ let mut outbound_tx = outbound_tx.clone();
async move {
let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
let mut buffer = Vec::new();
@@ -168,11 +178,13 @@ impl LanguageServer {
buffer.resize(message_len, 0);
stdout.read_exact(&mut buffer).await?;
- if let Ok(AnyNotification { method, params }) =
+ if let Ok(AnyNotification { id, method, params }) =
serde_json::from_slice(&buffer)
{
if let Some(handler) = notification_handlers.write().get_mut(method) {
- handler(params.get());
+ if let Err(e) = handler(id, params.get(), &mut outbound_tx) {
+ log::error!("error handling {} message: {:?}", method, e);
+ }
} else {
log::info!(
"unhandled notification {}:\n{}",
@@ -351,12 +363,11 @@ impl LanguageServer {
{
let prev_handler = self.notification_handlers.write().insert(
T::METHOD,
- Box::new(
- move |notification| match serde_json::from_str(notification) {
- Ok(notification) => f(notification),
- Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
- },
- ),
+ Box::new(move |_, params, _| {
+ let params = serde_json::from_str(params)?;
+ f(params);
+ Ok(())
+ }),
);
assert!(
@@ -370,6 +381,40 @@ impl LanguageServer {
}
}
+ pub fn on_custom_request<Params, Resp, F>(
+ &mut self,
+ method: &'static str,
+ mut f: F,
+ ) -> Subscription
+ where
+ F: 'static + Send + Sync + FnMut(Params) -> Result<Resp>,
+ Params: DeserializeOwned,
+ Resp: Serialize,
+ {
+ let prev_handler = self.notification_handlers.write().insert(
+ method,
+ Box::new(move |id, params, tx| {
+ if let Some(id) = id {
+ let params = serde_json::from_str(params)?;
+ let result = f(params)?;
+ let response = serde_json::to_vec(&Response { id, result })?;
+ tx.try_send(response)?;
+ }
+ Ok(())
+ }),
+ );
+
+ assert!(
+ prev_handler.is_none(),
+ "registered multiple handlers for the same notification"
+ );
+
+ Subscription {
+ method,
+ notification_handlers: self.notification_handlers.clone(),
+ }
+ }
+
pub fn name<'a>(self: &'a Arc<Self>) -> &'a str {
&self.name
}