eval.rs

  1use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
  2use anyhow::anyhow;
  3use assistant2::RequestKind;
  4use collections::HashMap;
  5use gpui::{App, Task};
  6use language_model::{LanguageModel, TokenUsage};
  7use serde::{Deserialize, Serialize};
  8use std::{
  9    fs,
 10    io::Write,
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13    time::Duration,
 14};
 15use util::command::new_smol_command;
 16
 17pub struct Eval {
 18    pub name: String,
 19    pub path: PathBuf,
 20    pub repo_path: PathBuf,
 21    pub eval_setup: EvalSetup,
 22    pub user_prompt: String,
 23}
 24
 25#[derive(Debug, Serialize)]
 26pub struct EvalOutput {
 27    pub diff: String,
 28    pub last_message: String,
 29    pub elapsed_time: Duration,
 30    pub assistant_response_count: usize,
 31    pub tool_use_counts: HashMap<Arc<str>, u32>,
 32    pub token_usage: TokenUsage,
 33}
 34
 35#[derive(Deserialize)]
 36pub struct EvalSetup {
 37    pub url: String,
 38    pub base_sha: String,
 39}
 40
 41impl Eval {
 42    /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
 43    /// if necessary.
 44    pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
 45        let prompt_path = path.join("prompt.txt");
 46        let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
 47        let setup_path = path.join("setup.json");
 48        let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
 49        let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
 50        let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
 51        Ok(Eval {
 52            name,
 53            path,
 54            repo_path,
 55            eval_setup,
 56            user_prompt,
 57        })
 58    }
 59
 60    pub fn run(
 61        self,
 62        app_state: Arc<HeadlessAppState>,
 63        model: Arc<dyn LanguageModel>,
 64        cx: &mut App,
 65    ) -> Task<anyhow::Result<EvalOutput>> {
 66        cx.spawn(async move |cx| {
 67            checkout_repo(&self.eval_setup, &self.repo_path).await?;
 68
 69            let (assistant, done_rx) =
 70                cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
 71
 72            let _worktree = assistant
 73                .update(cx, |assistant, cx| {
 74                    assistant.project.update(cx, |project, cx| {
 75                        project.create_worktree(&self.repo_path, true, cx)
 76                    })
 77                })?
 78                .await?;
 79
 80            let start_time = std::time::SystemTime::now();
 81
 82            assistant.update(cx, |assistant, cx| {
 83                assistant.thread.update(cx, |thread, cx| {
 84                    let context = vec![];
 85                    thread.insert_user_message(self.user_prompt.clone(), context, cx);
 86                    thread.send_to_model(model, RequestKind::Chat, cx);
 87                });
 88            })?;
 89
 90            done_rx.recv().await??;
 91
 92            let elapsed_time = start_time.elapsed()?;
 93
 94            let diff = query_git(&self.repo_path, vec!["diff"]).await?;
 95
 96            assistant.update(cx, |assistant, cx| {
 97                let thread = assistant.thread.read(cx);
 98                let last_message = thread.messages().last().unwrap();
 99                if last_message.role != language_model::Role::Assistant {
100                    return Err(anyhow!("Last message is not from assistant"));
101                }
102                let assistant_response_count = thread
103                    .messages()
104                    .filter(|message| message.role == language_model::Role::Assistant)
105                    .count();
106                Ok(EvalOutput {
107                    diff,
108                    last_message: last_message.text.clone(),
109                    elapsed_time,
110                    assistant_response_count,
111                    tool_use_counts: assistant.tool_use_counts.clone(),
112                    token_usage: thread.cumulative_token_usage(),
113                })
114            })?
115        })
116    }
117}
118
119impl EvalOutput {
120    // Method to save the output to a directory
121    pub fn save_to_directory(
122        &self,
123        output_dir: &Path,
124        eval_output_value: String,
125    ) -> anyhow::Result<()> {
126        // Create the output directory if it doesn't exist
127        fs::create_dir_all(&output_dir)?;
128
129        // Save the diff to a file
130        let diff_path = output_dir.join("diff.patch");
131        let mut diff_file = fs::File::create(&diff_path)?;
132        diff_file.write_all(self.diff.as_bytes())?;
133
134        // Save the last message to a file
135        let message_path = output_dir.join("assistant_response.txt");
136        let mut message_file = fs::File::create(&message_path)?;
137        message_file.write_all(self.last_message.as_bytes())?;
138
139        // Current metrics for this run
140        let current_metrics = serde_json::json!({
141            "elapsed_time_ms": self.elapsed_time.as_millis(),
142            "assistant_response_count": self.assistant_response_count,
143            "tool_use_counts": self.tool_use_counts,
144            "token_usage": self.token_usage,
145            "eval_output_value": eval_output_value,
146        });
147
148        // Get current timestamp in milliseconds
149        let timestamp = std::time::SystemTime::now()
150            .duration_since(std::time::UNIX_EPOCH)?
151            .as_millis()
152            .to_string();
153
154        // Path to metrics file
155        let metrics_path = output_dir.join("metrics.json");
156
157        // Load existing metrics if the file exists, or create a new object
158        let mut historical_metrics = if metrics_path.exists() {
159            let metrics_content = fs::read_to_string(&metrics_path)?;
160            serde_json::from_str::<serde_json::Value>(&metrics_content)
161                .unwrap_or_else(|_| serde_json::json!({}))
162        } else {
163            serde_json::json!({})
164        };
165
166        // Add new run with timestamp as key
167        if let serde_json::Value::Object(ref mut map) = historical_metrics {
168            map.insert(timestamp, current_metrics);
169        }
170
171        // Write updated metrics back to file
172        let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
173        let mut metrics_file = fs::File::create(&metrics_path)?;
174        metrics_file.write_all(metrics_json.as_bytes())?;
175
176        Ok(())
177    }
178}
179
/// Derives a directory name from a repository URL.
///
/// Every leading `https://` is removed, then each character that is not alphanumeric (per
/// [`char::is_alphanumeric`]) is replaced with `_`.
fn repo_dir_name(url: &str) -> String {
    let without_scheme = url.trim_start_matches("https://");
    without_scheme
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
184
185async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
186    if !repo_path.exists() {
187        smol::unblock({
188            let repo_path = repo_path.to_path_buf();
189            || std::fs::create_dir_all(repo_path)
190        })
191        .await?;
192        run_git(repo_path, vec!["init"]).await?;
193        run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
194    } else {
195        let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
196        if actual_origin != eval_setup.url {
197            return Err(anyhow!(
198                "remote origin {} does not match expected origin {}",
199                actual_origin,
200                eval_setup.url
201            ));
202        }
203
204        // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
205        // also remove build artifacts, and so prevent incremental reuse there.
206        run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
207        run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
208    }
209
210    run_git(
211        repo_path,
212        vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
213    )
214    .await?;
215    run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
216
217    Ok(())
218}
219
220async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
221    let exit_status = new_smol_command("git")
222        .current_dir(repo_path)
223        .args(args.clone())
224        .status()
225        .await?;
226    if exit_status.success() {
227        Ok(())
228    } else {
229        Err(anyhow!(
230            "`git {}` failed with {}",
231            args.join(" "),
232            exit_status,
233        ))
234    }
235}
236
237async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
238    let output = new_smol_command("git")
239        .current_dir(repo_path)
240        .args(args.clone())
241        .output()
242        .await?;
243    if output.status.success() {
244        Ok(String::from_utf8(output.stdout)?.trim().to_string())
245    } else {
246        Err(anyhow!(
247            "`git {}` failed with {}",
248            args.join(" "),
249            output.status
250        ))
251    }
252}