neovim_connection.rs

  1use std::path::PathBuf;
  2#[cfg(feature = "neovim")]
  3use std::{
  4    cmp,
  5    ops::{Deref, DerefMut, Range},
  6};
  7
  8#[cfg(feature = "neovim")]
  9use async_compat::Compat;
 10#[cfg(feature = "neovim")]
 11use async_trait::async_trait;
 12#[cfg(feature = "neovim")]
 13use gpui::Keystroke;
 14
 15#[cfg(feature = "neovim")]
 16use language::Point;
 17
 18#[cfg(feature = "neovim")]
 19use nvim_rs::{
 20    create::tokio::new_child_cmd, error::LoopError, Handler, Neovim, UiAttachOptions, Value,
 21};
 22#[cfg(feature = "neovim")]
 23use parking_lot::ReentrantMutex;
 24use serde::{Deserialize, Serialize};
 25#[cfg(feature = "neovim")]
 26use tokio::{
 27    process::{Child, ChildStdin, Command},
 28    task::JoinHandle,
 29};
 30
 31use crate::state::Mode;
 32use collections::VecDeque;
 33
// Neovim doesn't like to be started simultaneously from multiple threads. We use this lock
// to ensure we are only constructing one neovim connection at a time.
// It is a reentrant mutex: a thread that already holds the lock can acquire it again
// without deadlocking; other threads wait until it is released.
#[cfg(feature = "neovim")]
static NEOVIM_LOCK: ReentrantMutex<()> = ReentrantMutex::new(());
 38
/// One operation in a recorded neovim test script.
///
/// With the `neovim` feature enabled, entries are appended while the test drives a real
/// neovim instance and are written out as one JSON object per line. Without the feature,
/// the entries are read back and each test operation is checked against the next
/// recorded entry.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum NeovimData {
    /// Buffer text and selection installed by `set_state`, as marked text.
    Put { state: String },
    /// Keystroke text exactly as passed to `send_keystroke` (before it is translated
    /// to neovim key notation).
    Key(String),
    /// State observed by `state`: the marked text plus the mode neovim reported
    /// (`None` when the neovim mode string has no `Mode` equivalent).
    Get { state: String, mode: Option<Mode> },
    /// Register `name` read by `read_register`, with the value it contained.
    ReadRegister { name: char, value: String },
    /// Ex command executed by `exec`.
    Exec { command: String },
    /// Argument passed to `set_option`, i.e. the text following `set `.
    SetOption { value: String },
}
 48
