repl_editor.rs

  1//! REPL operations on an [`Editor`].
  2
  3use std::ops::Range;
  4use std::sync::Arc;
  5
  6use anyhow::{Context, Result};
  7use editor::Editor;
  8use gpui::{prelude::*, Entity, View, WeakView, WindowContext};
  9use language::{BufferSnapshot, Language, LanguageName, Point};
 10use project::{ProjectItem as _, WorktreeId};
 11
 12use crate::repl_store::ReplStore;
 13use crate::session::SessionEvent;
 14use crate::{
 15    ClearOutputs, Interrupt, JupyterSettings, KernelSpecification, Restart, Session, Shutdown,
 16};
 17
 18pub fn assign_kernelspec(
 19    kernel_specification: KernelSpecification,
 20    weak_editor: WeakView<Editor>,
 21    cx: &mut WindowContext,
 22) -> Result<()> {
 23    let store = ReplStore::global(cx);
 24    if !store.read(cx).is_enabled() {
 25        return Ok(());
 26    }
 27
 28    let worktree_id = crate::repl_editor::worktree_id_for_editor(weak_editor.clone(), cx)
 29        .context("editor is not in a worktree")?;
 30
 31    store.update(cx, |store, cx| {
 32        store.set_active_kernelspec(worktree_id, kernel_specification.clone(), cx);
 33    });
 34
 35    let fs = store.read(cx).fs().clone();
 36
 37    if let Some(session) = store.read(cx).get_session(weak_editor.entity_id()).cloned() {
 38        // Drop previous session, start new one
 39        session.update(cx, |session, cx| {
 40            session.clear_outputs(cx);
 41            session.shutdown(cx);
 42            cx.notify();
 43        });
 44    }
 45
 46    let session = cx.new_view(|cx| Session::new(weak_editor.clone(), fs, kernel_specification, cx));
 47
 48    weak_editor
 49        .update(cx, |_editor, cx| {
 50            cx.notify();
 51
 52            cx.subscribe(&session, {
 53                let store = store.clone();
 54                move |_this, _session, event, cx| match event {
 55                    SessionEvent::Shutdown(shutdown_event) => {
 56                        store.update(cx, |store, _cx| {
 57                            store.remove_session(shutdown_event.entity_id());
 58                        });
 59                    }
 60                }
 61            })
 62            .detach();
 63        })
 64        .ok();
 65
 66    store.update(cx, |store, _cx| {
 67        store.insert_session(weak_editor.entity_id(), session.clone());
 68    });
 69
 70    Ok(())
 71}
 72
