1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Deserialize;
4use std::collections::BTreeMap;
5use std::fs;
6use std::{
7 path::{Path, PathBuf},
8 rc::Rc,
9};
10use util::serde::default_true;
11
12use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
13
14mod add_arg_to_trait_method;
15mod file_search;
16
17pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
18 let mut threads: Vec<Rc<dyn Example>> = vec![
19 Rc::new(file_search::FileSearchExample),
20 Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
21 ];
22
23 for example_path in list_declarative_examples(examples_dir).unwrap() {
24 threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
25 }
26
27 threads
28}
29
30struct DeclarativeExample {
31 metadata: ExampleMetadata,
32 prompt: String,
33 diff_assertions: Vec<JudgeAssertion>,
34 thread_assertions: Vec<JudgeAssertion>,
35}
36
37impl DeclarativeExample {
38 pub fn load(example_path: &Path) -> Result<Self> {
39 let name = Self::name_from_path(example_path);
40 let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
41
42 let language_server = if base.require_lsp {
43 Some(crate::example::LanguageServer {
44 file_extension: base
45 .language_extension
46 .expect("Language extension is required when require_lsp = true"),
47 allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
48 })
49 } else {
50 None
51 };
52
53 let metadata = ExampleMetadata {
54 name,
55 url: base.url,
56 revision: base.revision,
57 language_server,
58 max_assertions: None,
59 };
60
61 Ok(DeclarativeExample {
62 metadata,
63 prompt: base.prompt,
64 thread_assertions: base
65 .thread_assertions
66 .into_iter()
67 .map(|(id, description)| JudgeAssertion { id, description })
68 .collect(),
69 diff_assertions: base
70 .diff_assertions
71 .into_iter()
72 .map(|(id, description)| JudgeAssertion { id, description })
73 .collect(),
74 })
75 }
76
77 pub fn name_from_path(path: &Path) -> String {
78 path.file_stem().unwrap().to_string_lossy().to_string()
79 }
80}
81
82#[derive(Clone, Debug, Deserialize)]
83pub struct ExampleToml {
84 pub url: String,
85 pub revision: String,
86 pub language_extension: Option<String>,
87 pub insert_id: Option<String>,
88 #[serde(default = "default_true")]
89 pub require_lsp: bool,
90 #[serde(default)]
91 pub allow_preexisting_diagnostics: bool,
92 pub prompt: String,
93 #[serde(default)]
94 pub diff_assertions: BTreeMap<String, String>,
95 #[serde(default)]
96 pub thread_assertions: BTreeMap<String, String>,
97}
98
99#[async_trait(?Send)]
100impl Example for DeclarativeExample {
101 fn meta(&self) -> ExampleMetadata {
102 self.metadata.clone()
103 }
104
105 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
106 cx.push_user_message(&self.prompt);
107 let _ = cx.run_to_end().await;
108 Ok(())
109 }
110
111 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
112 self.diff_assertions.clone()
113 }
114
115 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
116 self.thread_assertions.clone()
117 }
118}
119
/// Lists the `.toml` files located directly inside `examples_dir`.
///
/// The directory is canonicalized first, so the returned paths are absolute.
/// Subdirectories are not descended into. The result is sorted because the
/// order in which a directory yields its entries is platform-dependent.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be canonicalized or read, or if
/// reading any individual directory entry fails.
fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    let canonical_dir = std::fs::canonicalize(examples_dir)?;

    let mut toml_paths = Vec::new();
    for entry in std::fs::read_dir(canonical_dir)? {
        let candidate = entry?.path();
        if candidate.extension().map_or(false, |ext| ext == "toml") {
            toml_paths.push(candidate);
        }
    }

    toml_paths.sort();
    Ok(toml_paths)
}