1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::channel::mpsc;
8use futures::{FutureExt, StreamExt as _, select_biased};
9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
10use handlebars::Handlebars;
11use language::{DiagnosticSeverity, OffsetRangeExt};
12use language_model::{
13 LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
14 StopReason, TokenUsage,
15};
16use project::{LspStore, Project, ProjectPath};
17use serde::{Deserialize, Serialize};
18use std::fmt::Write as _;
19use std::fs::File;
20use std::io::Write as _;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use std::{
24 fs,
25 path::{Path, PathBuf},
26};
27use unindent::Unindent as _;
28use util::ResultExt as _;
29use util::command::new_smol_command;
30use util::serde::default_true;
31
32use crate::AgentAppState;
33
/// Root directory holding the eval example definitions.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Directory holding repository checkouts, one per repository URL (see `repo_path_for_url`).
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Directory in which each example gets its own git worktree, named after the example.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Longest the agent loop may go without emitting any `ThreadEvent` before the run is
/// considered stalled (two minutes).
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(120);
39
/// Parsed contents of an example's `base.toml`: which repository and revision to check out,
/// plus language-server requirements.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// URL of the git repository the example runs against; also determines the local
    /// checkout directory via `repo_path_for_url`.
    pub url: String,
    /// Git revision fetched and checked out into the example's worktree.
    pub revision: String,
    /// File extension (without the leading dot, since it is compared against
    /// `Path::extension`) of a file to open so a language server starts.
    /// Required when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in this file — purpose unclear; confirm with callers.
    pub insert_id: Option<String>,
    /// Whether a language server must attach before the agent runs. When omitted, the value
    /// comes from `util::serde::default_true` (presumably `true`).
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
49
/// An eval example loaded from a directory, together with the state used while running it.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name: the final component of the example's directory path.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown output file to append to. `None` until `setup` creates it.
    pub output_file: Option<Arc<Mutex<File>>>,
    /// Path to markdown output file
    pub output_file_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
}
66
/// Results collected from one agent run of an example.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Diff of the example worktree after the agent finished.
    pub repository_diff: String,
    /// Error and warning diagnostics reported after the run, as plain text.
    pub diagnostics: String,
    /// Number of assistant messages in the thread.
    pub response_count: usize,
    /// Cumulative token usage of the thread.
    pub token_usage: TokenUsage,
    /// How many finished tool uses produced a result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
