use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use agent::RequestKind;
use anyhow::{Context as _, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
    fs,
    io::Write,
    path::{Path, PathBuf},
    sync::Arc,
    time::Duration,
};
use util::command::new_smol_command;
16
17pub struct Eval {
18 pub name: String,
19 pub path: PathBuf,
20 pub repo_path: PathBuf,
21 pub eval_setup: EvalSetup,
22 pub user_prompt: String,
23}
24
/// Results collected from a single `Eval::run`.
///
/// NOTE: this derives `Serialize`, so the declaration order of the fields below is the order in
/// which they are serialized.
#[derive(Debug, Serialize)]
pub struct EvalOutput {
    /// Output of `git diff` (whitespace-trimmed) in the eval repository after the run finished.
    pub diff: String,
    /// Text of the thread's final message; `Eval::run` fails unless it is from the assistant.
    pub last_message: String,
    /// Time measured from just before the system prompt context is loaded until the assistant
    /// signals completion (repository checkout and worktree creation are not included).
    pub elapsed_time: Duration,
    /// Number of messages in the thread whose role is `Assistant`.
    pub assistant_response_count: usize,
    /// Tool use counts copied from the headless assistant (presumably keyed by tool name —
    /// confirm against `HeadlessAssistant`).
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    /// The thread's cumulative token usage at the end of the run.
    pub token_usage: TokenUsage,
}
34
35#[derive(Deserialize)]
36pub struct EvalSetup {
37 pub url: String,
38 pub base_sha: String,
39}
40
41impl Eval {
42 /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
43 /// if necessary.
44 pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
45 let prompt_path = path.join("prompt.txt");
46 let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
47 let setup_path = path.join("setup.json");
48 let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
49 let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
50 let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
51 Ok(Eval {
52 name,
53 path,
54 repo_path,
55 eval_setup,
56 user_prompt,
57 })
58 }
59
60 pub fn run(
61 self,
62 app_state: Arc<HeadlessAppState>,
63 model: Arc<dyn LanguageModel>,
64 cx: &mut App,
65 ) -> Task<anyhow::Result<EvalOutput>> {
66 cx.spawn(async move |cx| {
67 checkout_repo(&self.eval_setup, &self.repo_path).await?;
68
69 let (assistant, done_rx) =
70 cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
71
72 let _worktree = assistant
73 .update(cx, |assistant, cx| {
74 assistant.project.update(cx, |project, cx| {
75 project.create_worktree(&self.repo_path, true, cx)
76 })
77 })?
78 .await?;
79
80 let start_time = std::time::SystemTime::now();
81
82 let (system_prompt_context, load_error) = cx
83 .update(|cx| {
84 assistant
85 .read(cx)
86 .thread
87 .read(cx)
88 .load_system_prompt_context(cx)
89 })?
90 .await;
91
92 if let Some(load_error) = load_error {
93 return Err(anyhow!("{:?}", load_error));
94 };
95
96 assistant.update(cx, |assistant, cx| {
97 assistant.thread.update(cx, |thread, cx| {
98 let context = vec![];
99 thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
100 thread.set_system_prompt_context(system_prompt_context);
101 thread.send_to_model(model, RequestKind::Chat, cx);
102 });
103 })?;
104
105 done_rx.recv().await??;
106
107 let elapsed_time = start_time.elapsed()?;
108
109 let diff = query_git(&self.repo_path, vec!["diff"]).await?;
110
111 assistant.update(cx, |assistant, cx| {
112 let thread = assistant.thread.read(cx);
113 let last_message = thread.messages().last().unwrap();
114 if last_message.role != language_model::Role::Assistant {
115 return Err(anyhow!("Last message is not from assistant"));
116 }
117 let assistant_response_count = thread
118 .messages()
119 .filter(|message| message.role == language_model::Role::Assistant)
120 .count();
121 Ok(EvalOutput {
122 diff,
123 last_message: last_message.to_string(),
124 elapsed_time,
125 assistant_response_count,
126 tool_use_counts: assistant.tool_use_counts.clone(),
127 token_usage: thread.cumulative_token_usage(),
128 })
129 })?
130 })
131 }
132}
133
134impl EvalOutput {
135 // Method to save the output to a directory
136 pub fn save_to_directory(
137 &self,
138 output_dir: &Path,
139 eval_output_value: String,
140 ) -> anyhow::Result<()> {
141 // Create the output directory if it doesn't exist
142 fs::create_dir_all(&output_dir)?;
143
144 // Save the diff to a file
145 let diff_path = output_dir.join("diff.patch");
146 let mut diff_file = fs::File::create(&diff_path)?;
147 diff_file.write_all(self.diff.as_bytes())?;
148
149 // Save the last message to a file
150 let message_path = output_dir.join("assistant_response.txt");
151 let mut message_file = fs::File::create(&message_path)?;
152 message_file.write_all(self.last_message.as_bytes())?;
153
154 // Current metrics for this run
155 let current_metrics = serde_json::json!({
156 "elapsed_time_ms": self.elapsed_time.as_millis(),
157 "assistant_response_count": self.assistant_response_count,
158 "tool_use_counts": self.tool_use_counts,
159 "token_usage": self.token_usage,
160 "eval_output_value": eval_output_value,
161 });
162
163 // Get current timestamp in milliseconds
164 let timestamp = std::time::SystemTime::now()
165 .duration_since(std::time::UNIX_EPOCH)?
166 .as_millis()
167 .to_string();
168
169 // Path to metrics file
170 let metrics_path = output_dir.join("metrics.json");
171
172 // Load existing metrics if the file exists, or create a new object
173 let mut historical_metrics = if metrics_path.exists() {
174 let metrics_content = fs::read_to_string(&metrics_path)?;
175 serde_json::from_str::<serde_json::Value>(&metrics_content)
176 .unwrap_or_else(|_| serde_json::json!({}))
177 } else {
178 serde_json::json!({})
179 };
180
181 // Add new run with timestamp as key
182 if let serde_json::Value::Object(ref mut map) = historical_metrics {
183 map.insert(timestamp, current_metrics);
184 }
185
186 // Write updated metrics back to file
187 let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
188 let mut metrics_file = fs::File::create(&metrics_path)?;
189 metrics_file.write_all(metrics_json.as_bytes())?;
190
191 Ok(())
192 }
193}
194
/// Derives a filesystem-friendly directory name from a repository URL.
///
/// All leading `https://` prefixes are stripped. Then every character that is not alphanumeric
/// (per [`char::is_alphanumeric`], so Unicode letters and digits survive) becomes `_`.
fn repo_dir_name(url: &str) -> String {
    let without_scheme = url.trim_start_matches("https://");
    without_scheme
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
199
200async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
201 if !repo_path.exists() {
202 smol::unblock({
203 let repo_path = repo_path.to_path_buf();
204 || std::fs::create_dir_all(repo_path)
205 })
206 .await?;
207 run_git(repo_path, vec!["init"]).await?;
208 run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
209 } else {
210 let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
211 if actual_origin != eval_setup.url {
212 return Err(anyhow!(
213 "remote origin {} does not match expected origin {}",
214 actual_origin,
215 eval_setup.url
216 ));
217 }
218
219 // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
220 // also remove build artifacts, and so prevent incremental reuse there.
221 run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
222 run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
223 }
224
225 run_git(
226 repo_path,
227 vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
228 )
229 .await?;
230 run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
231
232 Ok(())
233}
234
235async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
236 let exit_status = new_smol_command("git")
237 .current_dir(repo_path)
238 .args(args.clone())
239 .status()
240 .await?;
241 if exit_status.success() {
242 Ok(())
243 } else {
244 Err(anyhow!(
245 "`git {}` failed with {}",
246 args.join(" "),
247 exit_status,
248 ))
249 }
250}
251
252async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
253 let output = new_smol_command("git")
254 .current_dir(repo_path)
255 .args(args.clone())
256 .output()
257 .await?;
258 if output.status.success() {
259 Ok(String::from_utf8(output.stdout)?.trim().to_string())
260 } else {
261 Err(anyhow!(
262 "`git {}` failed with {}",
263 args.join(" "),
264 output.status
265 ))
266 }
267}