use anyhow::{Context as _, Result};
use assistant_settings::AgentProfileId;
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fs;
use std::{
    path::{Path, PathBuf},
    rc::Rc,
};
use util::serde::default_true;
12
13use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
14
15mod add_arg_to_trait_method;
16mod code_block_citations;
17mod comment_translation;
18mod file_search;
19mod planets;
20
21pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
22 let mut threads: Vec<Rc<dyn Example>> = vec![
23 Rc::new(file_search::FileSearchExample),
24 Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
25 Rc::new(code_block_citations::CodeBlockCitations),
26 Rc::new(planets::Planets),
27 Rc::new(comment_translation::CommentTranslation),
28 ];
29
30 for example_path in list_declarative_examples(examples_dir).unwrap() {
31 threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
32 }
33
34 threads
35}
36
37struct DeclarativeExample {
38 metadata: ExampleMetadata,
39 prompt: String,
40 diff_assertions: Vec<JudgeAssertion>,
41 thread_assertions: Vec<JudgeAssertion>,
42}
43
44impl DeclarativeExample {
45 pub fn load(example_path: &Path) -> Result<Self> {
46 let name = Self::name_from_path(example_path);
47 let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
48
49 let language_server = if base.require_lsp {
50 Some(crate::example::LanguageServer {
51 file_extension: base
52 .language_extension
53 .expect("Language extension is required when require_lsp = true"),
54 allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
55 })
56 } else {
57 None
58 };
59
60 let profile_id = if let Some(profile_name) = base.profile_name {
61 AgentProfileId(profile_name.into())
62 } else {
63 AgentProfileId::default()
64 };
65
66 let metadata = ExampleMetadata {
67 name,
68 url: base.url,
69 revision: base.revision,
70 language_server,
71 max_assertions: None,
72 profile_id,
73 };
74
75 Ok(DeclarativeExample {
76 metadata,
77 prompt: base.prompt,
78 thread_assertions: base
79 .thread_assertions
80 .into_iter()
81 .map(|(id, description)| JudgeAssertion { id, description })
82 .collect(),
83 diff_assertions: base
84 .diff_assertions
85 .into_iter()
86 .map(|(id, description)| JudgeAssertion { id, description })
87 .collect(),
88 })
89 }
90
91 pub fn name_from_path(path: &Path) -> String {
92 path.file_stem().unwrap().to_string_lossy().to_string()
93 }
94}
95
/// On-disk schema of a declarative example: one `*.toml` file in the
/// examples directory.
///
/// Deserialized by `DeclarativeExample::load`, which turns it into the
/// example's metadata, prompt and assertions.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleToml {
    /// Copied verbatim into the example metadata's `url` (presumably the
    /// repository to check out; confirm against `ExampleMetadata`).
    pub url: String,
    /// Copied verbatim into the example metadata's `revision`.
    pub revision: String,
    /// File extension handed to the language server settings. It must be set
    /// whenever `require_lsp` is enabled.
    pub language_extension: Option<String>,
    /// Not read by the code in this file.
    /// NOTE(review): confirm whether anything still uses it.
    pub insert_id: Option<String>,
    /// Whether the example needs a language server. It is filled in by
    /// `default_true` when the key is omitted.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    /// Forwarded to the language server settings; `false` when omitted.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
    /// Text sent as the user message of the example's conversation.
    pub prompt: String,
    /// Agent profile to run with. `AgentProfileId::default()` is used when it
    /// is omitted.
    #[serde(default)]
    pub profile_name: Option<String>,
    /// `id = "description"` table that becomes the example's diff assertions;
    /// empty when omitted.
    #[serde(default)]
    pub diff_assertions: BTreeMap<String, String>,
    /// `id = "description"` table that becomes the example's thread assertions;
    /// empty when omitted.
    #[serde(default)]
    pub thread_assertions: BTreeMap<String, String>,
}
114
115#[async_trait(?Send)]
116impl Example for DeclarativeExample {
117 fn meta(&self) -> ExampleMetadata {
118 self.metadata.clone()
119 }
120
121 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
122 cx.push_user_message(&self.prompt);
123 let _ = cx.run_to_end().await;
124 Ok(())
125 }
126
127 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
128 self.diff_assertions.clone()
129 }
130
131 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
132 self.thread_assertions.clone()
133 }
134}
135
136fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
137 let path = std::fs::canonicalize(examples_dir).unwrap();
138 let entries = std::fs::read_dir(path).unwrap();
139 let mut result_paths = Vec::new();
140 for entry in entries {
141 let entry = entry?;
142 let path = entry.path();
143 if path.extension() == Some("toml".as_ref()) {
144 result_paths.push(path);
145 }
146 }
147 Ok(result_paths)
148}