neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::{
  3    cmp,
  4    ops::{Deref, DerefMut},
  5};
  6use std::{ops::Range, path::PathBuf};
  7
  8#[cfg(feature = "neovim")]
  9use async_compat::Compat;
 10#[cfg(feature = "neovim")]
 11use async_trait::async_trait;
 12#[cfg(feature = "neovim")]
 13use gpui::keymap_matcher::Keystroke;
 14
 15use language::Point;
 16
 17#[cfg(feature = "neovim")]
 18use nvim_rs::{
 19    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 20};
 21#[cfg(feature = "neovim")]
 22use parking_lot::ReentrantMutex;
 23use serde::{Deserialize, Serialize};
 24#[cfg(feature = "neovim")]
 25use tokio::{
 26    process::{Child, ChildStdin, Command},
 27    task::JoinHandle,
 28};
 29
 30use crate::state::Mode;
 31use collections::VecDeque;
 32
// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
// to ensure we are only constructing one neovim connection at a time.
// The lock guards no data (`()`); it only serializes startup. Being reentrant, a thread that
// already holds it can lock it again without deadlocking.
#[cfg(feature = "neovim")]
static NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37
 38#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 39pub enum NeovimData {
 40    Put { state: String },
 41    Key(String),
 42    Get { state: String, mode: Option<Mode> },
 43}
 44
