mermaid.rs

  1use collections::HashMap;
  2use gpui::{
  3    Animation, AnimationExt, AnyElement, Context, ImageSource, RenderImage, StyledText, Task, img,
  4    pulsating_between,
  5};
  6use std::collections::BTreeMap;
  7use std::ops::Range;
  8use std::sync::{Arc, OnceLock};
  9use std::time::Duration;
 10use ui::prelude::*;
 11
 12use crate::parser::{CodeBlockKind, MarkdownEvent, MarkdownTag};
 13
 14use super::{Markdown, MarkdownStyle, ParsedMarkdown};
 15
/// Diagram renders (finished or still in flight) keyed by diagram source text and scale.
type MermaidDiagramCache = HashMap<ParsedMarkdownMermaidDiagramContents, Arc<CachedMermaidDiagram>>;
 17
/// A mermaid diagram found in a fenced code block of the parsed markdown.
#[derive(Clone, Debug)]
pub(crate) struct ParsedMarkdownMermaidDiagram {
    /// Byte range of the code block's contents within the markdown source.
    pub(crate) content_range: Range<usize>,
    /// Diagram text and requested scale; also the key into the render cache.
    pub(crate) contents: ParsedMarkdownMermaidDiagramContents,
}
 23
/// Everything that determines how a diagram renders; used as the render cache key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct ParsedMarkdownMermaidDiagramContents {
    /// Mermaid source text, with the code block's trailing line terminator removed.
    pub(crate) contents: SharedString,
    /// Render scale in percent (clamped to 10..=500 when parsed from the info string).
    pub(crate) scale: u32,
}
 29
/// Mermaid render state for one markdown document.
#[derive(Default, Clone)]
pub(crate) struct MermaidState {
    /// Cache entries for the diagrams present in the document as of the last `update`.
    cache: MermaidDiagramCache,
    /// Diagram contents in document order as of the last `update`; used to pick a fallback
    /// image for a diagram that was edited in place.
    order: Vec<ParsedMarkdownMermaidDiagramContents>,
}
 35
