Merge branch 'main' into n/t2

Marshall Bowers created

Change summary

crates/assistant/src/codegen.rs     | 108 ++++++++++++++++----
crates/assistant/src/prompts.rs     |  18 +-
crates/theme2/src/default_colors.rs | 165 ++++++++++++++++++------------
3 files changed, 194 insertions(+), 97 deletions(-)

Detailed changes

crates/assistant/src/codegen.rs 🔗

@@ -118,7 +118,7 @@ impl Codegen {
 
                     let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                     let diff = cx.background().spawn(async move {
-                        let chunks = strip_markdown_codeblock(response.await?);
+                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
                         futures::pin_mut!(chunks);
                         let mut diff = StreamingDiff::new(selected_text.to_string());
 
@@ -279,12 +279,13 @@ impl Codegen {
     }
 }
 
-fn strip_markdown_codeblock(
+fn strip_invalid_spans_from_codeblock(
     stream: impl Stream<Item = Result<String>>,
 ) -> impl Stream<Item = Result<String>> {
     let mut first_line = true;
     let mut buffer = String::new();
-    let mut starts_with_fenced_code_block = false;
+    let mut starts_with_markdown_codeblock = false;
+    let mut includes_start_or_end_span = false;
     stream.filter_map(move |chunk| {
         let chunk = match chunk {
             Ok(chunk) => chunk,
@@ -292,11 +293,31 @@ fn strip_markdown_codeblock(
         };
         buffer.push_str(&chunk);
 
+        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
+            includes_start_or_end_span = true;
+
+            buffer = buffer
+                .strip_prefix("<|S|>")
+                .or_else(|| buffer.strip_prefix("<|S|"))
+                .unwrap_or(&buffer)
+                .to_string();
+        } else if buffer.ends_with("|E|>") {
+            includes_start_or_end_span = true;
+        } else if buffer.starts_with("<|")
+            || buffer.starts_with("<|S")
+            || buffer.starts_with("<|S|")
+            || buffer.ends_with("|")
+            || buffer.ends_with("|E")
+            || buffer.ends_with("|E|")
+        {
+            return future::ready(None);
+        }
+
         if first_line {
             if buffer == "" || buffer == "`" || buffer == "``" {
                 return future::ready(None);
             } else if buffer.starts_with("```") {
-                starts_with_fenced_code_block = true;
+                starts_with_markdown_codeblock = true;
                 if let Some(newline_ix) = buffer.find('\n') {
                     buffer.replace_range(..newline_ix + 1, "");
                     first_line = false;
@@ -306,16 +327,26 @@ fn strip_markdown_codeblock(
             }
         }
 
-        let text = if starts_with_fenced_code_block {
-            buffer
+        let mut text = buffer.to_string();
+        if starts_with_markdown_codeblock {
+            text = text
                 .strip_suffix("\n```\n")
-                .or_else(|| buffer.strip_suffix("\n```"))
-                .or_else(|| buffer.strip_suffix("\n``"))
-                .or_else(|| buffer.strip_suffix("\n`"))
-                .or_else(|| buffer.strip_suffix('\n'))
-                .unwrap_or(&buffer)
-        } else {
-            &buffer
+                .or_else(|| text.strip_suffix("\n```"))
+                .or_else(|| text.strip_suffix("\n``"))
+                .or_else(|| text.strip_suffix("\n`"))
+                .or_else(|| text.strip_suffix('\n'))
+                .unwrap_or(&text)
+                .to_string();
+        }
+
+        if includes_start_or_end_span {
+            text = text
+                .strip_suffix("|E|>")
+                .or_else(|| text.strip_suffix("E|>"))
+                .or_else(|| text.strip_prefix("|>"))
+                .or_else(|| text.strip_prefix(">"))
+                .unwrap_or(&text)
+                .to_string();
         };
 
         if text.contains('\n') {
@@ -328,6 +359,7 @@ fn strip_markdown_codeblock(
         } else {
             Some(Ok(buffer.clone()))
         };
+
         buffer = remainder;
         future::ready(result)
     })
@@ -558,50 +590,82 @@ mod tests {
     }
 
     #[gpui::test]
-    async fn test_strip_markdown_codeblock() {
+    async fn test_strip_invalid_spans_from_codeblock() {
         assert_eq!(
-            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
+            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
-            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+            strip_invalid_spans_from_codeblock(chunks(
+                "```html\n```js\nLorem ipsum dolor\n```\n```",
+                2
+            ))
+            .map(|chunk| chunk.unwrap())
+            .collect::<String>()
+            .await,
+            "```js\nLorem ipsum dolor\n```"
+        );
+        assert_eq!(
+            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
-            "```js\nLorem ipsum dolor\n```"
+            "``\nLorem ipsum dolor\n```"
         );
         assert_eq!(
-            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                 .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
-            "``\nLorem ipsum dolor\n```"
+            "Lorem ipsum"
         );
 
+        assert_eq!(
+            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum"
+        );
+
+        assert_eq!(
+            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum"
+        );
+        assert_eq!(
+            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum"
+        );
         fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
             stream::iter(
                 text.chars()

crates/assistant/src/prompts.rs 🔗

@@ -80,12 +80,12 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
             if !flushed_selection {
                 // The collapsed node ends after the selection starts, so we'll flush the selection first.
                 summary.extend(buffer.text_for_range(offset..selected_range.start));
-                summary.push_str("<|START|");
+                summary.push_str("<|S|");
                 if selected_range.end == selected_range.start {
                     summary.push_str(">");
                 } else {
                     summary.extend(buffer.text_for_range(selected_range.clone()));
-                    summary.push_str("|END|>");
+                    summary.push_str("|E|>");
                 }
                 offset = selected_range.end;
                 flushed_selection = true;
@@ -107,12 +107,12 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
     // Flush selection if we haven't already done so.
     if !flushed_selection && offset <= selected_range.start {
         summary.extend(buffer.text_for_range(offset..selected_range.start));
-        summary.push_str("<|START|");
+        summary.push_str("<|S|");
         if selected_range.end == selected_range.start {
             summary.push_str(">");
         } else {
             summary.extend(buffer.text_for_range(selected_range.clone()));
-            summary.push_str("|END|>");
+            summary.push_str("|E|>");
         }
         offset = selected_range.end;
     }
@@ -260,7 +260,7 @@ pub(crate) mod tests {
             summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
             indoc! {"
                 struct X {
-                    <|START|>a: usize,
+                    <|S|>a: usize,
                     b: usize,
                 }
 
@@ -286,7 +286,7 @@ pub(crate) mod tests {
                 impl X {
 
                     fn new() -> Self {
-                        let <|START|a |END|>= 1;
+                        let <|S|a |E|>= 1;
                         let b = 2;
                         Self { a, b }
                     }
@@ -307,7 +307,7 @@ pub(crate) mod tests {
                 }
 
                 impl X {
-                <|START|>
+                <|S|>
                     fn new() -> Self {}
 
                     pub fn a(&self, param: bool) -> usize {}
@@ -333,7 +333,7 @@ pub(crate) mod tests {
 
                     pub fn b(&self) -> usize {}
                 }
-                <|START|>"}
+                <|S|>"}
         );
 
         // Ensure nested functions get collapsed properly.
@@ -369,7 +369,7 @@ pub(crate) mod tests {
         assert_eq!(
             summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
             indoc! {"
-                <|START|>struct X {
+                <|S|>struct X {
                     a: usize,
                     b: usize,
                 }

crates/theme2/src/default_colors.rs 🔗

@@ -332,43 +332,43 @@ impl From<DefaultColorScaleSet> for ColorScaleSet {
 
 pub fn default_color_scales() -> ColorScales {
     ColorScales {
-        gray: gray().into(),
-        mauve: mauve().into(),
-        slate: slate().into(),
-        sage: sage().into(),
-        olive: olive().into(),
-        sand: sand().into(),
-        gold: gold().into(),
-        bronze: bronze().into(),
-        brown: brown().into(),
-        yellow: yellow().into(),
-        amber: amber().into(),
-        orange: orange().into(),
-        tomato: tomato().into(),
-        red: red().into(),
-        ruby: ruby().into(),
-        crimson: crimson().into(),
-        pink: pink().into(),
-        plum: plum().into(),
-        purple: purple().into(),
-        violet: violet().into(),
-        iris: iris().into(),
-        indigo: indigo().into(),
-        blue: blue().into(),
-        cyan: cyan().into(),
-        teal: teal().into(),
-        jade: jade().into(),
-        green: green().into(),
-        grass: grass().into(),
-        lime: lime().into(),
-        mint: mint().into(),
-        sky: sky().into(),
-        black: black().into(),
-        white: white().into(),
+        gray: gray(),
+        mauve: mauve(),
+        slate: slate(),
+        sage: sage(),
+        olive: olive(),
+        sand: sand(),
+        gold: gold(),
+        bronze: bronze(),
+        brown: brown(),
+        yellow: yellow(),
+        amber: amber(),
+        orange: orange(),
+        tomato: tomato(),
+        red: red(),
+        ruby: ruby(),
+        crimson: crimson(),
+        pink: pink(),
+        plum: plum(),
+        purple: purple(),
+        violet: violet(),
+        iris: iris(),
+        indigo: indigo(),
+        blue: blue(),
+        cyan: cyan(),
+        teal: teal(),
+        jade: jade(),
+        green: green(),
+        grass: grass(),
+        lime: lime(),
+        mint: mint(),
+        sky: sky(),
+        black: black(),
+        white: white(),
     }
 }
 
-fn gray() -> DefaultColorScaleSet {
+fn gray() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Gray",
         light: [
@@ -428,9 +428,10 @@ fn gray() -> DefaultColorScaleSet {
             "#ffffffed",
         ],
     }
+    .into()
 }
 
-fn mauve() -> DefaultColorScaleSet {
+fn mauve() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Mauve",
         light: [
@@ -490,9 +491,10 @@ fn mauve() -> DefaultColorScaleSet {
             "#fdfdffef",
         ],
     }
+    .into()
 }
 
-fn slate() -> DefaultColorScaleSet {
+fn slate() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Slate",
         light: [
@@ -552,9 +554,10 @@ fn slate() -> DefaultColorScaleSet {
             "#fcfdffef",
         ],
     }
+    .into()
 }
 
-fn sage() -> DefaultColorScaleSet {
+fn sage() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Sage",
         light: [
@@ -614,9 +617,10 @@ fn sage() -> DefaultColorScaleSet {
             "#fdfffeed",
         ],
     }
+    .into()
 }
 
-fn olive() -> DefaultColorScaleSet {
+fn olive() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Olive",
         light: [
@@ -676,9 +680,10 @@ fn olive() -> DefaultColorScaleSet {
             "#fdfffded",
         ],
     }
+    .into()
 }
 
-fn sand() -> DefaultColorScaleSet {
+fn sand() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Sand",
         light: [
@@ -738,9 +743,10 @@ fn sand() -> DefaultColorScaleSet {
             "#fffffded",
         ],
     }
+    .into()
 }
 
-fn gold() -> DefaultColorScaleSet {
+fn gold() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Gold",
         light: [
@@ -800,9 +806,10 @@ fn gold() -> DefaultColorScaleSet {
             "#fef7ede7",
         ],
     }
+    .into()
 }
 
-fn bronze() -> DefaultColorScaleSet {
+fn bronze() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Bronze",
         light: [
@@ -862,9 +869,10 @@ fn bronze() -> DefaultColorScaleSet {
             "#fff1e9ec",
         ],
     }
+    .into()
 }
 
-fn brown() -> DefaultColorScaleSet {
+fn brown() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Brown",
         light: [
@@ -924,9 +932,10 @@ fn brown() -> DefaultColorScaleSet {
             "#feecd4f2",
         ],
     }
+    .into()
 }
 
-fn yellow() -> DefaultColorScaleSet {
+fn yellow() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Yellow",
         light: [
@@ -986,9 +995,10 @@ fn yellow() -> DefaultColorScaleSet {
             "#fef6baf6",
         ],
     }
+    .into()
 }
 
-fn amber() -> DefaultColorScaleSet {
+fn amber() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Amber",
         light: [
@@ -1048,9 +1058,10 @@ fn amber() -> DefaultColorScaleSet {
             "#ffe7b3ff",
         ],
     }
+    .into()
 }
 
-fn orange() -> DefaultColorScaleSet {
+fn orange() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Orange",
         light: [
@@ -1110,9 +1121,10 @@ fn orange() -> DefaultColorScaleSet {
             "#ffe0c2ff",
         ],
     }
+    .into()
 }
 
-fn tomato() -> DefaultColorScaleSet {
+fn tomato() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Tomato",
         light: [
@@ -1172,9 +1184,10 @@ fn tomato() -> DefaultColorScaleSet {
             "#ffd6cefb",
         ],
     }
+    .into()
 }
 
-fn red() -> DefaultColorScaleSet {
+fn red() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Red",
         light: [
@@ -1234,9 +1247,10 @@ fn red() -> DefaultColorScaleSet {
             "#ffd1d9ff",
         ],
     }
+    .into()
 }
 
-fn ruby() -> DefaultColorScaleSet {
+fn ruby() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Ruby",
         light: [
@@ -1296,9 +1310,10 @@ fn ruby() -> DefaultColorScaleSet {
             "#ffd3e2fe",
         ],
     }
+    .into()
 }
 
-fn crimson() -> DefaultColorScaleSet {
+fn crimson() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Crimson",
         light: [
@@ -1358,9 +1373,10 @@ fn crimson() -> DefaultColorScaleSet {
             "#ffd5eafd",
         ],
     }
+    .into()
 }
 
-fn pink() -> DefaultColorScaleSet {
+fn pink() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Pink",
         light: [
@@ -1420,9 +1436,10 @@ fn pink() -> DefaultColorScaleSet {
             "#ffd3ecfd",
         ],
     }
+    .into()
 }
 
-fn plum() -> DefaultColorScaleSet {
+fn plum() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Plum",
         light: [
@@ -1482,9 +1499,10 @@ fn plum() -> DefaultColorScaleSet {
             "#feddfef4",
         ],
     }
+    .into()
 }
 
-fn purple() -> DefaultColorScaleSet {
+fn purple() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Purple",
         light: [
@@ -1544,9 +1562,10 @@ fn purple() -> DefaultColorScaleSet {
             "#f1ddfffa",
         ],
     }
+    .into()
 }
 
-fn violet() -> DefaultColorScaleSet {
+fn violet() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Violet",
         light: [
@@ -1606,9 +1625,10 @@ fn violet() -> DefaultColorScaleSet {
             "#e3defffe",
         ],
     }
+    .into()
 }
 
-fn iris() -> DefaultColorScaleSet {
+fn iris() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Iris",
         light: [
@@ -1668,9 +1688,10 @@ fn iris() -> DefaultColorScaleSet {
             "#e1e0fffe",
         ],
     }
+    .into()
 }
 
-fn indigo() -> DefaultColorScaleSet {
+fn indigo() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Indigo",
         light: [
@@ -1730,9 +1751,10 @@ fn indigo() -> DefaultColorScaleSet {
             "#d6e1ffff",
         ],
     }
+    .into()
 }
 
-fn blue() -> DefaultColorScaleSet {
+fn blue() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Blue",
         light: [
@@ -1792,9 +1814,10 @@ fn blue() -> DefaultColorScaleSet {
             "#c2e6ffff",
         ],
     }
+    .into()
 }
 
-fn cyan() -> DefaultColorScaleSet {
+fn cyan() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Cyan",
         light: [
@@ -1854,9 +1877,10 @@ fn cyan() -> DefaultColorScaleSet {
             "#bbf3fef7",
         ],
     }
+    .into()
 }
 
-fn teal() -> DefaultColorScaleSet {
+fn teal() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Teal",
         light: [
@@ -1916,9 +1940,10 @@ fn teal() -> DefaultColorScaleSet {
             "#b8ffebef",
         ],
     }
+    .into()
 }
 
-fn jade() -> DefaultColorScaleSet {
+fn jade() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Jade",
         light: [
@@ -1978,9 +2003,10 @@ fn jade() -> DefaultColorScaleSet {
             "#b8ffe1ef",
         ],
     }
+    .into()
 }
 
-fn green() -> DefaultColorScaleSet {
+fn green() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Green",
         light: [
@@ -2040,9 +2066,10 @@ fn green() -> DefaultColorScaleSet {
             "#bbffd7f0",
         ],
     }
+    .into()
 }
 
-fn grass() -> DefaultColorScaleSet {
+fn grass() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Grass",
         light: [
@@ -2102,9 +2129,10 @@ fn grass() -> DefaultColorScaleSet {
             "#ceffceef",
         ],
     }
+    .into()
 }
 
-fn lime() -> DefaultColorScaleSet {
+fn lime() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Lime",
         light: [
@@ -2164,9 +2192,10 @@ fn lime() -> DefaultColorScaleSet {
             "#e9febff7",
         ],
     }
+    .into()
 }
 
-fn mint() -> DefaultColorScaleSet {
+fn mint() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Mint",
         light: [
@@ -2226,9 +2255,10 @@ fn mint() -> DefaultColorScaleSet {
             "#cbfee9f5",
         ],
     }
+    .into()
 }
 
-fn sky() -> DefaultColorScaleSet {
+fn sky() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Sky",
         light: [
@@ -2288,9 +2318,10 @@ fn sky() -> DefaultColorScaleSet {
             "#c2f3ffff",
         ],
     }
+    .into()
 }
 
-fn black() -> DefaultColorScaleSet {
+fn black() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "Black",
         light: [
@@ -2350,9 +2381,10 @@ fn black() -> DefaultColorScaleSet {
             "#000000f2",
         ],
     }
+    .into()
 }
 
-fn white() -> DefaultColorScaleSet {
+fn white() -> ColorScaleSet {
     DefaultColorScaleSet {
         scale: "White",
         light: [
@@ -2412,4 +2444,5 @@ fn white() -> DefaultColorScaleSet {
             "#fffffff2",
         ],
     }
+    .into()
 }