ep_cli: Allow dynamically loading prompts (#48046)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction_cli/Cargo.toml              |  5 +
crates/edit_prediction_cli/src/filter_languages.rs | 54 ++++++++++++++-
crates/edit_prediction_cli/src/format_prompt.rs    |  4 
crates/edit_prediction_cli/src/main.rs             |  1 
crates/edit_prediction_cli/src/prompt_assets.rs    | 46 +++++++++++++
crates/edit_prediction_cli/src/qa.rs               |  5 
crates/edit_prediction_cli/src/repair.rs           |  5 
7 files changed, 106 insertions(+), 14 deletions(-)

Detailed changes

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -62,13 +62,16 @@ rand.workspace = true
 similar = "2.7.0"
 flate2 = "1.1.8"
 toml.workspace = true
-rust-embed.workspace = true
+rust-embed = { workspace = true, features = ["debug-embed"] }
 
 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.
 #
 # If we don't enable these features we get crashes when creating
 # a Tree-sitter WasmStore.
+[features]
+dynamic_prompts = []
+
 [package.metadata.cargo-machete]
 ignored = ["wasmtime"]
 

crates/edit_prediction_cli/src/filter_languages.rs 🔗

@@ -18,17 +18,24 @@
 use anyhow::{Context as _, Result, bail};
 use clap::Args;
 use collections::HashMap;
-use rust_embed::RustEmbed;
 use serde::Deserialize;
 use std::ffi::OsStr;
 use std::fs::File;
 use std::io::{self, BufRead, BufReader, BufWriter, Write};
 use std::path::{Path, PathBuf};
 
-#[derive(RustEmbed)]
-#[folder = "../languages/src/"]
-#[include = "*/config.toml"]
-struct LanguageConfigs;
+#[cfg(not(feature = "dynamic_prompts"))]
+mod language_configs_embedded {
+    use rust_embed::RustEmbed;
+
+    #[derive(RustEmbed)]
+    #[folder = "../languages/src/"]
+    #[include = "*/config.toml"]
+    pub struct LanguageConfigs;
+}
+
+#[cfg(not(feature = "dynamic_prompts"))]
+use language_configs_embedded::LanguageConfigs;
 
 #[derive(Debug, Deserialize)]
 struct LanguageConfig {
@@ -89,6 +96,7 @@ pub struct FilterLanguagesArgs {
     pub show_top_excluded: Option<usize>,
 }
 
+#[cfg(not(feature = "dynamic_prompts"))]
 fn build_extension_to_language_map() -> HashMap<String, String> {
     let mut map = HashMap::default();
 
@@ -113,6 +121,42 @@ fn build_extension_to_language_map() -> HashMap<String, String> {
     map
 }
 
+#[cfg(feature = "dynamic_prompts")]
+fn build_extension_to_language_map() -> HashMap<String, String> {
+    const LANGUAGES_SRC_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../languages/src");
+
+    let mut map = HashMap::default();
+
+    let languages_dir = Path::new(LANGUAGES_SRC_DIR);
+    let entries = match std::fs::read_dir(languages_dir) {
+        Ok(e) => e,
+        Err(_) => return map,
+    };
+
+    for entry in entries.flatten() {
+        let config_path = entry.path().join("config.toml");
+        if !config_path.exists() {
+            continue;
+        }
+
+        let content_str = match std::fs::read_to_string(&config_path) {
+            Ok(s) => s,
+            Err(_) => continue,
+        };
+
+        let config: LanguageConfig = match toml::from_str(&content_str) {
+            Ok(c) => c,
+            Err(_) => continue,
+        };
+
+        for suffix in &config.path_suffixes {
+            map.insert(suffix.to_lowercase(), config.name.clone());
+        }
+    }
+
+    map
+}
+
 fn get_all_languages(extension_map: &HashMap<String, String>) -> Vec<(String, Vec<String>)> {
     let mut language_to_extensions: HashMap<String, Vec<String>> = HashMap::default();
 

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -164,7 +164,6 @@ pub fn zeta2_output_for_patch(
 pub struct TeacherPrompt;
 
 impl TeacherPrompt {
-    const PROMPT: &str = include_str!("prompts/teacher.md");
     pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
     pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
     pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
@@ -182,7 +181,8 @@ impl TeacherPrompt {
         let context = Self::format_context(example);
         let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
 
-        let prompt = Self::PROMPT
+        let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
+        let prompt = prompt_template
             .replace("{{context}}", &context)
             .replace("{{edit_history}}", &edit_history)
             .replace("{{cursor_excerpt}}", &cursor_excerpt);

crates/edit_prediction_cli/src/prompt_assets.rs 🔗

@@ -0,0 +1,46 @@
+use std::borrow::Cow;
+
+#[cfg(feature = "dynamic_prompts")]
+pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
+    use anyhow::Context;
+    use std::collections::HashMap;
+    use std::path::Path;
+    use std::sync::{LazyLock, RwLock};
+
+    const PROMPTS_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/prompts");
+
+    static PROMPT_CACHE: LazyLock<RwLock<HashMap<&'static str, &'static str>>> =
+        LazyLock::new(|| RwLock::new(HashMap::default()));
+
+    let filesystem_path = Path::new(PROMPTS_DIR).join(name);
+    if let Some(cached_contents) = PROMPT_CACHE.read().unwrap().get(name) {
+        return Cow::Borrowed(cached_contents);
+    }
+    let contents = std::fs::read_to_string(&filesystem_path)
+        .context(name)
+        .expect("Failed to read prompt");
+    let leaked = contents.leak();
+    PROMPT_CACHE.write().unwrap().insert(name, leaked);
+    return Cow::Borrowed(leaked);
+}
+
+#[cfg(not(feature = "dynamic_prompts"))]
+pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
+    use rust_embed::RustEmbed;
+
+    #[derive(RustEmbed)]
+    #[folder = "src/prompts"]
+    struct EmbeddedPrompts;
+
+    match EmbeddedPrompts::get(name) {
+        Some(file) => match file.data {
+            Cow::Borrowed(bytes) => {
+                Cow::Borrowed(std::str::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
+            }
+            Cow::Owned(bytes) => {
+                Cow::Owned(String::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
+            }
+        },
+        None => panic!("prompt file not found: {name}"),
+    }
+}

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -15,8 +15,6 @@ use serde::{Deserialize, Serialize};
 use std::io::{BufWriter, Write};
 use std::path::PathBuf;
 
-const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md");
-
 /// Arguments for the QA command.
 #[derive(Debug, Clone, clap::Args)]
 pub struct QaArgs {
@@ -94,8 +92,9 @@ pub fn build_prompt(example: &Example) -> Option<String> {
         }
     }
 
+    let prompt_template = crate::prompt_assets::get_prompt("qa.md");
     Some(
-        PROMPT_TEMPLATE
+        prompt_template
             .replace("{edit_history}", &edit_history)
             .replace("{cursor_excerpt}", &cursor_excerpt)
             .replace("{actual_patch_word_diff}", &actual_patch_word_diff),

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -16,8 +16,6 @@ use anyhow::Result;
 use std::io::{BufWriter, Write};
 use std::path::PathBuf;
 
-const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md");
-
 /// Arguments for the repair command.
 #[derive(Debug, Clone, clap::Args)]
 pub struct RepairArgs {
@@ -91,8 +89,9 @@ pub fn build_repair_prompt(example: &Example) -> Option<String> {
         .confidence
         .map_or("unknown".to_string(), |v| v.to_string());
 
+    let prompt_template = crate::prompt_assets::get_prompt("repair.md");
     Some(
-        PROMPT_TEMPLATE
+        prompt_template
             .replace("{edit_history}", &edit_history)
             .replace("{context}", &context)
             .replace("{cursor_excerpt}", &cursor_excerpt)