/// Executes the runnable code around the editor's newest selection in the
/// editor's REPL session, creating that session first when none exists.
///
/// `move_down` is forwarded unchanged to [`Session::execute`] for every range.
///
/// Returns `Ok(())` without executing anything when the REPL store is not
/// enabled, when the editor's multibuffer is not a singleton, or when the
/// buffer has no project path.
///
/// # Errors
///
/// Fails when the editor has been dropped, or when no kernel specification is
/// available for the language of one of the runnable ranges.
pub fn run(editor: WeakView<Editor>, move_down: bool, cx: &mut WindowContext) -> Result<()> {
    let store = ReplStore::global(cx);
    if !store.read(cx).is_enabled() {
        return Ok(());
    }

    let editor = editor.upgrade().context("editor was dropped")?;
    let selected_range = editor
        .update(cx, |editor, cx| editor.selections.newest_adjusted(cx))
        .range();
    let multibuffer = editor.read(cx).buffer().clone();
    // Only editors backed by a single buffer are handled.
    let Some(buffer) = multibuffer.read(cx).as_singleton() else {
        return Ok(());
    };

    // The worktree id from the project path selects the kernel specification below.
    let Some(project_path) = buffer.read(cx).project_path(cx) else {
        return Ok(());
    };

    let (runnable_ranges, next_cell_point) =
        runnable_ranges(&buffer.read(cx).snapshot(), selected_range);

    for runnable_range in runnable_ranges {
        // Ranges whose language cannot be determined are skipped, not treated as errors.
        let Some(language) = multibuffer.read(cx).language_at(runnable_range.start, cx) else {
            continue;
        };

        // Resolved on every iteration, even when a session already exists; a missing
        // kernel aborts the whole run.
        let kernel_specification = store
            .read(cx)
            .active_kernelspec(project_path.worktree_id, Some(language.clone()), cx)
            .ok_or_else(|| anyhow::anyhow!("No kernel found for language: {}", language.name()))?;

        let fs = store.read(cx).fs().clone();

        // Reuse the editor's existing session, or create and register a new one.
        let session = if let Some(session) = store.read(cx).get_session(editor.entity_id()).cloned()
        {
            session
        } else {
            let weak_editor = editor.downgrade();
            let session = cx.new_view(|cx| Session::new(weak_editor, fs, kernel_specification, cx));

            editor.update(cx, |_editor, cx| {
                cx.notify();

                // Remove the session from the store once it reports shutdown.
                cx.subscribe(&session, {
                    let store = store.clone();
                    move |_this, _session, event, cx| match event {
                        SessionEvent::Shutdown(shutdown_event) => {
                            store.update(cx, |store, _cx| {
                                store.remove_session(shutdown_event.entity_id());
                            });
                        }
                    }
                })
                .detach();
            });

            store.update(cx, |store, _cx| {
                store.insert_session(editor.entity_id(), session.clone());
            });

            session
        };

        let selected_text;
        let anchor_range;
        let next_cursor;
        // Scope the multibuffer snapshot so it is released before the session is updated.
        {
            let snapshot = multibuffer.read(cx).read(cx);
            selected_text = snapshot
                .text_for_range(runnable_range.clone())
                .collect::<String>();
            anchor_range = snapshot.anchor_before(runnable_range.start)
                ..snapshot.anchor_after(runnable_range.end);
            next_cursor = next_cell_point.map(|point| snapshot.anchor_after(point));
        }

        session.update(cx, |session, cx| {
            session.execute(selected_text, anchor_range, next_cursor, move_down, cx);
        });
    }

    anyhow::Ok(())
}
157
/// The REPL state of an editor, as reported by [`session`].
#[allow(clippy::large_enum_variant)]
pub enum SessionSupport {
    /// The REPL store already holds a session for the editor.
    ActiveSession(View<Session>),
    /// No session exists yet, but this kernel specification is available for the
    /// editor's language and worktree.
    Inactive(KernelSpecification),
    /// No kernel specification is available, but the named language is one the
    /// REPL supports.
    RequiresSetup(LanguageName),
    /// The editor's language or worktree could not be determined, or the language
    /// is not supported.
    Unsupported,
}
165
166pub fn worktree_id_for_editor(
167    editor: WeakView<Editor>,
168    cx: &mut WindowContext,
169) -> Option<WorktreeId> {
170    editor.upgrade().and_then(|editor| {
171        editor
172            .read(cx)
173            .buffer()
174            .read(cx)
175            .as_singleton()?
176            .read(cx)
177            .project_path(cx)
178            .map(|path| path.worktree_id)
179    })
180}
181
182pub fn session(editor: WeakView<Editor>, cx: &mut WindowContext) -> SessionSupport {
183    let store = ReplStore::global(cx);
184    let entity_id = editor.entity_id();
185
186    if let Some(session) = store.read(cx).get_session(entity_id).cloned() {
187        return SessionSupport::ActiveSession(session);
188    };
189
190    let Some(language) = get_language(editor.clone(), cx) else {
191        return SessionSupport::Unsupported;
192    };
193
194    let worktree_id = worktree_id_for_editor(editor.clone(), cx);
195
196    let Some(worktree_id) = worktree_id else {
197        return SessionSupport::Unsupported;
198    };
199
200    let kernelspec = store
201        .read(cx)
202        .active_kernelspec(worktree_id, Some(language.clone()), cx);
203
204    match kernelspec {
205        Some(kernelspec) => SessionSupport::Inactive(kernelspec),
206        None => {
207            if language_supported(&language.clone()) {
208                SessionSupport::RequiresSetup(language.name())
209            } else {
210                SessionSupport::Unsupported
211            }
212        }
213    }
214}
215
216pub fn clear_outputs(editor: WeakView<Editor>, cx: &mut WindowContext) {
217    let store = ReplStore::global(cx);
218    let entity_id = editor.entity_id();
219    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
220        return;
221    };
222    session.update(cx, |session, cx| {
223        session.clear_outputs(cx);
224        cx.notify();
225    });
226}
227
228pub fn interrupt(editor: WeakView<Editor>, cx: &mut WindowContext) {
229    let store = ReplStore::global(cx);
230    let entity_id = editor.entity_id();
231    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
232        return;
233    };
234
235    session.update(cx, |session, cx| {
236        session.interrupt(cx);
237        cx.notify();
238    });
239}
240
241pub fn shutdown(editor: WeakView<Editor>, cx: &mut WindowContext) {
242    let store = ReplStore::global(cx);
243    let entity_id = editor.entity_id();
244    let Some(session) = store.read(cx).get_session(entity_id).cloned() else {
245        return;
246    };
247
248    session.update(cx, |session, cx| {
249        session.shutdown(cx);
250        cx.notify();
251    });
252}
253
254pub fn restart(editor: WeakView<Editor>, cx: &mut WindowContext) {
255    let Some(editor) = editor.upgrade() else {
256        return;
257    };
258
259    let entity_id = editor.entity_id();
260
261    let Some(session) = ReplStore::global(cx)
262        .read(cx)
263        .get_session(entity_id)
264        .cloned()
265    else {
266        return;
267    };
268
269    session.update(cx, |session, cx| {
270        session.restart(cx);
271        cx.notify();
272    });
273}
274
275pub fn setup_editor_session_actions(editor: &mut Editor, editor_handle: WeakView<Editor>) {
276    editor
277        .register_action({
278            let editor_handle = editor_handle.clone();
279            move |_: &ClearOutputs, cx| {
280                if !JupyterSettings::enabled(cx) {
281                    return;
282                }
283
284                crate::clear_outputs(editor_handle.clone(), cx);
285            }
286        })
287        .detach();
288
289    editor
290        .register_action({
291            let editor_handle = editor_handle.clone();
292            move |_: &Interrupt, cx| {
293                if !JupyterSettings::enabled(cx) {
294                    return;
295                }
296
297                crate::interrupt(editor_handle.clone(), cx);
298            }
299        })
300        .detach();
301
302    editor
303        .register_action({
304            let editor_handle = editor_handle.clone();
305            move |_: &Shutdown, cx| {
306                if !JupyterSettings::enabled(cx) {
307                    return;
308                }
309
310                crate::shutdown(editor_handle.clone(), cx);
311            }
312        })
313        .detach();
314
315    editor
316        .register_action({
317            let editor_handle = editor_handle.clone();
318            move |_: &Restart, cx| {
319                if !JupyterSettings::enabled(cx) {
320                    return;
321                }
322
323                crate::restart(editor_handle.clone(), cx);
324            }
325        })
326        .detach();
327}
328
329fn cell_range(buffer: &BufferSnapshot, start_row: u32, end_row: u32) -> Range<Point> {
330    let mut snippet_end_row = end_row;
331    while buffer.is_line_blank(snippet_end_row) && snippet_end_row > start_row {
332        snippet_end_row -= 1;
333    }
334    Point::new(start_row, 0)..Point::new(snippet_end_row, buffer.line_len(snippet_end_row))
335}
336
337// Returns the ranges of the snippets in the buffer and the next point for moving the cursor to
338fn jupytext_cells(
339    buffer: &BufferSnapshot,
340    range: Range<Point>,
341) -> (Vec<Range<Point>>, Option<Point>) {
342    let mut current_row = range.start.row;
343
344    let Some(language) = buffer.language() else {
345        return (Vec::new(), None);
346    };
347
348    let default_scope = language.default_scope();
349    let comment_prefixes = default_scope.line_comment_prefixes();
350    if comment_prefixes.is_empty() {
351        return (Vec::new(), None);
352    }
353
354    let jupytext_prefixes = comment_prefixes
355        .iter()
356        .map(|comment_prefix| format!("{comment_prefix}%%"))
357        .collect::<Vec<_>>();
358
359    let mut snippet_start_row = None;
360    loop {
361        if jupytext_prefixes
362            .iter()
363            .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
364        {
365            snippet_start_row = Some(current_row);
366            break;
367        } else if current_row > 0 {
368            current_row -= 1;
369        } else {
370            break;
371        }
372    }
373
374    let mut snippets = Vec::new();
375    if let Some(mut snippet_start_row) = snippet_start_row {
376        for current_row in range.start.row + 1..=buffer.max_point().row {
377            if jupytext_prefixes
378                .iter()
379                .any(|prefix| buffer.contains_str_at(Point::new(current_row, 0), prefix))
380            {
381                snippets.push(cell_range(buffer, snippet_start_row, current_row - 1));
382
383                if current_row <= range.end.row {
384                    snippet_start_row = current_row;
385                } else {
386                    // Return our snippets as well as the next point for moving the cursor to
387                    return (snippets, Some(Point::new(current_row, 0)));
388                }
389            }
390        }
391
392        // Go to the end of the buffer (no more jupytext cells found)
393        snippets.push(cell_range(
394            buffer,
395            snippet_start_row,
396            buffer.max_point().row,
397        ));
398    }
399
400    (snippets, None)
401}
402
403fn runnable_ranges(
404    buffer: &BufferSnapshot,
405    range: Range<Point>,
406) -> (Vec<Range<Point>>, Option<Point>) {
407    if let Some(language) = buffer.language() {
408        if language.name() == "Markdown".into() {
409            return (markdown_code_blocks(buffer, range.clone()), None);
410        }
411    }
412
413    let (jupytext_snippets, next_cursor) = jupytext_cells(buffer, range.clone());
414    if !jupytext_snippets.is_empty() {
415        return (jupytext_snippets, next_cursor);
416    }
417
418    let snippet_range = cell_range(buffer, range.start.row, range.end.row);
419    let start_language = buffer.language_at(snippet_range.start);
420    let end_language = buffer.language_at(snippet_range.end);
421
422    if start_language
423        .zip(end_language)
424        .map_or(false, |(start, end)| start == end)
425    {
426        (vec![snippet_range], None)
427    } else {
428        (Vec::new(), None)
429    }
430}
431
432// We allow markdown code blocks to end in a trailing newline in order to render the output
433// below the final code fence. This is different than our behavior for selections and Jupytext cells.
434fn markdown_code_blocks(buffer: &BufferSnapshot, range: Range<Point>) -> Vec<Range<Point>> {
435    buffer
436        .injections_intersecting_range(range)
437        .filter(|(_, language)| language_supported(language))
438        .map(|(content_range, _)| {
439            buffer.offset_to_point(content_range.start)..buffer.offset_to_point(content_range.end)
440        })
441        .collect()
442}
443
444fn language_supported(language: &Arc<Language>) -> bool {
445    match language.name().0.as_ref() {
446        "TypeScript" | "Python" => true,
447        _ => false,
448    }
449}
450
451fn get_language(editor: WeakView<Editor>, cx: &mut WindowContext) -> Option<Arc<Language>> {
452    editor
453        .update(cx, |editor, cx| {
454            let selection = editor.selections.newest::<usize>(cx);
455            let buffer = editor.buffer().read(cx).snapshot(cx);
456            buffer.language_at(selection.head()).cloned()
457        })
458        .ok()
459        .flatten()
460}
461
// Tests for `runnable_ranges`: plain selections, Jupytext cells, and Markdown code blocks.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageRegistry};

    /// Plain selections in a buffer without cell markers: the selected rows are
    /// returned as one range with trailing blank lines trimmed.
    #[gpui::test]
    fn test_snippet_ranges(cx: &mut AppContext) {
        // Create a test language
        let test_language = Arc::new(Language::new(
            LanguageConfig {
                name: "TestLang".into(),
                line_comments: vec!["# ".into()],
                ..Default::default()
            },
            None,
        ));

        let buffer = cx.new_model(|cx| {
            Buffer::local(
                indoc! { r#"
                    print(1 + 1)
                    print(2 + 2)

                    print(4 + 4)


                "# },
                cx,
            )
            .with_language(test_language, cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Single-point selection
        let (snippets, _) = runnable_ranges(&snapshot, Point::new(0, 4)..Point::new(0, 4));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();
        assert_eq!(snippets, vec!["print(1 + 1)"]);

        // Multi-line selection
        let (snippets, _) = runnable_ranges(&snapshot, Point::new(0, 5)..Point::new(2, 0));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();
        assert_eq!(
            snippets,
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Trimming multiple trailing blank lines
        let (snippets, _) = runnable_ranges(&snapshot, Point::new(0, 5)..Point::new(5, 0));

        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();
        assert_eq!(
            snippets,
            vec![indoc! { r#"
                print(1 + 1)
                print(2 + 2)

                print(4 + 4)"# }]
        );
    }

    /// Buffers with `# %%` markers: the cell surrounding an empty selection is
    /// returned, and every cell intersecting a non-empty selection is returned.
    #[gpui::test]
    fn test_jupytext_snippet_ranges(cx: &mut AppContext) {
        // Create a test language
        let test_language = Arc::new(Language::new(
            LanguageConfig {
                name: "TestLang".into(),
                line_comments: vec!["# ".into()],
                ..Default::default()
            },
            None,
        ));

        let buffer = cx.new_model(|cx| {
            Buffer::local(
                indoc! { r#"
                    # Hello!
                    # %% [markdown]
                    # This is some arithmetic
                    print(1 + 1)
                    print(2 + 2)

                    # %%
                    print(3 + 3)
                    print(4 + 4)

                    print(5 + 5)



                "# },
                cx,
            )
            .with_language(test_language, cx)
        });
        let snapshot = buffer.read(cx).snapshot();

        // Jupytext snippet surrounding an empty selection
        let (snippets, _) = runnable_ranges(&snapshot, Point::new(2, 5)..Point::new(2, 5));

        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();
        assert_eq!(
            snippets,
            vec![indoc! { r#"
                # %% [markdown]
                # This is some arithmetic
                print(1 + 1)
                print(2 + 2)"# }]
        );

        // Jupytext snippets intersecting a non-empty selection
        let (snippets, _) = runnable_ranges(&snapshot, Point::new(2, 5)..Point::new(6, 2));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();
        assert_eq!(
            snippets,
            vec![
                indoc! { r#"
                    # %% [markdown]
                    # This is some arithmetic
                    print(1 + 1)
                    print(2 + 2)"#
                },
                indoc! { r#"
                    # %%
                    print(3 + 3)
                    print(4 + 4)

                    print(5 + 5)"#
                }
            ]
        );
    }

    /// Markdown buffers: the contents of fenced TypeScript and Python code
    /// blocks intersecting the selection are returned, each keeping its
    /// trailing newline.
    #[gpui::test]
    fn test_markdown_code_blocks(cx: &mut AppContext) {
        let markdown = languages::language("markdown", tree_sitter_md::LANGUAGE.into());
        let typescript = languages::language(
            "typescript",
            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        );
        let python = languages::language("python", tree_sitter_python::LANGUAGE.into());
        let language_registry = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
        language_registry.add(markdown.clone());
        language_registry.add(typescript.clone());
        language_registry.add(python.clone());

        // Two code blocks intersecting with selection
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```typescript
                    let foo = 999;
                    console.log(foo + 1999);
                    ```

                    ```typescript
                    console.log("foo")
                    ```
                    "#
                },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        let (snippets, _) = runnable_ranges(&snapshot, Point::new(3, 5)..Point::new(8, 5));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();

        assert_eq!(
            snippets,
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n"
            ]
        );

        // Three code blocks intersecting with selection
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```typescript
                    let foo = 999;
                    console.log(foo + 1999);
                    ```

                    ```ts
                    console.log("foo")
                    ```

                    ```typescript
                    console.log("another code block")
                    ```
                "# },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        let (snippets, _) = runnable_ranges(&snapshot, Point::new(3, 5)..Point::new(12, 5));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();

        assert_eq!(
            snippets,
            vec![
                indoc! { r#"
                    let foo = 999;
                    console.log(foo + 1999);
                    "#
                },
                "console.log(\"foo\")\n",
                "console.log(\"another code block\")\n",
            ]
        );

        // Python code block
        let buffer = cx.new_model(|cx| {
            let mut buffer = Buffer::local(
                indoc! { r#"
                    Hey this is Markdown!

                    ```python
                    print("hello there")
                    print("hello there")
                    print("hello there")
                    ```
                "# },
                cx,
            );
            buffer.set_language_registry(language_registry.clone());
            buffer.set_language(Some(markdown.clone()), cx);
            buffer
        });
        let snapshot = buffer.read(cx).snapshot();

        let (snippets, _) = runnable_ranges(&snapshot, Point::new(4, 5)..Point::new(5, 5));
        let snippets = snippets
            .into_iter()
            .map(|range| snapshot.text_for_range(range).collect::<String>())
            .collect::<Vec<_>>();

        assert_eq!(
            snippets,
            vec![indoc! { r#"
                print("hello there")
                print("hello there")
                print("hello there")
                "#
            },]
        );
    }
}