neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::{
  3    cmp,
  4    ops::{Deref, DerefMut},
  5};
  6use std::{ops::Range, path::PathBuf};
  7
  8#[cfg(feature = "neovim")]
  9use async_compat::Compat;
 10#[cfg(feature = "neovim")]
 11use async_trait::async_trait;
 12#[cfg(feature = "neovim")]
 13use gpui::keymap_matcher::Keystroke;
 14
 15use language::Point;
 16
 17#[cfg(feature = "neovim")]
 18use nvim_rs::{
 19    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 20};
 21#[cfg(feature = "neovim")]
 22use parking_lot::ReentrantMutex;
 23use serde::{Deserialize, Serialize};
 24#[cfg(feature = "neovim")]
 25use tokio::{
 26    process::{Child, ChildStdin, Command},
 27    task::JoinHandle,
 28};
 29
 30use crate::state::Mode;
 31use collections::VecDeque;
 32
// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
// to ensure we are only constructing one neovim connection at a time.
//
// `NeovimConnection::new` holds the guard for its whole startup sequence (spawning the
// process, attaching the UI, setting initial options). Because the mutex is reentrant, a
// thread that already holds it can lock it again without deadlocking.
#[cfg(feature = "neovim")]
static NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37
/// One recorded interaction with neovim.
///
/// With the `neovim` feature every operation performed on a `NeovimConnection` is
/// appended as one of these entries, and the list is written to the test-data file as
/// one JSON object per line. Without the feature the entries are read back from that
/// file and replayed in order.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum NeovimData {
    /// The buffer text and selection were set from this marked text (`set_state`).
    Put { state: String },
    /// A keystroke, in the textual form that was passed to `send_keystroke`.
    Key(String),
    /// An observed state: marked text (text plus selections) and the mode, or `None`
    /// when neovim reported a mode this harness does not map.
    Get { state: String, mode: Option<Mode> },
    /// The value that was read from register `name`.
    ReadRegister { name: char, value: String },
    /// The argument that was passed to `:set`.
    SetOption { value: String },
}
 46