75
/// Values rendered into the `judge_prompt.hbs` handlebars template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff produced by the agent run.
    pub repository_diff: String,
    /// Contents of the example's `criteria.md`.
    pub criteria: String,
}
81
/// The judge model's verdict, parsed from its `<analysis>` and `<score>` tags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Text found inside the `<analysis>` tag.
    pub analysis: String,
    /// Integer found inside the `<score>` tag.
    pub score: u32,
}
87
88impl Example {
89 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
90 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
91 let name = Self::name_from_path(dir_path);
92 let base_path = dir_path.join("base.toml");
93 let prompt_path = dir_path.join("prompt.md");
94 let criteria_path = dir_path.join("criteria.md");
95
96 let output_file_path = run_dir.join(format!(
97 "{}.md",
98 dir_path.file_name().unwrap().to_str().unwrap()
99 ));
100
101 Ok(Example {
102 name: name.clone(),
103 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
104 prompt: fs::read_to_string(prompt_path.clone())?,
105 criteria: fs::read_to_string(criteria_path.clone())?,
106 output_file: None,
107 output_file_path,
108 log_prefix: name,
109 })
110 }
111
112 pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
113 self.log_prefix = format!(
114 "{}{:<width$}\x1b[0m | ",
115 color,
116 self.name,
117 width = name_width
118 );
119 }
120
121 pub fn name_from_path(path: &Path) -> String {
122 path.file_name().unwrap().to_string_lossy().to_string()
123 }
124
125 pub fn worktree_path(&self) -> PathBuf {
126 Path::new(WORKTREES_DIR)
127 .canonicalize()
128 .context(format!("No such directory {WORKTREES_DIR}"))
129 .unwrap()
130 .join(&self.name)
131 }
132
133 /// Set up the example by checking out the specified Git revision
134 pub async fn setup(&mut self) -> Result<()> {
135 let repo_path = repo_path_for_url(&self.base.url);
136
137 println!("{}Fetching", self.log_prefix);
138
139 run_git(
140 &repo_path,
141 &["fetch", "--depth", "1", "origin", &self.base.revision],
142 )
143 .await?;
144
145 let worktree_path = self.worktree_path();
146
147 if worktree_path.is_dir() {
148 println!("{}Resetting existing worktree", self.log_prefix);
149
150 // TODO: consider including "-x" to remove ignored files. The downside of this is that
151 // it will also remove build artifacts, and so prevent incremental reuse there.
152 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
153 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
154 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
155 } else {
156 println!("{}Creating worktree", self.log_prefix);
157
158 let worktree_path_string = worktree_path.to_string_lossy().to_string();
159
160 run_git(
161 &repo_path,
162 &[
163 "worktree",
164 "add",
165 "-f",
166 &worktree_path_string,
167 &self.base.revision,
168 ],
169 )
170 .await?;
171 }
172
173 // Create the output file
174 let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?));
175 self.output_file = Some(output_file);
176
177 Ok(())
178 }
179
180 /// Returns the output file, panicking if it's not set
181 fn output_file(&self) -> Arc<Mutex<File>> {
182 self.output_file
183 .clone()
184 .expect("Output file not created. Call setup() first.")
185 }
186
187 pub fn run(
188 &self,
189 model: Arc<dyn LanguageModel>,
190 app_state: Arc<AgentAppState>,
191 cx: &mut App,
192 ) -> Task<Result<RunOutput>> {
193 let project = Project::local(
194 app_state.client.clone(),
195 app_state.node_runtime.clone(),
196 app_state.user_store.clone(),
197 app_state.languages.clone(),
198 Arc::new(DapRegistry::default()),
199 app_state.fs.clone(),
200 None,
201 cx,
202 );
203
204 let worktree_path = self.worktree_path();
205 let worktree = project.update(cx, |project, cx| {
206 project.create_worktree(&worktree_path, true, cx)
207 });
208
209 let tools = cx.new(|_| ToolWorkingSet::default());
210 let thread_store =
211 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
212 let this = self.clone();
213
214 cx.spawn(async move |cx| {
215 let worktree = worktree.await?;
216
217 // Wait for worktree scan to finish before choosing a file to open.
218 worktree
219 .update(cx, |worktree, _cx| {
220 worktree.as_local().unwrap().scan_complete()
221 })?
222 .await;
223
224 let lsp_open_handle_and_store = if this.base.require_lsp {
225 let language_extension = this.base.language_extension.as_deref().context(
226 "language_extension field is required in base.toml when `require_lsp == true`",
227 )?;
228
229 // Open a file that matches the language to cause LSP to start.
230 let language_file = worktree.read_with(cx, |worktree, _cx| {
231 worktree
232 .files(false, 0)
233 .find_map(|e| {
234 if e.path.clone().extension().and_then(|ext| ext.to_str())
235 == Some(language_extension)
236 {
237 Some(ProjectPath {
238 worktree_id: worktree.id(),
239 path: e.path.clone(),
240 })
241 } else {
242 None
243 }
244 })
245 .context("Failed to find a file for example language")
246 })??;
247
248 let open_language_file_buffer_task = project.update(cx, |project, cx| {
249 project.open_buffer(language_file.clone(), cx)
250 })?;
251
252 let language_file_buffer = open_language_file_buffer_task.await?;
253
254 let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
255 (
256 project.register_buffer_with_language_servers(&language_file_buffer, cx),
257 project.lsp_store().clone(),
258 )
259 })?;
260
261 // TODO: remove this once the diagnostics tool waits for new diagnostics
262 cx.background_executor().timer(Duration::new(5, 0)).await;
263 wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;
264
265 lsp_store.update(cx, |lsp_store, cx| {
266 lsp_open_handle.update(cx, |buffer, cx| {
267 buffer.update(cx, |buffer, cx| {
268 let has_language_server = lsp_store
269 .language_servers_for_local_buffer(buffer, cx)
270 .next()
271 .is_some();
272 if has_language_server {
273 Ok(())
274 } else {
275 Err(anyhow!(
276 "`{:?}` was opened to cause the language server to start, \
277 but no language servers are registered for its buffer. \
278 Set `require_lsp = false` in `base.toml` to skip this.",
279 language_file
280 ))
281 }
282 })
283 })
284 })??;
285
286 Some((lsp_open_handle, lsp_store))
287 } else {
288 None
289 };
290
291 if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
292 return Err(anyhow!("Setup only mode"));
293 }
294
295 let thread_store = thread_store.await;
296 let thread =
297 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
298
299 {
300 let output_file_ref = this.output_file();
301 let mut output_file = output_file_ref.lock().unwrap();
302 writeln!(&mut output_file, "👤 USER:").log_err();
303 writeln!(&mut output_file, "{}", this.prompt).log_err();
304 writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
305 output_file.flush().log_err();
306 }
307
308 let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
309 Mutex::new(HashMap::default()).into();
310
311 let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
312
313 let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
314 thread_event_tx.unbounded_send(event.clone()).log_err();
315 });
316
317 let event_handler_task = cx.spawn({
318 // Need to clone the Arc here because the reference from output_file() won't live long enough
319 let output_file = this.output_file.clone().unwrap();
320 let log_prefix = this.log_prefix.clone();
321 let tool_use_counts = tool_use_counts.clone();
322 let thread = thread.downgrade();
323 async move |cx| {
324 loop {
325 let event = select_biased! {
326 event = thread_event_rx.next() => event,
327 _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
328 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
329 }
330 };
331 let Some(event) = event else {
332 return Err(anyhow!("ThreadEvent channel ended early"));
333 };
334
335 let mut output_file = output_file.lock().unwrap();
336
337 match event {
338 ThreadEvent::Stopped(reason) => match reason {
339 Ok(StopReason::EndTurn) => {
340 return Ok(());
341 }
342 Ok(StopReason::MaxTokens) => {
343 return Err(anyhow!("Exceeded maximum tokens"));
344 }
345 Ok(StopReason::ToolUse) => {
346 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
347 println!("{}StopReason: Tool use", log_prefix);
348 }
349 }
350 Err(error) => {
351 return Err(anyhow!(error.clone()));
352 }
353 },
354 ThreadEvent::ShowError(thread_error) => {
355 break Err(anyhow!(thread_error.clone()));
356 }
357 ThreadEvent::StreamedAssistantText(_, chunk) => {
358 write!(&mut output_file, "{}", chunk).log_err();
359 }
360 ThreadEvent::StreamedAssistantThinking(_, chunk) => {
361 write!(&mut output_file, "{}", chunk).log_err();
362 }
363 ThreadEvent::UsePendingTools { tool_uses } => {
364 writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
365 for tool_use in tool_uses {
366 writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
367 .log_err();
368 }
369 }
370 ThreadEvent::ToolFinished {
371 tool_use_id,
372 pending_tool_use,
373 ..
374 } => {
375 if let Some(tool_use) = pending_tool_use {
376 let message = format!("TOOL FINISHED: {}", tool_use.name);
377 println!("{}{message}", log_prefix);
378 writeln!(&mut output_file, "\n{}", message).log_err();
379 }
380 thread.update(cx, |thread, _cx| {
381 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
382 writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
383 let mut tool_use_counts = tool_use_counts.lock().unwrap();
384 *tool_use_counts
385 .entry(tool_result.tool_name.clone())
386 .or_insert(0) += 1;
387 }
388 })?;
389 }
390 ThreadEvent::ToolConfirmationNeeded => {
391 panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
392 },
393 ThreadEvent::StreamedCompletion |
394 ThreadEvent::MessageAdded(_) |
395 ThreadEvent::MessageEdited(_) |
396 ThreadEvent::MessageDeleted(_) |
397 ThreadEvent::SummaryChanged |
398 ThreadEvent::SummaryGenerated |
399 ThreadEvent::CheckpointChanged => {
400 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
401 println!("{}Event: {:#?}", log_prefix, event);
402 }
403 }
404 }
405
406 output_file.flush().log_err();
407 }
408 }
409 });
410
411 thread.update(cx, |thread, cx| {
412 let context = vec![];
413 thread.insert_user_message(this.prompt.clone(), context, None, cx);
414 thread.send_to_model(model, RequestKind::Chat, cx);
415 })?;
416
417 event_handler_task.await?;
418
419 println!("{}Stopped", this.log_prefix);
420
421 if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
422 wait_for_lang_server(lsp_store, this.log_prefix.clone(), cx).await?;
423 }
424
425 println!("{}Getting repository diff", this.log_prefix);
426 let repository_diff = this.repository_diff().await?;
427
428 println!("{}Getting diagnostics", this.log_prefix);
429 let diagnostics = cx
430 .update(move |cx| {
431 cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
432 })?
433 .await?;
434 println!("{}Got diagnostics", this.log_prefix);
435
436 drop(subscription);
437 drop(lsp_open_handle_and_store);
438
439 thread.update(cx, |thread, _cx| {
440 let response_count = thread
441 .messages()
442 .filter(|message| message.role == language_model::Role::Assistant)
443 .count();
444 RunOutput {
445 repository_diff,
446 diagnostics,
447 response_count,
448 token_usage: thread.cumulative_token_usage(),
449 tool_use_counts: tool_use_counts.lock().unwrap().clone(),
450 }
451 })
452 })
453 }
454
455 pub async fn judge(
456 &self,
457 model: Arc<dyn LanguageModel>,
458 repository_diff: String,
459 cx: &AsyncApp,
460 ) -> Result<JudgeOutput> {
461 let judge_prompt = include_str!("judge_prompt.hbs");
462 let judge_prompt_name = "judge_prompt";
463 let mut handlebars = Handlebars::new();
464 handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
465 let prompt = handlebars.render(
466 judge_prompt_name,
467 &JudgeInput {
468 repository_diff,
469 criteria: self.criteria.clone(),
470 },
471 )?;
472
473 let request = LanguageModelRequest {
474 messages: vec![LanguageModelRequestMessage {
475 role: Role::User,
476 content: vec![MessageContent::Text(prompt)],
477 cache: false,
478 }],
479 temperature: None,
480 tools: Vec::new(),
481 stop: Vec::new(),
482 };
483
484 let response = send_language_model_request(model, request, cx).await?;
485
486 let output_file_ref = self.output_file();
487 let mut output_file = output_file_ref.lock().unwrap();
488
489 writeln!(&mut output_file, "\n\n").log_err();
490 writeln!(&mut output_file, "========================================").log_err();
491 writeln!(&mut output_file, " JUDGE OUTPUT ").log_err();
492 writeln!(&mut output_file, "========================================").log_err();
493 writeln!(&mut output_file, "\n{}", &response).log_err();
494
495 parse_judge_output(&response)
496 }
497
498 pub async fn repository_diff(&self) -> Result<String> {
499 let worktree_path = self.worktree_path();
500 run_git(&worktree_path, &["add", "-N"]).await?;
501 run_git(&worktree_path, &["diff"]).await
502 }
503}
504
505fn wait_for_lang_server(
506 lsp_store: &Entity<LspStore>,
507 log_prefix: String,
508 cx: &mut AsyncApp,
509) -> Task<Result<()>> {
510 if cx
511 .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
512 .unwrap()
513 || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
514 {
515 return Task::ready(anyhow::Ok(()));
516 }
517
518 println!("{}⏵ Waiting for language server", log_prefix);
519
520 let (mut tx, mut rx) = mpsc::channel(1);
521
522 let subscription =
523 cx.subscribe(&lsp_store, {
524 let log_prefix = log_prefix.clone();
525 move |lsp_store, event, cx| {
526 match event {
527 project::LspStoreEvent::LanguageServerUpdate {
528 message:
529 client::proto::update_language_server::Variant::WorkProgress(
530 LspWorkProgress {
531 message: Some(message),
532 ..
533 },
534 ),
535 ..
536 } => println!("{}⟲ {message}", log_prefix),
537 _ => {}
538 }
539
540 if !has_pending_lang_server_work(&lsp_store, cx) {
541 tx.try_send(()).ok();
542 }
543 }
544 });
545
546 cx.spawn(async move |cx| {
547 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
548 let result = futures::select! {
549 _ = rx.next() => {
550 println!("{}⚑ Language server idle", log_prefix);
551 anyhow::Ok(())
552 },
553 _ = timeout.fuse() => {
554 Err(anyhow!("LSP wait timed out after 5 minutes"))
555 }
556 };
557 drop(subscription);
558 result
559 })
560}
561
562fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
563 lsp_store
564 .read(cx)
565 .language_server_statuses()
566 .any(|(_, status)| !status.pending_work.is_empty())
567}
568
569async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
570 let paths_with_diagnostics = project.update(cx, |project, cx| {
571 project
572 .diagnostic_summaries(true, cx)
573 .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
574 .map(|(project_path, _, _)| project_path)
575 .collect::<Vec<_>>()
576 })?;
577
578 let mut output = String::new();
579 for project_path in paths_with_diagnostics {
580 let buffer = project
581 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
582 .await?;
583 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
584
585 for (_, group) in snapshot.diagnostic_groups(None) {
586 let entry = &group.entries[group.primary_ix];
587 let range = entry.range.to_point(&snapshot);
588 let severity = match entry.diagnostic.severity {
589 DiagnosticSeverity::ERROR => "error",
590 DiagnosticSeverity::WARNING => "warning",
591 _ => continue,
592 };
593
594 writeln!(
595 output,
596 "{} at line {}: {}",
597 severity,
598 range.start.row + 1,
599 entry.diagnostic.message
600 )?;
601 }
602 }
603 anyhow::Ok(output)
604}
605
606fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
607 let analysis = get_tag("analysis", response)?.to_string();
608 let score = get_tag("score", response)?
609 .parse()
610 .context("error parsing score")?;
611
612 Ok(JudgeOutput { analysis, score })
613}
614
615fn get_tag(name: &'static str, response: &str) -> Result<String> {
616 let start_tag = format!("<{}>", name);
617 let end_tag = format!("</{}>", name);
618
619 let start_ix = response
620 .find(&start_tag)
621 .context(format!("{} start tag not found", name))?;
622 let content_start_ix = start_ix + start_tag.len();
623
624 let end_ix = content_start_ix
625 + response[content_start_ix..]
626 .find(&end_tag)
627 .context(format!("{} end tag not found", name))?;
628
629 let content = response[content_start_ix..end_ix].trim().unindent();
630
631 anyhow::Ok(content)
632}
633
634pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
635 let repo_name = repo_url
636 .trim_start_matches("https://")
637 .replace(|c: char| !c.is_alphanumeric(), "-");
638 Path::new(REPOS_DIR)
639 .canonicalize()
640 .context(format!("No such directory {REPOS_DIR}"))
641 .unwrap()
642 .join(repo_name)
643}
644
645pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
646 let output = new_smol_command("git")
647 .current_dir(repo_path)
648 .args(args)
649 .output()
650 .await?;
651
652 if output.status.success() {
653 Ok(String::from_utf8(output.stdout)?.trim().to_string())
654 } else {
655 Err(anyhow!(
656 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
657 args.join(" "),
658 repo_path.display(),
659 output.status,
660 String::from_utf8_lossy(&output.stderr),
661 String::from_utf8_lossy(&output.stdout),
662 ))
663 }
664}
665
666pub async fn send_language_model_request(
667 model: Arc<dyn LanguageModel>,
668 request: LanguageModelRequest,
669 cx: &AsyncApp,
670) -> anyhow::Result<String> {
671 match model.stream_completion_text(request, &cx).await {
672 Ok(mut stream) => {
673 let mut full_response = String::new();
674 while let Some(chunk_result) = stream.stream.next().await {
675 match chunk_result {
676 Ok(chunk_str) => {
677 full_response.push_str(&chunk_str);
678 }
679 Err(err) => {
680 return Err(anyhow!(
681 "Error receiving response from language model: {err}"
682 ));
683 }
684 }
685 }
686 Ok(full_response)
687 }
688 Err(err) => Err(anyhow!(
689 "Failed to get response from language model. Error was: {err}"
690 )),
691 }
692}
693
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }

    /// Malformed judge responses must produce errors rather than panics or bogus scores.
    #[test]
    fn test_parse_judge_output_errors() {
        // Missing `<analysis>` tag.
        assert!(parse_judge_output("<score>3</score>").is_err());
        // Missing `<score>` tag.
        assert!(parse_judge_output("<analysis>ok</analysis>").is_err());
        // Opening tag without a closing tag.
        assert!(parse_judge_output("<analysis>ok</analysis><score>3").is_err());
        // Non-numeric score.
        assert!(parse_judge_output("<analysis>ok</analysis><score>three</score>").is_err());
        // Negative score does not fit in a `u32`.
        assert!(parse_judge_output("<analysis>ok</analysis><score>-1</score>").is_err());
    }
}