1use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
2use anyhow::anyhow;
3use assistant2::RequestKind;
4use collections::HashMap;
5use gpui::{App, Task};
6use language_model::{LanguageModel, TokenUsage};
7use serde::{Deserialize, Serialize};
8use std::{
9 fs,
10 io::Write,
11 path::{Path, PathBuf},
12 sync::Arc,
13 time::Duration,
14};
15use util::command::new_smol_command;
16
17pub struct Eval {
18 pub name: String,
19 pub path: PathBuf,
20 pub repo_path: PathBuf,
21 pub eval_setup: EvalSetup,
22 pub user_prompt: String,
23}
24
/// Results collected from a single [`Eval::run`].
///
/// Derives `Serialize`; the serialized fields follow the declaration order below.
#[derive(Debug, Serialize)]
pub struct EvalOutput {
    /// Output of `git diff` in the eval repository after the assistant finished.
    pub diff: String,
    /// Text of the last message in the thread; `Eval::run` requires it to be from the assistant.
    pub last_message: String,
    /// Time measured from just before the prompt is sent until the assistant's completion
    /// signal is received.
    pub elapsed_time: Duration,
    /// Number of messages in the thread whose role is `Assistant`.
    pub assistant_response_count: usize,
    /// Tool usage counts copied from the assistant (presumably keyed by tool name — verify).
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
}
34
35#[derive(Deserialize)]
36pub struct EvalSetup {
37 pub url: String,
38 pub base_sha: String,
39}
40
41impl Eval {
42 /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
43 /// if necessary.
44 pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
45 let prompt_path = path.join("prompt.txt");
46 let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
47 let setup_path = path.join("setup.json");
48 let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
49 let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
50 let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
51 Ok(Eval {
52 name,
53 path,
54 repo_path,
55 eval_setup,
56 user_prompt,
57 })
58 }
59
60 pub fn run(
61 self,
62 app_state: Arc<HeadlessAppState>,
63 model: Arc<dyn LanguageModel>,
64 cx: &mut App,
65 ) -> Task<anyhow::Result<EvalOutput>> {
66 cx.spawn(move |mut cx| async move {
67 checkout_repo(&self.eval_setup, &self.repo_path).await?;
68
69 let (assistant, done_rx) =
70 cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
71
72 let _worktree = assistant
73 .update(&mut cx, |assistant, cx| {
74 assistant.project.update(cx, |project, cx| {
75 project.create_worktree(&self.repo_path, true, cx)
76 })
77 })?
78 .await?;
79
80 let start_time = std::time::SystemTime::now();
81
82 assistant.update(&mut cx, |assistant, cx| {
83 assistant.thread.update(cx, |thread, cx| {
84 let context = vec![];
85 thread.insert_user_message(self.user_prompt.clone(), context, cx);
86 thread.send_to_model(model, RequestKind::Chat, cx);
87 });
88 })?;
89
90 done_rx.recv().await??;
91
92 let elapsed_time = start_time.elapsed()?;
93
94 let diff = query_git(&self.repo_path, vec!["diff"]).await?;
95
96 assistant.update(&mut cx, |assistant, cx| {
97 let thread = assistant.thread.read(cx);
98 let last_message = thread.messages().last().unwrap();
99 if last_message.role != language_model::Role::Assistant {
100 return Err(anyhow!("Last message is not from assistant"));
101 }
102 let assistant_response_count = thread
103 .messages()
104 .filter(|message| message.role == language_model::Role::Assistant)
105 .count();
106 Ok(EvalOutput {
107 diff,
108 last_message: last_message.text.clone(),
109 elapsed_time,
110 assistant_response_count,
111 tool_use_counts: assistant.tool_use_counts.clone(),
112 token_usage: thread.cumulative_token_usage(),
113 })
114 })?
115 })
116 }
117}
118
119impl EvalOutput {
120 // Method to save the output to a directory
121 pub fn save_to_directory(
122 &self,
123 output_dir: &Path,
124 eval_output_value: String,
125 ) -> anyhow::Result<()> {
126 // Create the output directory if it doesn't exist
127 fs::create_dir_all(&output_dir)?;
128
129 // Save the diff to a file
130 let diff_path = output_dir.join("diff.patch");
131 let mut diff_file = fs::File::create(&diff_path)?;
132 diff_file.write_all(self.diff.as_bytes())?;
133
134 // Save the last message to a file
135 let message_path = output_dir.join("assistant_response.txt");
136 let mut message_file = fs::File::create(&message_path)?;
137 message_file.write_all(self.last_message.as_bytes())?;
138
139 // Current metrics for this run
140 let current_metrics = serde_json::json!({
141 "elapsed_time_ms": self.elapsed_time.as_millis(),
142 "assistant_response_count": self.assistant_response_count,
143 "tool_use_counts": self.tool_use_counts,
144 "token_usage": self.token_usage,
145 "eval_output_value": eval_output_value,
146 });
147
148 // Get current timestamp in milliseconds
149 let timestamp = std::time::SystemTime::now()
150 .duration_since(std::time::UNIX_EPOCH)?
151 .as_millis()
152 .to_string();
153
154 // Path to metrics file
155 let metrics_path = output_dir.join("metrics.json");
156
157 // Load existing metrics if the file exists, or create a new object
158 let mut historical_metrics = if metrics_path.exists() {
159 let metrics_content = fs::read_to_string(&metrics_path)?;
160 serde_json::from_str::<serde_json::Value>(&metrics_content)
161 .unwrap_or_else(|_| serde_json::json!({}))
162 } else {
163 serde_json::json!({})
164 };
165
166 // Add new run with timestamp as key
167 if let serde_json::Value::Object(ref mut map) = historical_metrics {
168 map.insert(timestamp, current_metrics);
169 }
170
171 // Write updated metrics back to file
172 let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
173 let mut metrics_file = fs::File::create(&metrics_path)?;
174 metrics_file.write_all(metrics_json.as_bytes())?;
175
176 Ok(())
177 }
178}
179
/// Turns a repository URL into a single filesystem-safe directory name.
///
/// Every leading `https://` is stripped. Each remaining character that is not alphanumeric
/// becomes `_`, e.g. `https://github.com/a/b.git` maps to `github_com_a_b_git`.
fn repo_dir_name(url: &str) -> String {
    let without_scheme = url.trim_start_matches("https://");
    without_scheme
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
184
185async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
186 if !repo_path.exists() {
187 smol::unblock({
188 let repo_path = repo_path.to_path_buf();
189 || std::fs::create_dir_all(repo_path)
190 })
191 .await?;
192 run_git(repo_path, vec!["init"]).await?;
193 run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
194 } else {
195 let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
196 if actual_origin != eval_setup.url {
197 return Err(anyhow!(
198 "remote origin {} does not match expected origin {}",
199 actual_origin,
200 eval_setup.url
201 ));
202 }
203
204 // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
205 // also remove build artifacts, and so prevent incremental reuse there.
206 run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
207 run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
208 }
209
210 run_git(
211 repo_path,
212 vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
213 )
214 .await?;
215 run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
216
217 Ok(())
218}
219
220async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
221 let exit_status = new_smol_command("git")
222 .current_dir(repo_path)
223 .args(args.clone())
224 .status()
225 .await?;
226 if exit_status.success() {
227 Ok(())
228 } else {
229 Err(anyhow!(
230 "`git {}` failed with {}",
231 args.join(" "),
232 exit_status,
233 ))
234 }
235}
236
237async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
238 let output = new_smol_command("git")
239 .current_dir(repo_path)
240 .args(args.clone())
241 .output()
242 .await?;
243 if output.status.success() {
244 Ok(String::from_utf8(output.stdout)?.trim().to_string())
245 } else {
246 Err(anyhow!(
247 "`git {}` failed with {}",
248 args.join(" "),
249 output.status
250 ))
251 }
252}