use agent_settings::AgentProfileId;
use anyhow::Context as _;
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fs;
use std::{
    path::{Path, PathBuf},
    rc::Rc,
};
use util::serde::default_true;

use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
14
15mod add_arg_to_trait_method;
16mod code_block_citations;
17mod comment_translation;
18mod file_search;
19mod grep_params_escapement;
20mod overwrite_file;
21mod planets;
22
23pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
24 let mut threads: Vec<Rc<dyn Example>> = vec![
25 Rc::new(file_search::FileSearchExample),
26 Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
27 Rc::new(code_block_citations::CodeBlockCitations),
28 Rc::new(planets::Planets),
29 Rc::new(comment_translation::CommentTranslation),
30 Rc::new(overwrite_file::FileOverwriteExample),
31 Rc::new(grep_params_escapement::GrepParamsEscapementExample),
32 ];
33
34 for example_path in list_declarative_examples(examples_dir).unwrap() {
35 threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
36 }
37
38 threads
39}
40
41struct DeclarativeExample {
42 metadata: ExampleMetadata,
43 prompt: String,
44 diff_assertions: Vec<JudgeAssertion>,
45 thread_assertions: Vec<JudgeAssertion>,
46}
47
48impl DeclarativeExample {
49 pub fn load(example_path: &Path) -> Result<Self> {
50 let name = Self::name_from_path(example_path);
51 let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
52 let example_dir = example_path.parent().unwrap();
53
54 let language_server = if base.require_lsp {
55 Some(crate::example::LanguageServer {
56 file_extension: base
57 .language_extension
58 .expect("Language extension is required when require_lsp = true"),
59 allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
60 })
61 } else {
62 None
63 };
64
65 let profile_id = if let Some(profile_name) = base.profile_name {
66 AgentProfileId(profile_name.into())
67 } else {
68 AgentProfileId::default()
69 };
70
71 let existing_thread_json = if let Some(path) = base.existing_thread_path {
72 let content = fs::read_to_string(example_dir.join(&path))
73 .unwrap_or_else(|_| panic!("Failed to read existing thread file: {}", path));
74 Some(content)
75 } else {
76 None
77 };
78
79 let metadata = ExampleMetadata {
80 name,
81 url: base.url,
82 revision: base.revision,
83 language_server,
84 max_assertions: None,
85 profile_id,
86 existing_thread_json,
87 max_turns: base.max_turns,
88 };
89
90 Ok(DeclarativeExample {
91 metadata,
92 prompt: base.prompt,
93 thread_assertions: base
94 .thread_assertions
95 .into_iter()
96 .map(|(id, description)| JudgeAssertion { id, description })
97 .collect(),
98 diff_assertions: base
99 .diff_assertions
100 .into_iter()
101 .map(|(id, description)| JudgeAssertion { id, description })
102 .collect(),
103 })
104 }
105
106 pub fn name_from_path(path: &Path) -> String {
107 path.file_stem().unwrap().to_string_lossy().to_string()
108 }
109}
110
111#[derive(Clone, Debug, Deserialize)]
112pub struct ExampleToml {
113 pub url: String,
114 pub revision: String,
115 pub language_extension: Option<String>,
116 pub insert_id: Option<String>,
117 #[serde(default = "default_true")]
118 pub require_lsp: bool,
119 #[serde(default)]
120 pub allow_preexisting_diagnostics: bool,
121 pub prompt: String,
122 #[serde(default)]
123 pub profile_name: Option<String>,
124 #[serde(default)]
125 pub diff_assertions: BTreeMap<String, String>,
126 #[serde(default)]
127 pub thread_assertions: BTreeMap<String, String>,
128 #[serde(default)]
129 pub existing_thread_path: Option<String>,
130 #[serde(default)]
131 pub max_turns: Option<u32>,
132}
133
134#[async_trait(?Send)]
135impl Example for DeclarativeExample {
136 fn meta(&self) -> ExampleMetadata {
137 self.metadata.clone()
138 }
139
140 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
141 cx.push_user_message(&self.prompt);
142 let max_turns = self.metadata.max_turns.unwrap_or(1000);
143 let _ = cx.run_turns(max_turns).await;
144 Ok(())
145 }
146
147 fn diff_assertions(&self) -> Vec<JudgeAssertion> {
148 self.diff_assertions.clone()
149 }
150
151 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
152 self.thread_assertions.clone()
153 }
154}
155
156fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
157 let path = std::fs::canonicalize(examples_dir).unwrap();
158 let entries = std::fs::read_dir(path).unwrap();
159 let mut result_paths = Vec::new();
160 for entry in entries {
161 let entry = entry?;
162 let path = entry.path();
163 if path.extension() == Some("toml".as_ref()) {
164 result_paths.push(path);
165 }
166 }
167 Ok(result_paths)
168}