/// Test helper that either talks to an embedded neovim process (with the `neovim`
/// feature) or replays a previously recorded script of such a session (without it).
pub struct NeovimConnection {
    // Recorded operations: appended to while recording, consumed from the front
    // while replaying.
    data: VecDeque<NeovimData>,
    // Name of the test data file (`test_data/<id>.json`) the recording is written to.
    #[cfg(feature = "neovim")]
    test_case_id: String,
    // RPC client for the embedded neovim process.
    #[cfg(feature = "neovim")]
    nvim: Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>,
    // Join handle returned by `new_child_cmd`; stored but never read in this file.
    #[cfg(feature = "neovim")]
    _join_handle: JoinHandle<Result<(), Box<LoopError>>>,
    // The spawned `nvim` child process; stored but never read in this file.
    #[cfg(feature = "neovim")]
    _child: Child,
}
 60
 61impl NeovimConnection {
 62    pub async fn new(test_case_id: String) -> Self {
 63        #[cfg(feature = "neovim")]
 64        let handler = NvimHandler {};
 65        #[cfg(feature = "neovim")]
 66        let (nvim, join_handle, child) = Compat::new(async {
 67            // Ensure we don't create neovim connections in parallel
 68            let _lock = NEOVIM_LOCK.lock();
 69            let (nvim, join_handle, child) = new_child_cmd(
 70                &mut Command::new("nvim")
 71                    .arg("--embed")
 72                    .arg("--clean")
 73                    // disable swap (otherwise after about 1000 test runs you run out of swap file names)
 74                    .arg("-n")
 75                    // disable writing files (just in case)
 76                    .arg("-m"),
 77                handler,
 78            )
 79            .await
 80            .expect("Could not connect to neovim process");
 81
 82            nvim.ui_attach(100, 100, &UiAttachOptions::default())
 83                .await
 84                .expect("Could not attach to ui");
 85
 86            // Makes system act a little more like zed in terms of indentation
 87            nvim.set_option("smartindent", nvim_rs::Value::Boolean(true))
 88                .await
 89                .expect("Could not set smartindent on startup");
 90
 91            (nvim, join_handle, child)
 92        })
 93        .await;
 94
 95        Self {
 96            #[cfg(feature = "neovim")]
 97            data: Default::default(),
 98            #[cfg(not(feature = "neovim"))]
 99            data: Self::read_test_data(&test_case_id),
100            #[cfg(feature = "neovim")]
101            test_case_id,
102            #[cfg(feature = "neovim")]
103            nvim,
104            #[cfg(feature = "neovim")]
105            _join_handle: join_handle,
106            #[cfg(feature = "neovim")]
107            _child: child,
108        }
109    }
110
111    // Sends a keystroke to the neovim process.
112    #[cfg(feature = "neovim")]
113    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
114        let mut keystroke = Keystroke::parse(keystroke_text).unwrap();
115
116        if keystroke.key == "<" {
117            keystroke.key = "lt".to_string()
118        }
119
120        let special = keystroke.modifiers.shift
121            || keystroke.modifiers.control
122            || keystroke.modifiers.alt
123            || keystroke.modifiers.command
124            || keystroke.key.len() > 1;
125        let start = if special { "<" } else { "" };
126        let shift = if keystroke.modifiers.shift { "S-" } else { "" };
127        let ctrl = if keystroke.modifiers.control {
128            "C-"
129        } else {
130            ""
131        };
132        let alt = if keystroke.modifiers.alt { "M-" } else { "" };
133        let cmd = if keystroke.modifiers.command {
134            "D-"
135        } else {
136            ""
137        };
138        let end = if special { ">" } else { "" };
139
140        let key = format!("{start}{shift}{ctrl}{alt}{cmd}{}{end}", keystroke.key);
141
142        self.data
143            .push_back(NeovimData::Key(keystroke_text.to_string()));
144        self.nvim
145            .input(&key)
146            .await
147            .expect("Could not input keystroke");
148    }
149
150    #[cfg(not(feature = "neovim"))]
151    pub async fn send_keystroke(&mut self, keystroke_text: &str) {
152        if matches!(self.data.front(), Some(NeovimData::Get { .. })) {
153            self.data.pop_front();
154        }
155        assert_eq!(
156            self.data.pop_front(),
157            Some(NeovimData::Key(keystroke_text.to_string())),
158            "operation does not match recorded script. re-record with --features=neovim"
159        );
160    }
161
162    #[cfg(feature = "neovim")]
163    pub async fn set_state(&mut self, marked_text: &str) {
164        let (text, selections) = parse_state(&marked_text);
165
166        let nvim_buffer = self
167            .nvim
168            .get_current_buf()
169            .await
170            .expect("Could not get neovim buffer");
171        let lines = text
172            .split('\n')
173            .map(|line| line.to_string())
174            .collect::<Vec<_>>();
175
176        nvim_buffer
177            .set_lines(0, -1, false, lines)
178            .await
179            .expect("Could not set nvim buffer text");
180
181        self.nvim
182            .input("<escape>")
183            .await
184            .expect("Could not send escape to nvim");
185        self.nvim
186            .input("<escape>")
187            .await
188            .expect("Could not send escape to nvim");
189
190        let nvim_window = self
191            .nvim
192            .get_current_win()
193            .await
194            .expect("Could not get neovim window");
195
196        if selections.len() != 1 {
197            panic!("must have one selection");
198        }
199        let selection = &selections[0];
200
201        let cursor = selection.start;
202        nvim_window
203            .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
204            .await
205            .expect("Could not set nvim cursor position");
206
207        if !selection.is_empty() {
208            self.nvim
209                .input("v")
210                .await
211                .expect("could not enter visual mode");
212
213            let cursor = selection.end;
214            nvim_window
215                .set_cursor((cursor.row as i64 + 1, cursor.column as i64))
216                .await
217                .expect("Could not set nvim cursor position");
218        }
219
220        if let Some(NeovimData::Get { mode, state }) = self.data.back() {
221            if *mode == Some(Mode::Normal) && *state == marked_text {
222                return;
223            }
224        }
225        self.data.push_back(NeovimData::Put {
226            state: marked_text.to_string(),
227        })
228    }
229
230    #[cfg(not(feature = "neovim"))]
231    pub async fn set_state(&mut self, marked_text: &str) {
232        if let Some(NeovimData::Get { mode, state: text }) = self.data.front() {
233            if *mode == Some(Mode::Normal) && *text == marked_text {
234                return;
235            }
236            self.data.pop_front();
237        }
238        assert_eq!(
239            self.data.pop_front(),
240            Some(NeovimData::Put {
241                state: marked_text.to_string()
242            }),
243            "operation does not match recorded script. re-record with --features=neovim"
244        );
245    }
246
247    #[cfg(feature = "neovim")]
248    pub async fn set_option(&mut self, value: &str) {
249        self.nvim
250            .command_output(format!("set {}", value).as_str())
251            .await
252            .unwrap();
253
254        self.data.push_back(NeovimData::SetOption {
255            value: value.to_string(),
256        })
257    }
258
259    #[cfg(not(feature = "neovim"))]
260    pub async fn set_option(&mut self, value: &str) {
261        if let Some(NeovimData::Get { .. }) = self.data.front() {
262            self.data.pop_front();
263        };
264        assert_eq!(
265            self.data.pop_front(),
266            Some(NeovimData::SetOption {
267                value: value.to_string(),
268            }),
269            "operation does not match recorded script. re-record with --features=neovim"
270        );
271    }
272
273    #[cfg(feature = "neovim")]
274    pub async fn exec(&mut self, value: &str) {
275        self.nvim
276            .command_output(format!("{}", value).as_str())
277            .await
278            .unwrap();
279
280        self.data.push_back(NeovimData::Exec {
281            command: value.to_string(),
282        })
283    }
284
285    #[cfg(not(feature = "neovim"))]
286    pub async fn exec(&mut self, value: &str) {
287        if let Some(NeovimData::Get { .. }) = self.data.front() {
288            self.data.pop_front();
289        };
290        assert_eq!(
291            self.data.pop_front(),
292            Some(NeovimData::Exec {
293                command: value.to_string(),
294            }),
295            "operation does not match recorded script. re-record with --features=neovim"
296        );
297    }
298
299    #[cfg(not(feature = "neovim"))]
300    pub async fn read_register(&mut self, register: char) -> String {
301        if let Some(NeovimData::Get { .. }) = self.data.front() {
302            self.data.pop_front();
303        };
304        if let Some(NeovimData::ReadRegister { name, value }) = self.data.pop_front() {
305            if name == register {
306                return value;
307            }
308        }
309
310        panic!("operation does not match recorded script. re-record with --features=neovim")
311    }
312
313    #[cfg(feature = "neovim")]
314    pub async fn read_register(&mut self, name: char) -> String {
315        let value = self
316            .nvim
317            .command_output(format!("echo getreg('{}')", name).as_str())
318            .await
319            .unwrap();
320
321        self.data.push_back(NeovimData::ReadRegister {
322            name,
323            value: value.clone(),
324        });
325
326        value
327    }
328
329    #[cfg(feature = "neovim")]
330    async fn read_position(&mut self, cmd: &str) -> u32 {
331        self.nvim
332            .command_output(cmd)
333            .await
334            .unwrap()
335            .parse::<u32>()
336            .unwrap()
337    }
338
339    #[cfg(feature = "neovim")]
340    pub async fn state(&mut self) -> (Option<Mode>, String) {
341        let nvim_buffer = self
342            .nvim
343            .get_current_buf()
344            .await
345            .expect("Could not get neovim buffer");
346        let text = nvim_buffer
347            .get_lines(0, -1, false)
348            .await
349            .expect("Could not get buffer text")
350            .join("\n");
351
352        // nvim columns are 1-based, so -1.
353        let mut cursor_row = self.read_position("echo line('.')").await - 1;
354        let mut cursor_col = self.read_position("echo col('.')").await - 1;
355        let mut selection_row = self.read_position("echo line('v')").await - 1;
356        let mut selection_col = self.read_position("echo col('v')").await - 1;
357        let total_rows = self.read_position("echo line('$')").await - 1;
358
359        let nvim_mode_text = self
360            .nvim
361            .get_mode()
362            .await
363            .expect("Could not get mode")
364            .into_iter()
365            .find_map(|(key, value)| {
366                if key.as_str() == Some("mode") {
367                    Some(value.as_str().unwrap().to_owned())
368                } else {
369                    None
370                }
371            })
372            .expect("Could not find mode value");
373
374        let mode = match nvim_mode_text.as_ref() {
375            "i" => Some(Mode::Insert),
376            "n" => Some(Mode::Normal),
377            "v" => Some(Mode::Visual),
378            "V" => Some(Mode::VisualLine),
379            "\x16" => Some(Mode::VisualBlock),
380            _ => None,
381        };
382
383        let mut selections = Vec::new();
384        // Vim uses the index of the first and last character in the selection
385        // Zed uses the index of the positions between the characters, so we need
386        // to add one to the end in visual mode.
387        match mode {
388            Some(Mode::VisualBlock) if selection_row != cursor_row => {
389                // in zed we fake a block selection by using multiple cursors (one per line)
390                // this code emulates that.
391                // to deal with casees where the selection is not perfectly rectangular we extract
392                // the content of the selection via the "a register to get the shape correctly.
393                self.nvim.input("\"aygv").await.unwrap();
394                let content = self.nvim.command_output("echo getreg('a')").await.unwrap();
395                let lines = content.split("\n").collect::<Vec<_>>();
396                let top = cmp::min(selection_row, cursor_row);
397                let left = cmp::min(selection_col, cursor_col);
398                for row in top..=cmp::max(selection_row, cursor_row) {
399                    let content = if row - top >= lines.len() as u32 {
400                        ""
401                    } else {
402                        lines[(row - top) as usize]
403                    };
404                    let line_len = self
405                        .read_position(format!("echo strlen(getline({}))", row + 1).as_str())
406                        .await;
407
408                    if left > line_len {
409                        continue;
410                    }
411
412                    let start = Point::new(row, left);
413                    let end = Point::new(row, left + content.len() as u32);
414                    if cursor_col >= selection_col {
415                        selections.push(start..end)
416                    } else {
417                        selections.push(end..start)
418                    }
419                }
420            }
421            Some(Mode::Visual) | Some(Mode::VisualLine) | Some(Mode::VisualBlock) => {
422                if selection_col > cursor_col {
423                    let selection_line_length =
424                        self.read_position("echo strlen(getline(line('v')))").await;
425                    if selection_line_length > selection_col {
426                        selection_col += 1;
427                    } else if selection_row < total_rows {
428                        selection_col = 0;
429                        selection_row += 1;
430                    }
431                } else {
432                    let cursor_line_length =
433                        self.read_position("echo strlen(getline(line('.')))").await;
434                    if cursor_line_length > cursor_col {
435                        cursor_col += 1;
436                    } else if cursor_row < total_rows {
437                        cursor_col = 0;
438                        cursor_row += 1;
439                    }
440                }
441                selections.push(
442                    Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col),
443                )
444            }
445            Some(Mode::Insert) | Some(Mode::Normal) | None => selections
446                .push(Point::new(selection_row, selection_col)..Point::new(cursor_row, cursor_col)),
447        }
448
449        let ranges = encode_ranges(&text, &selections);
450        let state = NeovimData::Get {
451            mode,
452            state: ranges.clone(),
453        };
454
455        if self.data.back() != Some(&state) {
456            self.data.push_back(state.clone());
457        }
458
459        (mode, ranges)
460    }
461
462    #[cfg(not(feature = "neovim"))]
463    pub async fn state(&mut self) -> (Option<Mode>, String) {
464        if let Some(NeovimData::Get { state: raw, mode }) = self.data.front() {
465            (*mode, raw.to_string())
466        } else {
467            panic!("operation does not match recorded script. re-record with --features=neovim");
468        }
469    }
470
471    pub async fn mode(&mut self) -> Option<Mode> {
472        self.state().await.0
473    }
474
475    pub async fn marked_text(&mut self) -> String {
476        self.state().await.1
477    }
478
479    fn test_data_path(test_case_id: &str) -> PathBuf {
480        let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
481        data_path.push("test_data");
482        data_path.push(format!("{}.json", test_case_id));
483        data_path
484    }
485
486    #[cfg(not(feature = "neovim"))]
487    fn read_test_data(test_case_id: &str) -> VecDeque<NeovimData> {
488        let path = Self::test_data_path(test_case_id);
489        let json = std::fs::read_to_string(path).expect(
490            "Could not read test data. Is it generated? Try running test with '--features neovim'",
491        );
492
493        let mut result = VecDeque::new();
494        for line in json.lines() {
495            result.push_back(
496                serde_json::from_str(line)
497                    .expect("invalid test data. regenerate it with '--features neovim'"),
498            );
499        }
500        result
501    }
502
503    #[cfg(feature = "neovim")]
504    fn write_test_data(test_case_id: &str, data: &VecDeque<NeovimData>) {
505        let path = Self::test_data_path(test_case_id);
506        let mut json = Vec::new();
507        for entry in data {
508            serde_json::to_writer(&mut json, entry).unwrap();
509            json.push(b'\n');
510        }
511        std::fs::create_dir_all(path.parent().unwrap())
512            .expect("could not create test data directory");
513        std::fs::write(path, json).expect("could not write out test data");
514    }
515}
516
// Exposes the underlying neovim RPC client so tests can call its API directly.
// Calls made through this deref are not recorded in the test data.
#[cfg(feature = "neovim")]
impl Deref for NeovimConnection {
    type Target = Neovim<nvim_rs::compat::tokio::Compat<ChildStdin>>;

