@@ -20,10 +20,10 @@ use std::{
future::Future,
io::Write,
path::PathBuf,
- str::FromStr,
+ str::{self, FromStr as _},
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
- Arc,
+ Arc, Weak,
},
};
use std::{path::Path, process::Stdio};
@@ -34,16 +34,18 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: ";
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
+type IoHandler = Box<dyn Send + FnMut(bool, &str)>;
pub struct LanguageServer {
server_id: LanguageServerId,
next_id: AtomicUsize,
- outbound_tx: channel::Sender<Vec<u8>>,
+ outbound_tx: channel::Sender<String>,
name: String,
capabilities: ServerCapabilities,
code_action_kinds: Option<Vec<CodeActionKind>>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+ io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@@ -56,9 +58,16 @@ pub struct LanguageServer {
#[repr(transparent)]
pub struct LanguageServerId(pub usize);
-pub struct Subscription {
- method: &'static str,
- notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+pub enum Subscription {
+ Detached,
+ Notification {
+ method: &'static str,
+ notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+ },
+ Io {
+ id: usize,
+ io_handlers: Weak<Mutex<HashMap<usize, IoHandler>>>,
+ },
}
#[derive(Serialize, Deserialize)]
@@ -177,33 +186,40 @@ impl LanguageServer {
Stdout: AsyncRead + Unpin + Send + 'static,
F: FnMut(AnyNotification) + 'static + Send,
{
- let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
+ let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
+ let (output_done_tx, output_done_rx) = barrier::channel();
let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+ let io_handlers = Arc::new(Mutex::new(HashMap::default()));
let input_task = cx.spawn(|cx| {
- let notification_handlers = notification_handlers.clone();
- let response_handlers = response_handlers.clone();
Self::handle_input(
stdout,
on_unhandled_notification,
- notification_handlers,
- response_handlers,
+ notification_handlers.clone(),
+ response_handlers.clone(),
+ io_handlers.clone(),
cx,
)
.log_err()
});
- let (output_done_tx, output_done_rx) = barrier::channel();
let output_task = cx.background().spawn({
- let response_handlers = response_handlers.clone();
- Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
+ Self::handle_output(
+ stdin,
+ outbound_rx,
+ output_done_tx,
+ response_handlers.clone(),
+ io_handlers.clone(),
+ )
+ .log_err()
});
Self {
server_id,
notification_handlers,
response_handlers,
+ io_handlers,
name: Default::default(),
capabilities: Default::default(),
code_action_kinds,
@@ -226,6 +242,7 @@ impl LanguageServer {
mut on_unhandled_notification: F,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+ io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
cx: AsyncAppContext,
) -> anyhow::Result<()>
where
@@ -252,7 +269,13 @@ impl LanguageServer {
buffer.resize(message_len, 0);
stdout.read_exact(&mut buffer).await?;
- log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
+
+ if let Ok(message) = str::from_utf8(&buffer) {
+ log::trace!("incoming message:{}", message);
+ for handler in io_handlers.lock().values_mut() {
+ handler(true, message);
+ }
+ }
if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
@@ -291,9 +314,10 @@ impl LanguageServer {
async fn handle_output<Stdin>(
stdin: Stdin,
- outbound_rx: channel::Receiver<Vec<u8>>,
+ outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+ io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
) -> anyhow::Result<()>
where
Stdin: AsyncWrite + Unpin + Send + 'static,
@@ -307,13 +331,17 @@ impl LanguageServer {
});
let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await {
- log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
+ log::trace!("outgoing message:{}", message);
+ for handler in io_handlers.lock().values_mut() {
+ handler(false, &message);
+ }
+
content_len_buffer.clear();
write!(content_len_buffer, "{}", message.len()).unwrap();
stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
stdin.write_all(&content_len_buffer).await?;
stdin.write_all("\r\n\r\n".as_bytes()).await?;
- stdin.write_all(&message).await?;
+ stdin.write_all(message.as_bytes()).await?;
stdin.flush().await?;
}
drop(output_done_tx);
@@ -464,6 +492,19 @@ impl LanguageServer {
self.on_custom_request(T::METHOD, f)
}
+ #[must_use]
+ pub fn on_io<F>(&self, f: F) -> Subscription
+ where
+ F: 'static + Send + FnMut(bool, &str),
+ {
+ let id = self.next_id.fetch_add(1, SeqCst);
+ self.io_handlers.lock().insert(id, Box::new(f));
+ Subscription::Io {
+ id,
+ io_handlers: Arc::downgrade(&self.io_handlers),
+ }
+ }
+
pub fn remove_request_handler<T: request::Request>(&self) {
self.notification_handlers.lock().remove(T::METHOD);
}
@@ -490,7 +531,7 @@ impl LanguageServer {
prev_handler.is_none(),
"registered multiple handlers for the same LSP method"
);
- Subscription {
+ Subscription::Notification {
method,
notification_handlers: self.notification_handlers.clone(),
}
@@ -537,7 +578,7 @@ impl LanguageServer {
},
};
if let Some(response) =
- serde_json::to_vec(&response).log_err()
+ serde_json::to_string(&response).log_err()
{
outbound_tx.try_send(response).ok();
}
@@ -560,7 +601,7 @@ impl LanguageServer {
message: error.to_string(),
}),
};
- if let Some(response) = serde_json::to_vec(&response).log_err() {
+ if let Some(response) = serde_json::to_string(&response).log_err() {
outbound_tx.try_send(response).ok();
}
}
@@ -572,7 +613,7 @@ impl LanguageServer {
prev_handler.is_none(),
"registered multiple handlers for the same LSP method"
);
- Subscription {
+ Subscription::Notification {
method,
notification_handlers: self.notification_handlers.clone(),
}
@@ -612,14 +653,14 @@ impl LanguageServer {
fn request_internal<T: request::Request>(
next_id: &AtomicUsize,
response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
- outbound_tx: &channel::Sender<Vec<u8>>,
+ outbound_tx: &channel::Sender<String>,
params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>>
where
T::Result: 'static + Send,
{
let id = next_id.fetch_add(1, SeqCst);
- let message = serde_json::to_vec(&Request {
+ let message = serde_json::to_string(&Request {
jsonrpc: JSON_RPC_VERSION,
id,
method: T::METHOD,
@@ -662,10 +703,10 @@ impl LanguageServer {
}
fn notify_internal<T: notification::Notification>(
- outbound_tx: &channel::Sender<Vec<u8>>,
+ outbound_tx: &channel::Sender<String>,
params: T::Params,
) -> Result<()> {
- let message = serde_json::to_vec(&Notification {
+ let message = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION,
method: T::METHOD,
params,
@@ -686,7 +727,7 @@ impl Drop for LanguageServer {
impl Subscription {
pub fn detach(mut self) {
- self.method = "";
+ *(&mut self) = Self::Detached;
}
}
@@ -698,7 +739,20 @@ impl fmt::Display for LanguageServerId {
impl Drop for Subscription {
fn drop(&mut self) {
- self.notification_handlers.lock().remove(self.method);
+ match self {
+ Subscription::Detached => {}
+ Subscription::Notification {
+ method,
+ notification_handlers,
+ } => {
+ notification_handlers.lock().remove(method);
+ }
+ Subscription::Io { id, io_handlers } => {
+ if let Some(io_handlers) = io_handlers.upgrade() {
+ io_handlers.lock().remove(id);
+ }
+ }
+ }
}
}