eval.rs

  1use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
  2use agent::RequestKind;
  3use anyhow::anyhow;
  4use collections::HashMap;
  5use gpui::{App, Task};
  6use language_model::{LanguageModel, TokenUsage};
  7use serde::{Deserialize, Serialize};
  8use std::{
  9    fs,
 10    io::Write,
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13    time::Duration,
 14};
 15use util::command::new_smol_command;
 16
 17pub struct Eval {
 18    pub name: String,
 19    pub path: PathBuf,
 20    pub repo_path: PathBuf,
 21    pub eval_setup: EvalSetup,
 22    pub user_prompt: String,
 23}
 24
 25#[derive(Debug, Serialize)]
 26pub struct EvalOutput {
 27    pub diff: String,
 28    pub last_message: String,
 29    pub elapsed_time: Duration,
 30    pub assistant_response_count: usize,
 31    pub tool_use_counts: HashMap<Arc<str>, u32>,
 32    pub token_usage: TokenUsage,
 33}
 34
 35#[derive(Deserialize)]
 36pub struct EvalSetup {
 37    pub url: String,
 38    pub base_sha: String,
 39}
 40
 41impl Eval {
 42    /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
 43    /// if necessary.
 44    pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
 45        let prompt_path = path.join("prompt.txt");
 46        let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
 47        let setup_path = path.join("setup.json");
 48        let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
 49        let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
 50        let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
 51        Ok(Eval {
 52            name,
 53            path,
 54            repo_path,
 55            eval_setup,
 56            user_prompt,
 57        })
 58    }
 59
 60    pub fn run(
 61        self,
 62        app_state: Arc<HeadlessAppState>,
 63        model: Arc<dyn LanguageModel>,
 64        cx: &mut App,
 65    ) -> Task<anyhow::Result<EvalOutput>> {
 66        cx.spawn(async move |cx| {
 67            checkout_repo(&self.eval_setup, &self.repo_path).await?;
 68
 69            let (assistant, done_rx) =
 70                cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
 71
 72            let _worktree = assistant
 73                .update(cx, |assistant, cx| {
 74                    assistant.project.update(cx, |project, cx| {
 75                        project.create_worktree(&self.repo_path, true, cx)
 76                    })
 77                })?
 78                .await?;
 79
 80            let start_time = std::time::SystemTime::now();
 81
 82            let (system_prompt_context, load_error) = cx
 83                .update(|cx| {
 84                    assistant
 85                        .read(cx)
 86                        .thread
 87                        .read(cx)
 88                        .load_system_prompt_context(cx)
 89                })?
 90                .await;
 91
 92            if let Some(load_error) = load_error {
 93                return Err(anyhow!("{:?}", load_error));
 94            };
 95
 96            assistant.update(cx, |assistant, cx| {
 97                assistant.thread.update(cx, |thread, cx| {
 98                    let context = vec![];
 99                    thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
100                    thread.set_system_prompt_context(system_prompt_context);
101                    thread.send_to_model(model, RequestKind::Chat, cx);
102                });
103            })?;
104
105            done_rx.recv().await??;
106
107            let elapsed_time = start_time.elapsed()?;
108
109            let diff = query_git(&self.repo_path, vec!["diff"]).await?;
110
111            assistant.update(cx, |assistant, cx| {
112                let thread = assistant.thread.read(cx);
113                let last_message = thread.messages().last().unwrap();
114                if last_message.role != language_model::Role::Assistant {
115                    return Err(anyhow!("Last message is not from assistant"));
116                }
117                let assistant_response_count = thread
118                    .messages()
119                    .filter(|message| message.role == language_model::Role::Assistant)
120                    .count();
121                Ok(EvalOutput {
122                    diff,
123                    last_message: last_message.to_string(),
124                    elapsed_time,
125                    assistant_response_count,
126                    tool_use_counts: assistant.tool_use_counts.clone(),
127                    token_usage: thread.cumulative_token_usage(),
128                })
129            })?
130        })
131    }
132}
133
134impl EvalOutput {
135    // Method to save the output to a directory
136    pub fn save_to_directory(
137        &self,
138        output_dir: &Path,
139        eval_output_value: String,
140    ) -> anyhow::Result<()> {
141        // Create the output directory if it doesn't exist
142        fs::create_dir_all(&output_dir)?;
143
144        // Save the diff to a file
145        let diff_path = output_dir.join("diff.patch");
146        let mut diff_file = fs::File::create(&diff_path)?;
147        diff_file.write_all(self.diff.as_bytes())?;
148
149        // Save the last message to a file
150        let message_path = output_dir.join("assistant_response.txt");
151        let mut message_file = fs::File::create(&message_path)?;
152        message_file.write_all(self.last_message.as_bytes())?;
153
154        // Current metrics for this run
155        let current_metrics = serde_json::json!({
156            "elapsed_time_ms": self.elapsed_time.as_millis(),
157            "assistant_response_count": self.assistant_response_count,
158            "tool_use_counts": self.tool_use_counts,
159            "token_usage": self.token_usage,
160            "eval_output_value": eval_output_value,
161        });
162
163        // Get current timestamp in milliseconds
164        let timestamp = std::time::SystemTime::now()
165            .duration_since(std::time::UNIX_EPOCH)?
166            .as_millis()
167            .to_string();
168
169        // Path to metrics file
170        let metrics_path = output_dir.join("metrics.json");
171
172        // Load existing metrics if the file exists, or create a new object
173        let mut historical_metrics = if metrics_path.exists() {
174            let metrics_content = fs::read_to_string(&metrics_path)?;
175            serde_json::from_str::<serde_json::Value>(&metrics_content)
176                .unwrap_or_else(|_| serde_json::json!({}))
177        } else {
178            serde_json::json!({})
179        };
180
181        // Add new run with timestamp as key
182        if let serde_json::Value::Object(ref mut map) = historical_metrics {
183            map.insert(timestamp, current_metrics);
184        }
185
186        // Write updated metrics back to file
187        let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
188        let mut metrics_file = fs::File::create(&metrics_path)?;
189        metrics_file.write_all(metrics_json.as_bytes())?;
190
191        Ok(())
192    }
193}
194
/// Derives a directory name for a repository from its URL.
///
/// Leading `https://` prefixes are removed, and every character that is not alphanumeric is
/// replaced by an underscore.
fn repo_dir_name(url: &str) -> String {
    let without_scheme = url.trim_start_matches("https://");
    without_scheme
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
199
200async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
201    if !repo_path.exists() {
202        smol::unblock({
203            let repo_path = repo_path.to_path_buf();
204            || std::fs::create_dir_all(repo_path)
205        })
206        .await?;
207        run_git(repo_path, vec!["init"]).await?;
208        run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
209    } else {
210        let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
211        if actual_origin != eval_setup.url {
212            return Err(anyhow!(
213                "remote origin {} does not match expected origin {}",
214                actual_origin,
215                eval_setup.url
216            ));
217        }
218
219        // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
220        // also remove build artifacts, and so prevent incremental reuse there.
221        run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
222        run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
223    }
224
225    run_git(
226        repo_path,
227        vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
228    )
229    .await?;
230    run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
231
232    Ok(())
233}
234
235async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
236    let exit_status = new_smol_command("git")
237        .current_dir(repo_path)
238        .args(args.clone())
239        .status()
240        .await?;
241    if exit_status.success() {
242        Ok(())
243    } else {
244        Err(anyhow!(
245            "`git {}` failed with {}",
246            args.join(" "),
247            exit_status,
248        ))
249    }
250}
251
252async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
253    let output = new_smol_command("git")
254        .current_dir(repo_path)
255        .args(args.clone())
256        .output()
257        .await?;
258    if output.status.success() {
259        Ok(String::from_utf8(output.stdout)?.trim().to_string())
260    } else {
261        Err(anyhow!(
262            "`git {}` failed with {}",
263            args.join(" "),
264            output.status
265        ))
266    }
267}