    fn deref(&self) -> &Self::Target {
        &self.nvim
    }
}
525
// Mutable access to the underlying neovim RPC client. As with `Deref`, operations
// performed through it bypass the recording in `data`.
#[cfg(feature = "neovim")]
impl DerefMut for NeovimConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.nvim
    }
}
532
/// Persists everything recorded during this session to the test data file, so that
/// later runs without the `neovim` feature can replay it.
#[cfg(feature = "neovim")]
impl Drop for NeovimConnection {
    fn drop(&mut self) {
        let (case_id, recorded) = (&self.test_case_id, &self.data);
        Self::write_test_data(case_id, recorded);
    }
}
539
// Stateless RPC handler passed to `new_child_cmd` when the embedded neovim is spawned.
#[cfg(feature = "neovim")]
#[derive(Clone)]
struct NvimHandler {}
543
/// RPC callbacks from the embedded neovim process.
///
/// Requests from neovim are not supported, and every notification is ignored.
#[cfg(feature = "neovim")]
#[async_trait]
impl Handler for NvimHandler {
    type Writer = nvim_rs::compat::tokio::Compat<ChildStdin>;

    /// Not supported: panics via `unimplemented!` if neovim ever sends a request.
    async fn handle_request(
        &self,
        _name: String,
        _args: Vec<Value>,
        _nvim: Neovim<Self::Writer>,
    ) -> Result<Value, Value> {
        unimplemented!()
    }

