neovim_connection.rs

  1#[cfg(feature = "neovim")]
  2use std::{
  3    cmp,
  4    ops::{Deref, DerefMut},
  5};
  6use std::{ops::Range, path::PathBuf};
  7
  8#[cfg(feature = "neovim")]
  9use async_compat::Compat;
 10#[cfg(feature = "neovim")]
 11use async_trait::async_trait;
 12#[cfg(feature = "neovim")]
 13use gpui::keymap_matcher::Keystroke;
 14
 15use language::Point;
 16
 17#[cfg(feature = "neovim")]
 18use nvim_rs::{
 19    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 20};
 21#[cfg(feature = "neovim")]
 22use parking_lot::ReentrantMutex;
 23use serde::{Deserialize, Serialize};
 24#[cfg(feature = "neovim")]
 25use tokio::{
 26    process::{Child, ChildStdin, Command},
 27    task::JoinHandle,
 28};
 29
 30use crate::state::Mode;
 31use collections::VecDeque;
 32
// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
// to ensure we are only constructing one neovim connection at a time.
// NOTE(review): this is a *reentrant* mutex, so it serializes construction across threads
// but does not block a second construction started from the thread that already holds it —
// confirm that is acceptable for the test executor in use.
#[cfg(feature = "neovim")]
static NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 37
/// One recorded operation of a `NeovimConnection` session.
///
/// With the `neovim` feature these entries are appended while a test drives a real
/// neovim process and are written out as one JSON object per line; without the
/// feature they are read back and consumed in order. Variant and field names are part
/// of the serialized format, so renaming them invalidates existing test data.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum NeovimData {
    /// The buffer was set to `state` (text with selection markers).
    Put { state: String },
    /// A keystroke, in the textual form passed to `send_keystroke`.
    Key(String),
    /// An observed buffer/selection `state` (text with selection markers) and `mode`.
    Get { state: String, mode: Option<Mode> },
    /// The contents `value` read from register `name`.
    ReadRegister { name: char, value: String },
    /// An option applied via `:set {value}`.
    SetOption { value: String },
}
 46
