neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::{
  3    cmp,
  4    ops::{Deref, DerefMut},
  5};
  6use std::{ops::Range, path::PathBuf};
  7
  8#[cfg(feature = "neovim")]
  9use async_compat::Compat;
 10#[cfg(feature = "neovim")]
 11use async_trait::async_trait;
 12#[cfg(feature = "neovim")]
 13use gpui::keymap_matcher::Keystroke;
 14
 15use language::Point;
 16
 17#[cfg(feature = "neovim")]
 18use nvim_rs::{
 19    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 20};
 21#[cfg(feature = "neovim")]
 22use parking_lot::ReentrantMutex;
 23use serde::{Deserialize, Serialize};
 24#[cfg(feature = "neovim")]
 25use tokio::{
 26    process::{Child, ChildStdin, Command},
 27    task::JoinHandle,
 28};
 29
 30use crate::state::Mode;
 31use collections::VecDeque;
 32
// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
// to ensure we are only constructing one neovim connection at a time: `NeovimConnection::new`
// holds the guard from spawning the `nvim` process until the UI is attached and the startup
// options are set.
#[cfg(feature = "neovim")]
static NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37
 38#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 39pub enum NeovimData {
 40    Put { state: String },
 41    Key(String),
 42    Get { state: String, mode: Option<Mode> },
 43    ReadRegister { name: char, value: String },
 44}
 45
