Add judge to new eval + provide LSP diagnostics (#28713)

Michael Sloan , Antonio Scandurra , and agus created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <antonio@zed.dev>
Co-authored-by: agus <agus@zed.dev>

Change summary

Cargo.lock                                                  |  11 
crates/agent/src/thread.rs                                  |  16 
crates/eval/.gitignore                                      |   3 
crates/eval/Cargo.toml                                      |  13 
crates/eval/examples/find_and_replace_diff_card/base.toml   |   3 
crates/eval/examples/find_and_replace_diff_card/criteria.md |   2 
crates/eval/examples/find_and_replace_diff_card/rubric.md   |   0 
crates/eval/src/eval.rs                                     | 247 ++
crates/eval/src/example.rs                                  | 600 ++++++
crates/eval/src/judge_prompt.hbs                            |  25 
crates/language_model/src/language_model.rs                 |   2 
11 files changed, 838 insertions(+), 84 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4878,25 +4878,36 @@ dependencies = [
  "assistant_settings",
  "assistant_tool",
  "assistant_tools",
+ "async-watch",
+ "chrono",
+ "clap",
  "client",
  "context_server",
  "dap",
  "env_logger 0.11.8",
+ "extension",
  "fs",
  "futures 0.3.31",
  "gpui",
  "gpui_tokio",
+ "handlebars 4.5.0",
  "language",
+ "language_extension",
  "language_model",
  "language_models",
+ "languages",
  "node_runtime",
+ "paths",
  "project",
  "prompt_store",
  "release_channel",
  "reqwest_client",
  "serde",
  "settings",
+ "shellexpand 2.1.2",
  "toml 0.8.20",
+ "unindent",
+ "util",
  "workspace-hack",
 ]
 

crates/agent/src/thread.rs 🔗

@@ -827,7 +827,7 @@ impl Thread {
                     })
                     .collect(),
                 initial_project_snapshot,
-                cumulative_token_usage: this.cumulative_token_usage.clone(),
+                cumulative_token_usage: this.cumulative_token_usage,
                 detailed_summary_state: this.detailed_summary_state.clone(),
                 exceeded_window_error: this.exceeded_window_error.clone(),
             })