/// Test-harness connection that either drives a real embedded neovim process and records
/// what happens (`neovim` feature) or replays a previously recorded script.
///
/// NOTE(review): Rust drops fields in declaration order, so reordering the fields below
/// changes the teardown order of the neovim client, its I/O task and the child process.
pub struct NeovimConnection {
    /// Recorded operations: appended to while recording, consumed from the front while
    /// replaying.
    data: VecDeque<NeovimData>,
    /// Names the test-data file that the recording is written to on drop.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    /// RPC client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    /// Handle returned by `new_child_cmd`; never read, presumably kept so the task is
    /// owned for the lifetime of the connection.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    /// The spawned `nvim` process; never read, held for the lifetime of the connection.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 58
 59impl NeovimConnection {
 60    pub async fn new(test_case_id: String) -> Self {
 61        #[cfg(feature = "neovim")]
 62        let handler = NvimHandler {};
 63        #[cfg(feature = "neovim")]
 64        let (nvim, join_handle, child) = Compat::new(async {
 65            // Ensure we don't create neovim connections in parallel
 66            let _lock = NEOVIM_LOCK.lock();
 67            let (nvim, join_handle, child) = new_child_cmd(
 68                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 69                handler,
 70            )
 71            .await
 72            .expect("Could not connect to neovim process");
 73
 74            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 75                .await
 76                .expect("Could not attach to ui");
 77
 78            // Makes system act a little more like zed in terms of indentation
 79            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 80                .await
 81                .expect("Could not set smartindent on startup");
 82
 83            (nvim, join_handle, child)
 84        })
 85        .await;
 86
 87        Self {
 88            #[cfg(feature = "neovim")]
 89            data: Default::default(),
 90            #[cfg(not(feature = "neovim"))]
 91            data: Self::read_test_data(&test_case_id),
 92            #[cfg(feature = "neovim")]
 93            test_case_id,
 94            #[cfg(feature = "neovim")]
 95            nvim,
 96            #[cfg(feature = "neovim")]
 97            _join_handle: join_handle,
 98            #[cfg(feature = "neovim")]
 99            _child: child,
100        }
101    }
102
103    // Sends a keystroke to the neovim process.
104    #[cfg(feature = "neovim")]
105    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
106        let keystroke = Keystroke::parse(keystroke_text).unwrap();
107        let special = keystroke.shift
108            || keystroke.ctrl
109            || keystroke.alt
110            || keystroke.cmd
111            || keystroke.key.len() > 1;
112        let start = if special { "<" } else { "" };
113        let shift = if keystroke.shift { "S-" } else { "" };
114        let ctrl = if keystroke.ctrl { "C-" } else { "" };
115        let alt = if keystroke.alt { "M-" } else { "" };
116        let cmd = if keystroke.cmd { "D-" } else { "" };
117        let end = if special { ">" } else { "" };
118
119        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
120
121        self.data
122            .push_back(NeovimData::Key(keystroke_text.to_string()));
123        self.nvim
124            .input(&key)
125            .await
126            .expect("Could not input keystroke");
127    }
128
129    #[cfg(not(feature = "neovim"))]
130    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
131        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
132            self.data.pop_front();
133        }
134        assert_eq!(
135            self.data.pop_front(),
136            Some(NeovimData::Key(keystroke_text.to_string())),
137            "operation does not match recorded script. re-record with --features=neovim"
138        );
139    }
140
141    #[cfg(feature = "neovim")]
142    pub async fn set_state(&mut self, marked_text: &str) {
143        let (text, selections) = parse_state(&marked_text);
144
145        let nvim_buffer = self
146            .nvim
147            .get_current_buf()
148            .await
149            .expect("Could not get neovim buffer");
150        let lines = text
151            .split('\n')
152            .map(|line| line.to_string())
153            .collect::<Vec<_>>();
154
155        nvim_buffer
156            .set_lines(0, -1, false, lines)
157            .await
158            .expect("Could not set nvim buffer text");
159
160        self.nvim
161            .input("<escape>")
162            .await
163            .expect("Could not send escape to nvim");
164        self.nvim
165            .input("<escape>")
166            .await
167            .expect("Could not send escape to nvim");
168
169        let nvim_window = self
170            .nvim
171            .get_current_win()
172            .await
173            .expect("Could not get neovim window");
174
175        if selections.len() != 1 {
176            panic!("must have one selection");
177        }
178        let selection = &selections[0];
179
180        let cursor = selection.start;
181        nvim_window
182            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
183            .await
184            .expect("Could not set nvim cursor position");
185
186        if !selection.is_empty() {
187            self.nvim
188                .input("v")
189                .await
190                .expect("could not enter visual mode");
191
192            let cursor = selection.end;
193            nvim_window
194                .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
195                .await
196                .expect("Could not set nvim cursor position");
197        }
198
199        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
200            if *mode == Some(Mode::Normal) && *state == marked_text {
201                return;
202            }
203        }
204        self.data.push_back(NeovimData::Put {
205            state: marked_text.to_string(),
206        })
207    }
208
209    #[cfg(not(feature = "neovim"))]
210    pub async fn set_state(&mut self, marked_text: &str) {
211        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
212            if *mode == Some(Mode::Normal) && *text == marked_text {
213                return;
214            }
215            self.data.pop_front();
216        }
217        assert_eq!(
218            self.data.pop_front(),
219            Some(NeovimData::Put {
220                state: marked_text.to_string()
221            }),
222            "operation does not match recorded script. re-record with --features=neovim"
223        );
224    }
225
226    #[cfg(feature = "neovim")]
227    pub async fn set_option(&mut self, value: &str) {
228        self.nvim
229            .command_output(format!("set {}", value).as_str())
230            .await
231            .unwrap();
232
233        self.data.push_back(NeovimData::SetOption {
234            value: value.to_string(),
235        })
236    }
237
238    #[cfg(not(feature = "neovim"))]
239    pub async fn set_option(&mut self, value: &str) {
240        if let Some(NeovimData::Get { .. }) = self.data.front() {
241            self.data.pop_front();
242        };
243        assert_eq!(
244            self.data.pop_front(),
245            Some(NeovimData::SetOption {
246                value: value.to_string(),
247            }),
248            "operation does not match recorded script. re-record with --features=neovim"
249        );
250    }
251
252    #[cfg(not(feature = "neovim"))]
253    pub async fn read_register(&mut self, register: char) -> String {
254        if let Some(NeovimData::Get { .. }) = self.data.front() {
255            self.data.pop_front();
256        };
257        if let Some(NeovimData::ReadRegister { name, value }) = self.data.pop_front() {
258            if name == register {
259                return value;
260            }
261        }
262
263        panic!("operation does not match recorded script. re-record with --features=neovim")
264    }
265
266    #[cfg(feature = "neovim")]
267    pub async fn read_register(&mut self, name: char) -> String {
268        let value = self
269            .nvim
270            .command_output(format!("echo getreg('{}')", name).as_str())
271            .await
272            .unwrap();
273
274        self.data.push_back(NeovimData::ReadRegister {
275            name,
276            value: value.clone(),
277        });
278
279        value
280    }
281
282    #[cfg(feature = "neovim")]
283    async fn read_position(&mut self, cmd: &str) -> u32 {
284        self.nvim
285            .command_output(cmd)
286            .await
287            .unwrap()
288            .parse::<u32>()
289            .unwrap()
290    }
291
292    #[cfg(feature = "neovim")]
293    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
294        let nvim_buffer = self
295            .nvim
296            .get_current_buf()
297            .await
298            .expect("Could not get neovim buffer");
299        let text = nvim_buffer
300            .get_lines(0, -1, false)
301            .await
302            .expect("Could not get buffer text")
303            .join("\n");
304
305        // nvim columns are 1-based, so -1.
306        let mut cursor_row = self.read_position("echo line('.')").await - 1;
307        let mut cursor_col = self.read_position("echo col('.')").await - 1;
308        let mut selection_row = self.read_position("echo line('v')").await - 1;
309        let mut selection_col = self.read_position("echo col('v')").await - 1;
310        let total_rows = self.read_position("echo line('$')").await - 1;
311
312        let nvim_mode_text = self
313            .nvim
314            .get_mode()
315            .await
316            .expect("Could not get mode")
317            .into_iter()
318            .find_map(|(key, value)| {
319                if key.as_str() == Some("mode") {
320                    Some(value.as_str().unwrap().to_owned())
321                } else {
322                    None
323                }
324            })
325            .expect("Could not find mode value");
326
327        let mode = match nvim_mode_text.as_ref() {
328            "i" => Some(Mode::Insert),
329            "n" => Some(Mode::Normal),
330            "v" => Some(Mode::Visual),
331            "V" => Some(Mode::VisualLine),
332            "\x16" => Some(Mode::VisualBlock),
333            _ => None,
334        };
335
336        let mut selections = Vec::new();
337        // Vim uses the index of the first and last character in the selection
338        // Zed uses the index of the positions between the characters, so we need
339        // to add one to the end in visual mode.
340        match mode {
341            Some(Mode::VisualBlock) if selection_row != cursor_row => {
342                // in zed we fake a block selecrtion by using multiple cursors (one per line)
343                // this code emulates that.
344                // to deal with casees where the selection is not perfectly rectangular we extract
345                // the content of the selection via the "a register to get the shape correctly.
346                self.nvim.input("\"aygv").await.unwrap();
347                let content = self.nvim.command_output("echo getreg('a')").await.unwrap();
348                let lines = content.split("\n").collect::<Vec<_>>();
349                let top = cmp::min(selection_row, cursor_row);
350                let left = cmp::min(selection_col, cursor_col);
351                for row in top..=cmp::max(selection_row, cursor_row) {
352                    let content = if row - top >= lines.len() as u32 {
353                        ""
354                    } else {
355                        lines[(row - top) as usize]
356                    };
357                    let line_len = self
358                        .read_position(format!("echo strlen(getline({}))", row + 1).as_str())
359                        .await;
360
361                    if left > line_len {
362                        continue;
363                    }
364
365                    let start = Point::new(row, left);
366                    let end = Point::new(row, left + content.len() as u32);
367                    if cursor_col >= selection_col {
368                        selections.push(start..end)
369                    } else {
370                        selections.push(end..start)
371                    }
372                }
373            }
374            Some(Mode::Visual) | Some(Mode::VisualLine) | Some(Mode::VisualBlock) => {
375                if selection_col > cursor_col {
376                    let selection_line_length =
377                        self.read_position("echo strlen(getline(line('v')))").await;
378                    if selection_line_length > selection_col {
379                        selection_col += 1;
380                    } else if selection_row < total_rows {
381                        selection_col = 0;
382                        selection_row += 1;
383                    }
384                } else {
385                    let cursor_line_length =
386                        self.read_position("echo strlen(getline(line('.')))").await;
387                    if cursor_line_length > cursor_col {
388                        cursor_col += 1;
389                    } else if cursor_row < total_rows {
390                        cursor_col = 0;
391                        cursor_row += 1;
392                    }
393                }
394                selections.push(
395                    Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col),
396                )
397            }
398            Some(Mode::Insert) | Some(Mode::Normal) | None => selections
399                .push(Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col)),
400        }
401
402        let state = NeovimData::Get {
403            mode,
404            state: encode_ranges(&text, &selections),
405        };
406
407        if self.data.back() != Some(&state) {
408            self.data.push_back(state.clone());
409        }
410
411        (mode, text, selections)
412    }
413
414    #[cfg(not(feature = "neovim"))]
415    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
416        if let Some(NeovimData::Get { state: text, mode }) = self.data.front() {
417            let (text, ranges) = parse_state(text);
418            (*mode, text, ranges)
419        } else {
420            panic!("operation does not match recorded script. re-record with --features=neovim");
421        }
422    }
423
424    pub async fn selections(&mut self) -> Vec<Range<Point>> {
425        self.state().await.2
426    }
427
428    pub async fn mode(&mut self) -> Option<Mode> {
429        self.state().await.0
430    }
431
432    pub async fn text(&mut self) -> String {
433        self.state().await.1
434    }
435
436    fn test_data_path(test_case_id: &str) -> PathBuf {
437        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
438        data_path.push("test_data");
439        data_path.push(format!("{}.json", test_case_id));
440        data_path
441    }
442
443    #[cfg(not(feature = "neovim"))]
444    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
445        let path = Self::test_data_path(test_case_id);
446        let json = std::fs::read_to_string(path).expect(
447            "Could not read test data. Is it generated? Try running test with '--features neovim'",
448        );
449
450        let mut result = VecDeque::new();
451        for line in json.lines() {
452            result.push_back(
453                serde_json::from_str(line)
454                    .expect("invalid test data. regenerate it with '--features neovim'"),
455            );
456        }
457        result
458    }
459
460    #[cfg(feature = "neovim")]
461    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
462        let path = Self::test_data_path(test_case_id);
463        let mut json = Vec::new();
464        for entry in data {
465            serde_json::to_writer(&mut json, entry).unwrap();
466            json.push(b'\n');
467        }
468        std::fs::create_dir_all(path.parent().unwrap())
469            .expect("could not create test data directory");
470        std::fs::write(path, json).expect("could not write out test data");
471    }
472}
473
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    /// Gives shared access to the wrapped neovim RPC client.
    fn deref(&self) -> &Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>> {
        let client = &self.nvim;
        client
    }
}
482
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    /// Gives mutable access to the wrapped neovim RPC client.
    fn deref_mut(&mut self) -> &mut Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>> {
        let client = &mut self.nvim;
        client
    }
}
489
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    /// Persists everything recorded during the test to its test-data file, so the test
    /// can later be replayed without the `neovim` feature.
    fn drop(&mut self) {
        let Self {
            test_case_id, data, ..
        } = &*self;
        Self::write_test_data(test_case_id, data);
    }
}
496
/// Stateless RPC handler passed to `new_child_cmd` for the embedded neovim process.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler;
500
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    /// Requests from neovim are not supported; receiving one panics.
    async fn handle_request(
        &self,
        _name: String,
        _args: Vec<Value>,
        _nvim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        unimplemented!()
    }

    /// All notifications from neovim are ignored.
    async fn handle_notify(
        &self,
        _name: String,
        _args: Vec<Value>,
        _nvim: Neovim<Self::Writer>,
    ) {
    }
}
523
524fn parse_state(marked_text: &str) -> (String, Vec<Range<Point>>) {
525    let (text, ranges) = util::test::marked_text_ranges(marked_text, true);
526    let point_ranges = ranges
527        .into_iter()
528        .map(|byte_range| {
529            let mut point_range = Point::zero()..Point::zero();
530            let mut ix = 0;
531            let mut position = Point::zero();
532            for c in text.chars().chain(['\0']) {
533                if ix == byte_range.start {
534                    point_range.start = position;
535                }
536                if ix == byte_range.end {
537                    point_range.end = position;
538                }
539                let len_utf8 = c.len_utf8();
540                ix += len_utf8;
541                if c == '\n' {
542                    position.row += 1;
543                    position.column = 0;
544                } else {
545                    position.column += len_utf8 as u32;
546                }
547            }
548            point_range
549        })
550        .collect::<Vec<_>>();
551    (text, point_ranges)
552}
553
/// Counterpart of `parse_state`: renders `text` with `point_ranges` as marked text.
///
/// Each point (row, byte column) is converted to a byte offset by walking `text`. A point
/// that does not land on a character boundary inside `text` (e.g. a column past the end
/// of its line) maps to offset 0.
#[cfg(feature = "neovim")]
fn encode_ranges(text: &str, point_ranges: &[Range<Point>]) -> String {
    // Byte offset of `target` in `text`; the position just past the last character is
    // valid and maps to `text.len()`.
    let offset_of = |target: Point| -> usize {
        let mut position = Point::zero();
        for (ix, c) in text.char_indices() {
            if position == target {
                return ix;
            }
            if c == '\n' {
                position.row += 1;
                position.column = 0;
            } else {
                position.column += c.len_utf8() as u32;
            }
        }
        if position == target {
            text.len()
        } else {
            0
        }
    };

    let byte_ranges = point_ranges
        .iter()
        .map(|range| offset_of(range.start)..offset_of(range.end))
        .collect::<Vec<_>>();
    util::test::generate_marked_text(text, &byte_ranges[..], true)
}