use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
use collections::HashMap;
use dap::DapRegistry;
use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{
    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
    StopReason, TokenUsage,
};
use project::{LspStore, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
use std::fs::File;
use std::io::Write as _;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{
    fs,
    path::{Path, PathBuf},
};
use unindent::Unindent as _;
use util::ResultExt as _;
use util::command::new_smol_command;
use util::serde::default_true;

use crate::AgentAppState;

pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
pub const REPOS_DIR: &str = "./crates/eval/repos";
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);

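/// Parsed from an example's `base.toml`. A minimal sketch of the expected shape,
/// with illustrative values (the url and revision shown here are not taken from a
/// real example):
///
/// ```toml
/// url = "https://github.com/some-org/some-repo"
/// revision = "0123456789abcdef0123456789abcdef01234567"
/// language_extension = "rs"
/// require_lsp = true
/// ```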
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    pub url: String,
    pub revision: String,
    pub language_extension: Option<String>,
    pub insert_id: Option<String>,
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}

#[derive(Clone, Debug)]
pub struct Example {
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    pub log_file: Arc<Mutex<File>>,
    /// Path to markdown log file
    pub log_file_path: PathBuf,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    pub repository_diff: String,
    pub diagnostics: String,
    pub response_count: usize,
    pub token_usage: TokenUsage,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    pub repository_diff: String,
    pub criteria: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    pub analysis: String,
    pub score: u32,
}

impl Example {
    /// Load an example from a directory containing `base.toml`, `prompt.md`, and `criteria.md`.
    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
        let base_path = dir_path.join("base.toml");
        let prompt_path = dir_path.join("prompt.md");
        let criteria_path = dir_path.join("criteria.md");

        let log_file_path = run_dir.join(format!(
            "{}.md",
            dir_path.file_name().unwrap().to_str().unwrap()
        ));
        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
        println!("{}> Logging to {:?}", name, log_file_path);

        Ok(Example {
            name,
            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
            prompt: fs::read_to_string(&prompt_path)?,
            criteria: fs::read_to_string(&criteria_path)?,
            log_file,
            log_file_path,
        })
    }

    /// Absolute path of the worktree this example is checked out into.
    pub fn worktree_path(&self) -> PathBuf {
        Path::new(WORKTREES_DIR)
            .canonicalize()
            .context(format!("No such directory {WORKTREES_DIR}"))
            .unwrap()
            .join(&self.name)
    }

    /// Set up the example by fetching and checking out the specified Git revision.
    pub async fn setup(&self) -> Result<()> {
        let repo_path = repo_path_for_url(&self.base.url);

        println!("{}> Fetching", self.name);

        run_git(
            &repo_path,
            &["fetch", "--depth", "1", "origin", &self.base.revision],
        )
        .await?;

        let worktree_path = self.worktree_path();

        if worktree_path.is_dir() {
            println!("{}> Resetting existing worktree", self.name);

            // TODO: consider including "-x" to remove ignored files. The downside is that it
            // would also remove build artifacts, preventing their incremental reuse.
            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
        } else {
            println!("{}> Creating worktree", self.name);

            let worktree_path_string = worktree_path.to_string_lossy().to_string();

            run_git(
                &repo_path,
                &[
                    "worktree",
                    "add",
                    "-f",
                    &worktree_path_string,
                    &self.base.revision,
                ],
            )
            .await?;
        }

        Ok(())
    }

    /// Run the example against `model`, returning the resulting repository diff,
    /// project diagnostics, and token/tool usage.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for the worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language so that a language server starts.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for the example's language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            let event_handler_task = cx.spawn({
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!(
                                    "Agentic loop stalled - waited {:?} without any events",
                                    THREAD_EVENT_TIMEOUT
                                ));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        let mut log_file = log_file.lock().unwrap();

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                Ok(StopReason::ToolUse) => {}
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::UsePendingTools { tool_uses } => {
                                writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                                for tool_use in tool_uses {
                                    writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                        .log_err();
                                }
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                if let Some(tool_use) = pending_tool_use {
                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
                                    println!("{name}> {message}");
                                    writeln!(&mut log_file, "\n{}", message).log_err();
                                }
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                        writeln!(&mut log_file, "\n{}\n", tool_result.content)
                                            .log_err();
                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                        *tool_use_counts
                                            .entry(tool_result.tool_name.clone())
                                            .or_insert(0) += 1;
                                    }
                                })?;
                            }
                            _ => {}
                        }

                        log_file.flush().log_err();
                    }
                }
            });

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }

    /// Ask `model` to judge the repository diff against this example's criteria,
    /// returning the parsed analysis and score.
    pub async fn judge(
        &self,
        model: Arc<dyn LanguageModel>,
        repository_diff: String,
        cx: &AsyncApp,
    ) -> Result<JudgeOutput> {
        let judge_prompt = include_str!("judge_prompt.hbs");
        let judge_prompt_name = "judge_prompt";
        let mut handlebars = Handlebars::new();
        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
        let prompt = handlebars.render(
            judge_prompt_name,
            &JudgeInput {
                repository_diff,
                criteria: self.criteria.clone(),
            },
        )?;

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text(prompt)],
                cache: false,
            }],
            temperature: None,
            tools: Vec::new(),
            stop: Vec::new(),
        };

        let response = send_language_model_request(model, request, cx).await?;

        let mut log_file = self.log_file.lock().unwrap();

        writeln!(&mut log_file, "\n\n").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "\n{}", &response).log_err();

        parse_judge_output(&response)
    }

    /// Produce a diff of all changes in the worktree, including untracked files.
    pub async fn repository_diff(&self) -> Result<String> {
        let worktree_path = self.worktree_path();
        run_git(&worktree_path, &["add", "-N"]).await?;
        run_git(&worktree_path, &["diff"]).await
    }
}

/// Returns a task that resolves once the language server has no pending work,
/// or fails after a five-minute timeout.
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    if cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap()
        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
    {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    let (mut tx, mut rx) = mpsc::channel(1);

    let subscription = cx.subscribe(&lsp_store, {
        let name = name.clone();
        move |lsp_store, event, cx| {
            match event {
                project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } => println!("{name}> ⟲ {message}"),
                _ => {}
            }

            if !has_pending_lang_server_work(&lsp_store, cx) {
                tx.try_send(()).ok();
            }
        }
    });

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        drop(subscription);
        result
    })
}

fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
    lsp_store
        .read(cx)
        .language_server_statuses()
        .any(|(_, status)| !status.pending_work.is_empty())
}

/// Collect error and warning diagnostics for every project path that has any,
/// formatted one per line.
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(output)
}

/// Parse the judge's response, extracting the `<analysis>` and `<score>` tags.
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
    let analysis = get_tag("analysis", response)?.to_string();
    let score = get_tag("score", response)?
        .parse()
        .context("error parsing score")?;

    Ok(JudgeOutput { analysis, score })
}

/// Return the trimmed, unindented content between `<name>` and `</name>` in `response`.
fn get_tag(name: &'static str, response: &str) -> Result<String> {
    let start_tag = format!("<{}>", name);
    let end_tag = format!("</{}>", name);

    let start_ix = response
        .find(&start_tag)
        .context(format!("{} start tag not found", name))?;
    let content_start_ix = start_ix + start_tag.len();

    let end_ix = content_start_ix
        + response[content_start_ix..]
            .find(&end_tag)
            .context(format!("{} end tag not found", name))?;

    let content = response[content_start_ix..end_ix].trim().unindent();

    anyhow::Ok(content)
}

/// Map a repository URL to its local checkout under `REPOS_DIR`, e.g.
/// "https://github.com/owner/repo" becomes "<REPOS_DIR>/github-com-owner-repo".
pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
    let repo_name = repo_url
        .trim_start_matches("https://")
        .replace(|c: char| !c.is_alphanumeric(), "-");
    Path::new(REPOS_DIR)
        .canonicalize()
        .context(format!("No such directory {REPOS_DIR}"))
        .unwrap()
        .join(repo_name)
}

/// Run `git` with `args` in `repo_path`, returning trimmed stdout on success.
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
    let output = new_smol_command("git")
        .current_dir(repo_path)
        .args(args)
        .output()
        .await?;

    if output.status.success() {
        Ok(String::from_utf8(output.stdout)?.trim().to_string())
    } else {
        Err(anyhow!(
            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
            args.join(" "),
            repo_path.display(),
            output.status,
            String::from_utf8_lossy(&output.stderr),
            String::from_utf8_lossy(&output.stdout),
        ))
    }
}

/// Stream a completion from `model`, printing each chunk to stdout as it arrives
/// and returning the full concatenated response text.
pub async fn send_language_model_request(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &AsyncApp,
) -> anyhow::Result<String> {
    match model.stream_completion_text(request, &cx).await {
        Ok(mut stream) => {
            let mut full_response = String::new();
            while let Some(chunk_result) = stream.stream.next().await {
                match chunk_result {
                    Ok(chunk_str) => {
                        print!("{}", &chunk_str);
                        full_response.push_str(&chunk_str);
                    }
                    Err(err) => {
                        return Err(anyhow!(
                            "Error receiving response from language model: {err}"
                        ));
                    }
                }
            }
            Ok(full_response)
        }
        Err(err) => Err(anyhow!(
            "Failed to get response from language model: {err}"
        )),
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilation errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilation errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
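
    // A small sketch of error-path coverage. It assumes only that `get_tag` and
    // `parse_judge_output` report a missing tag or an unparsable score as `Err`
    // rather than panicking, which matches the implementations above.
    #[test]
    fn test_parse_judge_output_errors() {
        // No tags at all: `get_tag` should fail to find the start tag.
        assert!(get_tag("analysis", "no tags here").is_err());

        // A non-numeric score should fail to parse into a u32.
        let response = "<analysis>fine</analysis><score>three</score>";
        assert!(parse_judge_output(response).is_err());
    }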
}