ep_cli: Add filter languages subcommand (#47242)

Created by Ben Kunkle

Closes #ISSUE

Release Notes:

- N/A

Change summary

Cargo.lock                                         |   2 
crates/edit_prediction_cli/Cargo.toml              |   2 
crates/edit_prediction_cli/src/filter_languages.rs | 507 ++++++++++++++++
crates/edit_prediction_cli/src/main.rs             |  17 
4 files changed, 527 insertions(+), 1 deletion(-)
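
At a high level, the new subcommand streams a JSONL dataset, reads each example's `cursor_path`, maps its file extension to a language using the embedded language configs, and keeps only examples that match the requested languages or extensions. The sketch below shows just that keep/drop predicate; the map entries, allowed set, and sample paths are illustrative, and the real implementation in `filter_languages.rs` additionally handles stdin, filename-based suffixes, `--include-unknown`, `--stats`, and `--list`.

```rust
// Illustrative sketch only -- not the PR code. Core decision: map the
// cursor_path extension to a language, then check it against the allowed
// set (lowercased, as in the real command).
use std::collections::{HashMap, HashSet};
use std::path::Path;

fn keep_example(cursor_path: &str, map: &HashMap<&str, &str>, allowed: &HashSet<String>) -> bool {
    Path::new(cursor_path)
        .extension()
        .and_then(|ext| ext.to_str())
        .and_then(|ext| map.get(ext.to_lowercase().as_str()))
        .is_some_and(|lang| allowed.contains(&lang.to_lowercase()))
}

fn main() {
    // Hypothetical map entries; the real map is built from the embedded
    // `languages/*/config.toml` files.
    let map = HashMap::from([("rs", "Rust"), ("py", "Python")]);
    let allowed = HashSet::from(["rust".to_string()]);
    assert!(keep_example("src/main.rs", &map, &allowed));
    assert!(!keep_example("lib/foo.py", &map, &allowed));
    assert!(!keep_example("README", &map, &allowed));
}
```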

Detailed changes

Cargo.lock

@@ -5336,6 +5336,7 @@ dependencies = [
  "rand 0.9.2",
  "release_channel",
  "reqwest_client",
+ "rust-embed",
  "serde",
  "serde_json",
  "settings",
@@ -5346,6 +5347,7 @@ dependencies = [
  "sqlez_macros",
  "tempfile",
  "terminal_view",
+ "toml 0.8.23",
  "util",
  "wasmtime",
  "watch",

crates/edit_prediction_cli/Cargo.toml

@@ -59,6 +59,8 @@ zeta_prompt.workspace = true
 rand.workspace = true
 similar = "2.7.0"
 flate2 = "1.1.8"
+toml.workspace = true
+rust-embed.workspace = true
 
 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.

crates/edit_prediction_cli/src/filter_languages.rs

@@ -0,0 +1,507 @@
+//! `ep filter-languages` implementation.
+//!
+//! This command filters a JSONL dataset to include only examples where the
+//! cursor is in a file of a specified programming language.
+//!
+//! # Usage
+//!
+//! ```text
+//! ep filter-languages [input.jsonl] --languages rust,python,go
+//! ```
+//!
+//! # Language Detection
+//!
+//! Language is detected based on file extension of the `cursor_path` field.
+//! The extension-to-language mapping is built from the embedded language
+//! config files in the `languages` crate.
+
+use anyhow::{Context as _, Result, bail};
+use clap::Args;
+use collections::HashMap;
+use rust_embed::RustEmbed;
+use serde::Deserialize;
+use std::ffi::OsStr;
+use std::fs::File;
+use std::io::{self, BufRead, BufReader, BufWriter, Write};
+use std::path::{Path, PathBuf};
+
+#[derive(RustEmbed)]
+#[folder = "../languages/src/"]
+#[include = "*/config.toml"]
+struct LanguageConfigs;
+
+#[derive(Debug, Deserialize)]
+struct LanguageConfig {
+    name: String,
+    #[serde(default)]
+    path_suffixes: Vec<String>,
+}
+
+/// `ep filter-languages` CLI args.
+#[derive(Debug, Args, Clone)]
+#[command(
+    about = "Filter a JSONL dataset by programming language (based on cursor_path extension)",
+    after_help = r#"EXAMPLES:
+  # Filter to only Rust, Python, and Go examples
+  ep filter-languages input.jsonl --languages rust,python,go -o filtered.jsonl
+
+  # Filter by language and also include specific extensions
+  ep filter-languages input.jsonl --languages rust,python --extensions txt,md -o filtered.jsonl
+
+  # Filter by extensions only (no language filter)
+  ep filter-languages input.jsonl --extensions cs,java,swift -o filtered.jsonl
+
+  # List available languages
+  ep filter-languages --list
+
+  # Show statistics about languages in the input
+  ep filter-languages input.jsonl --stats
+
+NOTES:
+  Language names are case-insensitive.
+  Extensions should be specified without the leading dot (e.g., 'txt' not '.txt').
+  Use --list to see all available language names.
+"#
+)]
+pub struct FilterLanguagesArgs {
+    /// Comma-separated list of languages to include
+    #[arg(long, short = 'l', value_delimiter = ',')]
+    pub languages: Option<Vec<String>>,
+
+    /// Comma-separated list of file extensions to include (without leading dot)
+    #[arg(long, short = 'e', value_delimiter = ',')]
+    pub extensions: Option<Vec<String>>,
+
+    /// List all available languages and their extensions
+    #[arg(long)]
+    pub list: bool,
+
+    /// Show statistics about language distribution in the input
+    #[arg(long)]
+    pub stats: bool,
+
+    /// Include examples where language could not be detected
+    #[arg(long)]
+    pub include_unknown: bool,
+
+    /// Show top N excluded file extensions after filtering
+    #[arg(long, value_name = "N")]
+    pub show_top_excluded: Option<usize>,
+}
+
+fn build_extension_to_language_map() -> HashMap<String, String> {
+    let mut map = HashMap::default();
+
+    for file_path in LanguageConfigs::iter() {
+        if let Some(content) = LanguageConfigs::get(&file_path) {
+            let content_str = match std::str::from_utf8(&content.data) {
+                Ok(s) => s,
+                Err(_) => continue,
+            };
+
+            let config: LanguageConfig = match toml::from_str(content_str) {
+                Ok(c) => c,
+                Err(_) => continue,
+            };
+
+            for suffix in &config.path_suffixes {
+                map.insert(suffix.to_lowercase(), config.name.clone());
+            }
+        }
+    }
+
+    map
+}
+
+fn get_all_languages(extension_map: &HashMap<String, String>) -> Vec<(String, Vec<String>)> {
+    let mut language_to_extensions: HashMap<String, Vec<String>> = HashMap::default();
+
+    for (ext, lang) in extension_map {
+        language_to_extensions
+            .entry(lang.clone())
+            .or_default()
+            .push(ext.clone());
+    }
+
+    let mut result: Vec<_> = language_to_extensions.into_iter().collect();
+    result.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));
+    for (_, extensions) in &mut result {
+        extensions.sort();
+    }
+    result
+}
+
+fn detect_language(cursor_path: &str, extension_map: &HashMap<String, String>) -> Option<String> {
+    let path = Path::new(cursor_path);
+
+    if let Some(ext) = path.extension().and_then(OsStr::to_str) {
+        if let Some(lang) = extension_map.get(&ext.to_lowercase()) {
+            return Some(lang.clone());
+        }
+    }
+
+    if let Some(file_name) = path.file_name().and_then(OsStr::to_str) {
+        if let Some(lang) = extension_map.get(&file_name.to_lowercase()) {
+            return Some(lang.clone());
+        }
+    }
+
+    None
+}
+
+fn get_extension(cursor_path: &str) -> Option<String> {
+    let path = Path::new(cursor_path);
+
+    if let Some(ext) = path.extension().and_then(OsStr::to_str) {
+        return Some(ext.to_lowercase());
+    }
+
+    if let Some(file_name) = path.file_name().and_then(OsStr::to_str) {
+        return Some(file_name.to_lowercase());
+    }
+
+    None
+}
+
+fn read_lines_streaming(
+    input: Option<&Path>,
+) -> Result<Box<dyn Iterator<Item = io::Result<String>>>> {
+    let reader: Box<dyn BufRead> = match input {
+        Some(path) => {
+            let file =
+                File::open(path).with_context(|| format!("failed to open '{}'", path.display()))?;
+            Box::new(BufReader::new(file))
+        }
+        None => Box::new(BufReader::new(io::stdin())),
+    };
+    Ok(Box::new(reader.lines()))
+}
+
+fn get_cursor_path(line: &str) -> Option<String> {
+    let value: serde_json::Value = serde_json::from_str(line).ok()?;
+    value
+        .get("cursor_path")
+        .and_then(|v| v.as_str())
+        .map(|s| s.to_string())
+}
+
+pub fn run_filter_languages(
+    args: &FilterLanguagesArgs,
+    inputs: &[PathBuf],
+    output: Option<&PathBuf>,
+) -> Result<()> {
+    let extension_map = build_extension_to_language_map();
+
+    if args.list {
+        let languages = get_all_languages(&extension_map);
+        println!("Available languages ({}):", languages.len());
+        println!();
+        for (lang, extensions) in languages {
+            println!("  {}: {}", lang, extensions.join(", "));
+        }
+        return Ok(());
+    }
+
+    let input_path: Option<&Path> = match inputs.first().map(|p| p.as_path()) {
+        Some(p) if p.as_os_str() == "-" => None,
+        Some(p) => Some(p),
+        None => None,
+    };
+
+    if args.stats {
+        let stats_input =
+            input_path.with_context(|| "input file is required for --stats (cannot use stdin)")?;
+        return run_stats(stats_input, &extension_map);
+    }
+
+    if args.languages.is_none() && args.extensions.is_none() {
+        bail!(
+            "--languages and/or --extensions is required (use --list to see available languages, or --stats to see input distribution)"
+        );
+    }
+
+    let allowed_languages: std::collections::HashSet<String> = args
+        .languages
+        .as_ref()
+        .map(|langs| langs.iter().map(|l| l.to_lowercase()).collect())
+        .unwrap_or_default();
+
+    let allowed_extensions: std::collections::HashSet<String> = args
+        .extensions
+        .as_ref()
+        .map(|exts| {
+            exts.iter()
+                .map(|e| e.trim_start_matches('.').to_lowercase())
+                .collect()
+        })
+        .unwrap_or_default();
+
+    let language_name_lower_map: HashMap<String, String> = get_all_languages(&extension_map)
+        .into_iter()
+        .map(|(lang, _)| (lang.to_lowercase(), lang))
+        .collect();
+
+    if !allowed_languages.is_empty() {
+        for lang in &allowed_languages {
+            if !language_name_lower_map.contains_key(lang) {
+                eprintln!(
+                    "Warning: '{}' is not a recognized language name. Use --list to see available languages.",
+                    lang
+                );
+            }
+        }
+    }
+
+    let lines = read_lines_streaming(input_path)?;
+
+    let mut writer: Box<dyn Write> = match output {
+        Some(path) => {
+            if let Some(parent) = path.parent() {
+                if !parent.as_os_str().is_empty() {
+                    std::fs::create_dir_all(parent).with_context(|| {
+                        format!("failed to create directory '{}'", parent.display())
+                    })?;
+                }
+            }
+            let file = File::create(path)
+                .with_context(|| format!("failed to create '{}'", path.display()))?;
+            Box::new(BufWriter::new(file))
+        }
+        None => Box::new(BufWriter::new(io::stdout())),
+    };
+
+    let mut total_count = 0usize;
+    let mut included_count = 0usize;
+    let mut unknown_count = 0usize;
+    let mut excluded_extensions: HashMap<String, usize> = HashMap::default();
+
+    for line_result in lines {
+        let line = line_result.context("failed to read line")?;
+        if line.trim().is_empty() {
+            continue;
+        }
+
+        total_count += 1;
+
+        let cursor_path = match get_cursor_path(&line) {
+            Some(p) => p,
+            None => {
+                if args.include_unknown {
+                    unknown_count += 1;
+                    included_count += 1;
+                    writeln!(writer, "{}", line)?;
+                }
+                continue;
+            }
+        };
+
+        let language = detect_language(&cursor_path, &extension_map);
+        let extension = get_extension(&cursor_path);
+
+        let matches_language = match &language {
+            Some(lang) => allowed_languages.contains(&lang.to_lowercase()),
+            None => false,
+        };
+
+        let matches_extension = match &extension {
+            Some(ext) => allowed_extensions.contains(ext),
+            None => false,
+        };
+
+        let should_include = if matches_language || matches_extension {
+            true
+        } else if language.is_none() && args.include_unknown {
+            unknown_count += 1;
+            true
+        } else {
+            if let Some(ext) = &extension {
+                *excluded_extensions.entry(ext.clone()).or_default() += 1;
+            }
+            false
+        };
+
+        if should_include {
+            included_count += 1;
+            writeln!(writer, "{}", line)?;
+        }
+    }
+
+    writer.flush()?;
+
+    eprintln!(
+        "Filtered {} examples to {} ({} unknown language)",
+        total_count, included_count, unknown_count
+    );
+
+    if let Some(top_n) = args.show_top_excluded {
+        if !excluded_extensions.is_empty() {
+            let mut sorted: Vec<_> = excluded_extensions.into_iter().collect();
+            sorted.sort_by(|a, b| b.1.cmp(&a.1));
+            eprintln!("\nTop {} excluded extensions:", top_n.min(sorted.len()));
+            for (ext, count) in sorted.into_iter().take(top_n) {
+                eprintln!("  {:>6}  .{}", count, ext);
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn run_stats(input: &Path, extension_map: &HashMap<String, String>) -> Result<()> {
+    let lines = read_lines_streaming(Some(input))?;
+
+    let mut language_counts: HashMap<String, usize> = HashMap::default();
+    let mut unknown_extensions: HashMap<String, usize> = HashMap::default();
+    let mut total_count = 0usize;
+
+    for line_result in lines {
+        let line = line_result.context("failed to read line")?;
+        if line.trim().is_empty() {
+            continue;
+        }
+
+        total_count += 1;
+
+        let cursor_path = match get_cursor_path(&line) {
+            Some(p) => p,
+            None => {
+                *language_counts
+                    .entry("<no cursor_path>".to_string())
+                    .or_default() += 1;
+                continue;
+            }
+        };
+
+        match detect_language(&cursor_path, extension_map) {
+            Some(lang) => {
+                *language_counts.entry(lang).or_default() += 1;
+            }
+            None => {
+                let ext = Path::new(&cursor_path)
+                    .extension()
+                    .and_then(OsStr::to_str)
+                    .map(|s| s.to_string())
+                    .unwrap_or_else(|| {
+                        Path::new(&cursor_path)
+                            .file_name()
+                            .and_then(OsStr::to_str)
+                            .map(|s| s.to_string())
+                            .unwrap_or_else(|| "<no extension>".to_string())
+                    });
+                *unknown_extensions.entry(ext).or_default() += 1;
+                *language_counts.entry("<unknown>".to_string()).or_default() += 1;
+            }
+        }
+    }
+
+    let mut sorted_counts: Vec<_> = language_counts.into_iter().collect();
+    sorted_counts.sort_by(|a, b| b.1.cmp(&a.1));
+
+    println!("Language distribution ({} total examples):", total_count);
+    println!();
+    for (lang, count) in &sorted_counts {
+        let pct = (*count as f64 / total_count as f64) * 100.0;
+        println!("  {:>6} ({:>5.1}%)  {}", count, pct, lang);
+    }
+
+    if !unknown_extensions.is_empty() {
+        println!();
+        println!("Unknown extensions:");
+        let mut sorted_unknown: Vec<_> = unknown_extensions.into_iter().collect();
+        sorted_unknown.sort_by(|a, b| b.1.cmp(&a.1));
+        for (ext, count) in sorted_unknown.iter().take(30) {
+            println!("  {:>6}  .{}", count, ext);
+        }
+        if sorted_unknown.len() > 30 {
+            println!("  ... and {} more", sorted_unknown.len() - 30);
+        }
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_build_extension_map() {
+        let map = build_extension_to_language_map();
+        assert!(!map.is_empty());
+        assert_eq!(map.get("rs"), Some(&"Rust".to_string()));
+        assert_eq!(map.get("py"), Some(&"Python".to_string()));
+        assert_eq!(map.get("go"), Some(&"Go".to_string()));
+    }
+
+    #[test]
+    fn test_detect_language_by_extension() {
+        let map = build_extension_to_language_map();
+
+        assert_eq!(
+            detect_language("src/main.rs", &map),
+            Some("Rust".to_string())
+        );
+        assert_eq!(
+            detect_language("lib/foo.py", &map),
+            Some("Python".to_string())
+        );
+        assert_eq!(
+            detect_language("cmd/server.go", &map),
+            Some("Go".to_string())
+        );
+        assert_eq!(detect_language("index.tsx", &map), Some("TSX".to_string()));
+    }
+
+    #[test]
+    fn test_detect_language_by_filename() {
+        let map = build_extension_to_language_map();
+
+        // PKGBUILD is a filename-based match for Shell Script
+        assert_eq!(
+            detect_language("PKGBUILD", &map),
+            Some("Shell Script".to_string())
+        );
+        assert_eq!(
+            detect_language("project/PKGBUILD", &map),
+            Some("Shell Script".to_string())
+        );
+        // .env files are also Shell Script
+        assert_eq!(
+            detect_language(".env", &map),
+            Some("Shell Script".to_string())
+        );
+    }
+
+    #[test]
+    fn test_detect_language_unknown() {
+        let map = build_extension_to_language_map();
+
+        assert_eq!(detect_language("file.xyz123", &map), None);
+        assert_eq!(detect_language("random_file", &map), None);
+    }
+
+    #[test]
+    fn test_get_cursor_path() {
+        let line = r#"{"cursor_path": "src/main.rs", "other": "data"}"#;
+        assert_eq!(get_cursor_path(line), Some("src/main.rs".to_string()));
+
+        let line_no_cursor = r#"{"other": "data"}"#;
+        assert_eq!(get_cursor_path(line_no_cursor), None);
+
+        let invalid_json = "not json";
+        assert_eq!(get_cursor_path(invalid_json), None);
+    }
+
+    #[test]
+    fn test_get_all_languages() {
+        let map = build_extension_to_language_map();
+        let languages = get_all_languages(&map);
+
+        assert!(!languages.is_empty());
+
+        let rust_entry = languages.iter().find(|(name, _)| name == "Rust");
+        assert!(rust_entry.is_some());
+        let (_, rust_extensions) = rust_entry.unwrap();
+        assert!(rust_extensions.contains(&"rs".to_string()));
+    }
+}
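
A note on the embedded config parsing above: `LanguageConfig` deserializes only `name` and `path_suffixes`, and serde skips any other keys present in the real `crates/languages` config files because the struct does not enable `deny_unknown_fields`. A self-contained sketch, assuming `serde`, `toml`, and `anyhow` as in this crate's dependencies; the TOML content is a hypothetical example, not copied from the languages crate:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct LanguageConfig {
    name: String,
    #[serde(default)]
    path_suffixes: Vec<String>,
}

fn main() -> anyhow::Result<()> {
    // Hypothetical config; real files contain many more keys, which serde
    // ignores since the struct does not use `deny_unknown_fields`.
    let src = r#"
        name = "Rust"
        path_suffixes = ["rs"]
        grammar = "rust"
    "#;
    let config: LanguageConfig = toml::from_str(src)?;
    assert_eq!(config.name, "Rust");
    // Each suffix becomes a lowercased key pointing at the language name.
    assert_eq!(config.path_suffixes, vec!["rs"]);
    Ok(())
}
```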

crates/edit_prediction_cli/src/main.rs

@@ -1,6 +1,7 @@
 mod anthropic_client;
 mod distill;
 mod example;
+mod filter_languages;
 mod format_prompt;
 mod git;
 mod headless;
@@ -36,6 +37,7 @@ use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
 use crate::example::{Example, group_examples_by_repo, read_example_files};
+use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
 use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
@@ -153,6 +155,8 @@ enum Command {
     SplitCommit(SplitCommitArgs),
     /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
     Split(SplitArgs),
+    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
+    FilterLanguages(FilterLanguagesArgs),
 }
 
 impl Display for Command {
@@ -184,6 +188,7 @@ impl Display for Command {
             Command::Clean => write!(f, "clean"),
             Command::SplitCommit(_) => write!(f, "split-commit"),
             Command::Split(_) => write!(f, "split"),
+            Command::FilterLanguages(_) => write!(f, "filter-languages"),
         }
     }
 }
@@ -506,6 +511,15 @@ fn main() {
             }
             return;
         }
+        Command::FilterLanguages(filter_args) => {
+            if let Err(error) =
+                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
+            {
+                eprintln!("{error:#}");
+                std::process::exit(1);
+            }
+            return;
+        }
         _ => {}
     }
 
@@ -634,7 +648,8 @@ fn main() {
                                         Command::Clean
                                         | Command::Synthesize(_)
                                         | Command::SplitCommit(_)
-                                        | Command::Split(_) => {
+                                        | Command::Split(_)
+                                        | Command::FilterLanguages(_) => {
                                             unreachable!()
                                         }
                                     }
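
A usage note on the wiring above: `run_filter_languages` reads stdin when the input is `-` (or when no input is given) and writes to stdout when `-o` is omitted, so the new subcommand composes in a shell pipeline. A hypothetical invocation:

```text
# filter stdin down to Rust-only examples, writing JSONL to stdout
cat examples.jsonl | ep filter-languages - --languages rust > rust_only.jsonl
```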