mod.rs

  1use anyhow::Result;
  2use async_trait::async_trait;
  3use serde::Deserialize;
  4use std::collections::BTreeMap;
  5use std::fs;
  6use std::{
  7    path::{Path, PathBuf},
  8    rc::Rc,
  9};
 10use util::serde::default_true;
 11
 12use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 13
 14mod add_arg_to_trait_method;
 15mod code_block_citations;
 16mod comment_translation;
 17mod file_search;
 18mod planets;
 19
 20pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 21    let mut threads: Vec<Rc<dyn Example>> = vec![
 22        Rc::new(file_search::FileSearchExample),
 23        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 24        Rc::new(code_block_citations::CodeBlockCitations),
 25        Rc::new(planets::Planets),
 26        Rc::new(comment_translation::CommentTranslation),
 27    ];
 28
 29    for example_path in list_declarative_examples(examples_dir).unwrap() {
 30        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 31    }
 32
 33    threads
 34}
 35
 36struct DeclarativeExample {
 37    metadata: ExampleMetadata,
 38    prompt: String,
 39    diff_assertions: Vec<JudgeAssertion>,
 40    thread_assertions: Vec<JudgeAssertion>,
 41}
 42
 43impl DeclarativeExample {
 44    pub fn load(example_path: &Path) -> Result<Self> {
 45        let name = Self::name_from_path(example_path);
 46        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 47
 48        let language_server = if base.require_lsp {
 49            Some(crate::example::LanguageServer {
 50                file_extension: base
 51                    .language_extension
 52                    .expect("Language extension is required when require_lsp = true"),
 53                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 54            })
 55        } else {
 56            None
 57        };
 58
 59        let metadata = ExampleMetadata {
 60            name,
 61            url: base.url,
 62            revision: base.revision,
 63            language_server,
 64            max_assertions: None,
 65        };
 66
 67        Ok(DeclarativeExample {
 68            metadata,
 69            prompt: base.prompt,
 70            thread_assertions: base
 71                .thread_assertions
 72                .into_iter()
 73                .map(|(id, description)| JudgeAssertion { id, description })
 74                .collect(),
 75            diff_assertions: base
 76                .diff_assertions
 77                .into_iter()
 78                .map(|(id, description)| JudgeAssertion { id, description })
 79                .collect(),
 80        })
 81    }
 82
 83    pub fn name_from_path(path: &Path) -> String {
 84        path.file_stem().unwrap().to_string_lossy().to_string()
 85    }
 86}
 87
 88#[derive(Clone, Debug, Deserialize)]
 89pub struct ExampleToml {
 90    pub url: String,
 91    pub revision: String,
 92    pub language_extension: Option<String>,
 93    pub insert_id: Option<String>,
 94    #[serde(default = "default_true")]
 95    pub require_lsp: bool,
 96    #[serde(default)]
 97    pub allow_preexisting_diagnostics: bool,
 98    pub prompt: String,
 99    #[serde(default)]
100    pub diff_assertions: BTreeMap<String, String>,
101    #[serde(default)]
102    pub thread_assertions: BTreeMap<String, String>,
103}
104
105#[async_trait(?Send)]
106impl Example for DeclarativeExample {
107    fn meta(&self) -> ExampleMetadata {
108        self.metadata.clone()
109    }
110
111    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
112        cx.push_user_message(&self.prompt);
113        let _ = cx.run_to_end().await;
114        Ok(())
115    }
116
117    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
118        self.diff_assertions.clone()
119    }
120
121    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
122        self.thread_assertions.clone()
123    }
124}
125
126fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
127    let path = std::fs::canonicalize(examples_dir).unwrap();
128    let entries = std::fs::read_dir(path).unwrap();
129    let mut result_paths = Vec::new();
130    for entry in entries {
131        let entry = entry?;
132        let path = entry.path();
133        if path.extension() == Some("toml".as_ref()) {
134            result_paths.push(path);
135        }
136    }
137    Ok(result_paths)
138}