1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::channel::{mpsc, oneshot};
8use futures::{FutureExt, StreamExt as _};
9use gpui::{App, AsyncApp, Entity, Task};
10use handlebars::Handlebars;
11use language::{DiagnosticSeverity, OffsetRangeExt};
12use language_model::{
13 LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
14 StopReason, TokenUsage,
15};
16use project::{LspStore, Project, ProjectPath};
17use serde::{Deserialize, Serialize};
18use std::fmt::Write as _;
19use std::fs::File;
20use std::io::Write as _;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use std::{
24 fs,
25 path::{Path, PathBuf},
26};
27use unindent::Unindent as _;
28use util::ResultExt as _;
29use util::command::new_smol_command;
30use util::serde::default_true;
31
32use crate::AgentAppState;
33
/// Relative path of the examples directory.
/// NOTE(review): not referenced in this part of the file; presumably each
/// subdirectory is one example loadable via `Example::load_from_directory` — confirm.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path of the directory under which `repo_path_for_url` places one
/// repository directory per repository URL.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path of the directory under which `Example::worktree_path` places one
/// worktree per example name.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
37
/// Repository and language settings for an example, deserialized from its `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Repository URL; `repo_path_for_url` derives the local repository directory from it.
    pub url: String,
    /// Git revision that `Example::setup` fetches and checks out.
    pub revision: String,
    /// File extension (compared against `Path::extension`, so without the leading dot)
    /// that `Example::run` uses to pick a file to open so a language server starts.
    /// Must be set when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in this part of the file; purpose to be confirmed.
    pub insert_id: Option<String>,
    /// Whether `Example::run` must bring up a language server before prompting the model.
    /// Defaults to `true` when the key is absent.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
47
/// A single evaluation example: its settings, prompt, judging criteria and log file.
///
/// Cloning is shallow with respect to the log: all clones share the same
/// `Arc<Mutex<File>>`.
#[derive(Clone, Debug)]
pub struct Example {
    /// Name of the example; taken from the final component of its directory path and
    /// also used as the worktree directory name and as the prefix of console output.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    pub log_file: Arc<Mutex<File>>,
}
60
/// Result of `Example::run`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Output of `Example::repository_diff` taken after the model finished.
    pub repository_diff: String,
    /// Text produced by `query_lsp_diagnostics`: one line per reported error or warning.
    pub diagnostics: String,
    /// Number of messages in the thread whose role is `Assistant`.
    pub response_count: usize,
    /// The thread's cumulative token usage at the end of the run.
    pub token_usage: TokenUsage,
    /// Number of finished tool uses that had a result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
