1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::channel::mpsc;
8use futures::{FutureExt, StreamExt as _, select_biased};
9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
10use handlebars::Handlebars;
11use language::{DiagnosticSeverity, OffsetRangeExt};
12use language_model::{
13 LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
14 StopReason, TokenUsage,
15};
16use project::{LspStore, Project, ProjectPath};
17use serde::{Deserialize, Serialize};
18use std::fmt::Write as _;
19use std::fs::File;
20use std::io::Write as _;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use std::{
24 fs,
25 path::{Path, PathBuf},
26};
27use unindent::Unindent as _;
28use util::ResultExt as _;
29use util::command::new_smol_command;
30use util::serde::default_true;
31
32use crate::AgentAppState;
33
/// Relative path of the examples directory. Not referenced by the code visible in this
/// file; presumably read by the eval runner that discovers example directories.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path of the directory that holds repository clones; `repo_path_for_url`
/// canonicalizes it and appends a per-URL directory name.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path of the directory that holds per-example Git worktrees;
/// `Example::worktree_path` canonicalizes it and appends the example name.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Longest time the event loop in `Example::run` waits for the next `ThreadEvent`
/// before failing the run as stalled.
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
39
/// Repository settings of an example, deserialized from its `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Repository URL; `repo_path_for_url` maps it to the local clone directory.
    pub url: String,
    /// Git revision that `Example::setup` fetches and checks out into the worktree.
    pub revision: String,
    /// File extension (compared against `Path::extension`, so without a leading dot) used
    /// by `Example::run` to pick a file to open so that a language server starts.
    /// Required when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// NOTE(review): not used anywhere in this file; meaning to be confirmed against callers.
    pub insert_id: Option<String>,
    /// Whether `Example::run` must start a language server (and fail if none is registered)
    /// before running the agent. Filled in by `default_true` when absent from `base.toml`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
49
/// A single eval example: its inputs (`base.toml`, `prompt.md`, `criteria.md`) plus the
/// markdown output file that the run and the judge append to.
#[derive(Clone, Debug)]
pub struct Example {
    /// Final path component of the example directory; also used as the worktree
    /// directory name (see `worktree_path`).
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown output file to append to
    pub output_file: Arc<Mutex<File>>,
    /// Path to markdown output file
    pub output_file_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
}
66
/// Result of `Example::run`: what the agent changed and what the run cost.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Output of `Example::repository_diff` taken after the agent stopped.
    pub repository_diff: String,
    /// Text produced by `query_lsp_diagnostics`: one line per reported error/warning.
    pub diagnostics: String,
    /// Number of messages in the thread whose role is `Role::Assistant`.
    pub response_count: usize,
    /// The thread's cumulative token usage at the end of the run.
    pub token_usage: TokenUsage,
    /// Number of finished tool uses that had a recorded result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
