repl_editor.rs

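//! REPL operations on an [`Editor`]: assigning a kernel, running the code under the selection,
//! and controlling the editor's REPL session (clearing outputs, interrupting, restarting, and
//! shutting down).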
  2
  3use std::ops::Range;
  4use std::sync::Arc;
  5
  6use anyhow::{Context, Result};
  7use editor::Editor;
  8use gpui::{prelude::*, Entity, View, WeakView, WindowContext};
  9use language::{BufferSnapshot, Language, LanguageName, Point};
 10
 11use crate::repl_store::ReplStore;
 12use crate::session::SessionEvent;
 13use crate::{KernelSpecification, Session};
 14
 15pub fn assign_kernelspec(
 16    kernel_specification: KernelSpecification,
 17    weak_editor: WeakView<Editor>,
 18    cx: &mut WindowContext,
 19) -> Result<()> {
 20    let store = ReplStore::global(cx);
 21    if !store.read(cx).is_enabled() {
 22        return Ok(());
 23    }
 24
 25    let fs = store.read(cx).fs().clone();
 26    let telemetry = store.read(cx).telemetry().clone();
 27
 28    if let Some(session) = store.read(cx).get_session(weak_editor.entity_id()).cloned() {
 29        // Drop previous session, start new one
 30        session.update(cx, |session, cx| {
 31            session.clear_outputs(cx);
 32            session.shutdown(cx);
 33            cx.notify();
 34        });
 35    }
 36
 37    let session = cx
 38        .new_view(|cx| Session::new(weak_editor.clone(), fs, telemetry, kernel_specification, cx));
 39
 40    weak_editor
 41        .update(cx, |_editor, cx| {
 42            cx.notify();
 43
 44            cx.subscribe(&session, {
 45                let store = store.clone();
 46                move |_this, _session, event, cx| match event {
 47                    SessionEvent::Shutdown(shutdown_event) => {
 48                        store.update(cx, |store, _cx| {
 49                            store.remove_session(shutdown_event.entity_id());
 50                        });
 51                    }
 52                }
 53            })
 54            .detach();
 55        })
 56        .ok();
 57
 58    store.update(cx, |store, _cx| {
 59        store.insert_session(weak_editor.entity_id(), session.clone());
 60    });
 61
 62    Ok(())
 63}
 64
 65pub fn run(editor: WeakView<Editor>, move_down: bool, cx: &mut WindowContext) -> Result<()> {
 66    let store = ReplStore::global(cx);
 67    if !store.read(cx).is_enabled() {
 68        return Ok(());
 69    }
 70
 71    let editor = editor.upgrade().context("editor was dropped")?;
 72    let selected_range = editor
 73        .update(cx, |editor, cx| editor.selections.newest_adjusted(cx))
 74        .range();
 75    let multibuffer = editor.read(cx).buffer().clone();
 76    let Some(buffer) = multibuffer.read(cx).as_singleton() else {
 77        return Ok(());
 78    };
 79
 80    let (runnable_ranges, next_cell_point) =
 81        runnable_ranges(&buffer.read(cx).snapshot(), selected_range);
 82
 83    for runnable_range in runnable_ranges {
 84        let Some(language) = multibuffer.read(cx).language_at(runnable_range.start, cx) else {
 85            continue;
 86        };
 87
 88        let kernel_specification = store.update(cx, |store, cx| {
 89            store
 90                .kernelspec(language.code_fence_block_name().as_ref(), cx)
 91                .with_context(|| format!("No kernel found for language: {}", language.name()))
 92        })?;
 93
 94        let fs = store.read(cx).fs().clone();
 95        let telemetry = store.read(cx).telemetry().clone();
 96
 97        let session = if let Some(session) = store.read(cx).get_session(editor.entity_id()).cloned()
 98        {
 99            session
100        } else {
101            let weak_editor = editor.downgrade();
102            let session = cx
103                .new_view(|cx| Session::new(weak_editor, fs, telemetry, kernel_specification, cx));
104
105            editor.update(cx, |_editor, cx| {
106                cx.notify();
107
108                cx.subscribe(&session, {
109                    let store = store.clone();
110                    move |_this, _session, event, cx| match event {
111                        SessionEvent::Shutdown(shutdown_event) => {
112                            store.update(cx, |store, _cx| {
113                                store.remove_session(shutdown_event.entity_id());
114                            });
115                        }
116                    }
117                })
118                .detach();
119            });
120
121            store.update(cx, |store, _cx| {
122                store.insert_session(editor.entity_id(), session.clone());
123            });
124
125            session
126        };
127
128        let selected_text;
129        let anchor_range;
130        let next_cursor;
131        {
132            let snapshot = multibuffer.read(cx).read(cx);
133            selected_text = snapshot
134                .text_for_range(runnable_range.clone())
135                .collect::<String>();
136            anchor_range = snapshot.anchor_before(runnable_range.start)
137                ..snapshot.anchor_after(runnable_range.end);
138            next_cursor = next_cell_point.map(|point| snapshot.anchor_after(point));
139        }
140
141        session.update(cx, |session, cx| {
142            session.execute(selected_text, anchor_range, next_cursor, move_down, cx);
143        });
144    }
145
146    anyhow::Ok(())
147}
148
/// The REPL state of an editor, as reported by `session`.
// NOTE(review): `large_enum_variant` is silenced instead of boxing the large payload (presumably
// `KernelSpecification`); confirm this is intentional.
#[allow(clippy::large_enum_variant)]
pub enum SessionSupport {
    /// A session is already registered for the editor in the `ReplStore`.
    ActiveSession(View<Session>),
    /// No session is registered, but a kernel specification was found for the editor's language.
    Inactive(KernelSpecification),
    /// No kernel specification was found, but the language is one the REPL supports.
    RequiresSetup(LanguageName),
    /// No language could be determined for the editor, or the language is not supported.
    Unsupported,
}
156
157pub fn session(editor: WeakView<Editor>, cx: &mut WindowContext) -> SessionSupport {
158    let store = ReplStore::global(cx);
159    let entity_id = editor.entity_id();
160
161    if let Some(session) = store.read(cx).get_session(entity_id).cloned() {
162        return SessionSupport::ActiveSession(session);
163    };
164
165    let Some(language) = get_language(editor, cx) else {
166        return SessionSupport::Unsupported;
167    };
168    let kernelspec = store.update(cx, |store, cx| {
169        store.kernelspec(language.code_fence_block_name().as_ref(), cx)
170    });
171
172    match kernelspec {
173        Some(kernelspec) => SessionSupport::Inactive(kernelspec),
174        None => {
175            if language_supported(&language) {
176                SessionSupport::RequiresSetup(language.name())
177            } else {
178                SessionSupport::Unsupported
179            }
180        }
181    }
182}
183
184pub fn clear_outputs(editor: WeakView<Editor>, cx: &mut WindowContext) {
185    let store = ReplStore::global(cx);
186    let entity_id = editor.entity_id();
187    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
188        return;
189    };
190    session.update(cx, |session, cx| {
191        session.clear_outputs(cx);
192        cx.notify();
193    });
194}
195
196pub fn interrupt(editor: WeakView<Editor>, cx: &mut WindowContext) {
197    let store = ReplStore::global(cx);
198    let entity_id = editor.entity_id();
199    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
200        return;
201    };
202
203    session.update(cx, |session, cx| {
204        session.interrupt(cx);
205        cx.notify();
206    });
207}
208
209pub fn shutdown(editor: WeakView<Editor>, cx: &mut WindowContext) {
210    let store = ReplStore::global(cx);
211    let entity_id = editor.entity_id();
212    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
213        return;
214    };
215
216    session.update(cx, |session, cx| {
217        session.shutdown(cx);
218        cx.notify();
219    });
220}
221
222pub fn restart(editor: WeakView<Editor>, cx: &mut WindowContext) {
223    let Some(editor) = editor.upgrade() else {
224        return;
225    };
226
227    let entity_id = editor.entity_id();
228
229    let Some(session) = ReplStore::global(cx)
230        .read(cx)
231        .get_session(entity_id)
232        .cloned()
233    else {
234        return;
235    };
236
237    session.update(cx, |session, cx| {
238        session.restart(cx);
239        cx.notify();
240    });
241}
242
243fn cell_range(buffer: &BufferSnapshot, start_row: u32, end_row: u32) -> Range<Point> {
244    let mut snippet_end_row = end_row;
245    while buffer.is_line_blank(snippet_end_row) && snippet_end_row > start_row {
246        snippet_end_row -= 1;
247    }
248    Point::new(start_row, 0)..Point::new(snippet_end_row, buffer.line_len(snippet_end_row))
249}
250
251// Returns the ranges of the snippets in the buffer and the next point for moving the cursor to
252fn jupytext_cells(
253    buffer: &BufferSnapshot,
254    range: Range<Point>,
255) -> (Vec<Range<Point>>, Option<Point>) {
256    let mut current_row = range.start.row;
257
258    let Some(language) = buffer.language() else {
259        return (Vec::new(), None);
260    };
261
262    let default_scope = language.default_scope();
263    let comment_prefixes = default_scope.line_comment_prefixes();
264    if comment_prefixes.is_empty() {
265        return (Vec::new(), None);
266    }
267
268    let jupytext_prefixes = comment_prefixes
269        .iter()
270        .map(|comment_prefix| format!("{comment_prefix}%%"))
271        .collect::<Vec<_>>();
272
273    let mut snippet_start_row = None;
274    loop {
275        if jupytext_prefixes
276            .iter()
277            .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
278        {
279            snippet_start_row = Some(current_row);
280            break;
281        } else if current_row > 0 {
282            current_row -= 1;
283        } else {
284            break;
285        }
286    }
287
288    let mut snippets = Vec::new();
289    if let Some(mut snippet_start_row) = snippet_start_row {
290        for current_row in range.start.row + 1..=buffer.max_point().row {
291            if jupytext_prefixes
292                .iter()
293                .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
294            {
295                snippets.push(cell_range(buffer, snippet_start_row, current_row - 1));
296
297                if current_row <= range.end.row {
298                    snippet_start_row = current_row;
299                } else {
300                    // Return our snippets as well as the next point for moving the cursor to
301                    return (snippets, Some(Point::new(current_row, 0)));
302                }
303            }
304        }
305
306        // Go to the end of the buffer (no more jupytext cells found)
307        snippets.push(cell_range(
308            buffer,
309            snippet_start_row,
310            buffer.max_point().row,
311        ));
312    }
313
314    (snippets, None)
315}
316
317fn runnable_ranges(
318    buffer: &BufferSnapshot,
319    range: Range<Point>,
320) -> (Vec<Range<Point>>, Option<Point>) {
321    if let Some(language) = buffer.language() {
322        if language.name() == "Markdown".into() {
323            return (markdown_code_blocks(buffer, range.clone()), None);
324        }
325    }
326
327    let (jupytext_snippets, next_cursor) = jupytext_cells(buffer, range.clone());
328    if !jupytext_snippets.is_empty() {
329        return (jupytext_snippets, next_cursor);
330    }
331
332    let snippet_range = cell_range(buffer, range.start.row, range.end.row);
333    let start_language = buffer.language_at(snippet_range.start);
334    let end_language = buffer.language_at(snippet_range.end);
335
336    if start_language
337        .zip(end_language)
338        .map_or(false, |(start, end)| start == end)
339    {
340        (vec![snippet_range], None)
341    } else {
342        (Vec::new(), None)
343    }
344}
345
346// We allow markdown code blocks to end in a trailing newline in order to render the output
347// below the final code fence. This is different than our behavior for selections and Jupytext cells.
348fn markdown_code_blocks(buffer: &BufferSnapshot, range: Range<Point>) -> Vec<Range<Point>> {
349    buffer
350        .injections_intersecting_range(range)
351        .filter(|(_, language)| language_supported(language))
352        .map(|(content_range, _)| {
353            buffer.offset_to_point(content_range.start)..buffer.offset_to_point(content_range.end)
354        })
355        .collect()
356}
357
358fn language_supported(language: &Arc<Language>) -> bool {
359    match language.name().0.as_ref() {
360        "TypeScript" | "Python" => true,
361        _ => false,
362    }
363}
364
365fn get_language(editor: WeakView<Editor>, cx: &mut WindowContext) -> Option<Arc<Language>> {
366    editor
367        .update(cx, |editor, cx| {
368            let selection = editor.selections.newest::<usize>(cx);
369            let buffer = editor.buffer().read(cx).snapshot(cx);
370            buffer.language_at(selection.head()).cloned()
371        })
372        .ok()
373        .flatten()
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageRegistry};

    /// A plain test language whose only line-comment prefix is `"# "`, so `# %%` lines count as
    /// Jupytext cell headers.
    fn test_language() -> Arc<Language> {
        Arc::new(Language::new(
            LanguageConfig {
                name: "TestLang".into(),
                line_comments: vec!["# ".into()],
                ..Default::default()
            },
            None,
        ))
    }

    /// Snapshot of a local buffer containing `text` in the test language.
    fn test_language_snapshot(text: &str, cx: &mut AppContext) -> BufferSnapshot {
        let buffer = cx.new_model(|cx| Buffer::local(text, cx).with_language(test_language(), cx));
        buffer.read(cx).snapshot()
    }

    /// Snapshot of a Markdown buffer containing `text`, whose code fences resolve through
    /// `registry`.
    fn markdown_snapshot(
        text: &str,
        markdown: &Arc<Language>,
        registry: &Arc<LanguageRegistry>,
        cx: &mut AppContext,
    ) -> BufferSnapshot {
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(text, cx);
            buffer.set_language_registry(registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        buffer.read(cx).snapshot()
    }

    /// The text of every range `runnable_ranges` produces for `selection`.
    fn runnable_texts(snapshot: &BufferSnapshot, selection: Range<Point>) -> Vec<String> {
        let (ranges, _) = runnable_ranges(snapshot, selection);
        ranges
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect())
            .collect()
    }

    #[gpui::test]
    fn test_snippet_ranges(cx: &mut AppContext) {
        let snapshot = test_language_snapshot(
            indoc! { r#"
                print(1 + 1)
                print(2 + 2)

                print(4 + 4)


            "# },
            cx,
        );

        // Single-point selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 4)..Point::new(0, 4)),
            vec!["print(1 + 1)"]
        );

        // Multi-line selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 5)..Point::new(2, 0)),
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Trimming multiple trailing blank lines
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 5)..Point::new(5, 0)),
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)

                print(4 + 4)"# }]
        );
    }

    #[gpui::test]
    fn test_jupytext_snippet_ranges(cx: &mut AppContext) {
        let snapshot = test_language_snapshot(
            indoc! { r#"
                # Hello!
                # %% [markdown]
                # This is some arithmetic
                print(1 + 1)
                print(2 + 2)

                # %%
                print(3 + 3)
                print(4 + 4)

                print(5 + 5)



            "# },
            cx,
        );

        // Jupytext snippet surrounding an empty selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(2, 5)..Point::new(2, 5)),
            vec![indoc! { r#"
                # %% [markdown]
                # This is some arithmetic
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Jupytext snippets intersecting a non-empty selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(2, 5)..Point::new(6, 2)),
            vec![
                indoc! { r#"
                    # %% [markdown]
                    # This is some arithmetic
                    print(1 + 1)
                    print(2 + 2)"#
                },
                indoc! { r#"
                    # %%
                    print(3 + 3)
                    print(4 + 4)

                    print(5 + 5)"#
                }
            ]
        );
    }

    #[gpui::test]
    fn test_markdown_code_blocks(cx: &mut AppContext) {
        let markdown = languages::language("markdown", tree_sitter_md::LANGUAGE.into());
        let typescript = languages::language(
            "typescript",
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        );
        let python = languages::language("python", tree_sitter_python::LANGUAGE.into());
        let language_registry = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
        for language in [markdown.clone(), typescript, python] {
            language_registry.add(language);
        }

        // Two code blocks intersecting with selection
        let snapshot = markdown_snapshot(
            indoc! { r#"
                Hey this is Markdown!

                ```typescript
                let foo = 999;
                console.log(foo + 1999);
                ```

                ```typescript
                console.log("foo")
                ```
                "#
            },
            &markdown,
            &language_registry,
            cx,
        );
        assert_eq!(
            runnable_texts(&snapshot, Point::new(3, 5)..Point::new(8, 5)),
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n"
            ]
        );

        // Three code blocks intersecting with selection
        let snapshot = markdown_snapshot(
            indoc! { r#"
                Hey this is Markdown!

                ```typescript
                let foo = 999;
                console.log(foo + 1999);
                ```

                ```ts
                console.log("foo")
                ```

                ```typescript
                console.log("another code block")
                ```
            "# },
            &markdown,
            &language_registry,
            cx,
        );
        assert_eq!(
            runnable_texts(&snapshot, Point::new(3, 5)..Point::new(12, 5)),
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n",
                "console.log(\"another code block\")\n",
            ]
        );

        // Python code block
        let snapshot = markdown_snapshot(
            indoc! { r#"
                Hey this is Markdown!

                ```python
                print("hello there")
                print("hello there")
                print("hello there")
                ```
            "# },
            &markdown,
            &language_registry,
            cx,
        );
        assert_eq!(
            runnable_texts(&snapshot, Point::new(4, 5)..Point::new(5, 5)),
            vec![indoc! { r#"
                print("hello there")
                print("hello there")
                print("hello there")
                "#
            },]
        );
    }
}