/// Test harness connection. With the `neovim` feature it drives a real embedded neovim and
/// records every operation; without it, it replays a previously recorded script.
///
/// Note: fields are dropped in declaration order.
pub struct NeovimConnection {
    /// Operation script: appended to while recording, consumed from the front while replaying.
    data: VecDeque<NeovimData>,
    /// Names the test data file the recording is written to when the connection is dropped.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    /// Client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    /// Task handle returned by `new_child_cmd`; held but never read here (presumably the
    /// RPC I/O loop — confirm against the nvim_rs docs).
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    /// The spawned `nvim --embed --clean` child process; held but never read here.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 57
 58impl NeovimConnection {
 59    pub async fn new(test_case_id: String) -> Self {
 60        #[cfg(feature = "neovim")]
 61        let handler = NvimHandler {};
 62        #[cfg(feature = "neovim")]
 63        let (nvim, join_handle, child) = Compat::new(async {
 64            // Ensure we don't create neovim connections in parallel
 65            let _lock = NEOVIM_LOCK.lock();
 66            let (nvim, join_handle, child) = new_child_cmd(
 67                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 68                handler,
 69            )
 70            .await
 71            .expect("Could not connect to neovim process");
 72
 73            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 74                .await
 75                .expect("Could not attach to ui");
 76
 77            // Makes system act a little more like zed in terms of indentation
 78            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 79                .await
 80                .expect("Could not set smartindent on startup");
 81
 82            (nvim, join_handle, child)
 83        })
 84        .await;
 85
 86        Self {
 87            #[cfg(feature = "neovim")]
 88            data: Default::default(),
 89            #[cfg(not(feature = "neovim"))]
 90            data: Self::read_test_data(&test_case_id),
 91            #[cfg(feature = "neovim")]
 92            test_case_id,
 93            #[cfg(feature = "neovim")]
 94            nvim,
 95            #[cfg(feature = "neovim")]
 96            _join_handle: join_handle,
 97            #[cfg(feature = "neovim")]
 98            _child: child,
 99        }
100    }
101
102    // Sends a keystroke to the neovim process.
103    #[cfg(feature = "neovim")]
104    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
105        let keystroke = Keystroke::parse(keystroke_text).unwrap();
106        let special = keystroke.shift
107            || keystroke.ctrl
108            || keystroke.alt
109            || keystroke.cmd
110            || keystroke.key.len() > 1;
111        let start = if special { "<" } else { "" };
112        let shift = if keystroke.shift { "S-" } else { "" };
113        let ctrl = if keystroke.ctrl { "C-" } else { "" };
114        let alt = if keystroke.alt { "M-" } else { "" };
115        let cmd = if keystroke.cmd { "D-" } else { "" };
116        let end = if special { ">" } else { "" };
117
118        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
119
120        self.data
121            .push_back(NeovimData::Key(keystroke_text.to_string()));
122        self.nvim
123            .input(&key)
124            .await
125            .expect("Could not input keystroke");
126    }
127
128    #[cfg(not(feature = "neovim"))]
129    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
130        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
131            self.data.pop_front();
132        }
133        assert_eq!(
134            self.data.pop_front(),
135            Some(NeovimData::Key(keystroke_text.to_string())),
136            "operation does not match recorded script. re-record with --features=neovim"
137        );
138    }
139
140    #[cfg(feature = "neovim")]
141    pub async fn set_state(&mut self, marked_text: &str) {
142        let (text, selections) = parse_state(&marked_text);
143
144        let nvim_buffer = self
145            .nvim
146            .get_current_buf()
147            .await
148            .expect("Could not get neovim buffer");
149        let lines = text
150            .split('\n')
151            .map(|line| line.to_string())
152            .collect::<Vec<_>>();
153
154        nvim_buffer
155            .set_lines(0, -1, false, lines)
156            .await
157            .expect("Could not set nvim buffer text");
158
159        self.nvim
160            .input("<escape>")
161            .await
162            .expect("Could not send escape to nvim");
163        self.nvim
164            .input("<escape>")
165            .await
166            .expect("Could not send escape to nvim");
167
168        let nvim_window = self
169            .nvim
170            .get_current_win()
171            .await
172            .expect("Could not get neovim window");
173
174        if selections.len() != 1 {
175            panic!("must have one selection");
176        }
177        let selection = &selections[0];
178
179        let cursor = selection.start;
180        nvim_window
181            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
182            .await
183            .expect("Could not set nvim cursor position");
184
185        if !selection.is_empty() {
186            self.nvim
187                .input("v")
188                .await
189                .expect("could not enter visual mode");
190
191            let cursor = selection.end;
192            nvim_window
193                .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
194                .await
195                .expect("Could not set nvim cursor position");
196        }
197
198        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
199            if *mode == Some(Mode::Normal) && *state == marked_text {
200                return;
201            }
202        }
203        self.data.push_back(NeovimData::Put {
204            state: marked_text.to_string(),
205        })
206    }
207
208    #[cfg(not(feature = "neovim"))]
209    pub async fn set_state(&mut self, marked_text: &str) {
210        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
211            if *mode == Some(Mode::Normal) && *text == marked_text {
212                return;
213            }
214            self.data.pop_front();
215        }
216        assert_eq!(
217            self.data.pop_front(),
218            Some(NeovimData::Put {
219                state: marked_text.to_string()
220            }),
221            "operation does not match recorded script. re-record with --features=neovim"
222        );
223    }
224
225    #[cfg(not(feature = "neovim"))]
226    pub async fn read_register(&mut self, register: char) -> String {
227        if let Some(NeovimData::Get { .. }) = self.data.front() {
228            self.data.pop_front();
229        };
230        if let Some(NeovimData::ReadRegister { name, value }) = self.data.pop_front() {
231            if name == register {
232                return value;
233            }
234        }
235
236        panic!("operation does not match recorded script. re-record with --features=neovim")
237    }
238
239    #[cfg(feature = "neovim")]
240    pub async fn read_register(&mut self, name: char) -> String {
241        let value = self
242            .nvim
243            .command_output(format!("echo getreg('{}')", name).as_str())
244            .await
245            .unwrap();
246
247        self.data.push_back(NeovimData::ReadRegister {
248            name,
249            value: value.clone(),
250        });
251
252        value
253    }
254
255    #[cfg(feature = "neovim")]
256    async fn read_position(&mut self, cmd: &str) -> u32 {
257        self.nvim
258            .command_output(cmd)
259            .await
260            .unwrap()
261            .parse::<u32>()
262            .unwrap()
263    }
264
265    #[cfg(feature = "neovim")]
266    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
267        let nvim_buffer = self
268            .nvim
269            .get_current_buf()
270            .await
271            .expect("Could not get neovim buffer");
272        let text = nvim_buffer
273            .get_lines(0, -1, false)
274            .await
275            .expect("Could not get buffer text")
276            .join("\n");
277
278        // nvim columns are 1-based, so -1.
279        let mut cursor_row = self.read_position("echo line('.')").await - 1;
280        let mut cursor_col = self.read_position("echo col('.')").await - 1;
281        let mut selection_row = self.read_position("echo line('v')").await - 1;
282        let mut selection_col = self.read_position("echo col('v')").await - 1;
283        let total_rows = self.read_position("echo line('$')").await - 1;
284
285        let nvim_mode_text = self
286            .nvim
287            .get_mode()
288            .await
289            .expect("Could not get mode")
290            .into_iter()
291            .find_map(|(key, value)| {
292                if key.as_str() == Some("mode") {
293                    Some(value.as_str().unwrap().to_owned())
294                } else {
295                    None
296                }
297            })
298            .expect("Could not find mode value");
299
300        let mode = match nvim_mode_text.as_ref() {
301            "i" => Some(Mode::Insert),
302            "n" => Some(Mode::Normal),
303            "v" => Some(Mode::Visual),
304            "V" => Some(Mode::VisualLine),
305            "\x16" => Some(Mode::VisualBlock),
306            _ => None,
307        };
308
309        let mut selections = Vec::new();
310        // Vim uses the index of the first and last character in the selection
311        // Zed uses the index of the positions between the characters, so we need
312        // to add one to the end in visual mode.
313        match mode {
314            Some(Mode::VisualBlock) if selection_row != cursor_row => {
315                // in zed we fake a block selecrtion by using multiple cursors (one per line)
316                // this code emulates that.
317                // to deal with casees where the selection is not perfectly rectangular we extract
318                // the content of the selection via the "a register to get the shape correctly.
319                self.nvim.input("\"aygv").await.unwrap();
320                let content = self.nvim.command_output("echo getreg('a')").await.unwrap();
321                let lines = content.split("\n").collect::<Vec<_>>();
322                let top = cmp::min(selection_row, cursor_row);
323                let left = cmp::min(selection_col, cursor_col);
324                for row in top..=cmp::max(selection_row, cursor_row) {
325                    let content = if row - top >= lines.len() as u32 {
326                        ""
327                    } else {
328                        lines[(row - top) as usize]
329                    };
330                    let line_len = self
331                        .read_position(format!("echo strlen(getline({}))", row + 1).as_str())
332                        .await;
333
334                    if left > line_len {
335                        continue;
336                    }
337
338                    let start = Point::new(row, left);
339                    let end = Point::new(row, left + content.len() as u32);
340                    if cursor_col >= selection_col {
341                        selections.push(start..end)
342                    } else {
343                        selections.push(end..start)
344                    }
345                }
346            }
347            Some(Mode::Visual) | Some(Mode::VisualLine) | Some(Mode::VisualBlock) => {
348                if selection_col > cursor_col {
349                    let selection_line_length =
350                        self.read_position("echo strlen(getline(line('v')))").await;
351                    if selection_line_length > selection_col {
352                        selection_col += 1;
353                    } else if selection_row < total_rows {
354                        selection_col = 0;
355                        selection_row += 1;
356                    }
357                } else {
358                    let cursor_line_length =
359                        self.read_position("echo strlen(getline(line('.')))").await;
360                    if cursor_line_length > cursor_col {
361                        cursor_col += 1;
362                    } else if cursor_row < total_rows {
363                        cursor_col = 0;
364                        cursor_row += 1;
365                    }
366                }
367                selections.push(
368                    Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col),
369                )
370            }
371            Some(Mode::Insert) | Some(Mode::Normal) | None => selections
372                .push(Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col)),
373        }
374
375        let state = NeovimData::Get {
376            mode,
377            state: encode_ranges(&text, &selections),
378        };
379
380        if self.data.back() != Some(&state) {
381            self.data.push_back(state.clone());
382        }
383
384        (mode, text, selections)
385    }
386
387    #[cfg(not(feature = "neovim"))]
388    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
389        if let Some(NeovimData::Get { state: text, mode }) = self.data.front() {
390            let (text, ranges) = parse_state(text);
391            (*mode, text, ranges)
392        } else {
393            panic!("operation does not match recorded script. re-record with --features=neovim");
394        }
395    }
396
397    pub async fn selections(&mut self) -> Vec<Range<Point>> {
398        self.state().await.2
399    }
400
401    pub async fn mode(&mut self) -> Option<Mode> {
402        self.state().await.0
403    }
404
405    pub async fn text(&mut self) -> String {
406        self.state().await.1
407    }
408
409    fn test_data_path(test_case_id: &str) -> PathBuf {
410        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
411        data_path.push("test_data");
412        data_path.push(format!("{}.json", test_case_id));
413        data_path
414    }
415
416    #[cfg(not(feature = "neovim"))]
417    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
418        let path = Self::test_data_path(test_case_id);
419        let json = std::fs::read_to_string(path).expect(
420            "Could not read test data. Is it generated? Try running test with '--features neovim'",
421        );
422
423        let mut result = VecDeque::new();
424        for line in json.lines() {
425            result.push_back(
426                serde_json::from_str(line)
427                    .expect("invalid test data. regenerate it with '--features neovim'"),
428            );
429        }
430        result
431    }
432
433    #[cfg(feature = "neovim")]
434    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
435        let path = Self::test_data_path(test_case_id);
436        let mut json = Vec::new();
437        for entry in data {
438            serde_json::to_writer(&mut json, entry).unwrap();
439            json.push(b'\n');
440        }
441        std::fs::create_dir_all(path.parent().unwrap())
442            .expect("could not create test data directory");
443        std::fs::write(path, json).expect("could not write out test data");
444    }
445}
446
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    /// Exposes the underlying neovim client so tests can issue raw API requests.
    fn deref(&self) -> &Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
