diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 0d81b6fd155f6ddc61a9fee8fcba3325a2acf39f..20451bdc7a7e2e96a9a1b48ed32180250f64b6b6 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -62,13 +62,16 @@ rand.workspace = true similar = "2.7.0" flate2 = "1.1.8" toml.workspace = true -rust-embed.workspace = true +rust-embed = { workspace = true, features = ["debug-embed"] } # Wasmtime is included as a dependency in order to enable the same # features that are enabled in Zed. # # If we don't enable these features we get crashes when creating # a Tree-sitter WasmStore. +[features] +dynamic_prompts = [] + [package.metadata.cargo-machete] ignored = ["wasmtime"] diff --git a/crates/edit_prediction_cli/src/filter_languages.rs b/crates/edit_prediction_cli/src/filter_languages.rs index fa4addbd240ff611c5ceba53b3136c4f3b35f0b9..355b5708d43c35c74bf62608726309389a1bfe32 100644 --- a/crates/edit_prediction_cli/src/filter_languages.rs +++ b/crates/edit_prediction_cli/src/filter_languages.rs @@ -18,17 +18,24 @@ use anyhow::{Context as _, Result, bail}; use clap::Args; use collections::HashMap; -use rust_embed::RustEmbed; use serde::Deserialize; use std::ffi::OsStr; use std::fs::File; use std::io::{self, BufRead, BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; -#[derive(RustEmbed)] -#[folder = "../languages/src/"] -#[include = "*/config.toml"] -struct LanguageConfigs; +#[cfg(not(feature = "dynamic_prompts"))] +mod language_configs_embedded { + use rust_embed::RustEmbed; + + #[derive(RustEmbed)] + #[folder = "../languages/src/"] + #[include = "*/config.toml"] + pub struct LanguageConfigs; +} + +#[cfg(not(feature = "dynamic_prompts"))] +use language_configs_embedded::LanguageConfigs; #[derive(Debug, Deserialize)] struct LanguageConfig { @@ -89,6 +96,7 @@ pub struct FilterLanguagesArgs { pub show_top_excluded: Option, } +#[cfg(not(feature = "dynamic_prompts"))] fn build_extension_to_language_map() -> HashMap { let mut map = HashMap::default(); @@ -113,6 +121,42 @@ fn build_extension_to_language_map() -> HashMap { map } +#[cfg(feature = "dynamic_prompts")] +fn build_extension_to_language_map() -> HashMap { + const LANGUAGES_SRC_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../languages/src"); + + let mut map = HashMap::default(); + + let languages_dir = Path::new(LANGUAGES_SRC_DIR); + let entries = match std::fs::read_dir(languages_dir) { + Ok(e) => e, + Err(_) => return map, + }; + + for entry in entries.flatten() { + let config_path = entry.path().join("config.toml"); + if !config_path.exists() { + continue; + } + + let content_str = match std::fs::read_to_string(&config_path) { + Ok(s) => s, + Err(_) => continue, + }; + + let config: LanguageConfig = match toml::from_str(&content_str) { + Ok(c) => c, + Err(_) => continue, + }; + + for suffix in &config.path_suffixes { + map.insert(suffix.to_lowercase(), config.name.clone()); + } + } + + map +} + fn get_all_languages(extension_map: &HashMap) -> Vec<(String, Vec)> { let mut language_to_extensions: HashMap> = HashMap::default(); diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index d7d5fa75b55182bee7035e01d3e030f3a34af565..5588200f745b90d4f92b0a87f45571753f3b0d6f 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -164,7 +164,6 @@ pub fn zeta2_output_for_patch( pub struct TeacherPrompt; impl TeacherPrompt { - const PROMPT: &str = include_str!("prompts/teacher.md"); pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>"; pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>"; @@ -182,7 +181,8 @@ impl TeacherPrompt { let context = Self::format_context(example); let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range); - let prompt = Self::PROMPT + let prompt_template = crate::prompt_assets::get_prompt("teacher.md"); + let prompt = prompt_template .replace("{{context}}", &context) .replace("{{edit_history}}", &edit_history) .replace("{{cursor_excerpt}}", &cursor_excerpt); diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 2f229b3c392b5dc9114819637e0f3770b33db0ae..96662afa13a86ffb1b99d37aa50232eeafed9928 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -12,6 +12,7 @@ mod parse_output; mod paths; mod predict; mod progress; +mod prompt_assets; mod pull_examples; mod qa; mod reorder_patch; diff --git a/crates/edit_prediction_cli/src/prompt_assets.rs b/crates/edit_prediction_cli/src/prompt_assets.rs new file mode 100644 index 0000000000000000000000000000000000000000..cc3497a612fc2d0f60c380592c4c917b0af73636 --- /dev/null +++ b/crates/edit_prediction_cli/src/prompt_assets.rs @@ -0,0 +1,46 @@ +use std::borrow::Cow; + +#[cfg(feature = "dynamic_prompts")] +pub fn get_prompt(name: &'static str) -> Cow<'static, str> { + use anyhow::Context; + use std::collections::HashMap; + use std::path::Path; + use std::sync::{LazyLock, RwLock}; + + const PROMPTS_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/prompts"); + + static PROMPT_CACHE: LazyLock>> = + LazyLock::new(|| RwLock::new(HashMap::default())); + + let filesystem_path = Path::new(PROMPTS_DIR).join(name); + if let Some(cached_contents) = PROMPT_CACHE.read().unwrap().get(name) { + return Cow::Borrowed(cached_contents); + } + let contents = std::fs::read_to_string(&filesystem_path) + .context(name) + .expect("Failed to read prompt"); + let leaked = contents.leak(); + PROMPT_CACHE.write().unwrap().insert(name, leaked); + return Cow::Borrowed(leaked); +} + +#[cfg(not(feature = "dynamic_prompts"))] +pub fn get_prompt(name: &'static str) -> Cow<'static, str> { + use rust_embed::RustEmbed; + + #[derive(RustEmbed)] + #[folder = "src/prompts"] + struct EmbeddedPrompts; + + match EmbeddedPrompts::get(name) { + Some(file) => match file.data { + Cow::Borrowed(bytes) => { + Cow::Borrowed(std::str::from_utf8(bytes).expect("prompt file is not valid UTF-8")) + } + Cow::Owned(bytes) => { + Cow::Owned(String::from_utf8(bytes).expect("prompt file is not valid UTF-8")) + } + }, + None => panic!("prompt file not found: {name}"), + } +} diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs index e9c2e44549b67940a1dcfbb5529c123b4287f7e2..9a54353040afbe57ba431fc103b4a16f7cbca232 100644 --- a/crates/edit_prediction_cli/src/qa.rs +++ b/crates/edit_prediction_cli/src/qa.rs @@ -15,8 +15,6 @@ use serde::{Deserialize, Serialize}; use std::io::{BufWriter, Write}; use std::path::PathBuf; -const PROMPT_TEMPLATE: &str = include_str!("prompts/qa.md"); - /// Arguments for the QA command. #[derive(Debug, Clone, clap::Args)] pub struct QaArgs { @@ -94,8 +92,9 @@ pub fn build_prompt(example: &Example) -> Option { } } + let prompt_template = crate::prompt_assets::get_prompt("qa.md"); Some( - PROMPT_TEMPLATE + prompt_template .replace("{edit_history}", &edit_history) .replace("{cursor_excerpt}", &cursor_excerpt) .replace("{actual_patch_word_diff}", &actual_patch_word_diff), diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index e3cf424244dfcc2e407464a4048a6f7500813a1c..78d7232209ef6268fce943bff34e3b08274a02e8 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -16,8 +16,6 @@ use anyhow::Result; use std::io::{BufWriter, Write}; use std::path::PathBuf; -const PROMPT_TEMPLATE: &str = include_str!("prompts/repair.md"); - /// Arguments for the repair command. #[derive(Debug, Clone, clap::Args)] pub struct RepairArgs { @@ -91,8 +89,9 @@ pub fn build_repair_prompt(example: &Example) -> Option { .confidence .map_or("unknown".to_string(), |v| v.to_string()); + let prompt_template = crate::prompt_assets::get_prompt("repair.md"); Some( - PROMPT_TEMPLATE + prompt_template .replace("{edit_history}", &edit_history) .replace("{context}", &context) .replace("{cursor_excerpt}", &cursor_excerpt)