/// Drives a neovim session for a single test case, or replays a recording of one.
///
/// With the `neovim` feature enabled, operations go to a live embedded neovim process and are
/// appended to `data`. Without it, `data` holds the previously recorded script, which is
/// consumed from the front as the test runs.
pub struct NeovimConnection {
    // Recorded operations (live mode) or the remaining script to replay (replay mode).
    data: VecDeque<NeovimData>,
    // Identifies the test data file (`<test_case_id>.json`) the recording is written to.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    // RPC client connected to the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    // Held only to keep it alive for the connection's lifetime; never read.
    // NOTE(review): presumably the task running the RPC I/O loop — confirm.
    // Field order matters: fields drop in declaration order after `Drop::drop` runs.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    // Handle to the spawned `nvim` child process, held for the connection's lifetime.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 56
 57impl NeovimConnection {
 58    pub async fn new(test_case_id: String) -> Self {
 59        #[cfg(feature = "neovim")]
 60        let handler = NvimHandler {};
 61        #[cfg(feature = "neovim")]
 62        let (nvim, join_handle, child) = Compat::new(async {
 63            // Ensure we don't create neovim connections in parallel
 64            let _lock = NEOVIM_LOCK.lock();
 65            let (nvim, join_handle, child) = new_child_cmd(
 66                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 67                handler,
 68            )
 69            .await
 70            .expect("Could not connect to neovim process");
 71
 72            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 73                .await
 74                .expect("Could not attach to ui");
 75
 76            // Makes system act a little more like zed in terms of indentation
 77            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 78                .await
 79                .expect("Could not set smartindent on startup");
 80
 81            (nvim, join_handle, child)
 82        })
 83        .await;
 84
 85        Self {
 86            #[cfg(feature = "neovim")]
 87            data: Default::default(),
 88            #[cfg(not(feature = "neovim"))]
 89            data: Self::read_test_data(&test_case_id),
 90            #[cfg(feature = "neovim")]
 91            test_case_id,
 92            #[cfg(feature = "neovim")]
 93            nvim,
 94            #[cfg(feature = "neovim")]
 95            _join_handle: join_handle,
 96            #[cfg(feature = "neovim")]
 97            _child: child,
 98        }
 99    }
100
101    // Sends a keystroke to the neovim process.
102    #[cfg(feature = "neovim")]
103    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
104        let keystroke = Keystroke::parse(keystroke_text).unwrap();
105        let special = keystroke.shift
106            || keystroke.ctrl
107            || keystroke.alt
108            || keystroke.cmd
109            || keystroke.key.len() > 1;
110        let start = if special { "<" } else { "" };
111        let shift = if keystroke.shift { "S-" } else { "" };
112        let ctrl = if keystroke.ctrl { "C-" } else { "" };
113        let alt = if keystroke.alt { "M-" } else { "" };
114        let cmd = if keystroke.cmd { "D-" } else { "" };
115        let end = if special { ">" } else { "" };
116
117        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
118
119        self.data
120            .push_back(NeovimData::Key(keystroke_text.to_string()));
121        self.nvim
122            .input(&key)
123            .await
124            .expect("Could not input keystroke");
125    }
126
127    #[cfg(not(feature = "neovim"))]
128    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
129        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
130            self.data.pop_front();
131        }
132        assert_eq!(
133            self.data.pop_front(),
134            Some(NeovimData::Key(keystroke_text.to_string())),
135            "operation does not match recorded script. re-record with --features=neovim"
136        );
137    }
138
139    #[cfg(feature = "neovim")]
140    pub async fn set_state(&mut self, marked_text: &str) {
141        let (text, selections) = parse_state(&marked_text);
142
143        let nvim_buffer = self
144            .nvim
145            .get_current_buf()
146            .await
147            .expect("Could not get neovim buffer");
148        let lines = text
149            .split('\n')
150            .map(|line| line.to_string())
151            .collect::<Vec<_>>();
152
153        nvim_buffer
154            .set_lines(0, -1, false, lines)
155            .await
156            .expect("Could not set nvim buffer text");
157
158        self.nvim
159            .input("<escape>")
160            .await
161            .expect("Could not send escape to nvim");
162        self.nvim
163            .input("<escape>")
164            .await
165            .expect("Could not send escape to nvim");
166
167        let nvim_window = self
168            .nvim
169            .get_current_win()
170            .await
171            .expect("Could not get neovim window");
172
173        if selections.len() != 1 {
174            panic!("must have one selection");
175        }
176        let selection = &selections[0];
177
178        let cursor = selection.start;
179        nvim_window
180            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
181            .await
182            .expect("Could not set nvim cursor position");
183
184        if !selection.is_empty() {
185            self.nvim
186                .input("v")
187                .await
188                .expect("could not enter visual mode");
189
190            let cursor = selection.end;
191            nvim_window
192                .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
193                .await
194                .expect("Could not set nvim cursor position");
195        }
196
197        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
198            if *mode == Some(Mode::Normal) && *state == marked_text {
199                return;
200            }
201        }
202        self.data.push_back(NeovimData::Put {
203            state: marked_text.to_string(),
204        })
205    }
206
207    #[cfg(not(feature = "neovim"))]
208    pub async fn set_state(&mut self, marked_text: &str) {
209        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
210            if *mode == Some(Mode::Normal) && *text == marked_text {
211                return;
212            }
213            self.data.pop_front();
214        }
215        assert_eq!(
216            self.data.pop_front(),
217            Some(NeovimData::Put {
218                state: marked_text.to_string()
219            }),
220            "operation does not match recorded script. re-record with --features=neovim"
221        );
222    }
223
224    #[cfg(feature = "neovim")]
225    async fn read_position(&mut self, cmd: &str) -> u32 {
226        self.nvim
227            .command_output(cmd)
228            .await
229            .unwrap()
230            .parse::<u32>()
231            .unwrap()
232    }
233
234    #[cfg(feature = "neovim")]
235    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
236        let nvim_buffer = self
237            .nvim
238            .get_current_buf()
239            .await
240            .expect("Could not get neovim buffer");
241        let text = nvim_buffer
242            .get_lines(0, -1, false)
243            .await
244            .expect("Could not get buffer text")
245            .join("\n");
246
247        // nvim columns are 1-based, so -1.
248        let mut cursor_row = self.read_position("echo line('.')").await - 1;
249        let mut cursor_col = self.read_position("echo col('.')").await - 1;
250        let mut selection_row = self.read_position("echo line('v')").await - 1;
251        let mut selection_col = self.read_position("echo col('v')").await - 1;
252        let total_rows = self.read_position("echo line('$')").await - 1;
253
254        let nvim_mode_text = self
255            .nvim
256            .get_mode()
257            .await
258            .expect("Could not get mode")
259            .into_iter()
260            .find_map(|(key, value)| {
261                if key.as_str() == Some("mode") {
262                    Some(value.as_str().unwrap().to_owned())
263                } else {
264                    None
265                }
266            })
267            .expect("Could not find mode value");
268
269        let mode = match nvim_mode_text.as_ref() {
270            "i" => Some(Mode::Insert),
271            "n" => Some(Mode::Normal),
272            "v" => Some(Mode::Visual),
273            "V" => Some(Mode::VisualLine),
274            "\x16" => Some(Mode::VisualBlock),
275            _ => None,
276        };
277
278        let mut selections = Vec::new();
279        // Vim uses the index of the first and last character in the selection
280        // Zed uses the index of the positions between the characters, so we need
281        // to add one to the end in visual mode.
282        match mode {
283            Some(Mode::VisualBlock) if selection_row != cursor_row => {
284                // in zed we fake a block selecrtion by using multiple cursors (one per line)
285                // this code emulates that.
286                // to deal with casees where the selection is not perfectly rectangular we extract
287                // the content of the selection via the "a register to get the shape correctly.
288                self.nvim.input("\"aygv").await.unwrap();
289                let content = self.nvim.command_output("echo getreg('a')").await.unwrap();
290                let lines = content.split("\n").collect::<Vec<_>>();
291                let top = cmp::min(selection_row, cursor_row);
292                let left = cmp::min(selection_col, cursor_col);
293                for row in top..=cmp::max(selection_row, cursor_row) {
294                    let content = if row - top >= lines.len() as u32 {
295                        ""
296                    } else {
297                        lines[(row - top) as usize]
298                    };
299                    let line_len = self
300                        .read_position(format!("echo strlen(getline({}))", row + 1).as_str())
301                        .await;
302
303                    if left > line_len {
304                        continue;
305                    }
306
307                    let start = Point::new(row, left);
308                    let end = Point::new(row, left + content.len() as u32);
309                    if cursor_col >= selection_col {
310                        selections.push(start..end)
311                    } else {
312                        selections.push(end..start)
313                    }
314                }
315            }
316            Some(Mode::Visual) | Some(Mode::VisualLine) | Some(Mode::VisualBlock) => {
317                if selection_col > cursor_col {
318                    let selection_line_length =
319                        self.read_position("echo strlen(getline(line('v')))").await;
320                    if selection_line_length > selection_col {
321                        selection_col += 1;
322                    } else if selection_row < total_rows {
323                        selection_col = 0;
324                        selection_row += 1;
325                    }
326                } else {
327                    let cursor_line_length =
328                        self.read_position("echo strlen(getline(line('.')))").await;
329                    if cursor_line_length > cursor_col {
330                        cursor_col += 1;
331                    } else if cursor_row < total_rows {
332                        cursor_col = 0;
333                        cursor_row += 1;
334                    }
335                }
336                selections.push(
337                    Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col),
338                )
339            }
340            Some(Mode::Insert) | Some(Mode::Normal) | None => selections
341                .push(Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col)),
342        }
343
344        let state = NeovimData::Get {
345            mode,
346            state: encode_ranges(&text, &selections),
347        };
348
349        if self.data.back() != Some(&state) {
350            self.data.push_back(state.clone());
351        }
352
353        (mode, text, selections)
354    }
355
356    #[cfg(not(feature = "neovim"))]
357    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
358        if let Some(NeovimData::Get { state: text, mode }) = self.data.front() {
359            let (text, ranges) = parse_state(text);
360            (*mode, text, ranges)
361        } else {
362            panic!("operation does not match recorded script. re-record with --features=neovim");
363        }
364    }
365
366    pub async fn selections(&mut self) -> Vec<Range<Point>> {
367        self.state().await.2
368    }
369
370    pub async fn mode(&mut self) -> Option<Mode> {
371        self.state().await.0
372    }
373
374    pub async fn text(&mut self) -> String {
375        self.state().await.1
376    }
377
378    fn test_data_path(test_case_id: &str) -> PathBuf {
379        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
380        data_path.push("test_data");
381        data_path.push(format!("{}.json", test_case_id));
382        data_path
383    }
384
385    #[cfg(not(feature = "neovim"))]
386    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
387        let path = Self::test_data_path(test_case_id);
388        let json = std::fs::read_to_string(path).expect(
389            "Could not read test data. Is it generated? Try running test with '--features neovim'",
390        );
391
392        let mut result = VecDeque::new();
393        for line in json.lines() {
394            result.push_back(
395                serde_json::from_str(line)
396                    .expect("invalid test data. regenerate it with '--features neovim'"),
397            );
398        }
399        result
400    }
401
402    #[cfg(feature = "neovim")]
403    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
404        let path = Self::test_data_path(test_case_id);
405        let mut json = Vec::new();
406        for entry in data {
407            serde_json::to_writer(&mut json, entry).unwrap();
408            json.push(b'\n');
409        }
410        std::fs::create_dir_all(path.parent().unwrap())
411            .expect("could not create test data directory");
412        std::fs::write(path, json).expect("could not write out test data");
413    }
414}
415
/// Forwards to the wrapped neovim client, so its methods can be called directly on a
/// `NeovimConnection`.
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    fn deref(&self) -> &Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
424
/// Mutable counterpart of the `Deref` forwarding to the wrapped neovim client.
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
431
/// Flushes everything recorded during a live session to the test data file, so later runs
/// without the `neovim` feature can replay it.
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    fn drop(&mut self) {
        let Self {
            test_case_id, data, ..
        } = &*self;
        Self::write_test_data(test_case_id, data);
    }
}
438
/// Stateless RPC handler handed to `new_child_cmd` when spawning the neovim process.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler {}
442
/// RPC callbacks for the embedded neovim process.
///
/// Notifications are discarded. Requests initiated by neovim are not supported and panic.
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    async fn handle_request(
        &self,
        _name: String,
        _args: Vec<Value>,
        _nvim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        unimplemented!();
    }

    async fn handle_notify(&self, _name: String, _args: Vec<Value>, _nvim: Neovim<Self::Writer>) {
        // Intentionally ignored: the tests only query neovim's state on demand.
    }
}
465
466fn parse_state(marked_text: &str) -> (String, Vec<Range<Point>>) {
467    let (text, ranges) = util::test::marked_text_ranges(marked_text, true);
468    let point_ranges = ranges
469        .into_iter()
470        .map(|byte_range| {
471            let mut point_range = Point::zero()..Point::zero();
472            let mut ix = 0;
473            let mut position = Point::zero();
474            for c in text.chars().chain(['\0']) {
475                if ix == byte_range.start {
476                    point_range.start = position;
477                }
478                if ix == byte_range.end {
479                    point_range.end = position;
480                }
481                let len_utf8 = c.len_utf8();
482                ix += len_utf8;
483                if c == '\n' {
484                    position.row += 1;
485                    position.column = 0;
486                } else {
487                    position.column += len_utf8 as u32;
488                }
489            }
490            point_range
491        })
492        .collect::<Vec<_>>();
493    (text, point_ranges)
494}
495
/// Renders `text` with `point_ranges` marked, using `util::test::generate_marked_text`.
///
/// This is the encoding counterpart of `parse_state`. Each endpoint is a row/column `Point`
/// whose column counts bytes within the row. An endpoint that matches no position in `text`
/// is encoded as byte offset 0.
#[cfg(feature = "neovim")]
fn encode_ranges(text: &str, point_ranges: &[Range<Point>]) -> String {
    // Walk the text until the row/column position equals `point`, yielding its byte offset.
    let point_to_offset = |point: Point| -> usize {
        let mut position = Point::zero();
        for (ix, ch) in text.char_indices() {
            if position == point {
                return ix;
            }
            if ch == '\n' {
                position.row += 1;
                position.column = 0;
            } else {
                position.column += ch.len_utf8() as u32;
            }
        }
        // The position just past the last character is valid; anything else maps to zero.
        if position == point {
            text.len()
        } else {
            0
        }
    };

    let byte_ranges = point_ranges
        .iter()
        .map(|range| point_to_offset(range.start)..point_to_offset(range.end))
        .collect::<Vec<_>>();
    util::test::generate_marked_text(text, &byte_ranges, true)
}