neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::ops::{Deref, DerefMut};
  3use std::{ops::Range, path::PathBuf};
  4
  5#[cfg(feature = "neovim")]
  6use async_compat::Compat;
  7#[cfg(feature = "neovim")]
  8use async_trait::async_trait;
  9#[cfg(feature = "neovim")]
 10use gpui::keymap_matcher::Keystroke;
 11
 12use language::{Point, Selection};
 13
 14#[cfg(feature = "neovim")]
 15use lazy_static::lazy_static;
 16#[cfg(feature = "neovim")]
 17use nvim_rs::{
 18    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 19};
 20#[cfg(feature = "neovim")]
 21use parking_lot::ReentrantMutex;
 22use serde::{Deserialize, Serialize};
 23#[cfg(feature = "neovim")]
 24use tokio::{
 25    process::{Child, ChildStdin, Command},
 26    task::JoinHandle,
 27};
 28
 29use crate::state::Mode;
 30use collections::VecDeque;
 31
 32// Neovim doesn't like to be started simultaneously from multiple threads. We use thsi lock
 33// to ensure we are only constructing one neovim connection at a time.
 34#[cfg(feature = "neovim")]
 35lazy_static! {
 36    static ref NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37}
 38
 39#[derive(Serialize, Deserialize)]
 40pub enum NeovimData {
 41    Text(String),
 42    Selection { start: (u32, u32), end: (u32, u32) },
 43    Mode(Option<Mode>),
 44}
 45
