@@ -13,14 +13,15 @@ use parking_lot::Mutex;
use smol::io::BufReader;
use crate::{
- AnyNotification, AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, RequestId, ResponseHandler,
+ AnyResponse, CONTENT_LEN_HEADER, IoHandler, IoKind, NotificationOrRequest, RequestId,
+ ResponseHandler,
};
const HEADER_DELIMITER: &[u8; 4] = b"\r\n\r\n";
/// Handler for stdout of language server.
pub struct LspStdoutHandler {
pub(super) loop_handle: Task<Result<()>>,
- pub(super) notifications_channel: UnboundedReceiver<AnyNotification>,
+ pub(super) incoming_messages: UnboundedReceiver<NotificationOrRequest>,
}
async fn read_headers<Stdout>(reader: &mut BufReader<Stdout>, buffer: &mut Vec<u8>) -> Result<()>
@@ -54,13 +55,13 @@ impl LspStdoutHandler {
let loop_handle = cx.spawn(Self::handler(stdout, tx, response_handlers, io_handlers));
Self {
loop_handle,
- notifications_channel,
+ incoming_messages: notifications_channel,
}
}
async fn handler<Input>(
stdout: Input,
- notifications_sender: UnboundedSender<AnyNotification>,
+ notifications_sender: UnboundedSender<NotificationOrRequest>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
) -> anyhow::Result<()>
@@ -96,7 +97,7 @@ impl LspStdoutHandler {
}
}
- if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
+ if let Ok(msg) = serde_json::from_slice::<NotificationOrRequest>(&buffer) {
notifications_sender.unbounded_send(msg)?;
} else if let Ok(AnyResponse {
id, error, result, ..
@@ -242,7 +242,7 @@ struct Notification<'a, T> {
/// Language server RPC notification message before it is deserialized into a concrete type.
#[derive(Debug, Clone, Deserialize)]
-struct AnyNotification {
+struct NotificationOrRequest {
#[serde(default)]
id: Option<RequestId>,
method: String,
@@ -252,7 +252,10 @@ struct AnyNotification {
#[derive(Debug, Serialize, Deserialize)]
struct Error {
+ code: i64,
message: String,
+ #[serde(default)]
+ data: Option<serde_json::Value>,
}
pub trait LspRequestFuture<O>: Future<Output = ConnectionResult<O>> {
@@ -364,6 +367,7 @@ impl LanguageServer {
notification.method,
serde_json::to_string_pretty(¬ification.params).unwrap(),
);
+ false
},
);
@@ -389,7 +393,7 @@ impl LanguageServer {
Stdin: AsyncWrite + Unpin + Send + 'static,
Stdout: AsyncRead + Unpin + Send + 'static,
Stderr: AsyncRead + Unpin + Send + 'static,
- F: FnMut(AnyNotification) + 'static + Send + Sync + Clone,
+ F: Fn(&NotificationOrRequest) -> bool + 'static + Send + Sync + Clone,
{
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@@ -400,14 +404,34 @@ impl LanguageServer {
let io_handlers = Arc::new(Mutex::new(HashMap::default()));
let stdout_input_task = cx.spawn({
- let on_unhandled_notification = on_unhandled_notification.clone();
+ let unhandled_notification_wrapper = {
+ let response_channel = outbound_tx.clone();
+ async move |msg: NotificationOrRequest| {
+ let did_handle = on_unhandled_notification(&msg);
+ if !did_handle && let Some(message_id) = msg.id {
+ let response = AnyResponse {
+ jsonrpc: JSON_RPC_VERSION,
+ id: message_id,
+ error: Some(Error {
+ code: -32601,
+ message: format!("Unrecognized method `{}`", msg.method),
+ data: None,
+ }),
+ result: None,
+ };
+ if let Ok(response) = serde_json::to_string(&response) {
+ response_channel.send(response).await.ok();
+ }
+ }
+ }
+ };
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let io_handlers = io_handlers.clone();
async move |cx| {
- Self::handle_input(
+ Self::handle_incoming_messages(
stdout,
- on_unhandled_notification,
+ unhandled_notification_wrapper,
notification_handlers,
response_handlers,
io_handlers,
@@ -433,7 +457,7 @@ impl LanguageServer {
stdout.or(stderr)
});
let output_task = cx.background_spawn({
- Self::handle_output(
+ Self::handle_outgoing_messages(
stdin,
outbound_rx,
output_done_tx,
@@ -479,9 +503,9 @@ impl LanguageServer {
self.code_action_kinds.clone()
}
- async fn handle_input<Stdout, F>(
+ async fn handle_incoming_messages<Stdout>(
stdout: Stdout,
- mut on_unhandled_notification: F,
+ on_unhandled_notification: impl AsyncFn(NotificationOrRequest) + 'static + Send,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
@@ -489,7 +513,6 @@ impl LanguageServer {
) -> anyhow::Result<()>
where
Stdout: AsyncRead + Unpin + Send + 'static,
- F: FnMut(AnyNotification) + 'static + Send,
{
use smol::stream::StreamExt;
let stdout = BufReader::new(stdout);
@@ -506,15 +529,19 @@ impl LanguageServer {
cx.background_executor().clone(),
);
- while let Some(msg) = input_handler.notifications_channel.next().await {
- {
+ while let Some(msg) = input_handler.incoming_messages.next().await {
+ let unhandled_message = {
let mut notification_handlers = notification_handlers.lock();
if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) {
handler(msg.id, msg.params.unwrap_or(Value::Null), cx);
+ None
} else {
- drop(notification_handlers);
- on_unhandled_notification(msg);
+ Some(msg)
}
+ };
+
+ if let Some(msg) = unhandled_message {
+ on_unhandled_notification(msg).await;
}
// Don't starve the main thread when receiving lots of notifications at once.
@@ -558,7 +585,7 @@ impl LanguageServer {
}
}
- async fn handle_output<Stdin>(
+ async fn handle_outgoing_messages<Stdin>(
stdin: Stdin,
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
@@ -1036,7 +1063,9 @@ impl LanguageServer {
jsonrpc: JSON_RPC_VERSION,
id,
value: LspResult::Error(Some(Error {
+ code: lsp_types::error_codes::REQUEST_FAILED,
message: error.to_string(),
+ data: None,
})),
},
};
@@ -1057,7 +1086,9 @@ impl LanguageServer {
id,
result: None,
error: Some(Error {
+ code: -32700, // Parse error
message: error.to_string(),
+ data: None,
}),
};
if let Some(response) = serde_json::to_string(&response).log_err() {
@@ -1559,7 +1590,7 @@ impl FakeLanguageServer {
root,
Some(workspace_folders.clone()),
cx,
- |_| {},
+ |_| false,
);
server.process_name = process_name;
let fake = FakeLanguageServer {
@@ -1582,9 +1613,10 @@ impl FakeLanguageServer {
notifications_tx
.try_send((
msg.method.to_string(),
- msg.params.unwrap_or(Value::Null).to_string(),
+ msg.params.as_ref().unwrap_or(&Value::Null).to_string(),
))
.ok();
+ true
},
);
server.process_name = name.as_str().into();
@@ -1862,7 +1894,7 @@ mod tests {
#[gpui::test]
fn test_deserialize_string_digit_id() {
let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
- let notification = serde_json::from_str::<AnyNotification>(json)
+ let notification = serde_json::from_str::<NotificationOrRequest>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Str("2".to_string());
assert_eq!(notification.id, Some(expected_id));
@@ -1871,7 +1903,7 @@ mod tests {
#[gpui::test]
fn test_deserialize_string_id() {
let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
- let notification = serde_json::from_str::<AnyNotification>(json)
+ let notification = serde_json::from_str::<NotificationOrRequest>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Str("anythingAtAll".to_string());
assert_eq!(notification.id, Some(expected_id));
@@ -1880,7 +1912,7 @@ mod tests {
#[gpui::test]
fn test_deserialize_int_id() {
let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
- let notification = serde_json::from_str::<AnyNotification>(json)
+ let notification = serde_json::from_str::<NotificationOrRequest>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Int(2);
assert_eq!(notification.id, Some(expected_id));