69
/// Data rendered into the judge prompt template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff of the changes being judged.
    pub repository_diff: String,
    /// Criteria the judge scores the diff against (the example's `criteria.md`).
    pub criteria: String,
}
75
/// Judge verdict as parsed from the model response by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Trimmed, unindented content of the response's `<analysis>` tag.
    pub analysis: String,
    /// Content of the response's `<score>` tag parsed as an unsigned integer.
    pub score: u32,
}
81
82impl Example {
83 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
84 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
85 let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
86 let base_path = dir_path.join("base.toml");
87 let prompt_path = dir_path.join("prompt.md");
88 let criteria_path = dir_path.join("criteria.md");
89
90 let log_file_path = run_dir.join(format!(
91 "{}.md",
92 dir_path.file_name().unwrap().to_str().unwrap()
93 ));
94 let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
95 println!("{}> Logging to {:?}", name, log_file_path);
96
97 Ok(Example {
98 name,
99 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
100 prompt: fs::read_to_string(prompt_path.clone())?,
101 criteria: fs::read_to_string(criteria_path.clone())?,
102 log_file,
103 })
104 }
105
106 pub fn worktree_path(&self) -> PathBuf {
107 Path::new(WORKTREES_DIR)
108 .canonicalize()
109 .context(format!("No such directory {WORKTREES_DIR}"))
110 .unwrap()
111 .join(&self.name)
112 }
113
114 /// Set up the example by checking out the specified Git revision
115 pub async fn setup(&self) -> Result<()> {
116 let repo_path = repo_path_for_url(&self.base.url);
117
118 println!("{}> Fetching", self.name);
119
120 run_git(
121 &repo_path,
122 &["fetch", "--depth", "1", "origin", &self.base.revision],
123 )
124 .await?;
125
126 let worktree_path = self.worktree_path();
127
128 if worktree_path.is_dir() {
129 println!("{}> Resetting existing worktree", self.name);
130
131 // TODO: consider including "-x" to remove ignored files. The downside of this is that
132 // it will also remove build artifacts, and so prevent incremental reuse there.
133 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
134 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
135 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
136 } else {
137 println!("{}> Creating worktree", self.name);
138
139 let worktree_path_string = worktree_path.to_string_lossy().to_string();
140
141 run_git(
142 &repo_path,
143 &[
144 "worktree",
145 "add",
146 "-f",
147 &worktree_path_string,
148 &self.base.revision,
149 ],
150 )
151 .await?;
152 }
153
154 Ok(())
155 }
156
    /// Runs the example against `model` and returns what it produced.
    ///
    /// Synchronously creates a local project, adds the example's worktree to it and
    /// starts loading a thread store. The returned task then:
    ///
    /// 1. waits for the worktree scan to complete;
    /// 2. when `require_lsp` is set, opens a file with the configured extension,
    ///    registers it with the language servers and waits for them to become idle;
    /// 3. sends the example prompt to the model in a new thread, logging streamed
    ///    output and tool activity to the example's log file;
    /// 4. once the thread stops, collects the repository diff, the LSP diagnostics,
    ///    the assistant message count, token usage and per-tool use counts.
    ///
    /// # Errors
    ///
    /// The task fails when `require_lsp` is set but `language_extension` is missing, no
    /// matching file exists, or no language server is registered for the opened buffer;
    /// when the `ZED_EVAL_SETUP_ONLY` environment variable is set (the task stops after
    /// setup with a "Setup only mode" error); when the thread stops with `MaxTokens`,
    /// stops with an error, or emits `ShowError`; or when any underlying project, git or
    /// language server wait operation fails.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        // The spawned future is `move`, so it works on its own clone of the example.
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                // Fail early if opening the file did not result in any language server
                // being associated with its buffer.
                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                // Both values are kept until after diagnostics have been collected
                // (they are dropped explicitly near the end of this task).
                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            // Setup-only mode ends the run here, reported through the error channel.
            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            // Scoped so the log file lock is released before anything is awaited.
            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            // One-shot completion signal: the first terminal thread event takes the
            // sender out of the `Option`, so later terminal events are ignored.
            let (tx, rx) = oneshot::channel();
            let mut tx = Some(tx);

            let subscription = cx.subscribe(&thread, {
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                move |thread, event: &ThreadEvent, cx| {
                    let mut log_file = log_file.lock().unwrap();

                    match event {
                        ThreadEvent::Stopped(reason) => match reason {
                            Ok(StopReason::EndTurn) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Ok(())).ok();
                                }
                            }
                            Ok(StopReason::MaxTokens) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
                                }
                            }
                            // Stopping for tool use is not terminal; keep waiting.
                            Ok(StopReason::ToolUse) => {}
                            Err(error) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!(error.clone()))).ok();
                                }
                            }
                        },
                        ThreadEvent::ShowError(thread_error) => {
                            if let Some(tx) = tx.take() {
                                tx.send(Err(anyhow!(thread_error.clone()))).ok();
                            }
                        }
                        ThreadEvent::StreamedAssistantText(_, chunk) => {
                            write!(&mut log_file, "{}", chunk).log_err();
                        }
                        ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                            write!(&mut log_file, "{}", chunk).log_err();
                        }
                        ThreadEvent::UsePendingTools { tool_uses } => {
                            writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                            for tool_use in tool_uses {
                                writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                    .log_err();
                            }
                        }
                        ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            ..
                        } => {
                            if let Some(tool_use) = pending_tool_use {
                                let message = format!("TOOL FINISHED: {}", tool_use.name);
                                println!("{name}> {message}");
                                writeln!(&mut log_file, "\n{}", message).log_err();
                            }
                            // A tool use is only counted when the thread has a result for it.
                            if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                                writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
                                let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                *tool_use_counts
                                    .entry(tool_result.tool_name.clone())
                                    .or_insert(0) += 1;
                            }
                        }
                        _ => {}
                    }

                    log_file.flush().log_err();
                }
            })?;

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            // Outer `?`: the sender was dropped without sending.
            // Inner `?`: the thread reported a failure.
            rx.await??;

            // Wait for the language server to go idle again before reading diagnostics.
            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }
389
390 pub async fn judge(
391 &mut self,
392 model: Arc<dyn LanguageModel>,
393 repository_diff: String,
394 cx: &AsyncApp,
395 ) -> Result<JudgeOutput> {
396 let judge_prompt = include_str!("judge_prompt.hbs");
397 let judge_prompt_name = "judge_prompt";
398 let mut handlebars = Handlebars::new();
399 handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
400 let prompt = handlebars.render(
401 judge_prompt_name,
402 &JudgeInput {
403 repository_diff,
404 criteria: self.criteria.clone(),
405 },
406 )?;
407
408 let request = LanguageModelRequest {
409 messages: vec![LanguageModelRequestMessage {
410 role: Role::User,
411 content: vec![MessageContent::Text(prompt)],
412 cache: false,
413 }],
414 temperature: None,
415 tools: Vec::new(),
416 stop: Vec::new(),
417 };
418
419 let response = send_language_model_request(model, request, cx).await?;
420
421 let mut log_file = self.log_file.lock().unwrap();
422
423 writeln!(&mut log_file, "\n\n").log_err();
424 writeln!(&mut log_file, "========================================").log_err();
425 writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
426 writeln!(&mut log_file, "========================================").log_err();
427 writeln!(&mut log_file, "\n{}", &response).log_err();
428
429 parse_judge_output(&response)
430 }
431
432 pub async fn repository_diff(&self) -> Result<String> {
433 let worktree_path = self.worktree_path();
434 run_git(&worktree_path, &["add", "-N"]).await?;
435 run_git(&worktree_path, &["diff"]).await
436 }
437}
438
439fn wait_for_lang_server(
440 lsp_store: &Entity<LspStore>,
441 name: String,
442 cx: &mut AsyncApp,
443) -> Task<Result<()>> {
444 if cx
445 .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
446 .unwrap()
447 || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
448 {
449 return Task::ready(anyhow::Ok(()));
450 }
451
452 println!("{}> ⏵ Waiting for language server", name);
453
454 let (mut tx, mut rx) = mpsc::channel(1);
455
456 let subscription =
457 cx.subscribe(&lsp_store, {
458 let name = name.clone();
459 move |lsp_store, event, cx| {
460 match event {
461 project::LspStoreEvent::LanguageServerUpdate {
462 message:
463 client::proto::update_language_server::Variant::WorkProgress(
464 LspWorkProgress {
465 message: Some(message),
466 ..
467 },
468 ),
469 ..
470 } => println!("{name}> ⟲ {message}"),
471 _ => {}
472 }
473
474 if !has_pending_lang_server_work(&lsp_store, cx) {
475 tx.try_send(()).ok();
476 }
477 }
478 });
479
480 cx.spawn(async move |cx| {
481 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
482 let result = futures::select! {
483 _ = rx.next() => {
484 println!("{}> ⚑ Language server idle", name);
485 anyhow::Ok(())
486 },
487 _ = timeout.fuse() => {
488 Err(anyhow!("LSP wait timed out after 5 minutes"))
489 }
490 };
491 drop(subscription);
492 result
493 })
494}
495
496fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
497 lsp_store
498 .read(cx)
499 .language_server_statuses()
500 .any(|(_, status)| !status.pending_work.is_empty())
501}
502
503async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
504 let paths_with_diagnostics = project.update(cx, |project, cx| {
505 project
506 .diagnostic_summaries(true, cx)
507 .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
508 .map(|(project_path, _, _)| project_path)
509 .collect::<Vec<_>>()
510 })?;
511
512 let mut output = String::new();
513 for project_path in paths_with_diagnostics {
514 let buffer = project
515 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
516 .await?;
517 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
518
519 for (_, group) in snapshot.diagnostic_groups(None) {
520 let entry = &group.entries[group.primary_ix];
521 let range = entry.range.to_point(&snapshot);
522 let severity = match entry.diagnostic.severity {
523 DiagnosticSeverity::ERROR => "error",
524 DiagnosticSeverity::WARNING => "warning",
525 _ => continue,
526 };
527
528 writeln!(
529 output,
530 "{} at line {}: {}",
531 severity,
532 range.start.row + 1,
533 entry.diagnostic.message
534 )?;
535 }
536 }
537 anyhow::Ok(output)
538}
539
540fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
541 let analysis = get_tag("analysis", response)?.to_string();
542 let score = get_tag("score", response)?
543 .parse()
544 .context("error parsing score")?;
545
546 Ok(JudgeOutput { analysis, score })
547}
548
549fn get_tag(name: &'static str, response: &str) -> Result<String> {
550 let start_tag = format!("<{}>", name);
551 let end_tag = format!("</{}>", name);
552
553 let start_ix = response
554 .find(&start_tag)
555 .context(format!("{} start tag not found", name))?;
556 let content_start_ix = start_ix + start_tag.len();
557
558 let end_ix = content_start_ix
559 + response[content_start_ix..]
560 .find(&end_tag)
561 .context(format!("{} end tag not found", name))?;
562
563 let content = response[content_start_ix..end_ix].trim().unindent();
564
565 anyhow::Ok(content)
566}
567
568pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
569 let repo_name = repo_url
570 .trim_start_matches("https://")
571 .replace(|c: char| !c.is_alphanumeric(), "-");
572 Path::new(REPOS_DIR)
573 .canonicalize()
574 .context(format!("No such directory {REPOS_DIR}"))
575 .unwrap()
576 .join(repo_name)
577}
578
579pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
580 let output = new_smol_command("git")
581 .current_dir(repo_path)
582 .args(args)
583 .output()
584 .await?;
585
586 if output.status.success() {
587 Ok(String::from_utf8(output.stdout)?.trim().to_string())
588 } else {
589 Err(anyhow!(
590 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
591 args.join(" "),
592 repo_path.display(),
593 output.status,
594 String::from_utf8_lossy(&output.stderr),
595 String::from_utf8_lossy(&output.stdout),
596 ))
597 }
598}
599
600pub async fn send_language_model_request(
601 model: Arc<dyn LanguageModel>,
602 request: LanguageModelRequest,
603 cx: &AsyncApp,
604) -> anyhow::Result<String> {
605 match model.stream_completion_text(request, &cx).await {
606 Ok(mut stream) => {
607 let mut full_response = String::new();
608 while let Some(chunk_result) = stream.stream.next().await {
609 match chunk_result {
610 Ok(chunk_str) => {
611 print!("{}", &chunk_str);
612 full_response.push_str(&chunk_str);
613 }
614 Err(err) => {
615 return Err(anyhow!(
616 "Error receiving response from language model: {err}"
617 ));
618 }
619 }
620 }
621 Ok(full_response)
622 }
623 Err(err) => Err(anyhow!(
624 "Failed to get response from language model. Error was: {err}"
625 )),
626 }
627}
628
#[cfg(test)]
mod test {
    use super::*;

    /// `parse_judge_output` handles a single-line analysis as well as a multi-line
    /// analysis surrounded by unrelated text.
    #[test]
    fn test_parse_judge_output() {
        let single_line = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let multi_line = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let cases = [
            (
                single_line,
                "The model did a good job but there were still compilations errors.",
                3,
            ),
            (multi_line, "Failed to compile:\n- Error 1\n- Error 2", 1),
        ];

        for (response, expected_analysis, expected_score) in cases {
            let output = parse_judge_output(&response).unwrap();
            assert_eq!(output.analysis, expected_analysis);
            assert_eq!(output.score, expected_score);
        }
    }
}