@@ -1016,7 +1016,7 @@ impl Thread {
         let task = cx.spawn(async move |thread, cx| {
             let stream = model.stream_completion(request, &cx);
             let initial_token_usage =
-                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
+                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
             let stream_completion = async {
                 let mut events = stream.await?;
                 let mut stop_reason = StopReason::EndTurn;
@@ -1038,9 +1038,9 @@ impl Thread {
                                 stop_reason = reason;
                             }
                             LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
-                                thread.cumulative_token_usage =
-                                    thread.cumulative_token_usage.clone() + token_usage.clone()
-                                        - current_token_usage.clone();
+                                thread.cumulative_token_usage = thread.cumulative_token_usage
+                                    + token_usage
+                                    - current_token_usage;
                                 current_token_usage = token_usage;
                             }
                             LanguageModelCompletionEvent::Text(chunk) => {
@@ -1183,7 +1183,7 @@ impl Thread {
                     thread.auto_capture_telemetry(cx);
 
                     if let Ok(initial_usage) = initial_token_usage {
-                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
+                        let usage = thread.cumulative_token_usage - initial_usage;
 
                         telemetry::event!(
                             "Assistant Thread Completion",
@@ -1862,6 +1862,10 @@ impl Thread {
             .detach();
     }
 
+    pub fn cumulative_token_usage(&self) -> TokenUsage {
+        self.cumulative_token_usage
+    }
+
     pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
         let model_registry = LanguageModelRegistry::read_global(cx);
         let Some(model) = model_registry.default_model() else {

crates/eval/Cargo.toml 🔗

@@ -7,28 +7,39 @@ edition.workspace = true
 [dependencies]
 agent.workspace = true
 anyhow.workspace = true
+async-watch.workspace = true
+assistant_settings.workspace = true
 assistant_tool.workspace = true
 assistant_tools.workspace = true
-assistant_settings.workspace = true
+chrono.workspace = true
+clap.workspace = true
 client.workspace = true
 context_server.workspace = true
 dap.workspace = true
 env_logger.workspace = true
+extension.workspace = true
 fs.workspace = true
 futures.workspace = true
 gpui.workspace = true
 gpui_tokio.workspace = true
+handlebars.workspace = true
 language.workspace = true
+language_extension.workspace = true
 language_model.workspace = true
 language_models.workspace = true
+languages.workspace = true
 node_runtime.workspace = true
+paths.workspace = true
 project.workspace = true
 prompt_store.workspace = true
 release_channel.workspace = true
 reqwest_client.workspace = true
 serde.workspace = true
 settings.workspace = true
+shellexpand.workspace = true
 toml.workspace = true
+unindent.workspace = true
+util.workspace = true
 workspace-hack.workspace = true
 
 [[bin]]

crates/eval/examples/find_and_replace_diff_card/criteria.md 🔗

@@ -0,0 +1,2 @@
+1. The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the string we were returning before, and a new `card` field that contains a view for the card
+2. The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged.

crates/eval/src/eval.rs 🔗

@@ -1,32 +1,75 @@
 mod example;
 
 use assistant_settings::AssistantSettings;
-use client::{Client, UserStore};
+use client::{Client, ProxySettings, UserStore};
 pub(crate) use example::*;
 
 use ::fs::RealFs;
-use anyhow::anyhow;
-use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
+use anyhow::{Result, anyhow};
+use clap::Parser;
+use extension::ExtensionHostProxy;
+use futures::future;
+use gpui::http_client::{Uri, read_proxy_from_env};
+use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
+use gpui_tokio::Tokio;
 use language::LanguageRegistry;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 };
-use node_runtime::NodeRuntime;
+use node_runtime::{NodeBinaryOptions, NodeRuntime};
 use project::Project;
+use project::project_settings::ProjectSettings;
 use prompt_store::PromptBuilder;
+use release_channel::AppVersion;
 use reqwest_client::ReqwestClient;
 use settings::{Settings, SettingsStore};
+use std::collections::HashSet;
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use util::ResultExt as _;
+
+pub const RUNS_DIR: &str = "./crates/eval/runs";
+
+#[derive(Parser, Debug)]
+#[command(name = "eval", disable_version_flag = true)]
+struct Args {
+    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
+    #[arg(value_name = "EXAMPLE_SUBSTRING")]
+    examples: Vec<String>,
+    /// Model to use (default: "claude-3-7-sonnet-latest")
+    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+    model: String,
+}
 
 fn main() {
     env_logger::init();
+
+    let args = Args::parse();
+    let all_available_examples = list_all_examples().unwrap();
+    let example_paths = all_available_examples
+        .iter()
+        .filter_map(|example_path| {
+            let name = example_path.file_name()?.to_string_lossy();
+            if args.examples.is_empty()
+                || args
+                    .examples
+                    .iter()
+                    .any(|name_substring| name.contains(name_substring))
+            {
+                Some(example_path.clone())
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<_>>();
+
     let http_client = Arc::new(ReqwestClient::new());
     let app = Application::headless().with_http_client(http_client.clone());
 
     app.run(move |cx| {
         let app_state = init(cx);
 
-        let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
+        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 
         LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
             registry.set_default_model(Some(model.clone()), cx);
@@ -39,17 +82,142 @@ fn main() {
         cx.spawn(async move |cx| {
             authenticate.await.unwrap();
 
-            let example =
-                Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
-            example.setup()?;
-            cx.update(|cx| example.run(model, app_state, cx))?.await?;
+            std::fs::create_dir_all(REPOS_DIR)?;
+            std::fs::create_dir_all(WORKTREES_DIR)?;
+
+            let run_dir = Path::new(RUNS_DIR).join(format!(
+                "{}",
+                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
+            ));
+            std::fs::create_dir_all(&run_dir)?;
+
+            let mut examples = Vec::new();
+            for example_path in example_paths {
+                let example = Example::load_from_directory(&example_path, &run_dir)?;
+                examples.push((example_path, example));
+            }
+            let mut repo_urls = HashSet::new();
+
+            let mut clone_tasks = Vec::new();
+
+            for (_, example) in examples.iter() {
+                let repo_url = example.base.url.clone();
+                if repo_urls.insert(repo_url.clone()) {
+                    let repo_path = repo_path_for_url(&repo_url);
+
+                    if !repo_path.join(".git").is_dir() {
+                        println!("Cloning: {}", repo_url);
+
+                        let git_task = cx.spawn(async move |_cx| {
+                            std::fs::create_dir_all(&repo_path)?;
+                            run_git(&repo_path, &["init"]).await?;
+                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
+                        });
+
+                        clone_tasks.push(git_task);
+                    } else {
+                        println!("Already cloned: {}", repo_url);
+
+                        let actual_origin =
+                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
+                        if actual_origin != repo_url {
+                            return Err(anyhow!(
+                                "remote origin {} does not match expected origin {}",
+                                actual_origin,
+                                repo_url,
+                            ));
+                        }
+                    }
+                }
+            }
+
+            future::join_all(clone_tasks).await;
+
+            let tasks = examples
+                .into_iter()
+                .map(|(example_path, example)| {
+                    let app_state = app_state.clone();
+                    let model = model.clone();
+                    cx.spawn(async move |cx| {
+                        (
+                            example_path,
+                            run_example(example, model, app_state, cx).await,
+                        )
+                    })
+                })
+                .collect::<Vec<_>>();
+
+            let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
+
+            println!("\n\n");
+            println!("========================================");
+            println!("              EVAL RESULTS              ");
+            println!("========================================");
+            println!("");
 
-            anyhow::Ok(())
+            let mut judge_scores = Vec::new();
+
+            for (example_path, result) in results {
+                let example_name = example_path.file_name().unwrap().to_string_lossy();
+                match result {
+                    Err(err) => {
+                        println!("💥 {:<30}: {:?}", example_name, err);
+                    }
+                    Ok(judge_output) => {
+                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
+
+                        println!(
+                            "{} {:<30}: {}",
+                            SCORES[judge_output.score.min(5) as usize],
+                            example_name,
+                            judge_output.score,
+                        );
+                        judge_scores.push(judge_output.score);
+                    }
+                }
+            }
+
+            let score_count = judge_scores.len();
+            let average_score = judge_scores
+                .into_iter()
+                .map(|score| score as f32)
+                .sum::<f32>()
+                / (score_count as f32);
+            println!("\nAverage score: {average_score}");
+
+            cx.update(|cx| cx.quit())
         })
         .detach_and_log_err(cx);
     });
 }
 
+async fn run_example(
+    mut example: Example,
+    model: Arc<dyn LanguageModel>,
+    app_state: Arc<AgentAppState>,
+    cx: &mut AsyncApp,
+) -> Result<JudgeOutput> {
+    example.setup().await?;
+    cx.update(|cx| example.run(model.clone(), app_state, cx))?
+        .await?;
+    let diff = example.repository_diff().await?;
+    example.judge(model, diff, cx).await
+}
+
+fn list_all_examples() -> Result<Vec<PathBuf>> {
+    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
+    let entries = std::fs::read_dir(path).unwrap();
+    let mut result_paths = Vec::new();
+    for entry in entries {
+        let entry = entry?;
+        let path = entry.path();
+        if path.is_dir() {
+            result_paths.push(path);
+        }
+    }
+    Ok(result_paths)
+}
+
 /// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
 pub struct AgentAppState {
     pub languages: Arc<LanguageRegistry>,
@@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
         .unwrap();
     cx.set_global(settings_store);
     client::init_settings(cx);
+
+    // Set User-Agent so we can download language servers from GitHub
+    let user_agent = format!(
+        "Zed/{} ({}; {})",
+        AppVersion::global(cx),
+        std::env::consts::OS,
+        std::env::consts::ARCH
+    );
+    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
+    let proxy_url = proxy_str
+        .as_ref()
+        .and_then(|input| input.parse::<Uri>().ok())
+        .or_else(read_proxy_from_env);
+    let http = {
+        let _guard = Tokio::handle(cx).enter();
+
+        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
+            .expect("could not start HTTP client")
+    };
+    cx.set_http_client(Arc::new(http));
+
     Project::init_settings(cx);
 
     let client = Client::production(cx);
@@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
         cx.background_executor().clone(),
     ));
 
-    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
+    languages.set_language_server_download_dir(paths::languages_dir().clone());
+    let languages = Arc::new(languages);
 
     let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 
+    extension::init(cx);
+
+    let (tx, rx) = async_watch::channel(None);
+    cx.observe_global::<SettingsStore>(move |cx| {
+        let settings = &ProjectSettings::get_global(cx).node;
+        let options = NodeBinaryOptions {
+            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
+            allow_binary_download: true,
+            use_paths: settings.path.as_ref().map(|node_path| {
+                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
+                let npm_path = settings
+                    .npm_path
+                    .as_ref()
+                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
+                (
+                    node_path.clone(),
+                    npm_path.unwrap_or_else(|| {
+                        let base_path = PathBuf::new();
+                        node_path.parent().unwrap_or(&base_path).join("npm")
+                    }),
+                )
+            }),
+        };
+        tx.send(Some(options)).log_err();
+    })
+    .detach();
+    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
+
+    let extension_host_proxy = ExtensionHostProxy::global(cx);
+
     language::init(cx);
+    language_extension::init(extension_host_proxy.clone(), languages.clone());
     language_model::init(client.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+    languages::init(languages.clone(), node_runtime.clone(), cx);
     assistant_tools::init(client.http_client().clone(), cx);
     context_server::init(cx);
     let stdout_is_a_pty = false;
@@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
         client,
         user_store,
         fs,
-        node_runtime: NodeRuntime::unavailable(),
+        node_runtime,
         prompt_builder,
     })
 }

crates/eval/src/example.rs 🔗

@@ -1,83 +1,161 @@
 use agent::{RequestKind, ThreadEvent, ThreadStore};
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::ToolWorkingSet;
+use client::proto::LspWorkProgress;
 use dap::DapRegistry;
-use futures::channel::oneshot;
-use gpui::{App, Task};
-use language_model::{LanguageModel, StopReason};
-use project::Project;
-use serde::Deserialize;
-use std::process::Command;
-use std::sync::Arc;
+use futures::channel::{mpsc, oneshot};
+use futures::{FutureExt, StreamExt as _};
+use gpui::{App, AsyncApp, Entity, Task};
+use handlebars::Handlebars;
+use language::{DiagnosticSeverity, OffsetRangeExt};
+use language_model::{
+    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+    StopReason, TokenUsage,
+};
+use project::{LspStore, Project, ProjectPath};
+use serde::{Deserialize, Serialize};
+use std::fmt::Write as _;
+use std::fs::File;
+use std::io::Write as _;
+use std::sync::{Arc, Mutex};
+use std::time::Duration;
 use std::{
     fs,
     path::{Path, PathBuf},
 };
+use unindent::Unindent as _;
+use util::ResultExt as _;
+use util::command::new_smol_command;
+use util::serde::default_true;
 
 use crate::AgentAppState;
 
-#[derive(Debug, Deserialize)]
+pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
+pub const REPOS_DIR: &str = "./crates/eval/repos";
+pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
+
+#[derive(Clone, Debug, Deserialize)]
 pub struct ExampleBase {
-    pub path: PathBuf,
+    pub url: String,
     pub revision: String,
+    pub language_extension: Option<String>,
+    pub insert_id: Option<String>,
+    #[serde(default = "default_true")]
+    pub require_lsp: bool,
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Example {
+    pub name: String,
+    /// Content of `base.toml`
     pub base: ExampleBase,
-
-    /// Content of the prompt.md file
+    /// Content of `prompt.md`
     pub prompt: String,
+    /// Content of `criteria.md`
+    pub criteria: String,
+    /// Markdown log file to append to
+    pub log_file: Arc<Mutex<File>>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct RunOutput {
+    pub repository_diff: String,
+    pub diagnostics: String,
+    pub response_count: usize,
+    pub token_usage: TokenUsage,
+}
 
-    /// Content of the rubric.md file
-    pub _rubric: String,
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeInput {
+    pub repository_diff: String,
+    pub criteria: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeOutput {
+    pub analysis: String,
+    pub score: u32,
 }
 
 impl Example {
-    /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
-    pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
-        let base_path = dir_path.as_ref().join("base.toml");
-        let prompt_path = dir_path.as_ref().join("prompt.md");
-        let rubric_path = dir_path.as_ref().join("rubric.md");
+    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
+    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
+        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
+        let base_path = dir_path.join("base.toml");
+        let prompt_path = dir_path.join("prompt.md");
+        let criteria_path = dir_path.join("criteria.md");
 
-        let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
-        base.path = base.path.canonicalize()?;
+        let log_file_path = run_dir.join(format!(
+            "{}.md",
+            dir_path.file_name().unwrap().to_str().unwrap()
+        ));
+        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
+        println!("{}> Logging to {:?}", name, log_file_path);
 
         Ok(Example {
-            base,
-            prompt: fs::read_to_string(prompt_path)?,
-            _rubric: fs::read_to_string(rubric_path)?,
+            name,
+            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
+            prompt: fs::read_to_string(prompt_path.clone())?,
+            criteria: fs::read_to_string(criteria_path.clone())?,
+            log_file,
         })
     }
 
+    pub fn worktree_path(&self) -> PathBuf {
+        Path::new(WORKTREES_DIR)
+            .canonicalize()
+            .context(format!("No such directory {WORKTREES_DIR}"))
+            .unwrap()
+            .join(&self.name)
+    }
+
     /// Set up the example by checking out the specified Git revision
-    pub fn setup(&self) -> Result<()> {
-        // Check if the directory exists
-        let path = Path::new(&self.base.path);
-        anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
-
-        // Change to the project directory and checkout the specified revision
-        let output = Command::new("git")
-            .current_dir(&self.base.path)
-            .arg("checkout")
-            .arg(&self.base.revision)
-            .output()?;
-        anyhow::ensure!(
-            output.status.success(),
-            "Failed to checkout revision {}: {}",
-            self.base.revision,
-            String::from_utf8_lossy(&output.stderr),
-        );
+    pub async fn setup(&self) -> Result<()> {
+        let repo_path = repo_path_for_url(&self.base.url);
+
+        run_git(
+            &repo_path,
+            &["fetch", "--depth", "1", "origin", &self.base.revision],
+        )
+        .await?;
+
+        let worktree_path = self.worktree_path();
+
+        if worktree_path.is_dir() {
+            println!("{}> Resetting existing worktree", self.name);
+
+            // TODO: consider including "-x" to remove ignored files. The downside of this is that
+            // it will also remove build artifacts, and so prevent incremental reuse there.
+            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
+        } else {
+            println!("{}> Creating worktree", self.name);
+
+            let worktree_path_string = worktree_path.to_string_lossy().to_string();
+
+            run_git(
+                &repo_path,
+                &[
+                    "worktree",
+                    "add",
+                    "-f",
+                    &worktree_path_string,
+                    &self.base.revision,
+                ],
+            )
+            .await?;
+        }
 
         Ok(())
     }
 
     pub fn run(
-        self,
+        &self,
         model: Arc<dyn LanguageModel>,
         app_state: Arc<AgentAppState>,
         cx: &mut App,
-    ) -> Task<Result<()>> {
+    ) -> Task<Result<RunOutput>> {
         let project = Project::local(
             app_state.client.clone(),
             app_state.node_runtime.clone(),
@@ -89,30 +167,119 @@ impl Example {
             cx,
         );
 
+        let worktree_path = self.worktree_path();
         let worktree = project.update(cx, |project, cx| {
-            project.create_worktree(self.base.path, true, cx)
+            project.create_worktree(&worktree_path, true, cx)
         });
 
         let tools = Arc::new(ToolWorkingSet::default());
         let thread_store =
             ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
+        let this = self.clone();
 
-        println!("USER:");
-        println!("{}", self.prompt);
-        println!("ASSISTANT:");
         cx.spawn(async move |cx| {
-            worktree.await?;
+            let worktree = worktree.await?;
+
+            // Wait for worktree scan to finish before choosing a file to open.
+            worktree
+                .update(cx, |worktree, _cx| {
+                    worktree.as_local().unwrap().scan_complete()
+                })?
+                .await;
+
+            let lsp_open_handle_and_store = if this.base.require_lsp {
+                let language_extension = this.base.language_extension.as_deref().context(
+                    "language_extension field is required in base.toml when `require_lsp == true`",
+                )?;
+
+                // Open a file that matches the language to cause LSP to start.
+                let language_file = worktree.read_with(cx, |worktree, _cx| {
+                    worktree
+                        .files(false, 0)
+                        .find_map(|e| {
+                            if e.path.clone().extension().and_then(|ext| ext.to_str())
+                                == Some(language_extension)
+                            {
+                                Some(ProjectPath {
+                                    worktree_id: worktree.id(),
+                                    path: e.path.clone(),
+                                })
+                            } else {
+                                None
+                            }
+                        })
+                        .context("Failed to find a file for example language")
+                })??;
+
+                let open_language_file_buffer_task = project.update(cx, |project, cx| {
+                    project.open_buffer(language_file.clone(), cx)
+                })?;
+
+                let language_file_buffer = open_language_file_buffer_task.await?;
+
+                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
+                    (
+                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
+                        project.lsp_store().clone(),
+                    )
+                })?;
+
+                // TODO: remove this once the diagnostics tool waits for new diagnostics
+                cx.background_executor().timer(Duration::new(5, 0)).await;
+                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
+
+                lsp_store.update(cx, |lsp_store, cx| {
+                    lsp_open_handle.update(cx, |buffer, cx| {
+                        buffer.update(cx, |buffer, cx| {
+                            let has_language_server = lsp_store
+                                .language_servers_for_local_buffer(buffer, cx)
+                                .next()
+                                .is_some();
+                            if has_language_server {
+                                Ok(())
+                            } else {
+                                Err(anyhow!(
+                                    "`{:?}` was opened to cause the language server to start, \
+                                    but no language servers are registered for its buffer. \
+                                    Set `require_lsp = false` in `base.toml` to skip this.",
+                                    language_file
+                                ))
+                            }
+                        })
+                    })
+                })??;
+
+                Some((lsp_open_handle, lsp_store))
+            } else {
+                None
+            };
+
+            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
+                return Err(anyhow!("Setup only mode"));
+            }
+
             let thread_store = thread_store.await;
             let thread =
                 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
 
+            {
+                let mut log_file = this.log_file.lock().unwrap();
+                writeln!(&mut log_file, "👤 USER:").log_err();
+                writeln!(&mut log_file, "{}", this.prompt).log_err();
+                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
+                log_file.flush().log_err();
+            }
+
             let (tx, rx) = oneshot::channel();
             let mut tx = Some(tx);
 
-            let _subscription =
-                cx.subscribe(
-                    &thread,
-                    move |thread, event: &ThreadEvent, cx| match event {
+            let _subscription = cx.subscribe(&thread, {
+                let log_file = this.log_file.clone();
+                let name = this.name.clone();
+                move |thread, event: &ThreadEvent, cx| {
+                    let mut log_file = log_file.lock().unwrap();
+
+                    match event {
                         ThreadEvent::Stopped(reason) => match reason {
                             Ok(StopReason::EndTurn) => {
                                 if let Some(tx) = tx.take() {
@@ -137,15 +304,16 @@ impl Example {
                             }
                         }
                         ThreadEvent::StreamedAssistantText(_, chunk) => {
-                            print!("{}", chunk);
+                            write!(&mut log_file, "{}", chunk).log_err();
                         }
                         ThreadEvent::StreamedAssistantThinking(_, chunk) => {
-                            print!("{}", chunk);
+                            write!(&mut log_file, "{}", chunk).log_err();
                         }
                         ThreadEvent::UsePendingTools { tool_uses } => {
-                            println!("\n\nUSING TOOLS:");
+                            writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                             for tool_use in tool_uses {
-                                println!("{}: {}", tool_use.name, tool_use.input);
+                                writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
+                                    .log_err();
                             }
                         }
                         ThreadEvent::ToolFinished {
@@ -154,25 +322,331 @@ impl Example {
                             ..
                         } => {
                             if let Some(tool_use) = pending_tool_use {
-                                println!("\nTOOL FINISHED: {}", tool_use.name);
+                                let message = format!("TOOL FINISHED: {}", tool_use.name);
+                                println!("{name}> {message}");
+                                writeln!(&mut log_file, "\n{}", message).log_err();
                             }
                             if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
-                                println!("\n{}\n", tool_result.content);
+                                let message = format!("\n{}\n", tool_result.content);
+                                writeln!(&mut log_file, "{}", message).log_err();
                             }
                         }
                         _ => {}
-                    },
-                )?;
+                    }
+
+                    log_file.flush().log_err();
+                }
+            })?;
 
             thread.update(cx, |thread, cx| {
                 let context = vec![];
-                thread.insert_user_message(self.prompt.clone(), context, None, cx);
+                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                 thread.send_to_model(model, RequestKind::Chat, cx);
             })?;
 
             rx.await??;
 
-            Ok(())
+            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
+                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
+            }
+
+            let repository_diff = this.repository_diff().await?;
+            let diagnostics = cx
+                .update(move |cx| {
+                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
+                })?
+                .await?;
+
+            drop(lsp_open_handle_and_store);
+
+            thread.update(cx, |thread, _cx| {
+                let response_count = thread
+                    .messages()
+                    .filter(|message| message.role == language_model::Role::Assistant)
+                    .count();
+                RunOutput {
+                    repository_diff,
+                    diagnostics,
+                    response_count,
+                    token_usage: thread.cumulative_token_usage(),
+                }
+            })
         })
     }
+
+    pub async fn judge(
+        &mut self,
+        model: Arc<dyn LanguageModel>,
+        repository_diff: String,
+        cx: &AsyncApp,
+    ) -> Result<JudgeOutput> {
+        let judge_prompt = include_str!("judge_prompt.hbs");
+        let judge_prompt_name = "judge_prompt";
+        let mut handlebars = Handlebars::new();
+        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
+        let prompt = handlebars.render(
+            judge_prompt_name,
+            &JudgeInput {
+                repository_diff,
+                criteria: self.criteria.clone(),
+            },
+        )?;
+
+        let request = LanguageModelRequest {
+            messages: vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::Text(prompt)],
+                cache: false,
+            }],
+            temperature: None,
+            tools: Vec::new(),
+            stop: Vec::new(),
+        };
+
+        let response = send_language_model_request(model, request, cx).await?;
+
+        let mut log_file = self.log_file.lock().unwrap();
+
+        writeln!(&mut log_file, "\n\n").log_err();
+        writeln!(&mut log_file, "========================================").log_err();
+        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
+        writeln!(&mut log_file, "========================================").log_err();
+        writeln!(&mut log_file, "\n{}", &response).log_err();
+
+        parse_judge_output(&response)
+    }
+
+    pub async fn repository_diff(&self) -> Result<String> {
+        let worktree_path = self.worktree_path();
+        run_git(&worktree_path, &["add", "-N"]).await?;
+        run_git(&worktree_path, &["diff"]).await
+    }
+}
+
+fn wait_for_lang_server(
+    lsp_store: &Entity<LspStore>,
+    name: String,
+    cx: &mut AsyncApp,
+) -> Task<Result<()>> {
+    if cx
+        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
+        .unwrap()
+        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
+    {
+        return Task::ready(anyhow::Ok(()));
+    }
+
+    println!("{}> ⏵ Waiting for language server", name);
+
+    let (mut tx, mut rx) = mpsc::channel(1);
+
+    let subscription =
+        cx.subscribe(&lsp_store, {
+            let name = name.clone();
+            move |lsp_store, event, cx| {
+                match event {
+                    project::LspStoreEvent::LanguageServerUpdate {
+                        message:
+                            client::proto::update_language_server::Variant::WorkProgress(
+                                LspWorkProgress {
+                                    message: Some(message),
+                                    ..
+                                },
+                            ),
+                        ..
+                    } => println!("{name}> ⟲ {message}"),
+                    _ => {}
+                }
+
+                if !has_pending_lang_server_work(&lsp_store, cx) {
+                    tx.try_send(()).ok();
+                }
+            }
+        });
+
+    cx.spawn(async move |cx| {
+        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
+        let result = futures::select! {
+            _ = rx.next() => {
+                println!("{}> ⚑ Language server idle", name);
+                anyhow::Ok(())
+            },
+            _ = timeout.fuse() => {
+                Err(anyhow!("LSP wait timed out after 5 minutes"))
+            }
+        };
+        drop(subscription);
+        result
+    })
+}
+
+fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
+    lsp_store
+        .read(cx)
+        .language_server_statuses()
+        .any(|(_, status)| !status.pending_work.is_empty())
+}
+
+async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
+    let paths_with_diagnostics = project.update(cx, |project, cx| {
+        project
+            .diagnostic_summaries(true, cx)
+            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
+            .map(|(project_path, _, _)| project_path)
+            .collect::<Vec<_>>()
+    })?;
+
+    let mut output = String::new();
+    for project_path in paths_with_diagnostics {
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+            .await?;
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+        for (_, group) in snapshot.diagnostic_groups(None) {
+            let entry = &group.entries[group.primary_ix];
+            let range = entry.range.to_point(&snapshot);
+            let severity = match entry.diagnostic.severity {
+                DiagnosticSeverity::ERROR => "error",
+                DiagnosticSeverity::WARNING => "warning",
+                _ => continue,
+            };
+
+            writeln!(
+                output,
+                "{} at line {}: {}",
+                severity,
+                range.start.row + 1,
+                entry.diagnostic.message
+            )?;
+        }
+    }
+    anyhow::Ok(output)
+}
+
+fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
+    let analysis = get_tag("analysis", response)?.to_string();
+    let score = get_tag("score", response)?
+        .parse()
+        .context("error parsing score")?;
+
+    Ok(JudgeOutput { analysis, score })
+}
+
+fn get_tag(name: &'static str, response: &str) -> Result<String> {
+    let start_tag = format!("<{}>", name);
+    let end_tag = format!("</{}>", name);
+
+    let start_ix = response
+        .find(&start_tag)
+        .context(format!("{} start tag not found", name))?;
+    let content_start_ix = start_ix + start_tag.len();
+
+    let end_ix = content_start_ix
+        + response[content_start_ix..]
+            .find(&end_tag)
+            .context(format!("{} end tag not found", name))?;
+
+    let content = response[content_start_ix..end_ix].trim().unindent();
+
+    anyhow::Ok(content)
+}
+
+pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
+    let repo_name = repo_url
+        .trim_start_matches("https://")
+        .replace(|c: char| !c.is_alphanumeric(), "-");
+    Path::new(REPOS_DIR)
+        .canonicalize()
+        .context(format!("No such directory {REPOS_DIR}"))
+        .unwrap()
+        .join(repo_name)
+}
+
+pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+    let output = new_smol_command("git")
+        .current_dir(repo_path)
+        .args(args)
+        .output()
+        .await?;
+
+    if output.status.success() {
+        Ok(String::from_utf8(output.stdout)?.trim().to_string())
+    } else {
+        Err(anyhow!(
+            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+            args.join(" "),
+            repo_path.display(),
+            output.status,
+            String::from_utf8_lossy(&output.stderr),
+            String::from_utf8_lossy(&output.stdout),
+        ))
+    }
+}
+
+pub async fn send_language_model_request(
+    model: Arc<dyn LanguageModel>,
+    request: LanguageModelRequest,
+    cx: &AsyncApp,
+) -> anyhow::Result<String> {
+    match model.stream_completion_text(request, &cx).await {
+        Ok(mut stream) => {
+            let mut full_response = String::new();
+            while let Some(chunk_result) = stream.stream.next().await {
+                match chunk_result {
+                    Ok(chunk_str) => {
+                        print!("{}", &chunk_str);
+                        full_response.push_str(&chunk_str);
+                    }
+                    Err(err) => {
+                        return Err(anyhow!(
+                            "Error receiving response from language model: {err}"
+                        ));
+                    }
+                }
+            }
+            Ok(full_response)
+        }
+        Err(err) => Err(anyhow!(
+            "Failed to get response from language model. Error was: {err}"
+        )),
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_parse_judge_output() {
+        let response = r#"
+            <analysis>The model did a good job but there were still compilations errors.</analysis>
+            <score>3</score>
+        "#
+        .unindent();
+
+        let output = parse_judge_output(&response).unwrap();
+        assert_eq!(
+            output.analysis,
+            "The model did a good job but there were still compilations errors."
+        );
+        assert_eq!(output.score, 3);
+
+        let response = r#"
+            Text around ignored
+
+            <analysis>
+                Failed to compile:
+                - Error 1
+                - Error 2
+            </analysis>
+
+            <score>1</score>
+        "#
+        .unindent();
+
+        let output = parse_judge_output(&response).unwrap();
+        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
+        assert_eq!(output.score, 1);
+    }
 }

crates/eval/src/judge_prompt.hbs 🔗

@@ -0,0 +1,25 @@
+You are an expert software developer tasked with evaluating the following changes to a codebase:
+
+<changes>
+{{repository_diff}}
+</changes>
+
+Use the following criteria to score the above changes.
+
+<criteria>
+{{criteria}}
+</criteria>
+
+Based on these criteria, give the test output a score between 0 and 5.
+
+- 5 means: changes meet all criteria
+- 0 means: changes don't meet any criteria
+
+Be suspicious of the changes because they were generated by an LLM.
+Sometimes the LLM decides to change random code, so if the changes are not mentioned in the criteria, penalize the score.
+Analyze the diff hunk by hunk and describe how each change meets or fails to meet the criteria.
+
+```
+<analysis>{YOUR ANALYSIS HERE}</analysis>
+<score>{YOUR SCORE HERE}</score>
+```

crates/language_model/src/language_model.rs 🔗

@@ -83,7 +83,7 @@ pub enum StopReason {
     ToolUse,
 }
 
-#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
 pub struct TokenUsage {
     #[serde(default, skip_serializing_if = "is_default")]
     pub input_tokens: u32,