repl_editor.rs

  1//! REPL operations on an [`Editor`].
  2
  3use std::ops::Range;
  4use std::sync::Arc;
  5
  6use anyhow::{Context, Result};
  7use editor::Editor;
  8use gpui::{prelude::*, Entity, View, WeakView, WindowContext};
  9use language::{BufferSnapshot, Language, LanguageName, Point};
 10use project::{ProjectItem as _, WorktreeId};
 11
 12use crate::repl_store::ReplStore;
 13use crate::session::SessionEvent;
 14use crate::{
 15    ClearOutputs, Interrupt, JupyterSettings, KernelSpecification, Restart, Session, Shutdown,
 16};
 17
/// Makes `kernel_specification` the active kernelspec for the editor's worktree and starts a
/// fresh REPL session for `weak_editor` with it.
///
/// If the editor already has a session, that session's outputs are cleared and it is shut
/// down before the replacement session is created and registered in the [`ReplStore`].
///
/// # Errors
///
/// Returns an error if the editor is not in a worktree (this includes the editor having been
/// dropped, since the worktree lookup fails in that case). Returns `Ok(())` without doing
/// anything when the REPL is disabled.
pub fn assign_kernelspec(
    kernel_specification: KernelSpecification,
    weak_editor: WeakView<Editor>,
    cx: &mut WindowContext,
) -> Result<()> {
    let store = ReplStore::global(cx);
    if !store.read(cx).is_enabled() {
        return Ok(());
    }

    let worktree_id = crate::repl_editor::worktree_id_for_editor(weak_editor.clone(), cx)
        .context("editor is not in a worktree")?;

    // Remember the choice for the whole worktree, not just this editor.
    store.update(cx, |store, cx| {
        store.set_active_kernelspec(worktree_id, kernel_specification.clone(), cx);
    });

    let fs = store.read(cx).fs().clone();
    let telemetry = store.read(cx).telemetry().clone();

    if let Some(session) = store.read(cx).get_session(weak_editor.entity_id()).cloned() {
        // Drop previous session, start new one
        // NOTE(review): the old session's `SessionEvent::Shutdown` handler calls
        // `remove_session(shutdown_event.entity_id())`. If that id is the editor's (the key the
        // new session is inserted under at the end of this function) and the event is
        // delivered after that insert, it would unregister the new session. Confirm what
        // `shutdown_event` identifies and when `shutdown` emits it.
        session.update(cx, |session, cx| {
            session.clear_outputs(cx);
            session.shutdown(cx);
            cx.notify();
        });
    }

    let session = cx
        .new_view(|cx| Session::new(weak_editor.clone(), fs, telemetry, kernel_specification, cx));

    // Unregister the session from the store once it reports that it has shut down.
    // An `Err` from `update` (presumably the editor was dropped) is ignored. In that case no
    // subscription is installed, but the session is still inserted into the store below.
    weak_editor
        .update(cx, |_editor, cx| {
            cx.notify();

            cx.subscribe(&session, {
                let store = store.clone();
                move |_this, _session, event, cx| match event {
                    SessionEvent::Shutdown(shutdown_event) => {
                        store.update(cx, |store, _cx| {
                            store.remove_session(shutdown_event.entity_id());
                        });
                    }
                }
            })
            .detach();
        })
        .ok();

    store.update(cx, |store, _cx| {
        store.insert_session(weak_editor.entity_id(), session.clone());
    });

    Ok(())
}
 74
