use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
use collections::HashMap;
use dap::DapRegistry;
use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{
    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
    StopReason, TokenUsage,
};
use project::{LspStore, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
use std::fs::File;
use std::io::Write as _;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{
    fs,
    path::{Path, PathBuf},
};
use unindent::Unindent as _;
use util::ResultExt as _;
use util::command::new_smol_command;
use util::serde::default_true;

use crate::AgentAppState;

pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
pub const REPOS_DIR: &str = "./crates/eval/repos";
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

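/// Maximum time to wait between `ThreadEvent`s before concluding the agentic loop has stalled.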
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);

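/// Parsed contents of an example's `base.toml`.
///
/// An illustrative `base.toml` (the repository URL and revision below are hypothetical; the
/// optional `insert_id` field is omitted):
///
/// ```toml
/// url = "https://github.com/example-org/example-repo"
/// revision = "0123456789abcdef0123456789abcdef01234567"
/// language_extension = "rs"
/// require_lsp = true
/// ```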
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    pub url: String,
    pub revision: String,
    pub language_extension: Option<String>,
    pub insert_id: Option<String>,
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}

#[derive(Clone, Debug)]
pub struct Example {
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    pub log_file: Arc<Mutex<File>>,
}

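/// Metrics and artifacts collected from a single run of an example.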
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    pub repository_diff: String,
    pub diagnostics: String,
    pub response_count: usize,
    pub token_usage: TokenUsage,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    pub repository_diff: String,
    pub criteria: String,
}

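/// Judge verdict parsed from the model's response, which is expected to wrap its analysis in
/// `<analysis>` tags and a numeric score in `<score>` tags (see `parse_judge_output`).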
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    pub analysis: String,
    pub score: u32,
}

impl Example {
    /// Load an example from a directory containing `base.toml`, `prompt.md`, and `criteria.md`.
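    ///
    /// Expected layout (the directory name shown is illustrative):
    ///
    /// ```text
    /// examples/my_example/
    /// ├── base.toml    # repository URL, revision, and LSP settings
    /// ├── prompt.md    # the prompt sent to the model
    /// └── criteria.md  # criteria the judge scores against
    /// ```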
    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
        let base_path = dir_path.join("base.toml");
        let prompt_path = dir_path.join("prompt.md");
        let criteria_path = dir_path.join("criteria.md");

        let log_file_path = run_dir.join(format!(
            "{}.md",
            dir_path.file_name().unwrap().to_str().unwrap()
        ));
        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
        println!("{}> Logging to {:?}", name, log_file_path);

        Ok(Example {
            name,
            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
            prompt: fs::read_to_string(prompt_path.clone())?,
            criteria: fs::read_to_string(criteria_path.clone())?,
            log_file,
        })
    }

    pub fn worktree_path(&self) -> PathBuf {
        Path::new(WORKTREES_DIR)
            .canonicalize()
            .context(format!("No such directory {WORKTREES_DIR}"))
            .unwrap()
            .join(&self.name)
    }

    /// Set up the example by checking out the specified Git revision
    pub async fn setup(&self) -> Result<()> {
        let repo_path = repo_path_for_url(&self.base.url);

        println!("{}> Fetching", self.name);

        run_git(
            &repo_path,
            &["fetch", "--depth", "1", "origin", &self.base.revision],
        )
        .await?;

        let worktree_path = self.worktree_path();

        if worktree_path.is_dir() {
            println!("{}> Resetting existing worktree", self.name);

            // TODO: consider including "-x" to remove ignored files. The downside is that it
            // would also remove build artifacts, preventing their incremental reuse.
            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
        } else {
            println!("{}> Creating worktree", self.name);

            let worktree_path_string = worktree_path.to_string_lossy().to_string();

            run_git(
                &repo_path,
                &[
                    "worktree",
                    "add",
                    "-f",
                    &worktree_path_string,
                    &self.base.revision,
                ],
            )
            .await?;
        }

        Ok(())
    }

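    /// Run the example end to end: create a local project for the worktree, optionally wait for
    /// a language server, send the prompt to the model, stream thread events into the log file,
    /// and collect the resulting diff, diagnostics, and token/tool usage into a [`RunOutput`].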
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            let event_handler_task = cx.spawn({
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        let mut log_file = log_file.lock().unwrap();

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                Ok(StopReason::ToolUse) => {}
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::UsePendingTools { tool_uses } => {
                                writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                                for tool_use in tool_uses {
                                    writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                        .log_err();
                                }
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                if let Some(tool_use) = pending_tool_use {
                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
                                    println!("{name}> {message}");
                                    writeln!(&mut log_file, "\n{}", message).log_err();
                                }
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                        writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                        *tool_use_counts
                                            .entry(tool_result.tool_name.clone())
                                            .or_insert(0) += 1;
                                    }
                                })?;
                            }
                            _ => {}
                        }

                        log_file.flush().log_err();
                    }
                }
            });

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }

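    /// Ask the judge model to score the run: renders `judge_prompt.hbs` with the repository diff
    /// and this example's criteria, sends it as a single user message, logs the response, and
    /// parses the `<analysis>`/`<score>` tags out of it.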
    pub async fn judge(
        &mut self,
        model: Arc<dyn LanguageModel>,
        repository_diff: String,
        cx: &AsyncApp,
    ) -> Result<JudgeOutput> {
        let judge_prompt = include_str!("judge_prompt.hbs");
        let judge_prompt_name = "judge_prompt";
        let mut handlebars = Handlebars::new();
        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
        let prompt = handlebars.render(
            judge_prompt_name,
            &JudgeInput {
                repository_diff,
                criteria: self.criteria.clone(),
            },
        )?;

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text(prompt)],
                cache: false,
            }],
            temperature: None,
            tools: Vec::new(),
            stop: Vec::new(),
        };

        let response = send_language_model_request(model, request, cx).await?;

        let mut log_file = self.log_file.lock().unwrap();

        writeln!(&mut log_file, "\n\n").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "\n{}", &response).log_err();

        parse_judge_output(&response)
    }

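    /// Diff the worktree against HEAD. `git add -N` (intent-to-add) registers untracked files
    /// first so that newly created files show up in `git diff`.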
    pub async fn repository_diff(&self) -> Result<String> {
        let worktree_path = self.worktree_path();
        run_git(&worktree_path, &["add", "-N"]).await?;
        run_git(&worktree_path, &["diff"]).await
    }
}

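/// Resolves once the project's language servers report no pending work, or fails after a
/// five-minute timeout. Returns immediately when there is no pending work or when
/// `ZED_EVAL_SKIP_LS_WAIT` is set.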
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    if cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap()
        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
    {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    let (mut tx, mut rx) = mpsc::channel(1);

    let subscription =
        cx.subscribe(&lsp_store, {
            let name = name.clone();
            move |lsp_store, event, cx| {
                match event {
                    project::LspStoreEvent::LanguageServerUpdate {
                        message:
                            client::proto::update_language_server::Variant::WorkProgress(
                                LspWorkProgress {
                                    message: Some(message),
                                    ..
                                },
                            ),
                        ..
                    } => println!("{name}> ⟲ {message}"),
                    _ => {}
                }

                if !has_pending_lang_server_work(&lsp_store, cx) {
                    tx.try_send(()).ok();
                }
            }
        });

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        drop(subscription);
        result
    })
}

fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
    lsp_store
        .read(cx)
        .language_server_statuses()
        .any(|(_, status)| !status.pending_work.is_empty())
}

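/// Collect LSP diagnostics for every project path that reports errors or warnings, formatting
/// one line per primary diagnostic, e.g. "error at line 42: mismatched types" (the message shown
/// is illustrative).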
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(output)
}

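/// Extract the judge's verdict from its response. The response is expected to contain
/// `<analysis>...</analysis>` and `<score>...</score>` blocks; any surrounding text is ignored.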
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
    let analysis = get_tag("analysis", response)?.to_string();
    let score = get_tag("score", response)?
        .parse()
        .context("error parsing score")?;

    Ok(JudgeOutput { analysis, score })
}

fn get_tag(name: &'static str, response: &str) -> Result<String> {
    let start_tag = format!("<{}>", name);
    let end_tag = format!("</{}>", name);

    let start_ix = response
        .find(&start_tag)
        .context(format!("{} start tag not found", name))?;
    let content_start_ix = start_ix + start_tag.len();

    let end_ix = content_start_ix
        + response[content_start_ix..]
            .find(&end_tag)
            .context(format!("{} end tag not found", name))?;

    let content = response[content_start_ix..end_ix].trim().unindent();

    anyhow::Ok(content)
}

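/// Map a repository URL to its local clone path under [`REPOS_DIR`], replacing every
/// non-alphanumeric character with `-` (e.g. `https://github.com/example/repo` becomes
/// `github-com-example-repo`).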
pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
    let repo_name = repo_url
        .trim_start_matches("https://")
        .replace(|c: char| !c.is_alphanumeric(), "-");
    Path::new(REPOS_DIR)
        .canonicalize()
        .context(format!("No such directory {REPOS_DIR}"))
        .unwrap()
        .join(repo_name)
}

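/// Run a git command in `repo_path`, returning trimmed stdout on success or an error that
/// includes the exit status, stderr, and stdout otherwise.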
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
    let output = new_smol_command("git")
        .current_dir(repo_path)
        .args(args)
        .output()
        .await?;

    if output.status.success() {
        Ok(String::from_utf8(output.stdout)?.trim().to_string())
    } else {
        Err(anyhow!(
            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
            args.join(" "),
            repo_path.display(),
            output.status,
            String::from_utf8_lossy(&output.stderr),
            String::from_utf8_lossy(&output.stdout),
        ))
    }
}

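/// Send a request to the model and stream the completion, echoing each chunk to stdout and
/// returning the accumulated response text.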
pub async fn send_language_model_request(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &AsyncApp,
) -> anyhow::Result<String> {
    match model.stream_completion_text(request, &cx).await {
        Ok(mut stream) => {
            let mut full_response = String::new();
            while let Some(chunk_result) = stream.stream.next().await {
                match chunk_result {
                    Ok(chunk_str) => {
                        print!("{}", &chunk_str);
                        full_response.push_str(&chunk_str);
                    }
                    Err(err) => {
                        return Err(anyhow!(
                            "Error receiving response from language model: {err}"
                        ));
                    }
                }
            }
            Ok(full_response)
        }
        Err(err) => Err(anyhow!(
            "Failed to get response from language model. Error was: {err}"
        )),
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilation errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
656 "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
}