1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use dap::DapRegistry;
6use futures::channel::{mpsc, oneshot};
7use futures::{FutureExt, StreamExt as _};
8use gpui::{App, AsyncApp, Entity, Task};
9use handlebars::Handlebars;
10use language::{DiagnosticSeverity, OffsetRangeExt};
11use language_model::{
12 LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
13 StopReason, TokenUsage,
14};
15use project::{LspStore, Project, ProjectPath};
16use serde::{Deserialize, Serialize};
17use std::fmt::Write as _;
18use std::fs::File;
19use std::io::Write as _;
20use std::sync::{Arc, Mutex};
21use std::time::Duration;
22use std::{
23 fs,
24 path::{Path, PathBuf},
25};
26use unindent::Unindent as _;
27use util::ResultExt as _;
28use util::command::new_smol_command;
29use util::serde::default_true;
30
31use crate::AgentAppState;
32
/// Directory expected to contain one subdirectory per example (not read
/// directly in this file; presumably enumerated by the caller — confirm).
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Directory holding local checkouts of example repositories
/// (see `repo_path_for_url`).
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Directory in which per-example git worktrees are created
/// (see `Example::worktree_path`).
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
36
/// Per-example configuration, deserialized from the example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// URL of the git repository the example runs against; also used to
    /// derive the local repo path (see `repo_path_for_url`).
    pub url: String,
    /// Git revision to fetch and check out in the example's worktree.
    pub revision: String,
    /// Extension of files in the repo's main language. Used by `run` to open
    /// a matching file so a language server starts; required when
    /// `require_lsp` is true.
    pub language_extension: Option<String>,
    // NOTE(review): not referenced anywhere in this file; presumably an
    // external tracking id — confirm against callers.
    pub insert_id: Option<String>,
    /// When true (the default), `run` opens a file matching
    /// `language_extension` and fails if no language server registers for it.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