/// Test helper that talks to neovim and records (or replays) what it observes.
///
/// With the `neovim` feature enabled this drives a live embedded neovim process and appends
/// every observation to `data`. Without the feature, `data` holds a previously recorded
/// sequence that the query methods consume in order.
pub struct NeovimConnection {
    // Observations in call order: appended to while recording, popped from while replaying.
    data: VecDeque<NeovimData>,
    // Names the JSON file the recording is written to.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    // Client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    // Held only to keep the io-loop task handle and the child process alive with the
    // connection; never read. NOTE(review): field order is drop order — keep these last.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    #[cfg(feature = "neovim")]
    _child: Child,
}
 57
 58impl NeovimConnection {
 59    pub async fn new(test_case_id: String) -> Self {
 60        #[cfg(feature = "neovim")]
 61        let handler = NvimHandler {};
 62        #[cfg(feature = "neovim")]
 63        let (nvim, join_handle, child) = Compat::new(async {
 64            // Ensure we don't create neovim connections in parallel
 65            let _lock = NEOVIM_LOCK.lock();
 66            let (nvim, join_handle, child) = new_child_cmd(
 67                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 68                handler,
 69            )
 70            .await
 71            .expect("Could not connect to neovim process");
 72
 73            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 74                .await
 75                .expect("Could not attach to ui");
 76
 77            // Makes system act a little more like zed in terms of indentation
 78            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 79                .await
 80                .expect("Could not set smartindent on startup");
 81
 82            (nvim, join_handle, child)
 83        })
 84        .await;
 85
 86        Self {
 87            #[cfg(feature = "neovim")]
 88            data: Default::default(),
 89            #[cfg(not(feature = "neovim"))]
 90            data: Self::read_test_data(&test_case_id),
 91            #[cfg(feature = "neovim")]
 92            test_case_id,
 93            #[cfg(feature = "neovim")]
 94            nvim,
 95            #[cfg(feature = "neovim")]
 96            _join_handle: join_handle,
 97            #[cfg(feature = "neovim")]
 98            _child: child,
 99        }
100    }
101
102    // Sends a keystroke to the neovim process.
103    #[cfg(feature = "neovim")]
104    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
105        let keystroke = Keystroke::parse(keystroke_text).unwrap();
106        let special = keystroke.shift
107            || keystroke.ctrl
108            || keystroke.alt
109            || keystroke.cmd
110            || keystroke.key.len() > 1;
111        let start = if special { "<" } else { "" };
112        let shift = if keystroke.shift { "S-" } else { "" };
113        let ctrl = if keystroke.ctrl { "C-" } else { "" };
114        let alt = if keystroke.alt { "M-" } else { "" };
115        let cmd = if keystroke.cmd { "D-" } else { "" };
116        let end = if special { ">" } else { "" };
117
118        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
119
120        self.nvim
121            .input(&key)
122            .await
123            .expect("Could not input keystroke");
124    }
125
126    // If not running with a live neovim connection, this is a no-op
127    #[cfg(not(feature = "neovim"))]
128    pub async fn send_keystroke(&mut self, _keystroke_text: &str) {}
129
130    #[cfg(feature = "neovim")]
131    pub async fn set_state(&mut self, selection: Selection<Point>, text: &str) {
132        let nvim_buffer = self
133            .nvim
134            .get_current_buf()
135            .await
136            .expect("Could not get neovim buffer");
137        let lines = text
138            .split('\n')
139            .map(|line| line.to_string())
140            .collect::<Vec<_>>();
141
142        nvim_buffer
143            .set_lines(0, -1, false, lines)
144            .await
145            .expect("Could not set nvim buffer text");
146
147        self.nvim
148            .input("<escape>")
149            .await
150            .expect("Could not send escape to nvim");
151        self.nvim
152            .input("<escape>")
153            .await
154            .expect("Could not send escape to nvim");
155
156        let nvim_window = self
157            .nvim
158            .get_current_win()
159            .await
160            .expect("Could not get neovim window");
161
162        if !selection.is_empty() {
163            panic!("Setting neovim state with non empty selection not yet supported");
164        }
165        let cursor = selection.head();
166        nvim_window
167            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
168            .await
169            .expect("Could not set nvim cursor position");
170    }
171
172    #[cfg(not(feature = "neovim"))]
173    pub async fn set_state(&mut self, _selection: Selection<Point>, _text: &str) {}
174
175    #[cfg(feature = "neovim")]
176    pub async fn text(&mut self) -> String {
177        let nvim_buffer = self
178            .nvim
179            .get_current_buf()
180            .await
181            .expect("Could not get neovim buffer");
182        let text = nvim_buffer
183            .get_lines(0, -1, false)
184            .await
185            .expect("Could not get buffer text")
186            .join("\n");
187
188        self.data.push_back(NeovimData::Text(text.clone()));
189
190        text
191    }
192
193    #[cfg(not(feature = "neovim"))]
194    pub async fn text(&mut self) -> String {
195        if let Some(NeovimData::Text(text)) = self.data.pop_front() {
196            text
197        } else {
198            panic!("Invalid test data. Is test deterministic? Try running with '--features neovim' to regenerate");
199        }
200    }
201
202    #[cfg(feature = "neovim")]
203    pub async fn selection(&mut self) -> Range<Point> {
204        let cursor_row: u32 = self
205            .nvim
206            .command_output("echo line('.')")
207            .await
208            .unwrap()
209            .parse::<u32>()
210            .unwrap()
211            - 1; // Neovim rows start at 1
212        let cursor_col: u32 = self
213            .nvim
214            .command_output("echo col('.')")
215            .await
216            .unwrap()
217            .parse::<u32>()
218            .unwrap()
219            - 1; // Neovim columns start at 1
220
221        let (start, end) = if let Some(Mode::Visual { .. }) = self.mode().await {
222            self.nvim
223                .input("<escape>")
224                .await
225                .expect("Could not exit visual mode");
226            let nvim_buffer = self
227                .nvim
228                .get_current_buf()
229                .await
230                .expect("Could not get neovim buffer");
231            let (start_row, start_col) = nvim_buffer
232                .get_mark("<")
233                .await
234                .expect("Could not get selection start");
235            let (end_row, end_col) = nvim_buffer
236                .get_mark(">")
237                .await
238                .expect("Could not get selection end");
239            self.nvim
240                .input("gv")
241                .await
242                .expect("Could not reselect visual selection");
243
244            if cursor_row == start_row as u32 - 1 && cursor_col == start_col as u32 {
245                (
246                    (end_row as u32 - 1, end_col as u32),
247                    (start_row as u32 - 1, start_col as u32),
248                )
249            } else {
250                (
251                    (start_row as u32 - 1, start_col as u32),
252                    (end_row as u32 - 1, end_col as u32),
253                )
254            }
255        } else {
256            ((cursor_row, cursor_col), (cursor_row, cursor_col))
257        };
258
259        self.data.push_back(NeovimData::Selection { start, end });
260
261        Point::new(start.0, start.1)..Point::new(end.0, end.1)
262    }
263
264    #[cfg(not(feature = "neovim"))]
265    pub async fn selection(&mut self) -> Range<Point> {
266        // Selection code fetches the mode. This emulates that.
267        let _mode = self.mode().await;
268        if let Some(NeovimData::Selection { start, end }) = self.data.pop_front() {
269            Point::new(start.0, start.1)..Point::new(end.0, end.1)
270        } else {
271            panic!("Invalid test data. Is test deterministic? Try running with '--features neovim' to regenerate");
272        }
273    }
274
275    #[cfg(feature = "neovim")]
276    pub async fn mode(&mut self) -> Option<Mode> {
277        let nvim_mode_text = self
278            .nvim
279            .get_mode()
280            .await
281            .expect("Could not get mode")
282            .into_iter()
283            .find_map(|(key, value)| {
284                if key.as_str() == Some("mode") {
285                    Some(value.as_str().unwrap().to_owned())
286                } else {
287                    None
288                }
289            })
290            .expect("Could not find mode value");
291
292        let mode = match nvim_mode_text.as_ref() {
293            "i" => Some(Mode::Insert),
294            "n" => Some(Mode::Normal),
295            "v" => Some(Mode::Visual { line: false }),
296            "V" => Some(Mode::Visual { line: true }),
297            _ => None,
298        };
299
300        self.data.push_back(NeovimData::Mode(mode.clone()));
301
302        mode
303    }
304
305    #[cfg(not(feature = "neovim"))]
306    pub async fn mode(&mut self) -> Option<Mode> {
307        if let Some(NeovimData::Mode(mode)) = self.data.pop_front() {
308            mode
309        } else {
310            panic!("Invalid test data. Is test deterministic? Try running with '--features neovim' to regenerate");
311        }
312    }
313
314    fn test_data_path(test_case_id: &str) -> PathBuf {
315        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
316        data_path.push("test_data");
317        data_path.push(format!("{}.json", test_case_id));
318        data_path
319    }
320
321    #[cfg(not(feature = "neovim"))]
322    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
323        let path = Self::test_data_path(test_case_id);
324        let json = std::fs::read_to_string(path).expect(
325            "Could not read test data. Is it generated? Try running test with '--features neovim'",
326        );
327
328        serde_json::from_str(&json)
329            .expect("Test data corrupted. Try regenerating it with '--features neovim'")
330    }
331}
332
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    /// Lets callers use the underlying neovim client's API directly on the connection.
    fn deref(&self) -> &Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