/// A single diagram render, possibly still in progress.
struct CachedMermaidDiagram {
    /// Filled once by the background render with either the rasterized image or the error.
    render_image: Arc<OnceLock<anyhow::Result<Arc<RenderImage>>>>,
    /// Image from a previous version of this diagram, shown while `render_image` is unset.
    fallback_image: Option<Arc<RenderImage>>,
    /// Holds the render task for as long as this entry lives (presumably so it is not
    /// cancelled, since dropping a gpui `Task` cancels it).
    _task: Task<()>,
}
 41
 42impl MermaidState {
 43    pub(crate) fn clear(&mut self) {
 44        self.cache.clear();
 45        self.order.clear();
 46    }
 47
 48    fn get_fallback_image(
 49        idx: usize,
 50        old_order: &[ParsedMarkdownMermaidDiagramContents],
 51        new_order_len: usize,
 52        cache: &MermaidDiagramCache,
 53    ) -> Option<Arc<RenderImage>> {
 54        if old_order.len() != new_order_len {
 55            return None;
 56        }
 57
 58        old_order.get(idx).and_then(|old_content| {
 59            cache.get(old_content).and_then(|old_cached| {
 60                old_cached
 61                    .render_image
 62                    .get()
 63                    .and_then(|result| result.as_ref().ok().cloned())
 64                    .or_else(|| old_cached.fallback_image.clone())
 65            })
 66        })
 67    }
 68
 69    pub(crate) fn update(&mut self, parsed: &ParsedMarkdown, cx: &mut Context<Markdown>) {
 70        let mut new_order = Vec::new();
 71        for mermaid_diagram in parsed.mermaid_diagrams.values() {
 72            new_order.push(mermaid_diagram.contents.clone());
 73        }
 74
 75        for (idx, new_content) in new_order.iter().enumerate() {
 76            if !self.cache.contains_key(new_content) {
 77                let fallback =
 78                    Self::get_fallback_image(idx, &self.order, new_order.len(), &self.cache);
 79                self.cache.insert(
 80                    new_content.clone(),
 81                    Arc::new(CachedMermaidDiagram::new(new_content.clone(), fallback, cx)),
 82                );
 83            }
 84        }
 85
 86        let new_order_set: std::collections::HashSet<_> = new_order.iter().cloned().collect();
 87        self.cache
 88            .retain(|content, _| new_order_set.contains(content));
 89        self.order = new_order;
 90    }
 91}
 92
 93impl CachedMermaidDiagram {
 94    fn new(
 95        contents: ParsedMarkdownMermaidDiagramContents,
 96        fallback_image: Option<Arc<RenderImage>>,
 97        cx: &mut Context<Markdown>,
 98    ) -> Self {
 99        let render_image = Arc::new(OnceLock::<anyhow::Result<Arc<RenderImage>>>::new());
100        let render_image_clone = render_image.clone();
101        let svg_renderer = cx.svg_renderer();
102
103        let task = cx.spawn(async move |this, cx| {
104            let value = cx
105                .background_spawn(async move {
106                    let svg_string = mermaid_rs_renderer::render(&contents.contents)?;
107                    let scale = contents.scale as f32 / 100.0;
108                    svg_renderer
109                        .render_single_frame(svg_string.as_bytes(), scale, true)
110                        .map_err(|error| anyhow::anyhow!("{error}"))
111                })
112                .await;
113            let _ = render_image_clone.set(value);
114            this.update(cx, |_, cx| {
115                cx.notify();
116            })
117            .ok();
118        });
119
120        Self {
121            render_image,
122            fallback_image,
123            _task: task,
124        }
125    }
126
127    #[cfg(test)]
128    fn new_for_test(
129        render_image: Option<Arc<RenderImage>>,
130        fallback_image: Option<Arc<RenderImage>>,
131    ) -> Self {
132        let result = Arc::new(OnceLock::new());
133        if let Some(render_image) = render_image {
134            let _ = result.set(Ok(render_image));
135        }
136        Self {
137            render_image: result,
138            fallback_image,
139            _task: Task::ready(()),
140        }
141    }
142}
143
/// Parses a fenced code block info string such as `mermaid` or `mermaid 150`.
///
/// Returns `None` unless the first whitespace-separated word is exactly `mermaid`.
/// Otherwise returns the render scale in percent. The scale is the second word parsed as a
/// `u32`, defaults to 100 when that word is missing or unparsable, and is clamped to
/// `10..=500`.
fn parse_mermaid_info(info: &str) -> Option<u32> {
    const DEFAULT_SCALE: u32 = 100;
    const MIN_SCALE: u32 = 10;
    const MAX_SCALE: u32 = 500;

    let mut words = info.split_whitespace();
    if words.next() != Some("mermaid") {
        return None;
    }

    let scale = match words.next().map(str::parse::<u32>) {
        Some(Ok(requested)) => requested,
        _ => DEFAULT_SCALE,
    };
    Some(scale.clamp(MIN_SCALE, MAX_SCALE))
}
158
159pub(crate) fn extract_mermaid_diagrams(
160    source: &str,
161    events: &[(Range<usize>, MarkdownEvent)],
162) -> BTreeMap<usize, ParsedMarkdownMermaidDiagram> {
163    let mut mermaid_diagrams = BTreeMap::default();
164
165    for (source_range, event) in events {
166        let MarkdownEvent::Start(MarkdownTag::CodeBlock { kind, metadata }) = event else {
167            continue;
168        };
169        let CodeBlockKind::FencedLang(info) = kind else {
170            continue;
171        };
172        let Some(scale) = parse_mermaid_info(info.as_ref()) else {
173            continue;
174        };
175
176        let contents = source[metadata.content_range.clone()]
177            .strip_suffix('\n')
178            .unwrap_or(&source[metadata.content_range.clone()])
179            .to_string();
180        mermaid_diagrams.insert(
181            source_range.start,
182            ParsedMarkdownMermaidDiagram {
183                content_range: metadata.content_range.clone(),
184                contents: ParsedMarkdownMermaidDiagramContents {
185                    contents: contents.into(),
186                    scale,
187                },
188            },
189        );
190    }
191
192    mermaid_diagrams
193}
194
195pub(crate) fn render_mermaid_diagram(
196    parsed: &ParsedMarkdownMermaidDiagram,
197    mermaid_state: &MermaidState,
198    style: &MarkdownStyle,
199) -> AnyElement {
200    let cached = mermaid_state.cache.get(&parsed.contents);
201    let mut container = div().w_full();
202    container.style().refine(&style.code_block);
203
204    if let Some(result) = cached.and_then(|cached| cached.render_image.get()) {
205        match result {
206            Ok(render_image) => container
207                .child(
208                    div().w_full().child(
209                        img(ImageSource::Render(render_image.clone()))
210                            .max_w_full()
211                            .with_fallback(|| {
212                                div()
213                                    .child(Label::new("Failed to load mermaid diagram"))
214                                    .into_any_element()
215                            }),
216                    ),
217                )
218                .into_any_element(),
219            Err(_) => container
220                .child(StyledText::new(parsed.contents.contents.clone()))
221                .into_any_element(),
222        }
223    } else if let Some(fallback) = cached.and_then(|cached| cached.fallback_image.as_ref()) {
224        container
225            .child(
226                div()
227                    .w_full()
228                    .child(
229                        img(ImageSource::Render(fallback.clone()))
230                            .max_w_full()
231                            .with_fallback(|| {
232                                div()
233                                    .child(Label::new("Failed to load mermaid diagram"))
234                                    .into_any_element()
235                            }),
236                    )
237                    .with_animation(
238                        "mermaid-fallback-pulse",
239                        Animation::new(Duration::from_secs(2))
240                            .repeat()
241                            .with_easing(pulsating_between(0.6, 1.0)),
242                        |element, delta| element.opacity(delta),
243                    ),
244            )
245            .into_any_element()
246    } else {
247        container
248            .child(
249                Label::new("Rendering mermaid diagram...")
250                    .color(Color::Muted)
251                    .with_animation(
252                        "mermaid-loading-pulse",
253                        Animation::new(Duration::from_secs(2))
254                            .repeat()
255                            .with_easing(pulsating_between(0.4, 0.8)),
256                        |label, delta| label.alpha(delta),
257                    ),
258            )
259            .into_any_element()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::{
        CachedMermaidDiagram, MermaidDiagramCache, MermaidState,
        ParsedMarkdownMermaidDiagramContents, extract_mermaid_diagrams, parse_mermaid_info,
    };
    use crate::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownOptions, MarkdownStyle};
    use collections::HashMap;
    use gpui::{Context, IntoElement, Render, RenderImage, TestAppContext, Window, size};
    use std::sync::Arc;
    use ui::prelude::*;

    // Installs the settings store and the base theme if they are not present yet. Drawing
    // markdown presumably depends on both globals (theme colors/settings); each test calls
    // this before creating a window.
    fn ensure_theme_initialized(cx: &mut TestAppContext) {
        cx.update(|cx| {
            if !cx.has_global::<settings::SettingsStore>() {
                settings::init(cx);
            }
            if !cx.has_global::<theme::GlobalTheme>() {
                theme::init(theme::LoadThemes::JustBase, cx);
            }
        });
    }

    // Creates a `Markdown` entity for `markdown` with `options`, and lets pending tasks run
    // (`run_until_parked`). It then draws the markdown into a 600x600 test window using the
    // default code-block renderer with copy buttons and borders disabled, and returns the
    // rendered text layout.
    fn render_markdown_with_options(
        markdown: &str,
        options: MarkdownOptions,
        cx: &mut TestAppContext,
    ) -> crate::RenderedText {
        // Empty root view; it only exists to obtain a window context to draw into.
        struct TestWindow;

        impl Render for TestWindow {
            fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
                div()
            }
        }

        ensure_theme_initialized(cx);

        let (_, cx) = cx.add_window_view(|_, _| TestWindow);
        let markdown = cx.new(|cx| {
            Markdown::new_with_options(markdown.to_string().into(), None, None, options, cx)
        });
        cx.run_until_parked();
        let (rendered, _) = cx.draw(
            Default::default(),
            size(px(600.0), px(600.0)),
            |_window, _cx| {
                MarkdownElement::new(markdown, MarkdownStyle::default()).code_block_renderer(
                    CodeBlockRenderer::Default {
                        copy_button: false,
                        copy_button_on_hover: false,
                        border: false,
                    },
                )
            },
        );
        rendered.text
    }

    // Rasterizes an empty 1x1 SVG to get a real `RenderImage`. The tests below tell images
    // apart by `id`, which assumes each call yields a distinct id.
    fn mock_render_image(cx: &mut TestAppContext) -> Arc<RenderImage> {
        cx.update(|cx| {
            cx.svg_renderer()
                .render_single_frame(
                    br#"<svg xmlns="http://www.w3.org/2000/svg" width="1" height="1"></svg>"#,
                    1.0,
                    true,
                )
                .unwrap()
        })
    }

    // Cache key for `contents` at the default 100% scale.
    fn mermaid_contents(contents: &str) -> ParsedMarkdownMermaidDiagramContents {
        ParsedMarkdownMermaidDiagramContents {
            contents: contents.to_string().into(),
            scale: 100,
        }
    }

    // Cache keys for a document containing `diagrams`, in document order.
    fn mermaid_sequence(diagrams: &[&str]) -> Vec<ParsedMarkdownMermaidDiagramContents> {
        diagrams
            .iter()
            .map(|diagram| mermaid_contents(diagram))
            .collect()
    }

    // Runs `MermaidState::get_fallback_image` for the first occurrence of `new_diagram` in
    // `new_full_order`. Returns `None` if `new_diagram` does not occur there at all.
    fn mermaid_fallback(
        new_diagram: &str,
        new_full_order: &[ParsedMarkdownMermaidDiagramContents],
        old_full_order: &[ParsedMarkdownMermaidDiagramContents],
        cache: &MermaidDiagramCache,
    ) -> Option<Arc<RenderImage>> {
        let new_content = mermaid_contents(new_diagram);
        let idx = new_full_order
            .iter()
            .position(|diagram| diagram == &new_content)?;
        MermaidState::get_fallback_image(idx, old_full_order, new_full_order.len(), cache)
    }

    // Covers the default scale, explicit scale, clamping at both ends, and non-mermaid info
    // strings.
    #[test]
    fn test_parse_mermaid_info() {
        assert_eq!(parse_mermaid_info("mermaid"), Some(100));
        assert_eq!(parse_mermaid_info("mermaid 150"), Some(150));
        assert_eq!(parse_mermaid_info("mermaid 5"), Some(10));
        assert_eq!(parse_mermaid_info("mermaid 999"), Some(500));
        assert_eq!(parse_mermaid_info("rust"), None);
    }

    // Only the mermaid block is extracted, with its trailing newline stripped and its scale
    // taken from the info string.
    #[test]
    fn test_extract_mermaid_diagrams_parses_scale() {
        let markdown = "```mermaid 150\ngraph TD;\n```\n\n```rust\nfn main() {}\n```";
        let events = crate::parser::parse_markdown_with_options(markdown, false).events;
        let diagrams = extract_mermaid_diagrams(markdown, &events);

        assert_eq!(diagrams.len(), 1);
        let diagram = diagrams.values().next().unwrap();
        assert_eq!(diagram.contents.contents, "graph TD;");
        assert_eq!(diagram.contents.scale, 150);
    }

    // Editing the middle diagram (same diagram count) reuses the previous render at that
    // position as the fallback.
    #[gpui::test]
    fn test_mermaid_fallback_on_edit(cx: &mut TestAppContext) {
        let old_full_order = mermaid_sequence(&["graph A", "graph B", "graph C"]);
        let new_full_order = mermaid_sequence(&["graph A", "graph B modified", "graph C"]);

        let svg_b = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );
        cache.insert(
            mermaid_contents("graph B"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(svg_b.clone()),
                None,
            )),
        );
        cache.insert(
            mermaid_contents("graph C"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );

        let fallback =
            mermaid_fallback("graph B modified", &new_full_order, &old_full_order, &cache);

        assert_eq!(fallback.as_ref().map(|image| image.id), Some(svg_b.id));
    }

    // Inserting a diagram changes the diagram count, so no positional fallback is offered.
    #[gpui::test]
    fn test_mermaid_no_fallback_on_add_in_middle(cx: &mut TestAppContext) {
        let old_full_order = mermaid_sequence(&["graph A", "graph C"]);
        let new_full_order = mermaid_sequence(&["graph A", "graph NEW", "graph C"]);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );
        cache.insert(
            mermaid_contents("graph C"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );

        let fallback = mermaid_fallback("graph NEW", &new_full_order, &old_full_order, &cache);

        assert!(fallback.is_none());
    }

    // When the previous version is still pending (no render yet), its own fallback is
    // passed on, so successive edits keep showing the original image.
    #[gpui::test]
    fn test_mermaid_fallback_chains_on_rapid_edits(cx: &mut TestAppContext) {
        let old_full_order = mermaid_sequence(&["graph A", "graph B modified", "graph C"]);
        let new_full_order = mermaid_sequence(&["graph A", "graph B modified again", "graph C"]);

        let original_svg = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );
        cache.insert(
            mermaid_contents("graph B modified"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                None,
                Some(original_svg.clone()),
            )),
        );
        cache.insert(
            mermaid_contents("graph C"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );

        let fallback = mermaid_fallback(
            "graph B modified again",
            &new_full_order,
            &old_full_order,
            &cache,
        );

        assert_eq!(
            fallback.as_ref().map(|image| image.id),
            Some(original_svg.id)
        );
    }

    // Two identical blocks share one cache entry; editing the second still finds a
    // fallback through its old position.
    #[gpui::test]
    fn test_mermaid_fallback_with_duplicate_blocks_edit_second(cx: &mut TestAppContext) {
        let old_full_order = mermaid_sequence(&["graph A", "graph A", "graph B"]);
        let new_full_order = mermaid_sequence(&["graph A", "graph A edited", "graph B"]);

        let svg_a = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(svg_a.clone()),
                None,
            )),
        );
        cache.insert(
            mermaid_contents("graph B"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                Some(mock_render_image(cx)),
                None,
            )),
        );

        let fallback = mermaid_fallback("graph A edited", &new_full_order, &old_full_order, &cache);

        assert_eq!(fallback.as_ref().map(|image| image.id), Some(svg_a.id));
    }

    // With diagram rendering enabled, the diagram source must not show up as code-block
    // text in the rendered output.
    #[gpui::test]
    fn test_mermaid_rendering_replaces_code_block_text(cx: &mut TestAppContext) {
        let rendered = render_markdown_with_options(
            "```mermaid\ngraph TD;\n```",
            MarkdownOptions {
                render_mermaid_diagrams: true,
                ..Default::default()
            },
            cx,
        );

        let text = rendered
            .lines
            .iter()
            .map(|line| line.layout.wrapped_text())
            .collect::<Vec<_>>()
            .join("\n");

        assert!(!text.contains("graph TD;"));
    }

    // Even when the diagram is drawn as an image, source offsets at the start and the end
    // of the block's contents must still map to a rendered position.
    #[gpui::test]
    fn test_mermaid_source_anchor_maps_inside_block(cx: &mut TestAppContext) {
        struct TestWindow;

        impl Render for TestWindow {
            fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
                div()
            }
        }

        ensure_theme_initialized(cx);

        let (_, cx) = cx.add_window_view(|_, _| TestWindow);
        let markdown = cx.new(|cx| {
            Markdown::new_with_options(
                "```mermaid\ngraph TD;\n```".into(),
                None,
                None,
                MarkdownOptions {
                    render_mermaid_diagrams: true,
                    ..Default::default()
                },
                cx,
            )
        });
        cx.run_until_parked();
        let render_image = mock_render_image(cx);
        // Inject a finished render directly so the draw below takes the image branch,
        // regardless of the outcome of the real background render.
        markdown.update(cx, |markdown, _| {
            let contents = markdown
                .parsed_markdown
                .mermaid_diagrams
                .values()
                .next()
                .unwrap()
                .contents
                .clone();
            markdown.mermaid_state.cache.insert(
                contents.clone(),
                Arc::new(CachedMermaidDiagram::new_for_test(Some(render_image), None)),
            );
            markdown.mermaid_state.order = vec![contents];
        });

        let (rendered, _) = cx.draw(
            Default::default(),
            size(px(600.0), px(600.0)),
            |_window, _cx| {
                MarkdownElement::new(markdown.clone(), MarkdownStyle::default())
                    .code_block_renderer(CodeBlockRenderer::Default {
                        copy_button: false,
                        copy_button_on_hover: false,
                        border: false,
                    })
            },
        );

        let mermaid_diagram = markdown.update(cx, |markdown, _| {
            markdown
                .parsed_markdown
                .mermaid_diagrams
                .values()
                .next()
                .unwrap()
                .clone()
        });
        assert!(
            rendered
                .text
                .position_for_source_index(mermaid_diagram.content_range.start)
                .is_some()
        );
        assert!(
            rendered
                .text
                .position_for_source_index(mermaid_diagram.content_range.end.saturating_sub(1))
                .is_some()
        );
    }
}