neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::ops::{Deref, DerefMut};
  3use std::{ops::Range, path::PathBuf};
  4
  5#[cfg(feature = "neovim")]
  6use async_compat::Compat;
  7#[cfg(feature = "neovim")]
  8use async_trait::async_trait;
  9#[cfg(feature = "neovim")]
 10use gpui::keymap_matcher::Keystroke;
 11
 12use language::Point;
 13
 14#[cfg(feature = "neovim")]
 15use lazy_static::lazy_static;
 16#[cfg(feature = "neovim")]
 17use nvim_rs::{
 18    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 19};
 20#[cfg(feature = "neovim")]
 21use parking_lot::ReentrantMutex;
 22use serde::{Deserialize, Serialize};
 23#[cfg(feature = "neovim")]
 24use tokio::{
 25    process::{Child, ChildStdin, Command},
 26    task::JoinHandle,
 27};
 28
 29use crate::state::Mode;
 30use collections::VecDeque;
 31
 32// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
 33// to ensure we are only constructing one neovim connection at a time.
 34#[cfg(feature = "neovim")]
 35lazy_static! {
 36    static ref NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37}
 38
 39#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 40pub enum NeovimData {
 41    Put { state: String },
 42    Key(String),
 43    Get { state: String, mode: Option<Mode> },
 44}
 45
