mod.rs

use anyhow::{Context as _, Result};
use assistant_settings::AgentProfileId;
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fs;
use std::{
    path::{Path, PathBuf},
    rc::Rc,
};
use util::serde::default_true;

use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};

mod add_arg_to_trait_method;
mod code_block_citations;
mod comment_translation;
mod file_search;
mod planets;
 20
 21pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
 22    let mut threads: Vec<Rc<dyn Example>> = vec![
 23        Rc::new(file_search::FileSearchExample),
 24        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
 25        Rc::new(code_block_citations::CodeBlockCitations),
 26        Rc::new(planets::Planets),
 27        Rc::new(comment_translation::CommentTranslation),
 28    ];
 29
 30    for example_path in list_declarative_examples(examples_dir).unwrap() {
 31        threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
 32    }
 33
 34    threads
 35}
 36
 37struct DeclarativeExample {
 38    metadata: ExampleMetadata,
 39    prompt: String,
 40    diff_assertions: Vec<JudgeAssertion>,
 41    thread_assertions: Vec<JudgeAssertion>,
 42}
 43
 44impl DeclarativeExample {
 45    pub fn load(example_path: &Path) -> Result<Self> {
 46        let name = Self::name_from_path(example_path);
 47        let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
 48
 49        let language_server = if base.require_lsp {
 50            Some(crate::example::LanguageServer {
 51                file_extension: base
 52                    .language_extension
 53                    .expect("Language extension is required when require_lsp = true"),
 54                allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
 55            })
 56        } else {
 57            None
 58        };
 59
 60        let profile_id = if let Some(profile_name) = base.profile_name {
 61            AgentProfileId(profile_name.into())
 62        } else {
 63            AgentProfileId::default()
 64        };
 65
 66        let metadata = ExampleMetadata {
 67            name,
 68            url: base.url,
 69            revision: base.revision,
 70            language_server,
 71            max_assertions: None,
 72            profile_id,
 73        };
 74
 75        Ok(DeclarativeExample {
 76            metadata,
 77            prompt: base.prompt,
 78            thread_assertions: base
 79                .thread_assertions
 80                .into_iter()
 81                .map(|(id, description)| JudgeAssertion { id, description })
 82                .collect(),
 83            diff_assertions: base
 84                .diff_assertions
 85                .into_iter()
 86                .map(|(id, description)| JudgeAssertion { id, description })
 87                .collect(),
 88        })
 89    }
 90
 91    pub fn name_from_path(path: &Path) -> String {
 92        path.file_stem().unwrap().to_string_lossy().to_string()
 93    }
 94}
 95
 96#[derive(Clone, Debug, Deserialize)]
 97pub struct ExampleToml {
 98    pub url: String,
 99    pub revision: String,
100    pub language_extension: Option<String>,
101    pub insert_id: Option<String>,
102    #[serde(default = "default_true")]
103    pub require_lsp: bool,
104    #[serde(default)]
105    pub allow_preexisting_diagnostics: bool,
106    pub prompt: String,
107    #[serde(default)]
108    pub profile_name: Option<String>,
109    #[serde(default)]
110    pub diff_assertions: BTreeMap<String, String>,
111    #[serde(default)]
112    pub thread_assertions: BTreeMap<String, String>,
113}
114
115#[async_trait(?Send)]
116impl Example for DeclarativeExample {
117    fn meta(&self) -> ExampleMetadata {
118        self.metadata.clone()
119    }
120
121    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
122        cx.push_user_message(&self.prompt);
123        let _ = cx.run_to_end().await;
124        Ok(())
125    }
126
127    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
128        self.diff_assertions.clone()
129    }
130
131    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
132        self.thread_assertions.clone()
133    }
134}
135
136fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
137    let path = std::fs::canonicalize(examples_dir).unwrap();
138    let entries = std::fs::read_dir(path).unwrap();
139    let mut result_paths = Vec::new();
140    for entry in entries {
141        let entry = entry?;
142        let path = entry.path();
143        if path.extension() == Some("toml".as_ref()) {
144            result_paths.push(path);
145        }
146    }
147    Ok(result_paths)
148}