Sanitize ranges in code labels coming from extensions (#10307)

Max Brunsfeld created

Without any sanitization, extensions would be able to crash zed, because
the editor code assumes these ranges are valid.

Release Notes:

- N/A

Change summary

crates/extension/src/extension_lsp_adapter.rs | 118 +++++++++++++++-----
crates/extension/src/wasm_host/wit.rs         |  12 -
2 files changed, 91 insertions(+), 39 deletions(-)

Detailed changes

crates/extension/src/extension_lsp_adapter.rs 🔗

@@ -292,13 +292,13 @@ fn labels_from_wit(
     labels
         .into_iter()
         .map(|label| {
-            label.map(|label| {
-                build_code_label(
-                    &label,
-                    &language.highlight_text(&label.code.as_str().into(), 0..label.code.len()),
-                    &language,
-                )
-            })
+            let label = label?;
+            let runs = if !label.code.is_empty() {
+                language.highlight_text(&label.code.as_str().into(), 0..label.code.len())
+            } else {
+                Vec::new()
+            };
+            build_code_label(&label, &runs, &language)
         })
         .collect()
 }
@@ -307,7 +307,7 @@ fn build_code_label(
     label: &wit::CodeLabel,
     parsed_runs: &[(Range<usize>, HighlightId)],
     language: &Arc<Language>,
-) -> CodeLabel {
+) -> Option<CodeLabel> {
     let mut text = String::new();
     let mut runs = vec![];
 
@@ -315,7 +315,7 @@ fn build_code_label(
         match span {
             wit::CodeLabelSpan::CodeRange(range) => {
                 let range = Range::from(*range);
-
+                let code_span = &label.code.get(range.clone())?;
                 let mut input_ix = range.start;
                 let mut output_ix = text.len();
                 for (run_range, id) in parsed_runs {
@@ -327,19 +327,18 @@ fn build_code_label(
                     }
 
                     if run_range.start > input_ix {
-                        output_ix += run_range.start - input_ix;
-                        input_ix = run_range.start;
-                    }
-
-                    {
-                        let len = range.end.min(run_range.end) - input_ix;
-                        runs.push((output_ix..output_ix + len, *id));
+                        let len = run_range.start - input_ix;
                         output_ix += len;
                         input_ix += len;
                     }
+
+                    let len = range.end.min(run_range.end) - input_ix;
+                    runs.push((output_ix..output_ix + len, *id));
+                    output_ix += len;
+                    input_ix += len;
                 }
 
-                text.push_str(&label.code[range]);
+                text.push_str(code_span);
             }
             wit::CodeLabelSpan::Literal(span) => {
                 let highlight_id = language
@@ -356,11 +355,13 @@ fn build_code_label(
         }
     }
 
-    CodeLabel {
+    let filter_range = Range::from(label.filter_range);
+    text.get(filter_range.clone())?;
+    Some(CodeLabel {
         text,
         runs,
-        filter_range: label.filter_range.into(),
-    }
+        filter_range,
+    })
 }
 
 impl From<wit::Range> for Range<usize> {
@@ -472,13 +473,13 @@ fn extract_int<T: Serialize>(value: T) -> i32 {
 fn test_build_code_label() {
     use util::test::marked_text_ranges;
 
-    let (code, ranges) = marked_text_ranges(
+    let (code, code_ranges) = marked_text_ranges(
         "«const» «a»: «fn»(«Bcd»(«Efgh»)) -> «Ijklm» = pqrs.tuv",
         false,
     );
-    let runs = ranges
-        .iter()
-        .map(|range| (range.clone(), HighlightId(0)))
+    let code_runs = code_ranges
+        .into_iter()
+        .map(|range| (range, HighlightId(0)))
         .collect::<Vec<_>>();
 
     let label = build_code_label(
@@ -499,22 +500,75 @@ fn test_build_code_label() {
             },
             code,
         },
-        &runs,
+        &code_runs,
         &language::PLAIN_TEXT,
-    );
+    )
+    .unwrap();
 
-    let (text, ranges) = marked_text_ranges("pqrs.tuv: «fn»(«Bcd»(«Efgh»)) -> «Ijklm»", false);
-    let runs = ranges
-        .iter()
-        .map(|range| (range.clone(), HighlightId(0)))
+    let (label_text, label_ranges) =
+        marked_text_ranges("pqrs.tuv: «fn»(«Bcd»(«Efgh»)) -> «Ijklm»", false);
+    let label_runs = label_ranges
+        .into_iter()
+        .map(|range| (range, HighlightId(0)))
         .collect::<Vec<_>>();
 
     assert_eq!(
         label,
         CodeLabel {
-            text,
-            runs,
+            text: label_text,
+            runs: label_runs,
             filter_range: label.filter_range.clone()
         }
     )
 }
+
+#[test]
+fn test_build_code_label_with_invalid_ranges() {
+    use util::test::marked_text_ranges;
+
+    let (code, code_ranges) = marked_text_ranges("const «a»: «B» = '🏀'", false);
+    let code_runs = code_ranges
+        .into_iter()
+        .map(|range| (range, HighlightId(0)))
+        .collect::<Vec<_>>();
+
+    // A span uses a code range that is invalid because it starts inside of
+    // a multi-byte character.
+    let label = build_code_label(
+        &wit::CodeLabel {
+            spans: vec![
+                wit::CodeLabelSpan::CodeRange(wit::Range {
+                    start: code.find('B').unwrap() as u32,
+                    end: code.find(" = ").unwrap() as u32,
+                }),
+                wit::CodeLabelSpan::CodeRange(wit::Range {
+                    start: code.find('🏀').unwrap() as u32 + 1,
+                    end: code.len() as u32,
+                }),
+            ],
+            filter_range: wit::Range {
+                start: 0,
+                end: "B".len() as u32,
+            },
+            code,
+        },
+        &code_runs,
+        &language::PLAIN_TEXT,
+    );
+    assert!(label.is_none());
+
+    // Filter range extends beyond actual text
+    let label = build_code_label(
+        &wit::CodeLabel {
+            spans: vec![wit::CodeLabelSpan::Literal(wit::CodeLabelSpanLiteral {
+                text: "abc".into(),
+                highlight_name: Some("type".into()),
+            })],
+            filter_range: wit::Range { start: 0, end: 5 },
+            code: String::new(),
+        },
+        &code_runs,
+        &language::PLAIN_TEXT,
+    );
+    assert!(label.is_none());
+}

crates/extension/src/wasm_host/wit.rs 🔗

@@ -1,22 +1,20 @@
 mod since_v0_0_1;
 mod since_v0_0_4;
 mod since_v0_0_6;
+use since_v0_0_6 as latest;
 
-use std::ops::RangeInclusive;
-use std::sync::Arc;
-
+use super::{wasm_engine, WasmState};
 use anyhow::{Context, Result};
 use language::{LanguageServerName, LspAdapterDelegate};
 use semantic_version::SemanticVersion;
+use std::{ops::RangeInclusive, sync::Arc};
 use wasmtime::{
     component::{Component, Instance, Linker, Resource},
     Store,
 };
 
-use super::{wasm_engine, WasmState};
-
-use since_v0_0_6 as latest;
-
+#[cfg(test)]
+pub use latest::CodeLabelSpanLiteral;
 pub use latest::{
     zed::extension::lsp::{Completion, CompletionKind, InsertTextFormat, Symbol, SymbolKind},
     CodeLabel, CodeLabelSpan, Command, Range,