mod.rs

  1use anyhow::Result;
  2use async_trait::async_trait;
  3use serde::Deserialize;
  4use std::collections::BTreeMap;
  5use std::fs;
  6use std::{
  7    path::{Path, PathBuf},
  8    rc::Rc,
  9};
 10use util::serde::default_true;
 11
 12use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 13
 14mod add_arg_to_trait_method;
 15mod file_search;
 16
 17pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 18    let mut threads: Vec<Rc<dyn Example>> = vec![
 19        Rc::new(file_search::FileSearchExample),
 20        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 21    ];
 22
 23    for example_path in list_declarative_examples(examples_dir).unwrap() {
 24        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 25    }
 26
 27    threads
 28}
 29
/// An example whose definition is loaded from a TOML file (see
/// `DeclarativeExample::load`) rather than written as Rust code.
struct DeclarativeExample {
    /// Metadata returned from `Example::meta`.
    metadata: ExampleMetadata,
    /// Text pushed as the user message when the conversation runs.
    prompt: String,
    /// Assertions returned from `Example::diff_assertions`.
    diff_assertions: Vec<JudgeAssertion>,
    /// Assertions returned from `Example::thread_assertions`.
    thread_assertions: Vec<JudgeAssertion>,
}
 36
 37impl DeclarativeExample {
 38    pub fn load(example_path: &Path) -> Result<Self> {
 39        let name = Self::name_from_path(example_path);
 40        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 41
 42        let language_server = if base.require_lsp {
 43            Some(crate::example::LanguageServer {
 44                file_extension: base
 45                    .language_extension
 46                    .expect("Language extension is required when require_lsp = true"),
 47                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 48            })
 49        } else {
 50            None
 51        };
 52
 53        let metadata = ExampleMetadata {
 54            name,
 55            url: base.url,
 56            revision: base.revision,
 57            language_server,
 58            max_assertions: None,
 59        };
 60
 61        Ok(DeclarativeExample {
 62            metadata,
 63            prompt: base.prompt,
 64            thread_assertions: base
 65                .thread_assertions
 66                .into_iter()
 67                .map(|(id, description)| JudgeAssertion { id, description })
 68                .collect(),
 69            diff_assertions: base
 70                .diff_assertions
 71                .into_iter()
 72                .map(|(id, description)| JudgeAssertion { id, description })
 73                .collect(),
 74        })
 75    }
 76
 77    pub fn name_from_path(path: &Path) -> String {
 78        path.file_stem().unwrap().to_string_lossy().to_string()
 79    }
 80}
 81
/// On-disk shape of a declarative example, deserialized from a TOML file.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleToml {
    /// Copied into `ExampleMetadata::url` when the example is loaded.
    pub url: String,
    /// Copied into `ExampleMetadata::revision` when the example is loaded.
    pub revision: String,
    /// Used as the language server's `file_extension`; must be set when
    /// `require_lsp` is true.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in this file — confirm whether other
    /// code still relies on it.
    pub insert_id: Option<String>,
    /// Whether the example needs a language server. When the key is absent
    /// the value comes from `default_true` (presumably `true`).
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    /// Forwarded to `LanguageServer::allow_preexisting_diagnostics`; defaults
    /// to `false` when the key is absent.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
    /// Text sent as the user message when the example's conversation runs.
    pub prompt: String,
    /// Map from assertion id to description, converted into the example's
    /// diff assertions; empty when the key is absent.
    #[serde(default)]
    pub diff_assertions: BTreeMap<String, String>,
    /// Map from assertion id to description, converted into the example's
    /// thread assertions; empty when the key is absent.
    #[serde(default)]
    pub thread_assertions: BTreeMap<String, String>,
}
 98
 99#[async_trait(?Send)]
100impl Example for DeclarativeExample {
101    fn meta(&self) -> ExampleMetadata {
102        self.metadata.clone()
103    }
104
105    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
106        cx.push_user_message(&self.prompt);
107        let _ = cx.run_to_end().await;
108        Ok(())
109    }
110
111    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
112        self.diff_assertions.clone()
113    }
114
115    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
116        self.thread_assertions.clone()
117    }
118}
119
120fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
121    let path = std::fs::canonicalize(examples_dir).unwrap();
122    let entries = std::fs::read_dir(path).unwrap();
123    let mut result_paths = Vec::new();
124    for entry in entries {
125        let entry = entry?;
126        let path = entry.path();
127        if path.extension() == Some("toml".as_ref()) {
128            result_paths.push(path);
129        }
130    }
131    Ok(result_paths)
132}