341
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    /// Mutable counterpart of `deref`: hands out the underlying neovim client.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
348
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    /// Writes everything recorded through this connection to the test-data file named after
    /// `test_case_id`, creating the directory first if needed.
    fn drop(&mut self) {
        let data_path = Self::test_data_path(&self.test_case_id);
        let data_dir = data_path.parent().unwrap();
        std::fs::create_dir_all(data_dir).expect("Could not create test data directory");

        let serialized = serde_json::to_string(&self.data).expect("Could not serialize test data");
        std::fs::write(&data_path, serialized).expect("Could not write out test data");
    }
}
359
/// Stateless handler for messages neovim sends to this client.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler;
363
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    /// Answers any RPC request from neovim with an error value.
    ///
    /// Nothing here registers request handlers, so a request is unexpected. Replying with an
    /// error (instead of panicking, as `unimplemented!()` did) still gives neovim a response.
    /// NOTE(review): a panic here presumably left the request unanswered and the caller waiting.
    async fn handle_request(
        &self,
        event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        Err(Value::from(format!(
            "Unexpected request from neovim: {event_name}"
        )))
    }

    /// Ignores notifications, such as the UI events neovim emits once `ui_attach` has been
    /// called.
    async fn handle_notify(
        &self,
        _event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) {
    }
}