mod.rs

  1use anyhow::Result;
  2use async_trait::async_trait;
  3use serde::Deserialize;
  4use std::collections::BTreeMap;
  5use std::fs;
  6use std::{
  7    path::{Path, PathBuf},
  8    rc::Rc,
  9};
 10use util::serde::default_true;
 11
 12use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 13
 14mod file_search;
 15
 16pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 17    let mut threads: Vec<Rc<dyn Example>> = vec![Rc::new(file_search::FileSearchExample)];
 18
 19    for example_path in list_declarative_examples(examples_dir).unwrap() {
 20        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 21    }
 22
 23    threads
 24}
 25
 26struct DeclarativeExample {
 27    metadata: ExampleMetadata,
 28    prompt: String,
 29    diff_assertions: Vec<JudgeAssertion>,
 30    thread_assertions: Vec<JudgeAssertion>,
 31}
 32
 33impl DeclarativeExample {
 34    pub fn load(example_path: &Path) -> Result<Self> {
 35        let name = Self::name_from_path(example_path);
 36        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 37
 38        let language_server = if base.require_lsp {
 39            Some(crate::example::LanguageServer {
 40                file_extension: base
 41                    .language_extension
 42                    .expect("Language extension is required when require_lsp = true"),
 43                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 44            })
 45        } else {
 46            None
 47        };
 48
 49        let metadata = ExampleMetadata {
 50            name,
 51            url: base.url,
 52            revision: base.revision,
 53            language_server,
 54            max_assertions: None,
 55        };
 56
 57        Ok(DeclarativeExample {
 58            metadata,
 59            prompt: base.prompt,
 60            thread_assertions: base
 61                .thread_assertions
 62                .into_iter()
 63                .map(|(id, description)| JudgeAssertion { id, description })
 64                .collect(),
 65            diff_assertions: base
 66                .diff_assertions
 67                .into_iter()
 68                .map(|(id, description)| JudgeAssertion { id, description })
 69                .collect(),
 70        })
 71    }
 72
 73    pub fn name_from_path(path: &Path) -> String {
 74        path.file_stem().unwrap().to_string_lossy().to_string()
 75    }
 76}
 77
 78#[derive(Clone, Debug, Deserialize)]
 79pub struct ExampleToml {
 80    pub url: String,
 81    pub revision: String,
 82    pub language_extension: Option<String>,
 83    pub insert_id: Option<String>,
 84    #[serde(default = "default_true")]
 85    pub require_lsp: bool,
 86    #[serde(default)]
 87    pub allow_preexisting_diagnostics: bool,
 88    pub prompt: String,
 89    #[serde(default)]
 90    pub diff_assertions: BTreeMap<String, String>,
 91    #[serde(default)]
 92    pub thread_assertions: BTreeMap<String, String>,
 93}
 94
 95#[async_trait(?Send)]
 96impl Example for DeclarativeExample {
 97    fn meta(&self) -> ExampleMetadata {
 98        self.metadata.clone()
 99    }
100
101    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
102        cx.push_user_message(&self.prompt);
103        let _ = cx.run_to_end().await;
104        Ok(())
105    }
106
107    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
108        self.diff_assertions.clone()
109    }
110
111    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
112        self.thread_assertions.clone()
113    }
114}
115
116fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
117    let path = std::fs::canonicalize(examples_dir).unwrap();
118    let entries = std::fs::read_dir(path).unwrap();
119    let mut result_paths = Vec::new();
120    for entry in entries {
121        let entry = entry?;
122        let path = entry.path();
123        if path.extension() == Some("toml".as_ref()) {
124            result_paths.push(path);
125        }
126    }
127    Ok(result_paths)
128}