1use collections::HashMap;
2use gpui::{
3 Animation, AnimationExt, AnyElement, Context, ImageSource, RenderImage, StyledText, Task, img,
4 pulsating_between,
5};
6use std::collections::BTreeMap;
7use std::ops::Range;
8use std::sync::{Arc, OnceLock};
9use std::time::Duration;
10use ui::prelude::*;
11
12use crate::parser::{CodeBlockKind, MarkdownEvent, MarkdownTag};
13
14use super::{Markdown, MarkdownStyle, ParsedMarkdown};
15
16type MermaidDiagramCache = HashMap<ParsedMarkdownMermaidDiagramContents, Arc<CachedMermaidDiagram>>;
17
18#[derive(Clone, Debug)]
19pub(crate) struct ParsedMarkdownMermaidDiagram {
20 pub(crate) content_range: Range<usize>,
21 pub(crate) contents: ParsedMarkdownMermaidDiagramContents,
22}
23
24#[derive(Clone, Debug, PartialEq, Eq, Hash)]
25pub(crate) struct ParsedMarkdownMermaidDiagramContents {
26 pub(crate) contents: SharedString,
27 pub(crate) scale: u32,
28}
29
30#[derive(Default, Clone)]
31pub(crate) struct MermaidState {
32 cache: MermaidDiagramCache,
33 order: Vec<ParsedMarkdownMermaidDiagramContents>,
34}
35
36struct CachedMermaidDiagram {
37 render_image: Arc<OnceLock<anyhow::Result<Arc<RenderImage>>>>,
38 fallback_image: Option<Arc<RenderImage>>,
39 _task: Task<()>,
40}
41
42impl MermaidState {
43 pub(crate) fn clear(&mut self) {
44 self.cache.clear();
45 self.order.clear();
46 }
47
48 fn get_fallback_image(
49 idx: usize,
50 old_order: &[ParsedMarkdownMermaidDiagramContents],
51 new_order_len: usize,
52 cache: &MermaidDiagramCache,
53 ) -> Option<Arc<RenderImage>> {
54 if old_order.len() != new_order_len {
55 return None;
56 }
57
58 old_order.get(idx).and_then(|old_content| {
59 cache.get(old_content).and_then(|old_cached| {
60 old_cached
61 .render_image
62 .get()
63 .and_then(|result| result.as_ref().ok().cloned())
64 .or_else(|| old_cached.fallback_image.clone())
65 })
66 })
67 }
68
69 pub(crate) fn update(&mut self, parsed: &ParsedMarkdown, cx: &mut Context<Markdown>) {
70 let mut new_order = Vec::new();
71 for mermaid_diagram in parsed.mermaid_diagrams.values() {
72 new_order.push(mermaid_diagram.contents.clone());
73 }
74
75 for (idx, new_content) in new_order.iter().enumerate() {
76 if !self.cache.contains_key(new_content) {
77 let fallback =
78 Self::get_fallback_image(idx, &self.order, new_order.len(), &self.cache);
79 self.cache.insert(
80 new_content.clone(),
81 Arc::new(CachedMermaidDiagram::new(new_content.clone(), fallback, cx)),
82 );
83 }
84 }
85
86 let new_order_set: std::collections::HashSet<_> = new_order.iter().cloned().collect();
87 self.cache
88 .retain(|content, _| new_order_set.contains(content));
89 self.order = new_order;
90 }
91}
92
93impl CachedMermaidDiagram {
94 fn new(
95 contents: ParsedMarkdownMermaidDiagramContents,
96 fallback_image: Option<Arc<RenderImage>>,
97 cx: &mut Context<Markdown>,
98 ) -> Self {
99 let render_image = Arc::new(OnceLock::<anyhow::Result<Arc<RenderImage>>>::new());
100 let render_image_clone = render_image.clone();
101 let svg_renderer = cx.svg_renderer();
102
103 let task = cx.spawn(async move |this, cx| {
104 let value = cx
105 .background_spawn(async move {
106 let svg_string = mermaid_rs_renderer::render(&contents.contents)?;
107 let scale = contents.scale as f32 / 100.0;
108 svg_renderer
109 .render_single_frame(svg_string.as_bytes(), scale)
110 .map_err(|error| anyhow::anyhow!("{error}"))
111 })
112 .await;
113 let _ = render_image_clone.set(value);
114 this.update(cx, |_, cx| {
115 cx.notify();
116 })
117 .ok();
118 });
119
120 Self {
121 render_image,
122 fallback_image,
123 _task: task,
124 }
125 }
126
127 #[cfg(test)]
128 fn new_for_test(
129 render_image: Option<Arc<RenderImage>>,
130 fallback_image: Option<Arc<RenderImage>>,
131 ) -> Self {
132 let result = Arc::new(OnceLock::new());
133 if let Some(render_image) = render_image {
134 let _ = result.set(Ok(render_image));
135 }
136 Self {
137 render_image: result,
138 fallback_image,
139 _task: Task::ready(()),
140 }
141 }
142}
143
/// Parses a fenced code block info string of the form `mermaid [scale]`.
///
/// Returns `None` when the first whitespace-separated word is not exactly
/// `mermaid`. Otherwise returns the scale percentage, clamped to
/// `10..=500`:
///
/// - no second word, or a second word that is not a number: the default `100`;
/// - a number too large for `u32`: the maximum `500`, consistent with how
///   other oversized values are clamped rather than reset to the default.
fn parse_mermaid_info(info: &str) -> Option<u32> {
    const DEFAULT_SCALE: u32 = 100;
    const MIN_SCALE: u32 = 10;
    const MAX_SCALE: u32 = 500;

    let mut parts = info.split_whitespace();
    if parts.next()? != "mermaid" {
        return None;
    }

    let scale = match parts.next().map(str::parse::<u32>) {
        Some(Ok(scale)) => scale,
        // A huge but well-formed number should saturate, not fall back to the default.
        Some(Err(error)) if matches!(error.kind(), std::num::IntErrorKind::PosOverflow) => {
            MAX_SCALE
        }
        Some(Err(_)) | None => DEFAULT_SCALE,
    };

    Some(scale.clamp(MIN_SCALE, MAX_SCALE))
}
158
159pub(crate) fn extract_mermaid_diagrams(
160 source: &str,
161 events: &[(Range<usize>, MarkdownEvent)],
162) -> BTreeMap<usize, ParsedMarkdownMermaidDiagram> {
163 let mut mermaid_diagrams = BTreeMap::default();
164
165 for (source_range, event) in events {
166 let MarkdownEvent::Start(MarkdownTag::CodeBlock { kind, metadata }) = event else {
167 continue;
168 };
169 let CodeBlockKind::FencedLang(info) = kind else {
170 continue;
171 };
172 let Some(scale) = parse_mermaid_info(info.as_ref()) else {
173 continue;
174 };
175
176 let contents = source[metadata.content_range.clone()]
177 .strip_suffix('\n')
178 .unwrap_or(&source[metadata.content_range.clone()])
179 .to_string();
180 mermaid_diagrams.insert(
181 source_range.start,
182 ParsedMarkdownMermaidDiagram {
183 content_range: metadata.content_range.clone(),
184 contents: ParsedMarkdownMermaidDiagramContents {
185 contents: contents.into(),
186 scale,
187 },
188 },
189 );
190 }
191
192 mermaid_diagrams
193}
194
195pub(crate) fn render_mermaid_diagram(
196 parsed: &ParsedMarkdownMermaidDiagram,
197 mermaid_state: &MermaidState,
198 style: &MarkdownStyle,
199) -> AnyElement {
200 let cached = mermaid_state.cache.get(&parsed.contents);
201 let mut container = div().w_full();
202 container.style().refine(&style.code_block);
203
204 if let Some(result) = cached.and_then(|cached| cached.render_image.get()) {
205 match result {
206 Ok(render_image) => container
207 .child(
208 div().w_full().child(
209 img(ImageSource::Render(render_image.clone()))
210 .max_w_full()
211 .with_fallback(|| {
212 div()
213 .child(Label::new("Failed to load mermaid diagram"))
214 .into_any_element()
215 }),
216 ),
217 )
218 .into_any_element(),
219 Err(_) => container
220 .child(StyledText::new(parsed.contents.contents.clone()))
221 .into_any_element(),
222 }
223 } else if let Some(fallback) = cached.and_then(|cached| cached.fallback_image.as_ref()) {
224 container
225 .child(
226 div()
227 .w_full()
228 .child(
229 img(ImageSource::Render(fallback.clone()))
230 .max_w_full()
231 .with_fallback(|| {
232 div()
233 .child(Label::new("Failed to load mermaid diagram"))
234 .into_any_element()
235 }),
236 )
237 .with_animation(
238 "mermaid-fallback-pulse",
239 Animation::new(Duration::from_secs(2))
240 .repeat()
241 .with_easing(pulsating_between(0.6, 1.0)),
242 |element, delta| element.opacity(delta),
243 ),
244 )
245 .into_any_element()
246 } else {
247 container
248 .child(
249 Label::new("Rendering mermaid diagram...")
250 .color(Color::Muted)
251 .with_animation(
252 "mermaid-loading-pulse",
253 Animation::new(Duration::from_secs(2))
254 .repeat()
255 .with_easing(pulsating_between(0.4, 0.8)),
256 |label, delta| label.alpha(delta),
257 ),
258 )
259 .into_any_element()
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::{
        CachedMermaidDiagram, MermaidDiagramCache, MermaidState,
        ParsedMarkdownMermaidDiagramContents, extract_mermaid_diagrams, parse_mermaid_info,
    };
    use crate::{
        CodeBlockRenderer, CopyButtonVisibility, Markdown, MarkdownElement, MarkdownOptions,
        MarkdownStyle,
    };
    use collections::HashMap;
    use gpui::{Context, IntoElement, Render, RenderImage, TestAppContext, Window, size};
    use std::sync::Arc;
    use ui::prelude::*;

    /// Empty root view; exists only so the tests have a window to draw into.
    struct TestWindow;

    impl Render for TestWindow {
        fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
            div()
        }
    }

    /// Installs the settings store and base theme globals if they are missing.
    fn ensure_theme_initialized(cx: &mut TestAppContext) {
        cx.update(|cx| {
            if !cx.has_global::<settings::SettingsStore>() {
                settings::init(cx);
            }
            if !cx.has_global::<theme::GlobalTheme>() {
                theme_settings::init(theme::LoadThemes::JustBase, cx);
            }
        });
    }

    /// Markdown options with mermaid rendering switched on and all else default.
    fn mermaid_options() -> MarkdownOptions {
        MarkdownOptions {
            render_mermaid_diagrams: true,
            ..Default::default()
        }
    }

    /// Code block renderer used by the drawing tests: no copy button, no border.
    fn plain_code_block_renderer() -> CodeBlockRenderer {
        CodeBlockRenderer::Default {
            copy_button_visibility: CopyButtonVisibility::Hidden,
            border: false,
        }
    }

    /// Draws `markdown` in a 600x600 test window and returns the rendered text.
    fn render_markdown_with_options(
        markdown: &str,
        options: MarkdownOptions,
        cx: &mut TestAppContext,
    ) -> crate::RenderedText {
        ensure_theme_initialized(cx);

        let (_, cx) = cx.add_window_view(|_, _| TestWindow);
        let entity = cx.new(|cx| {
            Markdown::new_with_options(markdown.to_string().into(), None, None, options, cx)
        });
        cx.run_until_parked();

        let viewport = size(px(600.0), px(600.0));
        let (rendered, _) = cx.draw(Default::default(), viewport, |_window, _cx| {
            MarkdownElement::new(entity, MarkdownStyle::default())
                .code_block_renderer(plain_code_block_renderer())
        });
        rendered.text
    }

    /// Rasterizes a 1x1 empty SVG to obtain a distinct `RenderImage`.
    fn mock_render_image(cx: &mut TestAppContext) -> Arc<RenderImage> {
        let svg = br#"<svg xmlns="http://www.w3.org/2000/svg" width="1" height="1"></svg>"#;
        cx.update(|cx| cx.svg_renderer().render_single_frame(svg, 1.0).unwrap())
    }

    /// Diagram contents at the default scale of 100.
    fn mermaid_contents(contents: &str) -> ParsedMarkdownMermaidDiagramContents {
        ParsedMarkdownMermaidDiagramContents {
            contents: contents.to_string().into(),
            scale: 100,
        }
    }

    /// Converts diagram source strings into a document-order list of contents.
    fn mermaid_sequence(diagrams: &[&str]) -> Vec<ParsedMarkdownMermaidDiagramContents> {
        let mut sequence = Vec::with_capacity(diagrams.len());
        for diagram in diagrams {
            sequence.push(mermaid_contents(diagram));
        }
        sequence
    }

    /// Cache entry whose render already succeeded with `image` and which has no fallback.
    fn rendered_entry(image: Arc<RenderImage>) -> Arc<CachedMermaidDiagram> {
        Arc::new(CachedMermaidDiagram::new_for_test(Some(image), None))
    }

    /// Looks up the fallback for `new_diagram` at its first position in `new_full_order`.
    fn mermaid_fallback(
        new_diagram: &str,
        new_full_order: &[ParsedMarkdownMermaidDiagramContents],
        old_full_order: &[ParsedMarkdownMermaidDiagramContents],
        cache: &MermaidDiagramCache,
    ) -> Option<Arc<RenderImage>> {
        let target = mermaid_contents(new_diagram);
        let idx = new_full_order
            .iter()
            .position(|candidate| *candidate == target)?;
        MermaidState::get_fallback_image(idx, old_full_order, new_full_order.len(), cache)
    }

    #[test]
    fn test_parse_mermaid_info() {
        let cases = [
            ("mermaid", Some(100)),
            ("mermaid 150", Some(150)),
            ("mermaid 5", Some(10)),
            ("mermaid 999", Some(500)),
            ("rust", None),
        ];
        for (info, expected) in cases {
            assert_eq!(parse_mermaid_info(info), expected);
        }
    }

    #[test]
    fn test_extract_mermaid_diagrams_parses_scale() {
        let markdown = "```mermaid 150\ngraph TD;\n```\n\n```rust\nfn main() {}\n```";
        let events = crate::parser::parse_markdown_with_options(markdown, false).events;
        let diagrams = extract_mermaid_diagrams(markdown, &events);

        // The rust block must be ignored; only the mermaid block is extracted.
        assert_eq!(diagrams.len(), 1);
        let only_diagram = diagrams.values().next().unwrap();
        assert_eq!(only_diagram.contents.contents, "graph TD;");
        assert_eq!(only_diagram.contents.scale, 150);
    }

    #[gpui::test]
    fn test_mermaid_fallback_on_edit(cx: &mut TestAppContext) {
        let before = mermaid_sequence(&["graph A", "graph B", "graph C"]);
        let after = mermaid_sequence(&["graph A", "graph B modified", "graph C"]);

        let svg_b = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            rendered_entry(mock_render_image(cx)),
        );
        cache.insert(mermaid_contents("graph B"), rendered_entry(svg_b.clone()));
        cache.insert(
            mermaid_contents("graph C"),
            rendered_entry(mock_render_image(cx)),
        );

        let fallback = mermaid_fallback("graph B modified", &after, &before, &cache);

        // The edited diagram reuses the image of the diagram it replaced.
        assert_eq!(fallback.as_ref().map(|image| image.id), Some(svg_b.id));
    }

    #[gpui::test]
    fn test_mermaid_no_fallback_on_add_in_middle(cx: &mut TestAppContext) {
        let before = mermaid_sequence(&["graph A", "graph C"]);
        let after = mermaid_sequence(&["graph A", "graph NEW", "graph C"]);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            rendered_entry(mock_render_image(cx)),
        );
        cache.insert(
            mermaid_contents("graph C"),
            rendered_entry(mock_render_image(cx)),
        );

        let fallback = mermaid_fallback("graph NEW", &after, &before, &cache);

        // Diagram counts differ, so no slot can be matched by index.
        assert!(fallback.is_none());
    }

    #[gpui::test]
    fn test_mermaid_fallback_chains_on_rapid_edits(cx: &mut TestAppContext) {
        let before = mermaid_sequence(&["graph A", "graph B modified", "graph C"]);
        let after = mermaid_sequence(&["graph A", "graph B modified again", "graph C"]);

        let original_svg = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(
            mermaid_contents("graph A"),
            rendered_entry(mock_render_image(cx)),
        );
        // Still rendering: no result yet, only the image inherited from the first edit.
        cache.insert(
            mermaid_contents("graph B modified"),
            Arc::new(CachedMermaidDiagram::new_for_test(
                None,
                Some(original_svg.clone()),
            )),
        );
        cache.insert(
            mermaid_contents("graph C"),
            rendered_entry(mock_render_image(cx)),
        );

        let fallback = mermaid_fallback("graph B modified again", &after, &before, &cache);

        assert_eq!(
            fallback.as_ref().map(|image| image.id),
            Some(original_svg.id)
        );
    }

    #[gpui::test]
    fn test_mermaid_fallback_with_duplicate_blocks_edit_second(cx: &mut TestAppContext) {
        let before = mermaid_sequence(&["graph A", "graph A", "graph B"]);
        let after = mermaid_sequence(&["graph A", "graph A edited", "graph B"]);

        let svg_a = mock_render_image(cx);

        let mut cache: MermaidDiagramCache = HashMap::default();
        cache.insert(mermaid_contents("graph A"), rendered_entry(svg_a.clone()));
        cache.insert(
            mermaid_contents("graph B"),
            rendered_entry(mock_render_image(cx)),
        );

        let fallback = mermaid_fallback("graph A edited", &after, &before, &cache);

        assert_eq!(fallback.as_ref().map(|image| image.id), Some(svg_a.id));
    }

    #[gpui::test]
    fn test_mermaid_rendering_replaces_code_block_text(cx: &mut TestAppContext) {
        let rendered =
            render_markdown_with_options("```mermaid\ngraph TD;\n```", mermaid_options(), cx);

        let wrapped_lines: Vec<_> = rendered
            .lines
            .iter()
            .map(|line| line.layout.wrapped_text())
            .collect();
        let text = wrapped_lines.join("\n");

        assert!(!text.contains("graph TD;"));
    }

    #[gpui::test]
    fn test_mermaid_source_anchor_maps_inside_block(cx: &mut TestAppContext) {
        ensure_theme_initialized(cx);

        let (_, cx) = cx.add_window_view(|_, _| TestWindow);
        let markdown = cx.new(|cx| {
            Markdown::new_with_options(
                "```mermaid\ngraph TD;\n```".into(),
                None,
                None,
                mermaid_options(),
                cx,
            )
        });
        cx.run_until_parked();

        // Replace the diagram's cache entry with one that has already rendered.
        let render_image = mock_render_image(cx);
        markdown.update(cx, |markdown, _| {
            let first_diagram = markdown
                .parsed_markdown
                .mermaid_diagrams
                .values()
                .next()
                .unwrap();
            let contents = first_diagram.contents.clone();
            markdown
                .mermaid_state
                .cache
                .insert(contents.clone(), rendered_entry(render_image));
            markdown.mermaid_state.order = vec![contents];
        });

        let viewport = size(px(600.0), px(600.0));
        let (rendered, _) = cx.draw(Default::default(), viewport, |_window, _cx| {
            MarkdownElement::new(markdown.clone(), MarkdownStyle::default())
                .code_block_renderer(plain_code_block_renderer())
        });

        let content_range = markdown.update(cx, |markdown, _| {
            markdown
                .parsed_markdown
                .mermaid_diagrams
                .values()
                .next()
                .unwrap()
                .content_range
                .clone()
        });

        let first_index = content_range.start;
        let last_index = content_range.end.saturating_sub(1);
        assert!(
            rendered
                .text
                .position_for_source_index(first_index)
                .is_some()
        );
        assert!(
            rendered
                .text
                .position_for_source_index(last_index)
                .is_some()
        );
    }
}