@@ -1,32 +1,75 @@
mod example;
use assistant_settings::AssistantSettings;
-use client::{Client, UserStore};
+use client::{Client, ProxySettings, UserStore};
pub(crate) use example::*;
use ::fs::RealFs;
-use anyhow::anyhow;
-use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
+use anyhow::{Result, anyhow};
+use clap::Parser;
+use extension::ExtensionHostProxy;
+use futures::future;
+use gpui::http_client::{Uri, read_proxy_from_env};
+use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
+use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
-use node_runtime::NodeRuntime;
+use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
+use project::project_settings::ProjectSettings;
use prompt_store::PromptBuilder;
+use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
+use std::collections::HashSet;
+use std::path::{Path, PathBuf};
use std::sync::Arc;
+use util::ResultExt as _;
+
+pub const RUNS_DIR: &str = "./crates/eval/runs";
+
+#[derive(Parser, Debug)]
+#[command(name = "eval", disable_version_flag = true)]
+struct Args {
+ /// Runs all examples that contain these substrings. If unspecified, all examples are run.
+ #[arg(value_name = "EXAMPLE_SUBSTRING")]
+ examples: Vec<String>,
+ /// Model to use (default: "claude-3-7-sonnet-latest")
+ #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+ model: String,
+}
fn main() {
env_logger::init();
+
+ let args = Args::parse();
+ let all_available_examples = list_all_examples().unwrap();
+ let example_paths = all_available_examples
+ .iter()
+ .filter_map(|example_path| {
+ let name = example_path.file_name()?.to_string_lossy();
+ if args.examples.is_empty()
+ || args
+ .examples
+ .iter()
+ .any(|name_substring| name.contains(name_substring))
+ {
+ Some(example_path.clone())
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
app.run(move |cx| {
let app_state = init(cx);
- let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
+ let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
@@ -39,17 +82,142 @@ fn main() {
cx.spawn(async move |cx| {
authenticate.await.unwrap();
- let example =
- Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
- example.setup()?;
- cx.update(|cx| example.run(model, app_state, cx))?.await?;
+ std::fs::create_dir_all(REPOS_DIR)?;
+ std::fs::create_dir_all(WORKTREES_DIR)?;
+
+ let run_dir = Path::new(RUNS_DIR).join(format!(
+ "{}",
+ chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
+ ));
+ std::fs::create_dir_all(&run_dir)?;
+
+ let mut examples = Vec::new();
+ for example_path in example_paths {
+ let example = Example::load_from_directory(&example_path, &run_dir)?;
+ examples.push((example_path, example));
+ }
+ let mut repo_urls = HashSet::new();
+
+ let mut clone_tasks = Vec::new();
+
+ for (_, example) in examples.iter() {
+ let repo_url = example.base.url.clone();
+ if repo_urls.insert(repo_url.clone()) {
+ let repo_path = repo_path_for_url(&repo_url);
+
+ if !repo_path.join(".git").is_dir() {
+ println!("Cloning: {}", repo_url);
+
+ let git_task = cx.spawn(async move |_cx| {
+ std::fs::create_dir_all(&repo_path)?;
+ run_git(&repo_path, &["init"]).await?;
+ run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
+ });
+
+ clone_tasks.push(git_task);
+ } else {
+ println!("Already cloned: {}", repo_url);
+
+ let actual_origin =
+ run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
+ if actual_origin != repo_url {
+ return Err(anyhow!(
+ "remote origin {} does not match expected origin {}",
+ actual_origin,
+ repo_url,
+ ));
+ }
+ }
+ }
+ }
+
+ future::join_all(clone_tasks).await;
+
+ let tasks = examples
+ .into_iter()
+ .map(|(example_path, example)| {
+ let app_state = app_state.clone();
+ let model = model.clone();
+ cx.spawn(async move |cx| {
+ (
+ example_path,
+ run_example(example, model, app_state, cx).await,
+ )
+ })
+ })
+ .collect::<Vec<_>>();
+
+ let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
+
+ println!("\n\n");
+ println!("========================================");
+ println!(" EVAL RESULTS ");
+ println!("========================================");
+ println!("");
- anyhow::Ok(())
+ let mut judge_scores = Vec::new();
+
+ for (example_path, result) in results {
+ let example_name = example_path.file_name().unwrap().to_string_lossy();
+ match result {
+ Err(err) => {
+ println!("💥 {:<30}: {:?}", example_name, err);
+ }
+ Ok(judge_output) => {
+ const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
+
+ println!(
+ "{} {:<30}: {}",
+ SCORES[judge_output.score.min(5) as usize],
+ example_name,
+ judge_output.score,
+ );
+ judge_scores.push(judge_output.score);
+ }
+ }
+ }
+
+ let score_count = judge_scores.len();
+ let average_score = judge_scores
+ .into_iter()
+ .map(|score| score as f32)
+ .sum::<f32>()
+ / (score_count as f32);
+ println!("\nAverage score: {average_score}");
+
+ cx.update(|cx| cx.quit())
})
.detach_and_log_err(cx);
});
}
+async fn run_example(
+ mut example: Example,
+ model: Arc<dyn LanguageModel>,
+ app_state: Arc<AgentAppState>,
+ cx: &mut AsyncApp,
+) -> Result<JudgeOutput> {
+ example.setup().await?;
+ cx.update(|cx| example.run(model.clone(), app_state, cx))?
+ .await?;
+ let diff = example.repository_diff().await?;
+ example.judge(model, diff, cx).await
+}
+
+fn list_all_examples() -> Result<Vec<PathBuf>> {
+ let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
+ let entries = std::fs::read_dir(path).unwrap();
+ let mut result_paths = Vec::new();
+ for entry in entries {
+ let entry = entry?;
+ let path = entry.path();
+ if path.is_dir() {
+ result_paths.push(path);
+ }
+ }
+ Ok(result_paths)
+}
+
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
@@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
+
+ // Set User-Agent so we can download language servers from GitHub
+ let user_agent = format!(
+ "Zed/{} ({}; {})",
+ AppVersion::global(cx),
+ std::env::consts::OS,
+ std::env::consts::ARCH
+ );
+ let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
+ let proxy_url = proxy_str
+ .as_ref()
+ .and_then(|input| input.parse::<Uri>().ok())
+ .or_else(read_proxy_from_env);
+ let http = {
+ let _guard = Tokio::handle(cx).enter();
+
+ ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
+ .expect("could not start HTTP client")
+ };
+ cx.set_http_client(Arc::new(http));
+
Project::init_settings(cx);
let client = Client::production(cx);
@@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
cx.background_executor().clone(),
));
- let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+ let mut languages = LanguageRegistry::new(cx.background_executor().clone());
+ languages.set_language_server_download_dir(paths::languages_dir().clone());
+ let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ extension::init(cx);
+
+ let (tx, rx) = async_watch::channel(None);
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let settings = &ProjectSettings::get_global(cx).node;
+ let options = NodeBinaryOptions {
+ allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
+ allow_binary_download: true,
+ use_paths: settings.path.as_ref().map(|node_path| {
+ let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
+ let npm_path = settings
+ .npm_path
+ .as_ref()
+ .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
+ (
+ node_path.clone(),
+ npm_path.unwrap_or_else(|| {
+ let base_path = PathBuf::new();
+ node_path.parent().unwrap_or(&base_path).join("npm")
+ }),
+ )
+ }),
+ };
+ tx.send(Some(options)).log_err();
+ })
+ .detach();
+ let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
+
+ let extension_host_proxy = ExtensionHostProxy::global(cx);
+
language::init(cx);
+ language_extension::init(extension_host_proxy.clone(), languages.clone());
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+ languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
@@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
client,
user_store,
fs,
- node_runtime: NodeRuntime::unavailable(),
+ node_runtime,
prompt_builder,
})
}
@@ -1,83 +1,161 @@
use agent::{RequestKind, ThreadEvent, ThreadStore};
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
+use client::proto::LspWorkProgress;
use dap::DapRegistry;
-use futures::channel::oneshot;
-use gpui::{App, Task};
-use language_model::{LanguageModel, StopReason};
-use project::Project;
-use serde::Deserialize;
-use std::process::Command;
-use std::sync::Arc;
+use futures::channel::{mpsc, oneshot};
+use futures::{FutureExt, StreamExt as _};
+use gpui::{App, AsyncApp, Entity, Task};
+use handlebars::Handlebars;
+use language::{DiagnosticSeverity, OffsetRangeExt};
+use language_model::{
+ LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+ StopReason, TokenUsage,
+};
+use project::{LspStore, Project, ProjectPath};
+use serde::{Deserialize, Serialize};
+use std::fmt::Write as _;
+use std::fs::File;
+use std::io::Write as _;
+use std::sync::{Arc, Mutex};
+use std::time::Duration;
use std::{
fs,
path::{Path, PathBuf},
};
+use unindent::Unindent as _;
+use util::ResultExt as _;
+use util::command::new_smol_command;
+use util::serde::default_true;
use crate::AgentAppState;
-#[derive(Debug, Deserialize)]
+pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
+pub const REPOS_DIR: &str = "./crates/eval/repos";
+pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
+
+#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
- pub path: PathBuf,
+ pub url: String,
pub revision: String,
+ pub language_extension: Option<String>,
+ pub insert_id: Option<String>,
+ #[serde(default = "default_true")]
+ pub require_lsp: bool,
}
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Example {
+ pub name: String,
+ /// Content of `base.toml`
pub base: ExampleBase,
-
- /// Content of the prompt.md file
+ /// Content of `prompt.md`
pub prompt: String,
+ /// Content of `criteria.md`
+ pub criteria: String,
+ /// Markdown log file to append to
+ pub log_file: Arc<Mutex<File>>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct RunOutput {
+ pub repository_diff: String,
+ pub diagnostics: String,
+ pub response_count: usize,
+ pub token_usage: TokenUsage,
+}
- /// Content of the rubric.md file
- pub _rubric: String,
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeInput {
+ pub repository_diff: String,
+ pub criteria: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeOutput {
+ pub analysis: String,
+ pub score: u32,
}
impl Example {
- /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
- pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
- let base_path = dir_path.as_ref().join("base.toml");
- let prompt_path = dir_path.as_ref().join("prompt.md");
- let rubric_path = dir_path.as_ref().join("rubric.md");
+ /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
+ pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
+ let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
+ let base_path = dir_path.join("base.toml");
+ let prompt_path = dir_path.join("prompt.md");
+ let criteria_path = dir_path.join("criteria.md");
- let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
- base.path = base.path.canonicalize()?;
+ let log_file_path = run_dir.join(format!(
+ "{}.md",
+ dir_path.file_name().unwrap().to_str().unwrap()
+ ));
+ let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
+ println!("{}> Logging to {:?}", name, log_file_path);
Ok(Example {
- base,
- prompt: fs::read_to_string(prompt_path)?,
- _rubric: fs::read_to_string(rubric_path)?,
+ name,
+ base: toml::from_str(&fs::read_to_string(&base_path)?)?,
+ prompt: fs::read_to_string(prompt_path.clone())?,
+ criteria: fs::read_to_string(criteria_path.clone())?,
+ log_file,
})
}
+ pub fn worktree_path(&self) -> PathBuf {
+ Path::new(WORKTREES_DIR)
+ .canonicalize()
+ .context(format!("No such directory {WORKTREES_DIR}"))
+ .unwrap()
+ .join(&self.name)
+ }
+
/// Set up the example by checking out the specified Git revision
- pub fn setup(&self) -> Result<()> {
- // Check if the directory exists
- let path = Path::new(&self.base.path);
- anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
-
- // Change to the project directory and checkout the specified revision
- let output = Command::new("git")
- .current_dir(&self.base.path)
- .arg("checkout")
- .arg(&self.base.revision)
- .output()?;
- anyhow::ensure!(
- output.status.success(),
- "Failed to checkout revision {}: {}",
- self.base.revision,
- String::from_utf8_lossy(&output.stderr),
- );
+ pub async fn setup(&self) -> Result<()> {
+ let repo_path = repo_path_for_url(&self.base.url);
+
+ run_git(
+ &repo_path,
+ &["fetch", "--depth", "1", "origin", &self.base.revision],
+ )
+ .await?;
+
+ let worktree_path = self.worktree_path();
+
+ if worktree_path.is_dir() {
+ println!("{}> Resetting existing worktree", self.name);
+
+ // TODO: consider including "-x" to remove ignored files. The downside of this is that
+ // it will also remove build artifacts, and so prevent incremental reuse there.
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
+ } else {
+ println!("{}> Creating worktree", self.name);
+
+ let worktree_path_string = worktree_path.to_string_lossy().to_string();
+
+ run_git(
+ &repo_path,
+ &[
+ "worktree",
+ "add",
+ "-f",
+ &worktree_path_string,
+ &self.base.revision,
+ ],
+ )
+ .await?;
+ }
Ok(())
}
pub fn run(
- self,
+ &self,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
cx: &mut App,
- ) -> Task<Result<()>> {
+ ) -> Task<Result<RunOutput>> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
@@ -89,30 +167,119 @@ impl Example {
cx,
);
+ let worktree_path = self.worktree_path();
let worktree = project.update(cx, |project, cx| {
- project.create_worktree(self.base.path, true, cx)
+ project.create_worktree(&worktree_path, true, cx)
});
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
+ let this = self.clone();
- println!("USER:");
- println!("{}", self.prompt);
- println!("ASSISTANT:");
cx.spawn(async move |cx| {
- worktree.await?;
+ let worktree = worktree.await?;
+
+ // Wait for worktree scan to finish before choosing a file to open.
+ worktree
+ .update(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let lsp_open_handle_and_store = if this.base.require_lsp {
+ let language_extension = this.base.language_extension.as_deref().context(
+ "language_extension field is required in base.toml when `require_lsp == true`",
+ )?;
+
+ // Open a file that matches the language to cause LSP to start.
+ let language_file = worktree.read_with(cx, |worktree, _cx| {
+ worktree
+ .files(false, 0)
+ .find_map(|e| {
+ if e.path.clone().extension().and_then(|ext| ext.to_str())
+ == Some(language_extension)
+ {
+ Some(ProjectPath {
+ worktree_id: worktree.id(),
+ path: e.path.clone(),
+ })
+ } else {
+ None
+ }
+ })
+ .context("Failed to find a file for example language")
+ })??;
+
+ let open_language_file_buffer_task = project.update(cx, |project, cx| {
+ project.open_buffer(language_file.clone(), cx)
+ })?;
+
+ let language_file_buffer = open_language_file_buffer_task.await?;
+
+ let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
+ (
+ project.register_buffer_with_language_servers(&language_file_buffer, cx),
+ project.lsp_store().clone(),
+ )
+ })?;
+
+ // TODO: remove this once the diagnostics tool waits for new diagnostics
+ cx.background_executor().timer(Duration::new(5, 0)).await;
+ wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
+
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_open_handle.update(cx, |buffer, cx| {
+ buffer.update(cx, |buffer, cx| {
+ let has_language_server = lsp_store
+ .language_servers_for_local_buffer(buffer, cx)
+ .next()
+ .is_some();
+ if has_language_server {
+ Ok(())
+ } else {
+ Err(anyhow!(
+ "`{:?}` was opened to cause the language server to start, \
+ but no language servers are registered for its buffer. \
+ Set `require_lsp = false` in `base.toml` to skip this.",
+ language_file
+ ))
+ }
+ })
+ })
+ })??;
+
+ Some((lsp_open_handle, lsp_store))
+ } else {
+ None
+ };
+
+ if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
+ return Err(anyhow!("Setup only mode"));
+ }
+
let thread_store = thread_store.await;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
+ {
+ let mut log_file = this.log_file.lock().unwrap();
+ writeln!(&mut log_file, "👤 USER:").log_err();
+ writeln!(&mut log_file, "{}", this.prompt).log_err();
+ writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
+ log_file.flush().log_err();
+ }
+
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
- let _subscription =
- cx.subscribe(
- &thread,
- move |thread, event: &ThreadEvent, cx| match event {
+ let _subscription = cx.subscribe(&thread, {
+ let log_file = this.log_file.clone();
+ let name = this.name.clone();
+ move |thread, event: &ThreadEvent, cx| {
+ let mut log_file = log_file.lock().unwrap();
+
+ match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
if let Some(tx) = tx.take() {
@@ -137,15 +304,16 @@ impl Example {
}
}
ThreadEvent::StreamedAssistantText(_, chunk) => {
- print!("{}", chunk);
+ write!(&mut log_file, "{}", chunk).log_err();
}
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
- print!("{}", chunk);
+ write!(&mut log_file, "{}", chunk).log_err();
}
ThreadEvent::UsePendingTools { tool_uses } => {
- println!("\n\nUSING TOOLS:");
+ writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
for tool_use in tool_uses {
- println!("{}: {}", tool_use.name, tool_use.input);
+ writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
+ .log_err();
}
}
ThreadEvent::ToolFinished {
@@ -154,25 +322,331 @@ impl Example {
..
} => {
if let Some(tool_use) = pending_tool_use {
- println!("\nTOOL FINISHED: {}", tool_use.name);
+ let message = format!("TOOL FINISHED: {}", tool_use.name);
+ println!("{name}> {message}");
+ writeln!(&mut log_file, "\n{}", message).log_err();
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
- println!("\n{}\n", tool_result.content);
+ let message = format!("\n{}\n", tool_result.content);
+ writeln!(&mut log_file, "{}", message).log_err();
}
}
_ => {}
- },
- )?;
+ }
+
+ log_file.flush().log_err();
+ }
+ })?;
thread.update(cx, |thread, cx| {
let context = vec![];
- thread.insert_user_message(self.prompt.clone(), context, None, cx);
+ thread.insert_user_message(this.prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
})?;
rx.await??;
- Ok(())
+ if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
+ wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
+ }
+
+ let repository_diff = this.repository_diff().await?;
+ let diagnostics = cx
+ .update(move |cx| {
+ cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
+ })?
+ .await?;
+
+ drop(lsp_open_handle_and_store);
+
+ thread.update(cx, |thread, _cx| {
+ let response_count = thread
+ .messages()
+ .filter(|message| message.role == language_model::Role::Assistant)
+ .count();
+ RunOutput {
+ repository_diff,
+ diagnostics,
+ response_count,
+ token_usage: thread.cumulative_token_usage(),
+ }
+ })
})
}
+
+ pub async fn judge(
+ &mut self,
+ model: Arc<dyn LanguageModel>,
+ repository_diff: String,
+ cx: &AsyncApp,
+ ) -> Result<JudgeOutput> {
+ let judge_prompt = include_str!("judge_prompt.hbs");
+ let judge_prompt_name = "judge_prompt";
+ let mut handlebars = Handlebars::new();
+ handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
+ let prompt = handlebars.render(
+ judge_prompt_name,
+ &JudgeInput {
+ repository_diff,
+ criteria: self.criteria.clone(),
+ },
+ )?;
+
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text(prompt)],
+ cache: false,
+ }],
+ temperature: None,
+ tools: Vec::new(),
+ stop: Vec::new(),
+ };
+
+ let response = send_language_model_request(model, request, cx).await?;
+
+ let mut log_file = self.log_file.lock().unwrap();
+
+ writeln!(&mut log_file, "\n\n").log_err();
+ writeln!(&mut log_file, "========================================").log_err();
+ writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
+ writeln!(&mut log_file, "========================================").log_err();
+ writeln!(&mut log_file, "\n{}", &response).log_err();
+
+ parse_judge_output(&response)
+ }
+
+ pub async fn repository_diff(&self) -> Result<String> {
+ let worktree_path = self.worktree_path();
+ run_git(&worktree_path, &["add", "-N"]).await?;
+ run_git(&worktree_path, &["diff"]).await
+ }
+}
+
+fn wait_for_lang_server(
+ lsp_store: &Entity<LspStore>,
+ name: String,
+ cx: &mut AsyncApp,
+) -> Task<Result<()>> {
+ if cx
+ .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
+ .unwrap()
+ || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
+ {
+ return Task::ready(anyhow::Ok(()));
+ }
+
+ println!("{}> ⏵ Waiting for language server", name);
+
+ let (mut tx, mut rx) = mpsc::channel(1);
+
+ let subscription =
+ cx.subscribe(&lsp_store, {
+ let name = name.clone();
+ move |lsp_store, event, cx| {
+ match event {
+ project::LspStoreEvent::LanguageServerUpdate {
+ message:
+ client::proto::update_language_server::Variant::WorkProgress(
+ LspWorkProgress {
+ message: Some(message),
+ ..
+ },
+ ),
+ ..
+ } => println!("{name}> ⟲ {message}"),
+ _ => {}
+ }
+
+ if !has_pending_lang_server_work(&lsp_store, cx) {
+ tx.try_send(()).ok();
+ }
+ }
+ });
+
+ cx.spawn(async move |cx| {
+ let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
+ let result = futures::select! {
+ _ = rx.next() => {
+ println!("{}> ⚑ Language server idle", name);
+ anyhow::Ok(())
+ },
+ _ = timeout.fuse() => {
+ Err(anyhow!("LSP wait timed out after 5 minutes"))
+ }
+ };
+ drop(subscription);
+ result
+ })
+}
+
+fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
+ lsp_store
+ .read(cx)
+ .language_server_statuses()
+ .any(|(_, status)| !status.pending_work.is_empty())
+}
+
+async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
+ let paths_with_diagnostics = project.update(cx, |project, cx| {
+ project
+ .diagnostic_summaries(true, cx)
+ .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
+ .map(|(project_path, _, _)| project_path)
+ .collect::<Vec<_>>()
+ })?;
+
+ let mut output = String::new();
+ for project_path in paths_with_diagnostics {
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ for (_, group) in snapshot.diagnostic_groups(None) {
+ let entry = &group.entries[group.primary_ix];
+ let range = entry.range.to_point(&snapshot);
+ let severity = match entry.diagnostic.severity {
+ DiagnosticSeverity::ERROR => "error",
+ DiagnosticSeverity::WARNING => "warning",
+ _ => continue,
+ };
+
+ writeln!(
+ output,
+ "{} at line {}: {}",
+ severity,
+ range.start.row + 1,
+ entry.diagnostic.message
+ )?;
+ }
+ }
+ anyhow::Ok(output)
+}
+
+fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
+ let analysis = get_tag("analysis", response)?.to_string();
+ let score = get_tag("score", response)?
+ .parse()
+ .context("error parsing score")?;
+
+ Ok(JudgeOutput { analysis, score })
+}
+
+fn get_tag(name: &'static str, response: &str) -> Result<String> {
+ let start_tag = format!("<{}>", name);
+ let end_tag = format!("</{}>", name);
+
+ let start_ix = response
+ .find(&start_tag)
+ .context(format!("{} start tag not found", name))?;
+ let content_start_ix = start_ix + start_tag.len();
+
+ let end_ix = content_start_ix
+ + response[content_start_ix..]
+ .find(&end_tag)
+ .context(format!("{} end tag not found", name))?;
+
+ let content = response[content_start_ix..end_ix].trim().unindent();
+
+ anyhow::Ok(content)
+}
+
+pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
+ let repo_name = repo_url
+ .trim_start_matches("https://")
+ .replace(|c: char| !c.is_alphanumeric(), "-");
+ Path::new(REPOS_DIR)
+ .canonicalize()
+ .context(format!("No such directory {REPOS_DIR}"))
+ .unwrap()
+ .join(repo_name)
+}
+
+pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+ let output = new_smol_command("git")
+ .current_dir(repo_path)
+ .args(args)
+ .output()
+ .await?;
+
+ if output.status.success() {
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+ } else {
+ Err(anyhow!(
+ "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+ args.join(" "),
+ repo_path.display(),
+ output.status,
+ String::from_utf8_lossy(&output.stderr),
+ String::from_utf8_lossy(&output.stdout),
+ ))
+ }
+}
+
+pub async fn send_language_model_request(
+ model: Arc<dyn LanguageModel>,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+) -> anyhow::Result<String> {
+ match model.stream_completion_text(request, &cx).await {
+ Ok(mut stream) => {
+ let mut full_response = String::new();
+ while let Some(chunk_result) = stream.stream.next().await {
+ match chunk_result {
+ Ok(chunk_str) => {
+ print!("{}", &chunk_str);
+ full_response.push_str(&chunk_str);
+ }
+ Err(err) => {
+ return Err(anyhow!(
+ "Error receiving response from language model: {err}"
+ ));
+ }
+ }
+ }
+ Ok(full_response)
+ }
+ Err(err) => Err(anyhow!(
+ "Failed to get response from language model. Error was: {err}"
+ )),
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_parse_judge_output() {
+ let response = r#"
+ <analysis>The model did a good job but there were still compilations errors.</analysis>
+ <score>3</score>
+ "#
+ .unindent();
+
+ let output = parse_judge_output(&response).unwrap();
+ assert_eq!(
+ output.analysis,
+ "The model did a good job but there were still compilations errors."
+ );
+ assert_eq!(output.score, 3);
+
+ let response = r#"
+ Text around ignored
+
+ <analysis>
+ Failed to compile:
+ - Error 1
+ - Error 2
+ </analysis>
+
+ <score>1</score>
+ "#
+ .unindent();
+
+ let output = parse_judge_output(&response).unwrap();
+ assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
+ assert_eq!(output.score, 1);
+ }
}