    /// Intentionally a no-op: notifications are discarded.
    async fn handle_notify(&self, _name: String, _args: Vec<Value>, _nvim: Neovim<Self::Writer>) {}
}
566
/// Splits `marked_text` into its plain text and its selections.
///
/// Selections come back as `Point` ranges, where a `Point` is a row plus a *byte*
/// column within that row. A byte offset that does not land on a character boundary
/// of the text (or past its end) maps to `Point::zero()`.
#[cfg(feature = "neovim")]
fn parse_state(marked_text: &str) -> (String, Vec<Range<Point>>) {
    let (text, byte_ranges) = util::test::marked_text_ranges(marked_text, true);

    let point_ranges: Vec<Range<Point>> = byte_ranges
        .into_iter()
        .map(|byte_range| {
            let mut start = Point::zero();
            let mut end = Point::zero();
            let mut point = Point::zero();

            // The trailing sentinel lets an offset equal to `text.len()` be mapped too.
            let boundaries = text
                .char_indices()
                .chain(std::iter::once((text.len(), '\0')));
            for (offset, ch) in boundaries {
                if offset == byte_range.start {
                    start = point;
                }
                if offset == byte_range.end {
                    end = point;
                }
                match ch {
                    '\n' => {
                        point.row += 1;
                        point.column = 0;
                    }
                    other => point.column += other.len_utf8() as u32,
                }
            }

            start..end
        })
        .collect();

    (text, point_ranges)
}
597
/// Encodes `point_ranges` over `text` as marked text (the inverse of `parse_state`).
///
/// Each `Point` is a row plus a *byte* column within that row. Points are mapped to byte
/// offsets by scanning `text`. A sentinel after the last character lets a point just
/// past the end of the text be encoded. A point that never matches a character
/// boundary (e.g. a column past the end of its line) encodes as offset 0.
#[cfg(feature = "neovim")]
fn encode_ranges(text: &str, point_ranges: &[Range<Point>]) -> String {
    let byte_ranges = point_ranges
        .iter()
        .map(|range| {
            let mut byte_range = 0..0;
            let mut ix = 0;
            let mut position = Point::zero();
            for c in text.chars().chain(['\0']) {
                if position == range.start {
                    byte_range.start = ix;
                }
                if position == range.end {
                    byte_range.end = ix;
                }
                let len_utf8 = c.len_utf8();
                ix += len_utf8;
                if c == '\n' {
                    position.row += 1;
                    position.column = 0;
                } else {
                    position.column += len_utf8 as u32;
                }
            }
            byte_range
        })
        .collect::<Vec<_>>();
    util::test::generate_marked_text(text, &byte_ranges[..], true)
}