mod.rs

  1use agent_settings::AgentProfileId;
  2use anyhow::Result;
  3use async_trait::async_trait;
  4use serde::Deserialize;
  5use std::collections::BTreeMap;
  6use std::fs;
  7use std::{
  8    path::{Path, PathBuf},
  9    rc::Rc,
 10};
 11use util::serde::default_true;
 12
 13use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 14
 15mod add_arg_to_trait_method;
 16mod code_block_citations;
 17mod comment_translation;
 18mod file_search;
 19mod grep_params_escapement;
 20mod overwrite_file;
 21mod planets;
 22
 23pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 24    let mut threads: Vec<Rc<dyn Example>> = vec![
 25        Rc::new(file_search::FileSearchExample),
 26        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 27        Rc::new(code_block_citations::CodeBlockCitations),
 28        Rc::new(planets::Planets),
 29        Rc::new(comment_translation::CommentTranslation),
 30        Rc::new(overwrite_file::FileOverwriteExample),
 31        Rc::new(grep_params_escapement::GrepParamsEscapementExample),
 32    ];
 33
 34    for example_path in list_declarative_examples(examples_dir).unwrap() {
 35        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 36    }
 37
 38    threads
 39}
 40
 41struct DeclarativeExample {
 42    metadata: ExampleMetadata,
 43    prompt: String,
 44    diff_assertions: Vec<JudgeAssertion>,
 45    thread_assertions: Vec<JudgeAssertion>,
 46}
 47
 48impl DeclarativeExample {
 49    pub fn load(example_path: &Path) -> Result<Self> {
 50        let name = Self::name_from_path(example_path);
 51        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 52        let example_dir = example_path.parent().unwrap();
 53
 54        let language_server = if base.require_lsp {
 55            Some(crate::example::LanguageServer {
 56                file_extension: base
 57                    .language_extension
 58                    .expect("Language extension is required when require_lsp = true"),
 59                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 60            })
 61        } else {
 62            None
 63        };
 64
 65        let profile_id = if let Some(profile_name) = base.profile_name {
 66            AgentProfileId(profile_name.into())
 67        } else {
 68            AgentProfileId::default()
 69        };
 70
 71        let existing_thread_json = if let Some(path) = base.existing_thread_path {
 72            let content = fs::read_to_string(example_dir.join(&path))
 73                .unwrap_or_else(|_| panic!("Failed to read existing thread file: {}", path));
 74            Some(content)
 75        } else {
 76            None
 77        };
 78
 79        let metadata = ExampleMetadata {
 80            name,
 81            url: base.url,
 82            revision: base.revision,
 83            language_server,
 84            max_assertions: None,
 85            profile_id,
 86            existing_thread_json,
 87            max_turns: base.max_turns,
 88        };
 89
 90        Ok(DeclarativeExample {
 91            metadata,
 92            prompt: base.prompt,
 93            thread_assertions: base
 94                .thread_assertions
 95                .into_iter()
 96                .map(|(id, description)| JudgeAssertion { id, description })
 97                .collect(),
 98            diff_assertions: base
 99                .diff_assertions
100                .into_iter()
101                .map(|(id, description)| JudgeAssertion { id, description })
102                .collect(),
103        })
104    }
105
106    pub fn name_from_path(path: &Path) -> String {
107        path.file_stem().unwrap().to_string_lossy().to_string()
108    }
109}
110
111#[derive(Clone, Debug, Deserialize)]
112pub struct ExampleToml {
113    pub url: String,
114    pub revision: String,
115    pub language_extension: Option<String>,
116    pub insert_id: Option<String>,
117    #[serde(default = "default_true")]
118    pub require_lsp: bool,
119    #[serde(default)]
120    pub allow_preexisting_diagnostics: bool,
121    pub prompt: String,
122    #[serde(default)]
123    pub profile_name: Option<String>,
124    #[serde(default)]
125    pub diff_assertions: BTreeMap<String, String>,
126    #[serde(default)]
127    pub thread_assertions: BTreeMap<String, String>,
128    #[serde(default)]
129    pub existing_thread_path: Option<String>,
130    #[serde(default)]
131    pub max_turns: Option<u32>,
132}
133
134#[async_trait(?Send)]
135impl Example for DeclarativeExample {
136    fn meta(&self) -> ExampleMetadata {
137        self.metadata.clone()
138    }
139
140    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
141        cx.push_user_message(&self.prompt);
142        let max_turns = self.metadata.max_turns.unwrap_or(1000);
143        let _ = cx.run_turns(max_turns).await;
144        Ok(())
145    }
146
147    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
148        self.diff_assertions.clone()
149    }
150
151    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
152        self.thread_assertions.clone()
153    }
154}
155
156fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
157    let path = std::fs::canonicalize(examples_dir).unwrap();
158    let entries = std::fs::read_dir(path).unwrap();
159    let mut result_paths = Vec::new();
160    for entry in entries {
161        let entry = entry?;
162        let path = entry.path();
163        if path.extension() == Some("toml".as_ref()) {
164            result_paths.push(path);
165        }
166    }
167    Ok(result_paths)
168}