/// Test harness connection that either drives a real neovim process and records what
/// it observes (`neovim` feature) or replays a previously recorded session (default).
pub struct NeovimConnection {
    // Recorded operations: appended to while recording, consumed from the front while
    // replaying.
    data: VecDeque<NeovimData>,
    // Identifies the test data file the recording is written to on drop.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    // RPC client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    // Handle of the RPC I/O loop task returned by `new_child_cmd`; held for the
    // lifetime of the connection but not otherwise used.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    // The spawned `nvim --embed` child process; held but not otherwise used here.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 58
 59impl NeovimConnection {
 60    pub async fn new(test_case_id: String) -> Self {
 61        #[cfg(feature = "neovim")]
 62        let handler = NvimHandler {};
 63        #[cfg(feature = "neovim")]
 64        let (nvim, join_handle, child) = Compat::new(async {
 65            // Ensure we don't create neovim connections in parallel
 66            let _lock = NEOVIM_LOCK.lock();
 67            let (nvim, join_handle, child) = new_child_cmd(
 68                &mut Command::new("nvim").arg("--embed").arg("--clean"),
 69                handler,
 70            )
 71            .await
 72            .expect("Could not connect to neovim process");
 73
 74            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 75                .await
 76                .expect("Could not attach to ui");
 77
 78            // Makes system act a little more like zed in terms of indentation
 79            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 80                .await
 81                .expect("Could not set smartindent on startup");
 82
 83            (nvim, join_handle, child)
 84        })
 85        .await;
 86
 87        Self {
 88            #[cfg(feature = "neovim")]
 89            data: Default::default(),
 90            #[cfg(not(feature = "neovim"))]
 91            data: Self::read_test_data(&test_case_id),
 92            #[cfg(feature = "neovim")]
 93            test_case_id,
 94            #[cfg(feature = "neovim")]
 95            nvim,
 96            #[cfg(feature = "neovim")]
 97            _join_handle: join_handle,
 98            #[cfg(feature = "neovim")]
 99            _child: child,
100        }
101    }
102
103    // Sends a keystroke to the neovim process.
104    #[cfg(feature = "neovim")]
105    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
106        let keystroke = Keystroke::parse(keystroke_text).unwrap();
107        let special = keystroke.shift
108            || keystroke.ctrl
109            || keystroke.alt
110            || keystroke.cmd
111            || keystroke.key.len() > 1;
112        let start = if special { "<" } else { "" };
113        let shift = if keystroke.shift { "S-" } else { "" };
114        let ctrl = if keystroke.ctrl { "C-" } else { "" };
115        let alt = if keystroke.alt { "M-" } else { "" };
116        let cmd = if keystroke.cmd { "D-" } else { "" };
117        let end = if special { ">" } else { "" };
118
119        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
120
121        self.data
122            .push_back(NeovimData::Key(keystroke_text.to_string()));
123        self.nvim
124            .input(&key)
125            .await
126            .expect("Could not input keystroke");
127    }
128
129    #[cfg(not(feature = "neovim"))]
130    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
131        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
132            self.data.pop_front();
133        }
134        assert_eq!(
135            self.data.pop_front(),
136            Some(NeovimData::Key(keystroke_text.to_string())),
137            "operation does not match recorded script. re-record with --features=neovim"
138        );
139    }
140
141    #[cfg(feature = "neovim")]
142    pub async fn set_state(&mut self, marked_text: &str) {
143        let (text, selections) = parse_state(&marked_text);
144
145        let nvim_buffer = self
146            .nvim
147            .get_current_buf()
148            .await
149            .expect("Could not get neovim buffer");
150        let lines = text
151            .split('\n')
152            .map(|line| line.to_string())
153            .collect::<Vec<_>>();
154
155        nvim_buffer
156            .set_lines(0, -1, false, lines)
157            .await
158            .expect("Could not set nvim buffer text");
159
160        self.nvim
161            .input("<escape>")
162            .await
163            .expect("Could not send escape to nvim");
164        self.nvim
165            .input("<escape>")
166            .await
167            .expect("Could not send escape to nvim");
168
169        let nvim_window = self
170            .nvim
171            .get_current_win()
172            .await
173            .expect("Could not get neovim window");
174
175        if selections.len() != 1 {
176            panic!("must have one selection");
177        }
178        let selection = &selections[0];
179
180        let cursor = selection.start;
181        nvim_window
182            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
183            .await
184            .expect("Could not set nvim cursor position");
185
186        if !selection.is_empty() {
187            self.nvim
188                .input("v")
189                .await
190                .expect("could not enter visual mode");
191
192            let cursor = selection.end;
193            nvim_window
194                .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
195                .await
196                .expect("Could not set nvim cursor position");
197        }
198
199        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
200            if *mode == Some(Mode::Normal) && *state == marked_text {
201                return;
202            }
203        }
204        self.data.push_back(NeovimData::Put {
205            state: marked_text.to_string(),
206        })
207    }
208
209    #[cfg(not(feature = "neovim"))]
210    pub async fn set_state(&mut self, marked_text: &str) {
211        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
212            if *mode == Some(Mode::Normal) && *text == marked_text {
213                return;
214            }
215            self.data.pop_front();
216        }
217        assert_eq!(
218            self.data.pop_front(),
219            Some(NeovimData::Put {
220                state: marked_text.to_string()
221            }),
222            "operation does not match recorded script. re-record with --features=neovim"
223        );
224    }
225
226    #[cfg(feature = "neovim")]
227    pub async fn set_option(&mut self, value: &str) {
228        self.nvim
229            .command_output(format!("set {}", value).as_str())
230            .await
231            .unwrap();
232
233        self.data.push_back(NeovimData::SetOption {
234            value: value.to_string(),
235        })
236    }
237
238    #[cfg(not(feature = "neovim"))]
239    pub async fn set_option(&mut self, value: &str) {
240        assert_eq!(
241            self.data.pop_front(),
242            Some(NeovimData::SetOption {
243                value: value.to_string(),
244            }),
245            "operation does not match recorded script. re-record with --features=neovim"
246        );
247    }
248
249    #[cfg(not(feature = "neovim"))]
250    pub async fn read_register(&mut self, register: char) -> String {
251        if let Some(NeovimData::Get { .. }) = self.data.front() {
252            self.data.pop_front();
253        };
254        if let Some(NeovimData::ReadRegister { name, value }) = self.data.pop_front() {
255            if name == register {
256                return value;
257            }
258        }
259
260        panic!("operation does not match recorded script. re-record with --features=neovim")
261    }
262
263    #[cfg(feature = "neovim")]
264    pub async fn read_register(&mut self, name: char) -> String {
265        let value = self
266            .nvim
267            .command_output(format!("echo getreg('{}')", name).as_str())
268            .await
269            .unwrap();
270
271        self.data.push_back(NeovimData::ReadRegister {
272            name,
273            value: value.clone(),
274        });
275
276        value
277    }
278
279    #[cfg(feature = "neovim")]
280    async fn read_position(&mut self, cmd: &str) -> u32 {
281        self.nvim
282            .command_output(cmd)
283            .await
284            .unwrap()
285            .parse::<u32>()
286            .unwrap()
287    }
288
289    #[cfg(feature = "neovim")]
290    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
291        let nvim_buffer = self
292            .nvim
293            .get_current_buf()
294            .await
295            .expect("Could not get neovim buffer");
296        let text = nvim_buffer
297            .get_lines(0, -1, false)
298            .await
299            .expect("Could not get buffer text")
300            .join("\n");
301
302        // nvim columns are 1-based, so -1.
303        let mut cursor_row = self.read_position("echo line('.')").await - 1;
304        let mut cursor_col = self.read_position("echo col('.')").await - 1;
305        let mut selection_row = self.read_position("echo line('v')").await - 1;
306        let mut selection_col = self.read_position("echo col('v')").await - 1;
307        let total_rows = self.read_position("echo line('$')").await - 1;
308
309        let nvim_mode_text = self
310            .nvim
311            .get_mode()
312            .await
313            .expect("Could not get mode")
314            .into_iter()
315            .find_map(|(key, value)| {
316                if key.as_str() == Some("mode") {
317                    Some(value.as_str().unwrap().to_owned())
318                } else {
319                    None
320                }
321            })
322            .expect("Could not find mode value");
323
324        let mode = match nvim_mode_text.as_ref() {
325            "i" => Some(Mode::Insert),
326            "n" => Some(Mode::Normal),
327            "v" => Some(Mode::Visual),
328            "V" => Some(Mode::VisualLine),
329            "\x16" => Some(Mode::VisualBlock),
330            _ => None,
331        };
332
333        let mut selections = Vec::new();
334        // Vim uses the index of the first and last character in the selection
335        // Zed uses the index of the positions between the characters, so we need
336        // to add one to the end in visual mode.
337        match mode {
338            Some(Mode::VisualBlock) if selection_row != cursor_row => {
339                // in zed we fake a block selecrtion by using multiple cursors (one per line)
340                // this code emulates that.
341                // to deal with casees where the selection is not perfectly rectangular we extract
342                // the content of the selection via the "a register to get the shape correctly.
343                self.nvim.input("\"aygv").await.unwrap();
344                let content = self.nvim.command_output("echo getreg('a')").await.unwrap();
345                let lines = content.split("\n").collect::<Vec<_>>();
346                let top = cmp::min(selection_row, cursor_row);
347                let left = cmp::min(selection_col, cursor_col);
348                for row in top..=cmp::max(selection_row, cursor_row) {
349                    let content = if row - top >= lines.len() as u32 {
350                        ""
351                    } else {
352                        lines[(row - top) as usize]
353                    };
354                    let line_len = self
355                        .read_position(format!("echo strlen(getline({}))", row + 1).as_str())
356                        .await;
357
358                    if left > line_len {
359                        continue;
360                    }
361
362                    let start = Point::new(row, left);
363                    let end = Point::new(row, left + content.len() as u32);
364                    if cursor_col >= selection_col {
365                        selections.push(start..end)
366                    } else {
367                        selections.push(end..start)
368                    }
369                }
370            }
371            Some(Mode::Visual) | Some(Mode::VisualLine) | Some(Mode::VisualBlock) => {
372                if selection_col > cursor_col {
373                    let selection_line_length =
374                        self.read_position("echo strlen(getline(line('v')))").await;
375                    if selection_line_length > selection_col {
376                        selection_col += 1;
377                    } else if selection_row < total_rows {
378                        selection_col = 0;
379                        selection_row += 1;
380                    }
381                } else {
382                    let cursor_line_length =
383                        self.read_position("echo strlen(getline(line('.')))").await;
384                    if cursor_line_length > cursor_col {
385                        cursor_col += 1;
386                    } else if cursor_row < total_rows {
387                        cursor_col = 0;
388                        cursor_row += 1;
389                    }
390                }
391                selections.push(
392                    Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col),
393                )
394            }
395            Some(Mode::Insert) | Some(Mode::Normal) | None => selections
396                .push(Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col)),
397        }
398
399        let state = NeovimData::Get {
400            mode,
401            state: encode_ranges(&text, &selections),
402        };
403
404        if self.data.back() != Some(&state) {
405            self.data.push_back(state.clone());
406        }
407
408        (mode, text, selections)
409    }
410
411    #[cfg(not(feature = "neovim"))]
412    pub async fn state(&mut self) -> (Option<Mode>, String, Vec<Range<Point>>) {
413        if let Some(NeovimData::Get { state: text, mode }) = self.data.front() {
414            let (text, ranges) = parse_state(text);
415            (*mode, text, ranges)
416        } else {
417            panic!("operation does not match recorded script. re-record with --features=neovim");
418        }
419    }
420
421    pub async fn selections(&mut self) -> Vec<Range<Point>> {
422        self.state().await.2
423    }
424
425    pub async fn mode(&mut self) -> Option<Mode> {
426        self.state().await.0
427    }
428
429    pub async fn text(&mut self) -> String {
430        self.state().await.1
431    }
432
433    fn test_data_path(test_case_id: &str) -> PathBuf {
434        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
435        data_path.push("test_data");
436        data_path.push(format!("{}.json", test_case_id));
437        data_path
438    }
439
440    #[cfg(not(feature = "neovim"))]
441    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
442        let path = Self::test_data_path(test_case_id);
443        let json = std::fs::read_to_string(path).expect(
444            "Could not read test data. Is it generated? Try running test with '--features neovim'",
445        );
446
447        let mut result = VecDeque::new();
448        for line in json.lines() {
449            result.push_back(
450                serde_json::from_str(line)
451                    .expect("invalid test data. regenerate it with '--features neovim'"),
452            );
453        }
454        result
455    }
456
457    #[cfg(feature = "neovim")]
458    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
459        let path = Self::test_data_path(test_case_id);
460        let mut json = Vec::new();
461        for entry in data {
462            serde_json::to_writer(&mut json, entry).unwrap();
463            json.push(b'\n');
464        }
465        std::fs::create_dir_all(path.parent().unwrap())
466            .expect("could not create test data directory");
467        std::fs::write(path, json).expect("could not write out test data");
468    }
469}
470
/// Gives tests direct read access to the underlying neovim RPC client. Calls made
/// through it bypass the wrapper methods and therefore are not recorded.
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    fn deref(&self) -> &Self::Target {
        &self.nvim
    }
}
479
/// Mutable counterpart of the `Deref` impl: exposes the neovim RPC client directly.
/// Calls made through it are not recorded.
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.nvim
    }
}
486
/// Persists the recorded session when the connection goes away. The file goes to the
/// crate's `test_data/<test_case_id>.json`, so the test can later be replayed without
/// neovim.
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    fn drop(&mut self) {
        Self::write_test_data(&self.test_case_id, &self.data);
    }
}
493
/// Stateless RPC handler passed to `new_child_cmd` when spawning neovim.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler {}
497
/// Callbacks for RPC messages that neovim sends to this client.
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    // Requests from neovim are not expected by these tests; receiving one panics via
    // `unimplemented!`, which surfaces the unexpected message.
    async fn handle_request(
        &self,
        _event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        unimplemented!();
    }

    // Notifications are ignored. They presumably include the UI events that follow the
    // `ui_attach` call in `NeovimConnection::new`.
    async fn handle_notify(
        &self,
        _event_name: String,
        _arguments: Vec<Value>,
        _neovim: Neovim<Self::Writer>,
    ) {
    }
}
520
521fn parse_state(marked_text: &str) -> (String, Vec<Range<Point>>) {
522    let (text, ranges) = util::test::marked_text_ranges(marked_text, true);
523    let point_ranges = ranges
524        .into_iter()
525        .map(|byte_range| {
526            let mut point_range = Point::zero()..Point::zero();
527            let mut ix = 0;
528            let mut position = Point::zero();
529            for c in text.chars().chain(['\0']) {
530                if ix == byte_range.start {
531                    point_range.start = position;
532                }
533                if ix == byte_range.end {
534                    point_range.end = position;
535                }
536                let len_utf8 = c.len_utf8();
537                ix += len_utf8;
538                if c == '\n' {
539                    position.row += 1;
540                    position.column = 0;
541                } else {
542                    position.column += len_utf8 as u32;
543                }
544            }
545            point_range
546        })
547        .collect::<Vec<_>>();
548    (text, point_ranges)
549}
550
/// Renders `text` with `point_ranges` encoded as selection markers.
///
/// Each point (row, byte column) is converted back to a byte offset in `text`. A point
/// that never coincides with a character boundary of `text`, for example a column past
/// the end of its line, is encoded as offset 0.
#[cfg(feature = "neovim")]
fn encode_ranges(text: &str, point_ranges: &[Range<Point>]) -> String {
    // Walks the text to find the byte offset of `point`. The trailing '\0' sentinel makes
    // the position just past the last character reachable too. Positions strictly
    // increase during the walk, so the first match is the only one.
    let point_to_offset = |point: Point| -> usize {
        let mut ix = 0;
        let mut position = Point::zero();
        for c in text.chars().chain(['\0']) {
            if position == point {
                return ix;
            }
            let len_utf8 = c.len_utf8();
            ix += len_utf8;
            if c == '\n' {
                position.row += 1;
                position.column = 0;
            } else {
                position.column += len_utf8 as u32;
            }
        }
        0
    };

    let byte_ranges = point_ranges
        .iter()
        .map(|range| point_to_offset(range.start)..point_to_offset(range.end))
        .collect::<Vec<_>>();
    util::test::generate_marked_text(text, &byte_ranges[..], true)
}