455
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    /// Mutable access to the underlying neovim client.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { nvim, .. } = self;
        nvim
    }
}
462
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    /// Persists everything recorded during the session to the test data file so the test
    /// can later be replayed without neovim.
    fn drop(&mut self) {
        let Self {
            test_case_id, data, ..
        } = &*self;
        Self::write_test_data(test_case_id, data);
    }
}
469
/// RPC handler for the embedded neovim: notifications are ignored, requests are not
/// supported.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler;

#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    /// Requests from neovim are not expected in tests; receiving one panics.
    async fn handle_request(
        &self,
        _name: String,
        _args: Vec<Value>,
        _nvim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        unimplemented!()
    }

    /// Notifications from neovim are ignored.
    async fn handle_notify(&self, _name: String, _args: Vec<Value>, _nvim: Neovim<Self::Writer>) {
    }
}
496
497fn parse_state(marked_text: &str) -> (String, Vec<Range<Point>>) {
498    let (text, ranges) = util::test::marked_text_ranges(marked_text, true);
499    let point_ranges = ranges
500        .into_iter()
501        .map(|byte_range| {
502            let mut point_range = Point::zero()..Point::zero();
503            let mut ix = 0;
504            let mut position = Point::zero();
505            for c in text.chars().chain(['\0']) {
506                if ix == byte_range.start {
507                    point_range.start = position;
508                }
509                if ix == byte_range.end {
510                    point_range.end = position;
511                }
512                let len_utf8 = c.len_utf8();
513                ix += len_utf8;
514                if c == '\n' {
515                    position.row += 1;
516                    position.column = 0;
517                } else {
518                    position.column += len_utf8 as u32;
519                }
520            }
521            point_range
522        })
523        .collect::<Vec<_>>();
524    (text, point_ranges)
525}
526
/// Renders `text` with `point_ranges` marked in zed's marked-text syntax (the counterpart
/// of `parse_state`).
///
/// Each point is matched against the character boundaries of `text` (columns are byte
/// offsets within a line); a point that matches no boundary is encoded as byte offset 0.
#[cfg(feature = "neovim")]
fn encode_ranges(text: &str, point_ranges: &[Range<Point>]) -> String {
    let byte_ranges = point_ranges
        .iter()
        .map(|range| {
            let mut byte_range = 0..0;
            let mut ix = 0;
            let mut position = Point::zero();
            // The trailing '\0' lets a point at the very end of the text match.
            for c in text.chars().chain(['\0']) {
                if position == range.start {
                    byte_range.start = ix;
                }
                if position == range.end {
                    byte_range.end = ix;
                }
                let len_utf8 = c.len_utf8();
                ix += len_utf8;
                if c == '\n' {
                    position.row += 1;
                    position.column = 0;
                } else {
                    position.column += len_utf8 as u32;
                }
            }
            byte_range
        })
        .collect::<Vec<_>>();
    util::test::generate_marked_text(text, &byte_ranges[..], true)
}