75
/// Data rendered into the `judge_prompt.hbs` template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff of the example repository that the judge evaluates.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
81
/// Verdict extracted from the judge model's response by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Trimmed, unindented content of the response's `<analysis>` tag.
    pub analysis: String,
    /// Content of the response's `<score>` tag parsed as an unsigned integer.
    pub score: u32,
}
87
88impl Example {
89 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
90 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
91 let name = Self::name_from_path(dir_path);
92 let base_path = dir_path.join("base.toml");
93 let prompt_path = dir_path.join("prompt.md");
94 let criteria_path = dir_path.join("criteria.md");
95
96 let output_file_path = run_dir.join(format!(
97 "{}.md",
98 dir_path.file_name().unwrap().to_str().unwrap()
99 ));
100 let output_file = Arc::new(Mutex::new(File::create(&output_file_path).unwrap()));
101
102 Ok(Example {
103 name: name.clone(),
104 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
105 prompt: fs::read_to_string(prompt_path.clone())?,
106 criteria: fs::read_to_string(criteria_path.clone())?,
107 output_file,
108 output_file_path,
109 log_prefix: name,
110 })
111 }
112
113 pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
114 self.log_prefix = format!(
115 "{}{:<width$}\x1b[0m | ",
116 color,
117 self.name,
118 width = name_width
119 );
120 }
121
122 pub fn name_from_path(path: &Path) -> String {
123 path.file_name().unwrap().to_string_lossy().to_string()
124 }
125
126 pub fn worktree_path(&self) -> PathBuf {
127 Path::new(WORKTREES_DIR)
128 .canonicalize()
129 .context(format!("No such directory {WORKTREES_DIR}"))
130 .unwrap()
131 .join(&self.name)
132 }
133
134 /// Set up the example by checking out the specified Git revision
135 pub async fn setup(&self) -> Result<()> {
136 let repo_path = repo_path_for_url(&self.base.url);
137
138 println!("{}Fetching", self.log_prefix);
139
140 run_git(
141 &repo_path,
142 &["fetch", "--depth", "1", "origin", &self.base.revision],
143 )
144 .await?;
145
146 let worktree_path = self.worktree_path();
147
148 if worktree_path.is_dir() {
149 println!("{}Resetting existing worktree", self.log_prefix);
150
151 // TODO: consider including "-x" to remove ignored files. The downside of this is that
152 // it will also remove build artifacts, and so prevent incremental reuse there.
153 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
154 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
155 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
156 } else {
157 println!("{}Creating worktree", self.log_prefix);
158
159 let worktree_path_string = worktree_path.to_string_lossy().to_string();
160
161 run_git(
162 &repo_path,
163 &[
164 "worktree",
165 "add",
166 "-f",
167 &worktree_path_string,
168 &self.base.revision,
169 ],
170 )
171 .await?;
172 }
173
174 Ok(())
175 }
176
177 pub fn run(
178 &self,
179 model: Arc<dyn LanguageModel>,
180 app_state: Arc<AgentAppState>,
181 cx: &mut App,
182 ) -> Task<Result<RunOutput>> {
183 let project = Project::local(
184 app_state.client.clone(),
185 app_state.node_runtime.clone(),
186 app_state.user_store.clone(),
187 app_state.languages.clone(),
188 Arc::new(DapRegistry::default()),
189 app_state.fs.clone(),
190 None,
191 cx,
192 );
193
194 let worktree_path = self.worktree_path();
195 let worktree = project.update(cx, |project, cx| {
196 project.create_worktree(&worktree_path, true, cx)
197 });
198
199 let tools = cx.new(|_| ToolWorkingSet::default());
200 let thread_store =
201 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
202 let this = self.clone();
203
204 cx.spawn(async move |cx| {
205 let worktree = worktree.await?;
206
207 // Wait for worktree scan to finish before choosing a file to open.
208 worktree
209 .update(cx, |worktree, _cx| {
210 worktree.as_local().unwrap().scan_complete()
211 })?
212 .await;
213
214 let lsp_open_handle_and_store = if this.base.require_lsp {
215 let language_extension = this.base.language_extension.as_deref().context(
216 "language_extension field is required in base.toml when `require_lsp == true`",
217 )?;
218
219 // Open a file that matches the language to cause LSP to start.
220 let language_file = worktree.read_with(cx, |worktree, _cx| {
221 worktree
222 .files(false, 0)
223 .find_map(|e| {
224 if e.path.clone().extension().and_then(|ext| ext.to_str())
225 == Some(language_extension)
226 {
227 Some(ProjectPath {
228 worktree_id: worktree.id(),
229 path: e.path.clone(),
230 })
231 } else {
232 None
233 }
234 })
235 .context("Failed to find a file for example language")
236 })??;
237
238 let open_language_file_buffer_task = project.update(cx, |project, cx| {
239 project.open_buffer(language_file.clone(), cx)
240 })?;
241
242 let language_file_buffer = open_language_file_buffer_task.await?;
243
244 let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
245 (
246 project.register_buffer_with_language_servers(&language_file_buffer, cx),
247 project.lsp_store().clone(),
248 )
249 })?;
250
251 // TODO: remove this once the diagnostics tool waits for new diagnostics
252 cx.background_executor().timer(Duration::new(5, 0)).await;
253 wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;
254
255 lsp_store.update(cx, |lsp_store, cx| {
256 lsp_open_handle.update(cx, |buffer, cx| {
257 buffer.update(cx, |buffer, cx| {
258 let has_language_server = lsp_store
259 .language_servers_for_local_buffer(buffer, cx)
260 .next()
261 .is_some();
262 if has_language_server {
263 Ok(())
264 } else {
265 Err(anyhow!(
266 "`{:?}` was opened to cause the language server to start, \
267 but no language servers are registered for its buffer. \
268 Set `require_lsp = false` in `base.toml` to skip this.",
269 language_file
270 ))
271 }
272 })
273 })
274 })??;
275
276 Some((lsp_open_handle, lsp_store))
277 } else {
278 None
279 };
280
281 if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
282 return Err(anyhow!("Setup only mode"));
283 }
284
285 let thread_store = thread_store.await;
286 let thread =
287 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
288
289 {
290 let mut output_file = this.output_file.lock().unwrap();
291 writeln!(&mut output_file, "👤 USER:").log_err();
292 writeln!(&mut output_file, "{}", this.prompt).log_err();
293 writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
294 output_file.flush().log_err();
295 }
296
297 let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
298 Mutex::new(HashMap::default()).into();
299
300 let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
301
302 let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
303 thread_event_tx.unbounded_send(event.clone()).log_err();
304 });
305
306 let event_handler_task = cx.spawn({
307 let output_file = this.output_file.clone();
308 let log_prefix = this.log_prefix.clone();
309 let tool_use_counts = tool_use_counts.clone();
310 let thread = thread.downgrade();
311 async move |cx| {
312 loop {
313 let event = select_biased! {
314 event = thread_event_rx.next() => event,
315 _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
316 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
317 }
318 };
319 let Some(event) = event else {
320 return Err(anyhow!("ThreadEvent channel ended early"));
321 };
322
323 let mut output_file = output_file.lock().unwrap();
324
325 match event {
326 ThreadEvent::Stopped(reason) => match reason {
327 Ok(StopReason::EndTurn) => {
328 return Ok(());
329 }
330 Ok(StopReason::MaxTokens) => {
331 return Err(anyhow!("Exceeded maximum tokens"));
332 }
333 Ok(StopReason::ToolUse) => {
334 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
335 println!("{}StopReason: Tool use", log_prefix);
336 }
337 }
338 Err(error) => {
339 return Err(anyhow!(error.clone()));
340 }
341 },
342 ThreadEvent::ShowError(thread_error) => {
343 break Err(anyhow!(thread_error.clone()));
344 }
345 ThreadEvent::StreamedAssistantText(_, chunk) => {
346 write!(&mut output_file, "{}", chunk).log_err();
347 }
348 ThreadEvent::StreamedAssistantThinking(_, chunk) => {
349 write!(&mut output_file, "{}", chunk).log_err();
350 }
351 ThreadEvent::UsePendingTools { tool_uses } => {
352 writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
353 for tool_use in tool_uses {
354 writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
355 .log_err();
356 }
357 }
358 ThreadEvent::ToolFinished {
359 tool_use_id,
360 pending_tool_use,
361 ..
362 } => {
363 if let Some(tool_use) = pending_tool_use {
364 let message = format!("TOOL FINISHED: {}", tool_use.name);
365 println!("{}{message}", log_prefix);
366 writeln!(&mut output_file, "\n{}", message).log_err();
367 }
368 thread.update(cx, |thread, _cx| {
369 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
370 writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
371 let mut tool_use_counts = tool_use_counts.lock().unwrap();
372 *tool_use_counts
373 .entry(tool_result.tool_name.clone())
374 .or_insert(0) += 1;
375 }
376 })?;
377 }
378 ThreadEvent::ToolConfirmationNeeded => {
379 panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
380 },
381 ThreadEvent::StreamedCompletion |
382 ThreadEvent::MessageAdded(_) |
383 ThreadEvent::MessageEdited(_) |
384 ThreadEvent::MessageDeleted(_) |
385 ThreadEvent::SummaryChanged |
386 ThreadEvent::SummaryGenerated |
387 ThreadEvent::CheckpointChanged => {
388 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
389 println!("{}Event: {:#?}", log_prefix, event);
390 }
391 }
392 }
393
394 output_file.flush().log_err();
395 }
396 }
397 });
398
399 thread.update(cx, |thread, cx| {
400 let context = vec![];
401 thread.insert_user_message(this.prompt.clone(), context, None, cx);
402 thread.send_to_model(model, RequestKind::Chat, cx);
403 })?;
404
405 event_handler_task.await?;
406
407 println!("{}Stopped", this.log_prefix);
408
409 if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
410 wait_for_lang_server(lsp_store, this.log_prefix.clone(), cx).await?;
411 }
412
413 println!("{}Getting repository diff", this.log_prefix);
414 let repository_diff = this.repository_diff().await?;
415
416 println!("{}Getting diagnostics", this.log_prefix);
417 let diagnostics = cx
418 .update(move |cx| {
419 cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
420 })?
421 .await?;
422 println!("{}Got diagnostics", this.log_prefix);
423
424 drop(subscription);
425 drop(lsp_open_handle_and_store);
426
427 thread.update(cx, |thread, _cx| {
428 let response_count = thread
429 .messages()
430 .filter(|message| message.role == language_model::Role::Assistant)
431 .count();
432 RunOutput {
433 repository_diff,
434 diagnostics,
435 response_count,
436 token_usage: thread.cumulative_token_usage(),
437 tool_use_counts: tool_use_counts.lock().unwrap().clone(),
438 }
439 })
440 })
441 }
442
443 pub async fn judge(
444 &self,
445 model: Arc<dyn LanguageModel>,
446 repository_diff: String,
447 cx: &AsyncApp,
448 ) -> Result<JudgeOutput> {
449 let judge_prompt = include_str!("judge_prompt.hbs");
450 let judge_prompt_name = "judge_prompt";
451 let mut handlebars = Handlebars::new();
452 handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
453 let prompt = handlebars.render(
454 judge_prompt_name,
455 &JudgeInput {
456 repository_diff,
457 criteria: self.criteria.clone(),
458 },
459 )?;
460
461 let request = LanguageModelRequest {
462 messages: vec![LanguageModelRequestMessage {
463 role: Role::User,
464 content: vec![MessageContent::Text(prompt)],
465 cache: false,
466 }],
467 temperature: None,
468 tools: Vec::new(),
469 stop: Vec::new(),
470 };
471
472 let response = send_language_model_request(model, request, cx).await?;
473
474 let mut output_file = self.output_file.lock().unwrap();
475
476 writeln!(&mut output_file, "\n\n").log_err();
477 writeln!(&mut output_file, "========================================").log_err();
478 writeln!(&mut output_file, " JUDGE OUTPUT ").log_err();
479 writeln!(&mut output_file, "========================================").log_err();
480 writeln!(&mut output_file, "\n{}", &response).log_err();
481
482 parse_judge_output(&response)
483 }
484
485 pub async fn repository_diff(&self) -> Result<String> {
486 let worktree_path = self.worktree_path();
487 run_git(&worktree_path, &["add", "-N"]).await?;
488 run_git(&worktree_path, &["diff"]).await
489 }
490}
491
492fn wait_for_lang_server(
493 lsp_store: &Entity<LspStore>,
494 log_prefix: String,
495 cx: &mut AsyncApp,
496) -> Task<Result<()>> {
497 if cx
498 .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
499 .unwrap()
500 || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
501 {
502 return Task::ready(anyhow::Ok(()));
503 }
504
505 println!("{}⏵ Waiting for language server", log_prefix);
506
507 let (mut tx, mut rx) = mpsc::channel(1);
508
509 let subscription =
510 cx.subscribe(&lsp_store, {
511 let log_prefix = log_prefix.clone();
512 move |lsp_store, event, cx| {
513 match event {
514 project::LspStoreEvent::LanguageServerUpdate {
515 message:
516 client::proto::update_language_server::Variant::WorkProgress(
517 LspWorkProgress {
518 message: Some(message),
519 ..
520 },
521 ),
522 ..
523 } => println!("{}⟲ {message}", log_prefix),
524 _ => {}
525 }
526
527 if !has_pending_lang_server_work(&lsp_store, cx) {
528 tx.try_send(()).ok();
529 }
530 }
531 });
532
533 cx.spawn(async move |cx| {
534 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
535 let result = futures::select! {
536 _ = rx.next() => {
537 println!("{}⚑ Language server idle", log_prefix);
538 anyhow::Ok(())
539 },
540 _ = timeout.fuse() => {
541 Err(anyhow!("LSP wait timed out after 5 minutes"))
542 }
543 };
544 drop(subscription);
545 result
546 })
547}
548
549fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
550 lsp_store
551 .read(cx)
552 .language_server_statuses()
553 .any(|(_, status)| !status.pending_work.is_empty())
554}
555
556async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
557 let paths_with_diagnostics = project.update(cx, |project, cx| {
558 project
559 .diagnostic_summaries(true, cx)
560 .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
561 .map(|(project_path, _, _)| project_path)
562 .collect::<Vec<_>>()
563 })?;
564
565 let mut output = String::new();
566 for project_path in paths_with_diagnostics {
567 let buffer = project
568 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
569 .await?;
570 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
571
572 for (_, group) in snapshot.diagnostic_groups(None) {
573 let entry = &group.entries[group.primary_ix];
574 let range = entry.range.to_point(&snapshot);
575 let severity = match entry.diagnostic.severity {
576 DiagnosticSeverity::ERROR => "error",
577 DiagnosticSeverity::WARNING => "warning",
578 _ => continue,
579 };
580
581 writeln!(
582 output,
583 "{} at line {}: {}",
584 severity,
585 range.start.row + 1,
586 entry.diagnostic.message
587 )?;
588 }
589 }
590 anyhow::Ok(output)
591}
592
593fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
594 let analysis = get_tag("analysis", response)?.to_string();
595 let score = get_tag("score", response)?
596 .parse()
597 .context("error parsing score")?;
598
599 Ok(JudgeOutput { analysis, score })
600}
601
602fn get_tag(name: &'static str, response: &str) -> Result<String> {
603 let start_tag = format!("<{}>", name);
604 let end_tag = format!("</{}>", name);
605
606 let start_ix = response
607 .find(&start_tag)
608 .context(format!("{} start tag not found", name))?;
609 let content_start_ix = start_ix + start_tag.len();
610
611 let end_ix = content_start_ix
612 + response[content_start_ix..]
613 .find(&end_tag)
614 .context(format!("{} end tag not found", name))?;
615
616 let content = response[content_start_ix..end_ix].trim().unindent();
617
618 anyhow::Ok(content)
619}
620
621pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
622 let repo_name = repo_url
623 .trim_start_matches("https://")
624 .replace(|c: char| !c.is_alphanumeric(), "-");
625 Path::new(REPOS_DIR)
626 .canonicalize()
627 .context(format!("No such directory {REPOS_DIR}"))
628 .unwrap()
629 .join(repo_name)
630}
631
632pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
633 let output = new_smol_command("git")
634 .current_dir(repo_path)
635 .args(args)
636 .output()
637 .await?;
638
639 if output.status.success() {
640 Ok(String::from_utf8(output.stdout)?.trim().to_string())
641 } else {
642 Err(anyhow!(
643 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
644 args.join(" "),
645 repo_path.display(),
646 output.status,
647 String::from_utf8_lossy(&output.stderr),
648 String::from_utf8_lossy(&output.stdout),
649 ))
650 }
651}
652
653pub async fn send_language_model_request(
654 model: Arc<dyn LanguageModel>,
655 request: LanguageModelRequest,
656 cx: &AsyncApp,
657) -> anyhow::Result<String> {
658 match model.stream_completion_text(request, &cx).await {
659 Ok(mut stream) => {
660 let mut full_response = String::new();
661 while let Some(chunk_result) = stream.stream.next().await {
662 match chunk_result {
663 Ok(chunk_str) => {
664 full_response.push_str(&chunk_str);
665 }
666 Err(err) => {
667 return Err(anyhow!(
668 "Error receiving response from language model: {err}"
669 ));
670 }
671 }
672 }
673 Ok(full_response)
674 }
675 Err(err) => Err(anyhow!(
676 "Failed to get response from language model. Error was: {err}"
677 )),
678 }
679}
680
#[cfg(test)]
mod test {
    use super::*;

    /// Covers both a compact single-line response and a response with surrounding text
    /// and a multi-line analysis.
    #[test]
    fn test_parse_judge_output() {
        // Tags on single lines, nothing else around them.
        let compact_response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let compact = parse_judge_output(&compact_response).unwrap();
        assert_eq!(
            compact.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(compact.score, 3);

        // Text outside the tags is ignored and the multi-line analysis is trimmed.
        let verbose_response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let verbose = parse_judge_output(&verbose_response).unwrap();
        assert_eq!(
            verbose.analysis,
            "Failed to compile:\n- Error 1\n- Error 2"
        );
        assert_eq!(verbose.score, 1);
    }
}