mod.rs

  1use agent_settings::AgentProfileId;
  2use anyhow::Result;
  3use async_trait::async_trait;
  4use serde::Deserialize;
  5use std::collections::BTreeMap;
  6use std::fs;
  7use std::{
  8    path::{Path, PathBuf},
  9    rc::Rc,
 10};
 11use util::serde::default_true;
 12
 13use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 14
 15mod add_arg_to_trait_method;
 16mod code_block_citations;
 17mod comment_translation;
 18mod file_search;
 19mod overwrite_file;
 20mod planets;
 21
 22pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 23    let mut threads: Vec<Rc<dyn Example>> = vec![
 24        Rc::new(file_search::FileSearchExample),
 25        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 26        Rc::new(code_block_citations::CodeBlockCitations),
 27        Rc::new(planets::Planets),
 28        Rc::new(comment_translation::CommentTranslation),
 29        Rc::new(overwrite_file::FileOverwriteExample),
 30    ];
 31
 32    for example_path in list_declarative_examples(examples_dir).unwrap() {
 33        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 34    }
 35
 36    threads
 37}
 38
/// An example whose content is read from a TOML file rather than written as
/// Rust code (see `DeclarativeExample::load`).
struct DeclarativeExample {
    /// Metadata handed out by `Example::meta`.
    metadata: ExampleMetadata,
    /// Text pushed as the user message at the start of the conversation.
    prompt: String,
    /// Assertions returned from `Example::diff_assertions`.
    diff_assertions: Vec<JudgeAssertion>,
    /// Assertions returned from `Example::thread_assertions`.
    thread_assertions: Vec<JudgeAssertion>,
}
 45
 46impl DeclarativeExample {
 47    pub fn load(example_path: &Path) -> Result<Self> {
 48        let name = Self::name_from_path(example_path);
 49        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 50        let example_dir = example_path.parent().unwrap();
 51
 52        let language_server = if base.require_lsp {
 53            Some(crate::example::LanguageServer {
 54                file_extension: base
 55                    .language_extension
 56                    .expect("Language extension is required when require_lsp = true"),
 57                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 58            })
 59        } else {
 60            None
 61        };
 62
 63        let profile_id = if let Some(profile_name) = base.profile_name {
 64            AgentProfileId(profile_name.into())
 65        } else {
 66            AgentProfileId::default()
 67        };
 68
 69        let existing_thread_json = if let Some(path) = base.existing_thread_path {
 70            let content = fs::read_to_string(example_dir.join(&path))
 71                .unwrap_or_else(|_| panic!("Failed to read existing thread file: {}", path));
 72            Some(content)
 73        } else {
 74            None
 75        };
 76
 77        let metadata = ExampleMetadata {
 78            name,
 79            url: base.url,
 80            revision: base.revision,
 81            language_server,
 82            max_assertions: None,
 83            profile_id,
 84            existing_thread_json,
 85            max_turns: base.max_turns,
 86        };
 87
 88        Ok(DeclarativeExample {
 89            metadata,
 90            prompt: base.prompt,
 91            thread_assertions: base
 92                .thread_assertions
 93                .into_iter()
 94                .map(|(id, description)| JudgeAssertion { id, description })
 95                .collect(),
 96            diff_assertions: base
 97                .diff_assertions
 98                .into_iter()
 99                .map(|(id, description)| JudgeAssertion { id, description })
100                .collect(),
101        })
102    }
103
104    pub fn name_from_path(path: &Path) -> String {
105        path.file_stem().unwrap().to_string_lossy().to_string()
106    }
107}
108
/// Deserialized form of a declarative example's TOML file.
///
/// `DeclarativeExample::load` parses a file into this struct and then converts
/// it into the example's metadata, prompt and assertions.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleToml {
    /// Copied verbatim into `ExampleMetadata::url`.
    /// NOTE(review): presumably the URL of the repository the example runs against — confirm.
    pub url: String,
    /// Copied verbatim into `ExampleMetadata::revision`.
    pub revision: String,
    /// Used as the language server's `file_extension`; must be present
    /// whenever `require_lsp` is true.
    pub language_extension: Option<String>,
    /// Not read anywhere in this file.
    pub insert_id: Option<String>,
    /// Whether the example is given a language server. When the key is absent
    /// the value comes from `util::serde::default_true`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    /// Forwarded to the language server configuration; `false` when absent.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
    /// Text pushed as the user message when the conversation starts.
    pub prompt: String,
    /// Name wrapped in an `AgentProfileId`; `AgentProfileId::default()` is
    /// used when absent.
    #[serde(default)]
    pub profile_name: Option<String>,
    /// Assertion id -> description; each entry becomes a `JudgeAssertion`
    /// returned from `Example::diff_assertions`. Empty when absent.
    #[serde(default)]
    pub diff_assertions: BTreeMap<String, String>,
    /// Assertion id -> description; each entry becomes a `JudgeAssertion`
    /// returned from `Example::thread_assertions`. Empty when absent.
    #[serde(default)]
    pub thread_assertions: BTreeMap<String, String>,
    /// Path, joined onto the directory containing the TOML file, of a file
    /// whose contents are stored in `ExampleMetadata::existing_thread_json`.
    #[serde(default)]
    pub existing_thread_path: Option<String>,
    /// Turn limit passed to `run_turns`; the conversation falls back to 1000
    /// when absent.
    #[serde(default)]
    pub max_turns: Option<u32>,
}
131
132#[async_trait(?Send)]
133impl Example for DeclarativeExample {
134    fn meta(&self) -> ExampleMetadata {
135        self.metadata.clone()
136    }
137
138    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
139        cx.push_user_message(&self.prompt);
140        let max_turns = self.metadata.max_turns.unwrap_or(1000);
141        let _ = cx.run_turns(max_turns).await;
142        Ok(())
143    }
144
145    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
146        self.diff_assertions.clone()
147    }
148
149    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
150        self.thread_assertions.clone()
151    }
152}
153
/// Lists the `.toml` files located directly inside `examples_dir`.
///
/// The directory is canonicalized first, so the returned paths are absolute.
/// They are sorted because `read_dir` yields entries in a platform- and
/// filesystem-dependent order; sorting keeps the example order stable.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be resolved or read, or if an
/// individual directory entry cannot be read.
fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    // A bare `io::Error` does not say which path failed, so wrap it with the
    // directory while keeping the original error kind.
    let with_dir = |action: &str, err: std::io::Error| {
        std::io::Error::new(
            err.kind(),
            format!(
                "failed to {} examples directory {}: {}",
                action,
                examples_dir.display(),
                err
            ),
        )
    };

    let dir = std::fs::canonicalize(examples_dir).map_err(|err| with_dir("resolve", err))?;
    let entries = std::fs::read_dir(&dir).map_err(|err| with_dir("read", err))?;

    let mut result_paths = Vec::new();
    for entry in entries {
        let path = entry?.path();
        if path.extension() == Some("toml".as_ref()) {
            result_paths.push(path);
        }
    }
    result_paths.sort();
    Ok(result_paths)
}