/// Executes the runnable code around the newest selection of `editor` in its REPL session.
///
/// What gets executed is decided by [`runnable_ranges`]:
/// - for Markdown buffers, the supported fenced code blocks intersecting the selection;
/// - otherwise, the Jupytext cells touched by the selection;
/// - failing that, the selected rows.
///
/// Each range is sent to [`Session::execute`] together with `move_down` and, when a
/// Jupytext cell follows the selection, an anchor at the start of that cell.
///
/// If the editor has no session yet, one is created from the active kernelspec for the
/// editor's worktree and the language at the start of the first runnable range. The new
/// session is then registered in the [`ReplStore`].
///
/// # Errors
///
/// Returns an error in two cases:
/// - `editor` has been dropped;
/// - no kernelspec is available for the language of a runnable range.
///
/// Returns `Ok(())` without doing anything in three cases:
/// - the REPL is disabled;
/// - the editor is not backed by a single buffer;
/// - that buffer has no project path.
pub fn run(editor: WeakView<Editor>, move_down: bool, cx: &mut WindowContext) -> Result<()> {
    let store = ReplStore::global(cx);
    if !store.read(cx).is_enabled() {
        return Ok(());
    }

    let editor = editor.upgrade().context("editor was dropped")?;
    let selected_range = editor
        .update(cx, |editor, cx| editor.selections.newest_adjusted(cx))
        .range();
    let multibuffer = editor.read(cx).buffer().clone();
    let Some(buffer) = multibuffer.read(cx).as_singleton() else {
        return Ok(());
    };

    let Some(project_path) = buffer.read(cx).project_path(cx) else {
        return Ok(());
    };

    // The call on the right-hand side resolves to the `runnable_ranges` function. The local
    // binding of the same name only comes into scope after this statement.
    let (runnable_ranges, next_cell_point) =
        runnable_ranges(&buffer.read(cx).snapshot(), selected_range);

    for runnable_range in runnable_ranges {
        // The multibuffer is a singleton, so buffer points can be used on it directly.
        let Some(language) = multibuffer.read(cx).language_at(runnable_range.start, cx) else {
            continue;
        };

        // Looked up for every range. It is only used below when a session has to be created,
        // but a missing kernelspec aborts the whole run with an error either way.
        let kernel_specification = store
            .read(cx)
            .active_kernelspec(project_path.worktree_id, Some(language.clone()), cx)
            .ok_or_else(|| anyhow::anyhow!("No kernel found for language: {}", language.name()))?;

        let fs = store.read(cx).fs().clone();
        let telemetry = store.read(cx).telemetry().clone();

        // NOTE(review): sessions are keyed by editor only. When the runnable ranges have
        // different languages (e.g. Python and TypeScript blocks in one Markdown file), later
        // ranges reuse the session created for an earlier one. Confirm this is intended.
        let session = if let Some(session) = store.read(cx).get_session(editor.entity_id()).cloned()
        {
            session
        } else {
            let weak_editor = editor.downgrade();
            let session = cx
                .new_view(|cx| Session::new(weak_editor, fs, telemetry, kernel_specification, cx));

            // Unregister the session from the store once it reports that it has shut down.
            editor.update(cx, |_editor, cx| {
                cx.notify();

                cx.subscribe(&session, {
                    let store = store.clone();
                    move |_this, _session, event, cx| match event {
                        SessionEvent::Shutdown(shutdown_event) => {
                            store.update(cx, |store, _cx| {
                                store.remove_session(shutdown_event.entity_id());
                            });
                        }
                    }
                })
                .detach();
            });

            store.update(cx, |store, _cx| {
                store.insert_session(editor.entity_id(), session.clone());
            });

            session
        };

        let selected_text;
        let anchor_range;
        let next_cursor;
        // The snapshot is read from `cx` and confined to this block, presumably so that it is
        // released before `session.update` borrows `cx` mutably below.
        {
            let snapshot = multibuffer.read(cx).read(cx);
            selected_text = snapshot
                .text_for_range(runnable_range.clone())
                .collect::<String>();
            anchor_range = snapshot.anchor_before(runnable_range.start)
                ..snapshot.anchor_after(runnable_range.end);
            next_cursor = next_cell_point.map(|point| snapshot.anchor_after(point));
        }

        session.update(cx, |session, cx| {
            session.execute(selected_text, anchor_range, next_cursor, move_down, cx);
        });
    }

    anyhow::Ok(())
}
161
/// REPL availability for an editor, as reported by [`session`].
// NOTE(review): the lint is presumably allowed because `KernelSpecification` is much larger
// than the other variants. Boxing it would change the variant's public type — confirm the
// size difference before revisiting.
#[allow(clippy::large_enum_variant)]
pub enum SessionSupport {
    /// The editor already has a session registered in the [`ReplStore`].
    ActiveSession(View<Session>),
    /// The editor has no session, but this kernelspec is available to start one.
    Inactive(KernelSpecification),
    /// No kernelspec is available, but the REPL supports this language.
    RequiresSetup(LanguageName),
    /// The editor's language or location cannot be used with the REPL.
    Unsupported,
}
169
170pub fn worktree_id_for_editor(
171    editor: WeakView<Editor>,
172    cx: &mut WindowContext,
173) -> Option<WorktreeId> {
174    editor.upgrade().and_then(|editor| {
175        editor
176            .read(cx)
177            .buffer()
178            .read(cx)
179            .as_singleton()?
180            .read(cx)
181            .project_path(cx)
182            .map(|path| path.worktree_id)
183    })
184}
185
186pub fn session(editor: WeakView<Editor>, cx: &mut WindowContext) -> SessionSupport {
187    let store = ReplStore::global(cx);
188    let entity_id = editor.entity_id();
189
190    if let Some(session) = store.read(cx).get_session(entity_id).cloned() {
191        return SessionSupport::ActiveSession(session);
192    };
193
194    let Some(language) = get_language(editor.clone(), cx) else {
195        return SessionSupport::Unsupported;
196    };
197
198    let worktree_id = worktree_id_for_editor(editor.clone(), cx);
199
200    let Some(worktree_id) = worktree_id else {
201        return SessionSupport::Unsupported;
202    };
203
204    let kernelspec = store
205        .read(cx)
206        .active_kernelspec(worktree_id, Some(language.clone()), cx);
207
208    match kernelspec {
209        Some(kernelspec) => SessionSupport::Inactive(kernelspec),
210        None => {
211            if language_supported(&language.clone()) {
212                SessionSupport::RequiresSetup(language.name())
213            } else {
214                SessionSupport::Unsupported
215            }
216        }
217    }
218}
219
220pub fn clear_outputs(editor: WeakView<Editor>, cx: &mut WindowContext) {
221    let store = ReplStore::global(cx);
222    let entity_id = editor.entity_id();
223    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
224        return;
225    };
226    session.update(cx, |session, cx| {
227        session.clear_outputs(cx);
228        cx.notify();
229    });
230}
231
232pub fn interrupt(editor: WeakView<Editor>, cx: &mut WindowContext) {
233    let store = ReplStore::global(cx);
234    let entity_id = editor.entity_id();
235    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
236        return;
237    };
238
239    session.update(cx, |session, cx| {
240        session.interrupt(cx);
241        cx.notify();
242    });
243}
244
245pub fn shutdown(editor: WeakView<Editor>, cx: &mut WindowContext) {
246    let store = ReplStore::global(cx);
247    let entity_id = editor.entity_id();
248    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
249        return;
250    };
251
252    session.update(cx, |session, cx| {
253        session.shutdown(cx);
254        cx.notify();
255    });
256}
257
258pub fn restart(editor: WeakView<Editor>, cx: &mut WindowContext) {
259    let Some(editor) = editor.upgrade() else {
260        return;
261    };
262
263    let entity_id = editor.entity_id();
264
265    let Some(session) = ReplStore::global(cx)
266        .read(cx)
267        .get_session(entity_id)
268        .cloned()
269    else {
270        return;
271    };
272
273    session.update(cx, |session, cx| {
274        session.restart(cx);
275        cx.notify();
276    });
277}
278
279pub fn setup_editor_session_actions(editor: &mut Editor, editor_handle: WeakView<Editor>) {
280    editor
281        .register_action({
282            let editor_handle = editor_handle.clone();
283            move |_: &ClearOutputs, cx| {
284                if !JupyterSettings::enabled(cx) {
285                    return;
286                }
287
288                crate::clear_outputs(editor_handle.clone(), cx);
289            }
290        })
291        .detach();
292
293    editor
294        .register_action({
295            let editor_handle = editor_handle.clone();
296            move |_: &Interrupt, cx| {
297                if !JupyterSettings::enabled(cx) {
298                    return;
299                }
300
301                crate::interrupt(editor_handle.clone(), cx);
302            }
303        })
304        .detach();
305
306    editor
307        .register_action({
308            let editor_handle = editor_handle.clone();
309            move |_: &Shutdown, cx| {
310                if !JupyterSettings::enabled(cx) {
311                    return;
312                }
313
314                crate::shutdown(editor_handle.clone(), cx);
315            }
316        })
317        .detach();
318
319    editor
320        .register_action({
321            let editor_handle = editor_handle.clone();
322            move |_: &Restart, cx| {
323                if !JupyterSettings::enabled(cx) {
324                    return;
325                }
326
327                crate::restart(editor_handle.clone(), cx);
328            }
329        })
330        .detach();
331}
332
333fn cell_range(buffer: &BufferSnapshot, start_row: u32, end_row: u32) -> Range<Point> {
334    let mut snippet_end_row = end_row;
335    while buffer.is_line_blank(snippet_end_row) && snippet_end_row > start_row {
336        snippet_end_row -= 1;
337    }
338    Point::new(start_row, 0)..Point::new(snippet_end_row, buffer.line_len(snippet_end_row))
339}
340
341// Returns the ranges of the snippets in the buffer and the next point for moving the cursor to
342fn jupytext_cells(
343    buffer: &BufferSnapshot,
344    range: Range<Point>,
345) -> (Vec<Range<Point>>, Option<Point>) {
346    let mut current_row = range.start.row;
347
348    let Some(language) = buffer.language() else {
349        return (Vec::new(), None);
350    };
351
352    let default_scope = language.default_scope();
353    let comment_prefixes = default_scope.line_comment_prefixes();
354    if comment_prefixes.is_empty() {
355        return (Vec::new(), None);
356    }
357
358    let jupytext_prefixes = comment_prefixes
359        .iter()
360        .map(|comment_prefix| format!("{comment_prefix}%%"))
361        .collect::<Vec<_>>();
362
363    let mut snippet_start_row = None;
364    loop {
365        if jupytext_prefixes
366            .iter()
367            .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
368        {
369            snippet_start_row = Some(current_row);
370            break;
371        } else if current_row > 0 {
372            current_row -= 1;
373        } else {
374            break;
375        }
376    }
377
378    let mut snippets = Vec::new();
379    if let Some(mut snippet_start_row) = snippet_start_row {
380        for current_row in range.start.row + 1..=buffer.max_point().row {
381            if jupytext_prefixes
382                .iter()
383                .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
384            {
385                snippets.push(cell_range(buffer, snippet_start_row, current_row - 1));
386
387                if current_row <= range.end.row {
388                    snippet_start_row = current_row;
389                } else {
390                    // Return our snippets as well as the next point for moving the cursor to
391                    return (snippets, Some(Point::new(current_row, 0)));
392                }
393            }
394        }
395
396        // Go to the end of the buffer (no more jupytext cells found)
397        snippets.push(cell_range(
398            buffer,
399            snippet_start_row,
400            buffer.max_point().row,
401        ));
402    }
403
404    (snippets, None)
405}
406
407fn runnable_ranges(
408    buffer: &BufferSnapshot,
409    range: Range<Point>,
410) -> (Vec<Range<Point>>, Option<Point>) {
411    if let Some(language) = buffer.language() {
412        if language.name() == "Markdown".into() {
413            return (markdown_code_blocks(buffer, range.clone()), None);
414        }
415    }
416
417    let (jupytext_snippets, next_cursor) = jupytext_cells(buffer, range.clone());
418    if !jupytext_snippets.is_empty() {
419        return (jupytext_snippets, next_cursor);
420    }
421
422    let snippet_range = cell_range(buffer, range.start.row, range.end.row);
423    let start_language = buffer.language_at(snippet_range.start);
424    let end_language = buffer.language_at(snippet_range.end);
425
426    if start_language
427        .zip(end_language)
428        .map_or(false, |(start, end)| start == end)
429    {
430        (vec![snippet_range], None)
431    } else {
432        (Vec::new(), None)
433    }
434}
435
436// We allow markdown code blocks to end in a trailing newline in order to render the output
437// below the final code fence. This is different than our behavior for selections and Jupytext cells.
438fn markdown_code_blocks(buffer: &BufferSnapshot, range: Range<Point>) -> Vec<Range<Point>> {
439    buffer
440        .injections_intersecting_range(range)
441        .filter(|(_, language)| language_supported(language))
442        .map(|(content_range, _)| {
443            buffer.offset_to_point(content_range.start)..buffer.offset_to_point(content_range.end)
444        })
445        .collect()
446}
447
448fn language_supported(language: &Arc<Language>) -> bool {
449    match language.name().0.as_ref() {
450        "TypeScript" | "Python" => true,
451        _ => false,
452    }
453}
454
455fn get_language(editor: WeakView<Editor>, cx: &mut WindowContext) -> Option<Arc<Language>> {
456    editor
457        .update(cx, |editor, cx| {
458            let selection = editor.selections.newest::<usize>(cx);
459            let buffer = editor.buffer().read(cx).snapshot(cx);
460            buffer.language_at(selection.head()).cloned()
461        })
462        .ok()
463        .flatten()
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageRegistry};

    /// Builds a minimal language whose only line-comment prefix is `"# "`. For this language,
    /// `# %%` is the Jupytext cell marker.
    fn hash_comment_language() -> Arc<Language> {
        Arc::new(Language::new(
            LanguageConfig {
                name: "TestLang".into(),
                line_comments: vec!["# ".into()],
                ..Default::default()
            },
            None,
        ))
    }

    /// Runs [`runnable_ranges`] for `range` and returns the text of each resulting range.
    fn runnable_texts(snapshot: &BufferSnapshot, range: Range<Point>) -> Vec<String> {
        let (ranges, _next_cursor) = runnable_ranges(snapshot, range);
        ranges
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect()
    }

    #[gpui::test]
    fn test_snippet_ranges(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| {
            Buffer::local(
                indoc! { r#"
                    print(1 + 1)
                    print(2 + 2)

                    print(4 + 4)


                "# },
                cx,
            )
            .with_language(hash_comment_language(), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Single-point selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 4)..Point::new(0, 4)),
            vec!["print(1 + 1)"]
        );

        // Multi-line selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 5)..Point::new(2, 0)),
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Trimming multiple trailing blank lines
        assert_eq!(
            runnable_texts(&snapshot, Point::new(0, 5)..Point::new(5, 0)),
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)

                print(4 + 4)"# }]
        );
    }

    #[gpui::test]
    fn test_jupytext_snippet_ranges(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| {
            Buffer::local(
                indoc! { r#"
                    # Hello!
                    # %% [markdown]
                    # This is some arithmetic
                    print(1 + 1)
                    print(2 + 2)

                    # %%
                    print(3 + 3)
                    print(4 + 4)

                    print(5 + 5)



                "# },
                cx,
            )
            .with_language(hash_comment_language(), cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Jupytext snippet surrounding an empty selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(2, 5)..Point::new(2, 5)),
            vec![indoc! { r#"
                # %% [markdown]
                # This is some arithmetic
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Jupytext snippets intersecting a non-empty selection
        assert_eq!(
            runnable_texts(&snapshot, Point::new(2, 5)..Point::new(6, 2)),
            vec![
                indoc! { r#"
                    # %% [markdown]
                    # This is some arithmetic
                    print(1 + 1)
                    print(2 + 2)"#
                },
                indoc! { r#"
                    # %%
                    print(3 + 3)
                    print(4 + 4)

                    print(5 + 5)"#
                }
            ]
        );
    }

    #[gpui::test]
    fn test_markdown_code_blocks(cx: &mut AppContext) {
        let markdown = languages::language("markdown", tree_sitter_md::LANGUAGE.into());
        let typescript = languages::language(
            "typescript",
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        );
        let python = languages::language("python", tree_sitter_python::LANGUAGE.into());
        let language_registry = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
        for language in [&markdown, &typescript, &python] {
            language_registry.add(language.clone());
        }

        // Two code blocks intersecting with selection
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```typescript
                    let foo = 999;
                    console.log(foo + 1999);
                    ```

                    ```typescript
                    console.log("foo")
                    ```
                    "#
                },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        assert_eq!(
            runnable_texts(&snapshot, Point::new(3, 5)..Point::new(8, 5)),
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n"
            ]
        );

        // Three code blocks intersecting with selection
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```typescript
                    let foo = 999;
                    console.log(foo + 1999);
                    ```

                    ```ts
                    console.log("foo")
                    ```

                    ```typescript
                    console.log("another code block")
                    ```
                "# },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        assert_eq!(
            runnable_texts(&snapshot, Point::new(3, 5)..Point::new(12, 5)),
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n",
                "console.log(\"another code block\")\n",
            ]
        );

        // Python code block
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```python
                    print("hello there")
                    print("hello there")
                    print("hello there")
                    ```
                "# },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        assert_eq!(
            runnable_texts(&snapshot, Point::new(4, 5)..Point::new(5, 5)),
            vec![indoc! { r#"
                print("hello there")
                print("hello there")
                print("hello there")
                "#
            },]
        );
    }
}