/// A test harness connection that either drives a real neovim process and records the session
/// (`neovim` feature) or replays a previously recorded session from disk (no `neovim` feature).
pub struct NeovimConnection {
    // Recorded operations: appended to while recording, consumed from the front while replaying.
    data: VecDeque<NeovimData>,
    // Identifies the test-data file the recording is written to.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    // RPC client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    // Handle of the RPC io loop task; kept only so it is owned by the connection.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    // The spawned `nvim` child process; kept only so it is owned by the connection.
    // NOTE(review): field order is drop order — keep these last.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 57
 58impl NeovimConnection {
 59    pub async fn new(test_case_id: String) -> Self {
 60        #[cfg(feature = "neovim")]
 61        let handler = NvimHandler {};
 62        #[cfg(feature = "neovim")]
 63        let (nvim, join_handle, child) = Compat::new(async {
 64            // Ensure we don't create neovim connections in parallel
 65            let _lock = NEOVIM_LOCK.lock();
 66            let (nvim, join_handle, child) = new_child_cmd(
 67                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 68                handler,
 69            )
 70            .await
 71            .expect("Could not connect to neovim process");
 72
 73            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 74                .await
 75                .expect("Could not attach to ui");
 76
 77            // Makes system act a little more like zed in terms of indentation
 78            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 79                .await
 80                .expect("Could not set smartindent on startup");
 81
 82            (nvim, join_handle, child)
 83        })
 84        .await;
 85
 86        Self {
 87            #[cfg(feature = "neovim")]
 88            data: Default::default(),
 89            #[cfg(not(feature = "neovim"))]
 90            data: Self::read_test_data(&test_case_id),
 91            #[cfg(feature = "neovim")]
 92            test_case_id,
 93            #[cfg(feature = "neovim")]
 94            nvim,
 95            #[cfg(feature = "neovim")]
 96            _join_handle: join_handle,
 97            #[cfg(feature = "neovim")]
 98            _child: child,
 99        }
100    }
101
102    // Sends a keystroke to the neovim process.
103    #[cfg(feature = "neovim")]
104    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
105        let keystroke = Keystroke::parse(keystroke_text).unwrap();
106        let special = keystroke.shift
107            || keystroke.ctrl
108            || keystroke.alt
109            || keystroke.cmd
110            || keystroke.key.len() > 1;
111        let start = if special { "<" } else { "" };
112        let shift = if keystroke.shift { "S-" } else { "" };
113        let ctrl = if keystroke.ctrl { "C-" } else { "" };
114        let alt = if keystroke.alt { "M-" } else { "" };
115        let cmd = if keystroke.cmd { "D-" } else { "" };
116        let end = if special { ">" } else { "" };
117
118        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
119
120        self.data
121            .push_back(NeovimData::Key(keystroke_text.to_string()));
122        self.nvim
123            .input(&key)
124            .await
125            .expect("Could not input keystroke");
126    }
127
128    #[cfg(not(feature = "neovim"))]
129    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
130        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
131            self.data.pop_front();
132        }
133        assert_eq!(
134            self.data.pop_front(),
135            Some(NeovimData::Key(keystroke_text.to_string())),
136            "operation does not match recorded script. re-record with --features=neovim"
137        );
138    }
139
140    #[cfg(feature = "neovim")]
141    pub async fn set_state(&mut self, marked_text: &str) {
142        let (text, selection) = parse_state(&marked_text);
143
144        let nvim_buffer = self
145            .nvim
146            .get_current_buf()
147            .await
148            .expect("Could not get neovim buffer");
149        let lines = text
150            .split('\n')
151            .map(|line| line.to_string())
152            .collect::<Vec<_>>();
153
154        nvim_buffer
155            .set_lines(0, -1, false, lines)
156            .await
157            .expect("Could not set nvim buffer text");
158
159        self.nvim
160            .input("<escape>")
161            .await
162            .expect("Could not send escape to nvim");
163        self.nvim
164            .input("<escape>")
165            .await
166            .expect("Could not send escape to nvim");
167
168        let nvim_window = self
169            .nvim
170            .get_current_win()
171            .await
172            .expect("Could not get neovim window");
173
174        if !selection.is_empty() {
175            panic!("Setting neovim state with non empty selection not yet supported");
176        }
177        let cursor = selection.start;
178        nvim_window
179            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
180            .await
181            .expect("Could not set nvim cursor position");
182
183        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
184            if *mode == Some(Mode::Normal) && *state == marked_text {
185                return;
186            }
187        }
188        self.data.push_back(NeovimData::Put {
189            state: marked_text.to_string(),
190        })
191    }
192
193    #[cfg(not(feature = "neovim"))]
194    pub async fn set_state(&mut self, marked_text: &str) {
195        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
196            if *mode == Some(Mode::Normal) && *text == marked_text {
197                return;
198            }
199            self.data.pop_front();
200        }
201        assert_eq!(
202            self.data.pop_front(),
203            Some(NeovimData::Put {
204                state: marked_text.to_string()
205            }),
206            "operation does not match recorded script. re-record with --features=neovim"
207        );
208    }
209
210    #[cfg(feature = "neovim")]
211    pub async fn state(&mut self) -> (Option<Mode>, String, Range<Point>) {
212        let nvim_buffer = self
213            .nvim
214            .get_current_buf()
215            .await
216            .expect("Could not get neovim buffer");
217        let text = nvim_buffer
218            .get_lines(0, -1, false)
219            .await
220            .expect("Could not get buffer text")
221            .join("\n");
222
223        let cursor_row: u32 = self
224            .nvim
225            .command_output("echo line('.')")
226            .await
227            .unwrap()
228            .parse::<u32>()
229            .unwrap()
230            - 1; // Neovim rows start at 1
231        let cursor_col: u32 = self
232            .nvim
233            .command_output("echo col('.')")
234            .await
235            .unwrap()
236            .parse::<u32>()
237            .unwrap()
238            - 1; // Neovim columns start at 1
239
240        let nvim_mode_text = self
241            .nvim
242            .get_mode()
243            .await
244            .expect("Could not get mode")
245            .into_iter()
246            .find_map(|(key, value)| {
247                if key.as_str() == Some("mode") {
248                    Some(value.as_str().unwrap().to_owned())
249                } else {
250                    None
251                }
252            })
253            .expect("Could not find mode value");
254
255        let mode = match nvim_mode_text.as_ref() {
256            "i" => Some(Mode::Insert),
257            "n" => Some(Mode::Normal),
258            "v" => Some(Mode::Visual { line: false }),
259            "V" => Some(Mode::Visual { line: true }),
260            _ => None,
261        };
262
263        let (start, end) = if let Some(Mode::Visual { .. }) = mode {
264            self.nvim
265                .input("<escape>")
266                .await
267                .expect("Could not exit visual mode");
268            let nvim_buffer = self
269                .nvim
270                .get_current_buf()
271                .await
272                .expect("Could not get neovim buffer");
273            let (start_row, start_col) = nvim_buffer
274                .get_mark("<")
275                .await
276                .expect("Could not get selection start");
277            let (end_row, end_col) = nvim_buffer
278                .get_mark(">")
279                .await
280                .expect("Could not get selection end");
281            self.nvim
282                .input("gv")
283                .await
284                .expect("Could not reselect visual selection");
285
286            if cursor_row == start_row as u32 - 1 && cursor_col == start_col as u32 {
287                (
288                    Point::new(end_row as u32 - 1, end_col as u32),
289                    Point::new(start_row as u32 - 1, start_col as u32),
290                )
291            } else {
292                (
293                    Point::new(start_row as u32 - 1, start_col as u32),
294                    Point::new(end_row as u32 - 1, end_col as u32),
295                )
296            }
297        } else {
298            (
299                Point::new(cursor_row, cursor_col),
300                Point::new(cursor_row, cursor_col),
301            )
302        };
303
304        let state = NeovimData::Get {
305            mode,
306            state: encode_range(&text, start..end),
307        };
308
309        if self.data.back() != Some(&state) {
310            self.data.push_back(state.clone());
311        }
312
313        (mode, text, start..end)
314    }
315
316    #[cfg(not(feature = "neovim"))]
317    pub async fn state(&mut self) -> (Option<Mode>, String, Range<Point>) {
318        if let Some(NeovimData::Get { state: text, mode }) = self.data.front() {
319            let (text, range) = parse_state(text);
320            (*mode, text, range)
321        } else {
322            panic!("operation does not match recorded script. re-record with --features=neovim");
323        }
324    }
325
326    pub async fn selection(&mut self) -> Range<Point> {
327        self.state().await.2
328    }
329
330    pub async fn mode(&mut self) -> Option<Mode> {
331        self.state().await.0
332    }
333
334    pub async fn text(&mut self) -> String {
335        self.state().await.1
336    }
337
338    fn test_data_path(test_case_id: &str) -> PathBuf {
339        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
340        data_path.push("test_data");
341        data_path.push(format!("{}.json", test_case_id));
342        data_path
343    }
344
345    #[cfg(not(feature = "neovim"))]
346    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
347        let path = Self::test_data_path(test_case_id);
348        let json = std::fs::read_to_string(path).expect(
349            "Could not read test data. Is it generated? Try running test with '--features neovim'",
350        );
351
352        let mut result = VecDeque::new();
353        for line in json.lines() {
354            result.push_back(
355                serde_json::from_str(line)
356                    .expect("invalid test data. regenerate it with '--features neovim'"),
357            );
358        }
359        result
360    }
361
362    #[cfg(feature = "neovim")]
363    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
364        let path = Self::test_data_path(test_case_id);
365        let mut json = Vec::new();
366        for entry in data {
367            serde_json::to_writer(&mut json, entry).unwrap();
368            json.push(b'\n');
369        }
370        std::fs::create_dir_all(path.parent().unwrap())
371            .expect("could not create test data directory");
372        std::fs::write(path, json).expect("could not write out test data");
373    }
374}
375
/// Gives shared access to the underlying neovim RPC client.
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    fn deref(&self) -> &Self::Target {
        let NeovimConnection { nvim, .. } = self;
        nvim
    }
}
384
/// Gives exclusive access to the underlying neovim RPC client.
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let NeovimConnection { nvim, .. } = self;
        nvim
    }
}
391
/// Persists everything recorded during the session to this test case's data file.
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    fn drop(&mut self) {
        let (id, recorded) = (&self.test_case_id, &self.data);
        Self::write_test_data(id, recorded);
    }
}
398
/// RPC handler passed to `new_child_cmd`; it carries no state.
#[derive(Clone)]
#[cfg(feature = "neovim")]
struct NvimHandler {}
402
/// Minimal RPC handler: neovim-initiated requests are rejected and notifications ignored.
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    /// Replies to any request from neovim with an RPC error instead of panicking inside the
    /// io loop task (which `unimplemented!()` did).
    async fn handle_request(
        &self,
        event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        Err(Value::from(format!("unsupported request: {event_name}")))
    }

    /// Notifications from neovim are intentionally ignored.
    async fn handle_notify(
        &self,
        _event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) {
    }
}
425
426fn parse_state(marked_text: &str) -> (String, Range<Point>) {
427    let (text, ranges) = util::test::marked_text_ranges(marked_text, true);
428    let byte_range = ranges[0].clone();
429    let mut point_range = Point::zero()..Point::zero();
430    let mut ix = 0;
431    let mut position = Point::zero();
432    for c in text.chars().chain(['\0']) {
433        if ix == byte_range.start {
434            point_range.start = position;
435        }
436        if ix == byte_range.end {
437            point_range.end = position;
438        }
439        let len_utf8 = c.len_utf8();
440        ix += len_utf8;
441        if c == '\n' {
442            position.row += 1;
443            position.column = 0;
444        } else {
445            position.column += len_utf8 as u32;
446        }
447    }
448    (text, point_range)
449}
450
/// Encodes `range` (row/byte-column points, possibly reversed) into `text` as marked text via
/// `util::test::generate_marked_text`.
///
/// A point whose column lies past the end of its line is clamped to that line's end, and a
/// point whose row lies past the last line is clamped to the end of `text`. Previously such
/// points silently encoded as offset 0. NOTE(review): neovim appears to report such columns,
/// e.g. MAXCOL for the `'>` mark of a linewise visual selection — confirm.
#[cfg(feature = "neovim")]
fn encode_range(text: &str, range: Range<Point>) -> String {
    let offset_of = |point: Point| -> usize {
        let mut line_start = 0;
        for (row, line) in text.split('\n').enumerate() {
            if row == point.row as usize {
                let mut offset = line_start + (point.column as usize).min(line.len());
                // Never place a marker inside a multi-byte character.
                while !text.is_char_boundary(offset) {
                    offset -= 1;
                }
                return offset;
            }
            // +1 for the '\n' separator.
            line_start += line.len() + 1;
        }
        text.len()
    };
    let byte_range = offset_of(range.start)..offset_of(range.end);
    util::test::generate_marked_text(text, &[byte_range], true)
}