mod.rs

  1use anyhow::Result;
  2use async_trait::async_trait;
  3use serde::Deserialize;
  4use std::collections::BTreeMap;
  5use std::fs;
  6use std::{
  7    path::{Path, PathBuf},
  8    rc::Rc,
  9};
 10use util::serde::default_true;
 11
 12use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 13
 14mod add_arg_to_trait_method;
 15mod code_block_citations;
 16mod file_search;
 17mod planets;
 18
 19pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 20    let mut threads: Vec<Rc<dyn Example>> = vec![
 21        Rc::new(file_search::FileSearchExample),
 22        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 23        Rc::new(code_block_citations::CodeBlockCitations),
 24        Rc::new(planets::Planets),
 25    ];
 26
 27    for example_path in list_declarative_examples(examples_dir).unwrap() {
 28        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 29    }
 30
 31    threads
 32}
 33
 34struct DeclarativeExample {
 35    metadata: ExampleMetadata,
 36    prompt: String,
 37    diff_assertions: Vec<JudgeAssertion>,
 38    thread_assertions: Vec<JudgeAssertion>,
 39}
 40
 41impl DeclarativeExample {
 42    pub fn load(example_path: &Path) -> Result<Self> {
 43        let name = Self::name_from_path(example_path);
 44        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 45
 46        let language_server = if base.require_lsp {
 47            Some(crate::example::LanguageServer {
 48                file_extension: base
 49                    .language_extension
 50                    .expect("Language extension is required when require_lsp = true"),
 51                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 52            })
 53        } else {
 54            None
 55        };
 56
 57        let metadata = ExampleMetadata {
 58            name,
 59            url: base.url,
 60            revision: base.revision,
 61            language_server,
 62            max_assertions: None,
 63        };
 64
 65        Ok(DeclarativeExample {
 66            metadata,
 67            prompt: base.prompt,
 68            thread_assertions: base
 69                .thread_assertions
 70                .into_iter()
 71                .map(|(id, description)| JudgeAssertion { id, description })
 72                .collect(),
 73            diff_assertions: base
 74                .diff_assertions
 75                .into_iter()
 76                .map(|(id, description)| JudgeAssertion { id, description })
 77                .collect(),
 78        })
 79    }
 80
 81    pub fn name_from_path(path: &Path) -> String {
 82        path.file_stem().unwrap().to_string_lossy().to_string()
 83    }
 84}
 85
 86#[derive(Clone, Debug, Deserialize)]
 87pub struct ExampleToml {
 88    pub url: String,
 89    pub revision: String,
 90    pub language_extension: Option<String>,
 91    pub insert_id: Option<String>,
 92    #[serde(default = "default_true")]
 93    pub require_lsp: bool,
 94    #[serde(default)]
 95    pub allow_preexisting_diagnostics: bool,
 96    pub prompt: String,
 97    #[serde(default)]
 98    pub diff_assertions: BTreeMap<String, String>,
 99    #[serde(default)]
100    pub thread_assertions: BTreeMap<String, String>,
101}
102
103#[async_trait(?Send)]
104impl Example for DeclarativeExample {
105    fn meta(&self) -> ExampleMetadata {
106        self.metadata.clone()
107    }
108
109    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
110        cx.push_user_message(&self.prompt);
111        let _ = cx.run_to_end().await;
112        Ok(())
113    }
114
115    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
116        self.diff_assertions.clone()
117    }
118
119    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
120        self.thread_assertions.clone()
121    }
122}
123
/// Lists the declarative example files (`*.toml`) directly inside `examples_dir`.
///
/// The directory is canonicalized first, so the returned paths are absolute. They are
/// sorted so examples load in the same order on every platform and every run.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be canonicalized or read, or if reading
/// any directory entry fails.
fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    let dir = std::fs::canonicalize(examples_dir)?;
    let mut result_paths = Vec::new();
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.extension() == Some("toml".as_ref()) {
            result_paths.push(path);
        }
    }
    // `read_dir` yields entries in a platform/filesystem-dependent order; sort so the
    // list of examples is deterministic.
    result_paths.sort();
    Ok(result_paths)
}