46
/// A single eval example: repo/revision configuration, the prompt given to
/// the agent, the judging criteria, and a shared log file for the transcript.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name, taken from the directory it was loaded from.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to. Shared behind `Arc<Mutex<_>>` because
    /// the example is cloned into async tasks and event callbacks.
    pub log_file: Arc<Mutex<File>>,
}
59
/// Result of running an example's agent thread to completion.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// `git diff` of the worktree after the agent finished.
    pub repository_diff: String,
    /// Formatted LSP error/warning diagnostics gathered after the run.
    pub diagnostics: String,
    /// Number of assistant messages in the thread.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
}
67
/// Template input rendered into `judge_prompt.hbs` for the judge model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff produced by the run, shown to the judge.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
73
/// Judge model's verdict, parsed from `<analysis>`/`<score>` tags in its
/// response (see `parse_judge_output`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Free-form analysis text from the `<analysis>` tag.
    pub analysis: String,
    /// Numeric score from the `<score>` tag.
    pub score: u32,
}
79
80impl Example {
81 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
82 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
83 let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
84 let base_path = dir_path.join("base.toml");
85 let prompt_path = dir_path.join("prompt.md");
86 let criteria_path = dir_path.join("criteria.md");
87
88 let log_file_path = run_dir.join(format!(
89 "{}.md",
90 dir_path.file_name().unwrap().to_str().unwrap()
91 ));
92 let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
93 println!("{}> Logging to {:?}", name, log_file_path);
94
95 Ok(Example {
96 name,
97 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
98 prompt: fs::read_to_string(prompt_path.clone())?,
99 criteria: fs::read_to_string(criteria_path.clone())?,
100 log_file,
101 })
102 }
103
104 pub fn worktree_path(&self) -> PathBuf {
105 Path::new(WORKTREES_DIR)
106 .canonicalize()
107 .context(format!("No such directory {WORKTREES_DIR}"))
108 .unwrap()
109 .join(&self.name)
110 }
111
112 /// Set up the example by checking out the specified Git revision
113 pub async fn setup(&self) -> Result<()> {
114 let repo_path = repo_path_for_url(&self.base.url);
115
116 run_git(
117 &repo_path,
118 &["fetch", "--depth", "1", "origin", &self.base.revision],
119 )
120 .await?;
121
122 let worktree_path = self.worktree_path();
123
124 if worktree_path.is_dir() {
125 println!("{}> Resetting existing worktree", self.name);
126
127 // TODO: consider including "-x" to remove ignored files. The downside of this is that
128 // it will also remove build artifacts, and so prevent incremental reuse there.
129 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
130 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
131 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
132 } else {
133 println!("{}> Creating worktree", self.name);
134
135 let worktree_path_string = worktree_path.to_string_lossy().to_string();
136
137 run_git(
138 &repo_path,
139 &[
140 "worktree",
141 "add",
142 "-f",
143 &worktree_path_string,
144 &self.base.revision,
145 ],
146 )
147 .await?;
148 }
149
150 Ok(())
151 }
152
153 pub fn run(
154 &self,
155 model: Arc<dyn LanguageModel>,
156 app_state: Arc<AgentAppState>,
157 cx: &mut App,
158 ) -> Task<Result<RunOutput>> {
159 let project = Project::local(
160 app_state.client.clone(),
161 app_state.node_runtime.clone(),
162 app_state.user_store.clone(),
163 app_state.languages.clone(),
164 Arc::new(DapRegistry::default()),
165 app_state.fs.clone(),
166 None,
167 cx,
168 );
169
170 let worktree_path = self.worktree_path();
171 let worktree = project.update(cx, |project, cx| {
172 project.create_worktree(&worktree_path, true, cx)
173 });
174
175 let tools = Arc::new(ToolWorkingSet::default());
176 let thread_store =
177 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
178 let this = self.clone();
179
180 cx.spawn(async move |cx| {
181 let worktree = worktree.await?;
182
183 // Wait for worktree scan to finish before choosing a file to open.
184 worktree
185 .update(cx, |worktree, _cx| {
186 worktree.as_local().unwrap().scan_complete()
187 })?
188 .await;
189
190 let lsp_open_handle_and_store = if this.base.require_lsp {
191 let language_extension = this.base.language_extension.as_deref().context(
192 "language_extension field is required in base.toml when `require_lsp == true`",
193 )?;
194
195 // Open a file that matches the language to cause LSP to start.
196 let language_file = worktree.read_with(cx, |worktree, _cx| {
197 worktree
198 .files(false, 0)
199 .find_map(|e| {
200 if e.path.clone().extension().and_then(|ext| ext.to_str())
201 == Some(language_extension)
202 {
203 Some(ProjectPath {
204 worktree_id: worktree.id(),
205 path: e.path.clone(),
206 })
207 } else {
208 None
209 }
210 })
211 .context("Failed to find a file for example language")
212 })??;
213
214 let open_language_file_buffer_task = project.update(cx, |project, cx| {
215 project.open_buffer(language_file.clone(), cx)
216 })?;
217
218 let language_file_buffer = open_language_file_buffer_task.await?;
219
220 let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
221 (
222 project.register_buffer_with_language_servers(&language_file_buffer, cx),
223 project.lsp_store().clone(),
224 )
225 })?;
226
227 // TODO: remove this once the diagnostics tool waits for new diagnostics
228 cx.background_executor().timer(Duration::new(5, 0)).await;
229 wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
230
231 lsp_store.update(cx, |lsp_store, cx| {
232 lsp_open_handle.update(cx, |buffer, cx| {
233 buffer.update(cx, |buffer, cx| {
234 let has_language_server = lsp_store
235 .language_servers_for_local_buffer(buffer, cx)
236 .next()
237 .is_some();
238 if has_language_server {
239 Ok(())
240 } else {
241 Err(anyhow!(
242 "`{:?}` was opened to cause the language server to start, \
243 but no language servers are registered for its buffer. \
244 Set `require_lsp = false` in `base.toml` to skip this.",
245 language_file
246 ))
247 }
248 })
249 })
250 })??;
251
252 Some((lsp_open_handle, lsp_store))
253 } else {
254 None
255 };
256
257 if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
258 return Err(anyhow!("Setup only mode"));
259 }
260
261 let thread_store = thread_store.await;
262 let thread =
263 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
264
265 {
266 let mut log_file = this.log_file.lock().unwrap();
267 writeln!(&mut log_file, "👤 USER:").log_err();
268 writeln!(&mut log_file, "{}", this.prompt).log_err();
269 writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
270 log_file.flush().log_err();
271 }
272
273 let (tx, rx) = oneshot::channel();
274 let mut tx = Some(tx);
275
276 let _subscription = cx.subscribe(&thread, {
277 let log_file = this.log_file.clone();
278 let name = this.name.clone();
279 move |thread, event: &ThreadEvent, cx| {
280 let mut log_file = log_file.lock().unwrap();
281
282 match event {
283 ThreadEvent::Stopped(reason) => match reason {
284 Ok(StopReason::EndTurn) => {
285 if let Some(tx) = tx.take() {
286 tx.send(Ok(())).ok();
287 }
288 }
289 Ok(StopReason::MaxTokens) => {
290 if let Some(tx) = tx.take() {
291 tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
292 }
293 }
294 Ok(StopReason::ToolUse) => {}
295 Err(error) => {
296 if let Some(tx) = tx.take() {
297 tx.send(Err(anyhow!(error.clone()))).ok();
298 }
299 }
300 },
301 ThreadEvent::ShowError(thread_error) => {
302 if let Some(tx) = tx.take() {
303 tx.send(Err(anyhow!(thread_error.clone()))).ok();
304 }
305 }
306 ThreadEvent::StreamedAssistantText(_, chunk) => {
307 write!(&mut log_file, "{}", chunk).log_err();
308 }
309 ThreadEvent::StreamedAssistantThinking(_, chunk) => {
310 write!(&mut log_file, "{}", chunk).log_err();
311 }
312 ThreadEvent::UsePendingTools { tool_uses } => {
313 writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
314 for tool_use in tool_uses {
315 writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
316 .log_err();
317 }
318 }
319 ThreadEvent::ToolFinished {
320 tool_use_id,
321 pending_tool_use,
322 ..
323 } => {
324 if let Some(tool_use) = pending_tool_use {
325 let message = format!("TOOL FINISHED: {}", tool_use.name);
326 println!("{name}> {message}");
327 writeln!(&mut log_file, "\n{}", message).log_err();
328 }
329 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
330 let message = format!("\n{}\n", tool_result.content);
331 writeln!(&mut log_file, "{}", message).log_err();
332 }
333 }
334 _ => {}
335 }
336
337 log_file.flush().log_err();
338 }
339 })?;
340
341 thread.update(cx, |thread, cx| {
342 let context = vec![];
343 thread.insert_user_message(this.prompt.clone(), context, None, cx);
344 thread.send_to_model(model, RequestKind::Chat, cx);
345 })?;
346
347 rx.await??;
348
349 if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
350 wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
351 }
352
353 let repository_diff = this.repository_diff().await?;
354 let diagnostics = cx
355 .update(move |cx| {
356 cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
357 })?
358 .await?;
359
360 drop(lsp_open_handle_and_store);
361
362 thread.update(cx, |thread, _cx| {
363 let response_count = thread
364 .messages()
365 .filter(|message| message.role == language_model::Role::Assistant)
366 .count();
367 RunOutput {
368 repository_diff,
369 diagnostics,
370 response_count,
371 token_usage: thread.cumulative_token_usage(),
372 }
373 })
374 })
375 }
376
377 pub async fn judge(
378 &mut self,
379 model: Arc<dyn LanguageModel>,
380 repository_diff: String,
381 cx: &AsyncApp,
382 ) -> Result<JudgeOutput> {
383 let judge_prompt = include_str!("judge_prompt.hbs");
384 let judge_prompt_name = "judge_prompt";
385 let mut handlebars = Handlebars::new();
386 handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
387 let prompt = handlebars.render(
388 judge_prompt_name,
389 &JudgeInput {
390 repository_diff,
391 criteria: self.criteria.clone(),
392 },
393 )?;
394
395 let request = LanguageModelRequest {
396 messages: vec![LanguageModelRequestMessage {
397 role: Role::User,
398 content: vec![MessageContent::Text(prompt)],
399 cache: false,
400 }],
401 temperature: None,
402 tools: Vec::new(),
403 stop: Vec::new(),
404 };
405
406 let response = send_language_model_request(model, request, cx).await?;
407
408 let mut log_file = self.log_file.lock().unwrap();
409
410 writeln!(&mut log_file, "\n\n").log_err();
411 writeln!(&mut log_file, "========================================").log_err();
412 writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
413 writeln!(&mut log_file, "========================================").log_err();
414 writeln!(&mut log_file, "\n{}", &response).log_err();
415
416 parse_judge_output(&response)
417 }
418
419 pub async fn repository_diff(&self) -> Result<String> {
420 let worktree_path = self.worktree_path();
421 run_git(&worktree_path, &["add", "-N"]).await?;
422 run_git(&worktree_path, &["diff"]).await
423 }
424}
425
/// Returns a task that resolves once the language servers in `lsp_store`
/// report no pending work, or fails after a 5-minute timeout.
///
/// Resolves immediately if there is no pending work, or if the
/// `ZED_EVAL_SKIP_LS_WAIT` env var is set.
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    // Fast path: nothing to wait for. NOTE(review): the `unwrap` panics if the
    // app context is gone — acceptable here since this is an eval harness.
    if cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap()
        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
    {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    // Capacity 1 is enough: we only care about the first "idle" notification.
    let (mut tx, mut rx) = mpsc::channel(1);

    // Log progress messages and signal `rx` once all pending work is done.
    let subscription =
        cx.subscribe(&lsp_store, {
            let name = name.clone();
            move |lsp_store, event, cx| {
                match event {
                    project::LspStoreEvent::LanguageServerUpdate {
                        message:
                            client::proto::update_language_server::Variant::WorkProgress(
                                LspWorkProgress {
                                    message: Some(message),
                                    ..
                                },
                            ),
                        ..
                    } => println!("{name}> ⟲ {message}"),
                    _ => {}
                }

                if !has_pending_lang_server_work(&lsp_store, cx) {
                    // `try_send` may fail if the receiver is gone or the
                    // channel is full; both are harmless here.
                    tx.try_send(()).ok();
                }
            }
        });

    cx.spawn(async move |cx| {
        // Race the idle notification against a 5-minute timeout.
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        // Keep the subscription alive until we're done waiting.
        drop(subscription);
        result
    })
}
482
483fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
484 lsp_store
485 .read(cx)
486 .language_server_statuses()
487 .any(|(_, status)| !status.pending_work.is_empty())
488}
489
/// Collects LSP diagnostics from every project path that reports errors or
/// warnings, formatting one line per diagnostic group's primary entry.
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    // Use the cheap summaries first to find which paths need opening at all.
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        // Open the buffer to get a snapshot with position info for each entry.
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            // Report only each group's primary entry, and only errors/warnings.
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            // Rows are zero-based; print one-based line numbers.
            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(output)
}
526
527fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
528 let analysis = get_tag("analysis", response)?.to_string();
529 let score = get_tag("score", response)?
530 .parse()
531 .context("error parsing score")?;
532
533 Ok(JudgeOutput { analysis, score })
534}
535
536fn get_tag(name: &'static str, response: &str) -> Result<String> {
537 let start_tag = format!("<{}>", name);
538 let end_tag = format!("</{}>", name);
539
540 let start_ix = response
541 .find(&start_tag)
542 .context(format!("{} start tag not found", name))?;
543 let content_start_ix = start_ix + start_tag.len();
544
545 let end_ix = content_start_ix
546 + response[content_start_ix..]
547 .find(&end_tag)
548 .context(format!("{} end tag not found", name))?;
549
550 let content = response[content_start_ix..end_ix].trim().unindent();
551
552 anyhow::Ok(content)
553}
554
555pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
556 let repo_name = repo_url
557 .trim_start_matches("https://")
558 .replace(|c: char| !c.is_alphanumeric(), "-");
559 Path::new(REPOS_DIR)
560 .canonicalize()
561 .context(format!("No such directory {REPOS_DIR}"))
562 .unwrap()
563 .join(repo_name)
564}
565
566pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
567 let output = new_smol_command("git")
568 .current_dir(repo_path)
569 .args(args)
570 .output()
571 .await?;
572
573 if output.status.success() {
574 Ok(String::from_utf8(output.stdout)?.trim().to_string())
575 } else {
576 Err(anyhow!(
577 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
578 args.join(" "),
579 repo_path.display(),
580 output.status,
581 String::from_utf8_lossy(&output.stderr),
582 String::from_utf8_lossy(&output.stdout),
583 ))
584 }
585}
586
587pub async fn send_language_model_request(
588 model: Arc<dyn LanguageModel>,
589 request: LanguageModelRequest,
590 cx: &AsyncApp,
591) -> anyhow::Result<String> {
592 match model.stream_completion_text(request, &cx).await {
593 Ok(mut stream) => {
594 let mut full_response = String::new();
595 while let Some(chunk_result) = stream.stream.next().await {
596 match chunk_result {
597 Ok(chunk_str) => {
598 print!("{}", &chunk_str);
599 full_response.push_str(&chunk_str);
600 }
601 Err(err) => {
602 return Err(anyhow!(
603 "Error receiving response from language model: {err}"
604 ));
605 }
606 }
607 }
608 Ok(full_response)
609 }
610 Err(err) => Err(anyhow!(
611 "Failed to get response from language model. Error was: {err}"
612 )),
613 }
614}
615
#[cfg(test)]
mod test {
    use super::*;

    // Covers both a minimal response and one with surrounding text, multi-line
    // analysis content, and blank lines between tags.
    #[test]
    fn test_parse_judge_output() {
        // Simple case: both tags on their own lines, nothing else.
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        // Text outside the tags is ignored; multi-line analysis is trimmed
        // and unindented.
        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
}