example.rs

  1use agent::{RequestKind, ThreadEvent, ThreadStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::ToolWorkingSet;
  4use client::proto::LspWorkProgress;
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use futures::channel::mpsc;
  8use futures::{FutureExt, StreamExt as _, select_biased};
  9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
 10use handlebars::Handlebars;
 11use language::{DiagnosticSeverity, OffsetRangeExt};
 12use language_model::{
 13    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 14    StopReason, TokenUsage,
 15};
 16use project::{LspStore, Project, ProjectPath};
 17use serde::{Deserialize, Serialize};
 18use std::fmt::Write as _;
 19use std::fs::File;
 20use std::io::Write as _;
 21use std::sync::{Arc, Mutex};
 22use std::time::Duration;
 23use std::{
 24    fs,
 25    path::{Path, PathBuf},
 26};
 27use unindent::Unindent as _;
 28use util::ResultExt as _;
 29use util::command::new_smol_command;
 30use util::serde::default_true;
 31
 32use crate::AgentAppState;
 33
 34pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
 35pub const REPOS_DIR: &str = "./crates/eval/repos";
 36pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
 37
 38const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 39
 40#[derive(Clone, Debug, Deserialize)]
 41pub struct ExampleBase {
 42    pub url: String,
 43    pub revision: String,
 44    pub language_extension: Option<String>,
 45    pub insert_id: Option<String>,
 46    #[serde(default = "default_true")]
 47    pub require_lsp: bool,
 48}
 49
 50#[derive(Clone, Debug)]
 51pub struct Example {
 52    pub name: String,
 53    /// Content of `base.toml`
 54    pub base: ExampleBase,
 55    /// Content of `prompt.md`
 56    pub prompt: String,
 57    /// Content of `criteria.md`
 58    pub criteria: String,
 59    /// Markdown output file to append to
 60    pub output_file: Arc<Mutex<File>>,
 61    /// Path to markdown output file
 62    pub output_file_path: PathBuf,
 63    /// Prefix used for logging that identifies this example
 64    pub log_prefix: String,
 65}
 66
 67#[derive(Debug, Serialize, Deserialize, Clone)]
 68pub struct RunOutput {
 69    pub repository_diff: String,
 70    pub diagnostics: String,
 71    pub response_count: usize,
 72    pub token_usage: TokenUsage,
 73    pub tool_use_counts: HashMap<Arc<str>, u32>,
 74}
 75
 76#[derive(Debug, Clone, Serialize, Deserialize)]
 77pub struct JudgeInput {
 78    pub repository_diff: String,
 79    pub criteria: String,
 80}
 81
 82#[derive(Debug, Clone, Serialize, Deserialize)]
 83pub struct JudgeOutput {
 84    pub analysis: String,
 85    pub score: u32,
 86}
 87
 88impl Example {
 89    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 90    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 91        let name = Self::name_from_path(dir_path);
 92        let base_path = dir_path.join("base.toml");
 93        let prompt_path = dir_path.join("prompt.md");
 94        let criteria_path = dir_path.join("criteria.md");
 95
 96        let output_file_path = run_dir.join(format!(
 97            "{}.md",
 98            dir_path.file_name().unwrap().to_str().unwrap()
 99        ));
100        let output_file = Arc::new(Mutex::new(File::create(&output_file_path).unwrap()));
101
102        Ok(Example {
103            name: name.clone(),
104            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
105            prompt: fs::read_to_string(prompt_path.clone())?,
106            criteria: fs::read_to_string(criteria_path.clone())?,
107            output_file,
108            output_file_path,
109            log_prefix: name,
110        })
111    }
112
113    pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
114        self.log_prefix = format!(
115            "{}{:<width$}\x1b[0m | ",
116            color,
117            self.name,
118            width = name_width
119        );
120    }
121
122    pub fn name_from_path(path: &Path) -> String {
123        path.file_name().unwrap().to_string_lossy().to_string()
124    }
125
126    pub fn worktree_path(&self) -> PathBuf {
127        Path::new(WORKTREES_DIR)
128            .canonicalize()
129            .context(format!("No such directory {WORKTREES_DIR}"))
130            .unwrap()
131            .join(&self.name)
132    }
133
134    /// Set up the example by checking out the specified Git revision
135    pub async fn setup(&self) -> Result<()> {
136        let repo_path = repo_path_for_url(&self.base.url);
137
138        println!("{}Fetching", self.log_prefix);
139
140        run_git(
141            &repo_path,
142            &["fetch", "--depth", "1", "origin", &self.base.revision],
143        )
144        .await?;
145
146        let worktree_path = self.worktree_path();
147
148        if worktree_path.is_dir() {
149            println!("{}Resetting existing worktree", self.log_prefix);
150
151            // TODO: consider including "-x" to remove ignored files. The downside of this is that
152            // it will also remove build artifacts, and so prevent incremental reuse there.
153            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
154            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
155            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
156        } else {
157            println!("{}Creating worktree", self.log_prefix);
158
159            let worktree_path_string = worktree_path.to_string_lossy().to_string();
160
161            run_git(
162                &repo_path,
163                &[
164                    "worktree",
165                    "add",
166                    "-f",
167                    &worktree_path_string,
168                    &self.base.revision,
169                ],
170            )
171            .await?;
172        }
173
174        Ok(())
175    }
176
177    pub fn run(
178        &self,
179        model: Arc<dyn LanguageModel>,
180        app_state: Arc<AgentAppState>,
181        cx: &mut App,
182    ) -> Task<Result<RunOutput>> {
183        let project = Project::local(
184            app_state.client.clone(),
185            app_state.node_runtime.clone(),
186            app_state.user_store.clone(),
187            app_state.languages.clone(),
188            Arc::new(DapRegistry::default()),
189            app_state.fs.clone(),
190            None,
191            cx,
192        );
193
194        let worktree_path = self.worktree_path();
195        let worktree = project.update(cx, |project, cx| {
196            project.create_worktree(&worktree_path, true, cx)
197        });
198
199        let tools = cx.new(|_| ToolWorkingSet::default());
200        let thread_store =
201            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
202        let this = self.clone();
203
204        cx.spawn(async move |cx| {
205            let worktree = worktree.await?;
206
207            // Wait for worktree scan to finish before choosing a file to open.
208            worktree
209                .update(cx, |worktree, _cx| {
210                    worktree.as_local().unwrap().scan_complete()
211                })?
212                .await;
213
214            let lsp_open_handle_and_store = if this.base.require_lsp {
215                let language_extension = this.base.language_extension.as_deref().context(
216                    "language_extension field is required in base.toml when `require_lsp == true`",
217                )?;
218
219                // Open a file that matches the language to cause LSP to start.
220                let language_file = worktree.read_with(cx, |worktree, _cx| {
221                    worktree
222                        .files(false, 0)
223                        .find_map(|e| {
224                            if e.path.clone().extension().and_then(|ext| ext.to_str())
225                                == Some(language_extension)
226                            {
227                                Some(ProjectPath {
228                                    worktree_id: worktree.id(),
229                                    path: e.path.clone(),
230                                })
231                            } else {
232                                None
233                            }
234                        })
235                        .context("Failed to find a file for example language")
236                })??;
237
238                let open_language_file_buffer_task = project.update(cx, |project, cx| {
239                    project.open_buffer(language_file.clone(), cx)
240                })?;
241
242                let language_file_buffer = open_language_file_buffer_task.await?;
243
244                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
245                    (
246                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
247                        project.lsp_store().clone(),
248                    )
249                })?;
250
251                // TODO: remove this once the diagnostics tool waits for new diagnostics
252                cx.background_executor().timer(Duration::new(5, 0)).await;
253                wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;
254
255                lsp_store.update(cx, |lsp_store, cx| {
256                    lsp_open_handle.update(cx, |buffer, cx| {
257                        buffer.update(cx, |buffer, cx| {
258                            let has_language_server = lsp_store
259                                .language_servers_for_local_buffer(buffer, cx)
260                                .next()
261                                .is_some();
262                            if has_language_server {
263                                Ok(())
264                            } else {
265                                Err(anyhow!(
266                                    "`{:?}` was opened to cause the language server to start, \
267                                    but no language servers are registered for its buffer. \
268                                    Set `require_lsp = false` in `base.toml` to skip this.",
269                                    language_file
270                                ))
271                            }
272                        })
273                    })
274                })??;
275
276                Some((lsp_open_handle, lsp_store))
277            } else {
278                None
279            };
280
281            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
282                return Err(anyhow!("Setup only mode"));
283            }
284
285            let thread_store = thread_store.await;
286            let thread =
287                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
288
289            {
290                let mut output_file = this.output_file.lock().unwrap();
291                writeln!(&mut output_file, "👤 USER:").log_err();
292                writeln!(&mut output_file, "{}", this.prompt).log_err();
293                writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
294                output_file.flush().log_err();
295            }
296
297            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
298                Mutex::new(HashMap::default()).into();
299
300            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
301
302            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
303                thread_event_tx.unbounded_send(event.clone()).log_err();
304            });
305
306            let event_handler_task = cx.spawn({
307                let output_file = this.output_file.clone();
308                let log_prefix = this.log_prefix.clone();
309                let tool_use_counts = tool_use_counts.clone();
310                let thread = thread.downgrade();
311                async move |cx| {
312                    loop {
313                        let event = select_biased! {
314                            event = thread_event_rx.next() => event,
315                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
316                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
317                            }
318                        };
319                        let Some(event) = event else {
320                            return Err(anyhow!("ThreadEvent channel ended early"));
321                        };
322
323                        let mut output_file = output_file.lock().unwrap();
324
325                        match event {
326                            ThreadEvent::Stopped(reason) => match reason {
327                                Ok(StopReason::EndTurn) => {
328                                    return Ok(());
329                                }
330                                Ok(StopReason::MaxTokens) => {
331                                    return Err(anyhow!("Exceeded maximum tokens"));
332                                }
333                                Ok(StopReason::ToolUse) => {
334                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
335                                        println!("{}StopReason: Tool use", log_prefix);
336                                    }
337                                }
338                                Err(error) => {
339                                    return Err(anyhow!(error.clone()));
340                                }
341                            },
342                            ThreadEvent::ShowError(thread_error) => {
343                                break Err(anyhow!(thread_error.clone()));
344                            }
345                            ThreadEvent::StreamedAssistantText(_, chunk) => {
346                                write!(&mut output_file, "{}", chunk).log_err();
347                            }
348                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
349                                write!(&mut output_file, "{}", chunk).log_err();
350                            }
351                            ThreadEvent::UsePendingTools { tool_uses } => {
352                                writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
353                                for tool_use in tool_uses {
354                                    writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
355                                        .log_err();
356                                }
357                            }
358                            ThreadEvent::ToolFinished {
359                                tool_use_id,
360                                pending_tool_use,
361                                ..
362                            } => {
363                                if let Some(tool_use) = pending_tool_use {
364                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
365                                    println!("{}{message}", log_prefix);
366                                    writeln!(&mut output_file, "\n{}", message).log_err();
367                                }
368                                thread.update(cx, |thread, _cx| {
369                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
370                                        writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
371                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
372                                        *tool_use_counts
373                                            .entry(tool_result.tool_name.clone())
374                                            .or_insert(0) += 1;
375                                    }
376                                })?;
377                            }
378                            ThreadEvent::ToolConfirmationNeeded => {
379                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
380                            },
381                            ThreadEvent::StreamedCompletion |
382                            ThreadEvent::MessageAdded(_) |
383                            ThreadEvent::MessageEdited(_) |
384                            ThreadEvent::MessageDeleted(_) |
385                            ThreadEvent::SummaryChanged |
386                            ThreadEvent::SummaryGenerated |
387                            ThreadEvent::CheckpointChanged => {
388                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
389                                    println!("{}Event: {:#?}", log_prefix, event);
390                                }
391                            }
392                        }
393
394                        output_file.flush().log_err();
395                    }
396                }
397            });
398
399            thread.update(cx, |thread, cx| {
400                let context = vec![];
401                thread.insert_user_message(this.prompt.clone(), context, None, cx);
402                thread.send_to_model(model, RequestKind::Chat, cx);
403            })?;
404
405            event_handler_task.await?;
406
407            println!("{}Stopped", this.log_prefix);
408
409            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
410                wait_for_lang_server(lsp_store, this.log_prefix.clone(), cx).await?;
411            }
412
413            println!("{}Getting repository diff", this.log_prefix);
414            let repository_diff = this.repository_diff().await?;
415
416            println!("{}Getting diagnostics", this.log_prefix);
417            let diagnostics = cx
418                .update(move |cx| {
419                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
420                })?
421                .await?;
422            println!("{}Got diagnostics", this.log_prefix);
423
424            drop(subscription);
425            drop(lsp_open_handle_and_store);
426
427            thread.update(cx, |thread, _cx| {
428                let response_count = thread
429                    .messages()
430                    .filter(|message| message.role == language_model::Role::Assistant)
431                    .count();
432                RunOutput {
433                    repository_diff,
434                    diagnostics,
435                    response_count,
436                    token_usage: thread.cumulative_token_usage(),
437                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
438                }
439            })
440        })
441    }
442
443    pub async fn judge(
444        &self,
445        model: Arc<dyn LanguageModel>,
446        repository_diff: String,
447        cx: &AsyncApp,
448    ) -> Result<JudgeOutput> {
449        let judge_prompt = include_str!("judge_prompt.hbs");
450        let judge_prompt_name = "judge_prompt";
451        let mut handlebars = Handlebars::new();
452        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
453        let prompt = handlebars.render(
454            judge_prompt_name,
455            &JudgeInput {
456                repository_diff,
457                criteria: self.criteria.clone(),
458            },
459        )?;
460
461        let request = LanguageModelRequest {
462            messages: vec![LanguageModelRequestMessage {
463                role: Role::User,
464                content: vec![MessageContent::Text(prompt)],
465                cache: false,
466            }],
467            temperature: None,
468            tools: Vec::new(),
469            stop: Vec::new(),
470        };
471
472        let response = send_language_model_request(model, request, cx).await?;
473
474        let mut output_file = self.output_file.lock().unwrap();
475
476        writeln!(&mut output_file, "\n\n").log_err();
477        writeln!(&mut output_file, "========================================").log_err();
478        writeln!(&mut output_file, "              JUDGE OUTPUT              ").log_err();
479        writeln!(&mut output_file, "========================================").log_err();
480        writeln!(&mut output_file, "\n{}", &response).log_err();
481
482        parse_judge_output(&response)
483    }
484
485    pub async fn repository_diff(&self) -> Result<String> {
486        let worktree_path = self.worktree_path();
487        run_git(&worktree_path, &["add", "-N"]).await?;
488        run_git(&worktree_path, &["diff"]).await
489    }
490}
491
492fn wait_for_lang_server(
493    lsp_store: &Entity<LspStore>,
494    log_prefix: String,
495    cx: &mut AsyncApp,
496) -> Task<Result<()>> {
497    if cx
498        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
499        .unwrap()
500        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
501    {
502        return Task::ready(anyhow::Ok(()));
503    }
504
505    println!("{}⏵ Waiting for language server", log_prefix);
506
507    let (mut tx, mut rx) = mpsc::channel(1);
508
509    let subscription =
510        cx.subscribe(&lsp_store, {
511            let log_prefix = log_prefix.clone();
512            move |lsp_store, event, cx| {
513                match event {
514                    project::LspStoreEvent::LanguageServerUpdate {
515                        message:
516                            client::proto::update_language_server::Variant::WorkProgress(
517                                LspWorkProgress {
518                                    message: Some(message),
519                                    ..
520                                },
521                            ),
522                        ..
523                    } => println!("{}{message}", log_prefix),
524                    _ => {}
525                }
526
527                if !has_pending_lang_server_work(&lsp_store, cx) {
528                    tx.try_send(()).ok();
529                }
530            }
531        });
532
533    cx.spawn(async move |cx| {
534        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
535        let result = futures::select! {
536            _ = rx.next() => {
537                println!("{}⚑ Language server idle", log_prefix);
538                anyhow::Ok(())
539            },
540            _ = timeout.fuse() => {
541                Err(anyhow!("LSP wait timed out after 5 minutes"))
542            }
543        };
544        drop(subscription);
545        result
546    })
547}
548
549fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
550    lsp_store
551        .read(cx)
552        .language_server_statuses()
553        .any(|(_, status)| !status.pending_work.is_empty())
554}
555
556async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
557    let paths_with_diagnostics = project.update(cx, |project, cx| {
558        project
559            .diagnostic_summaries(true, cx)
560            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
561            .map(|(project_path, _, _)| project_path)
562            .collect::<Vec<_>>()
563    })?;
564
565    let mut output = String::new();
566    for project_path in paths_with_diagnostics {
567        let buffer = project
568            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
569            .await?;
570        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
571
572        for (_, group) in snapshot.diagnostic_groups(None) {
573            let entry = &group.entries[group.primary_ix];
574            let range = entry.range.to_point(&snapshot);
575            let severity = match entry.diagnostic.severity {
576                DiagnosticSeverity::ERROR => "error",
577                DiagnosticSeverity::WARNING => "warning",
578                _ => continue,
579            };
580
581            writeln!(
582                output,
583                "{} at line {}: {}",
584                severity,
585                range.start.row + 1,
586                entry.diagnostic.message
587            )?;
588        }
589    }
590    anyhow::Ok(output)
591}
592
593fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
594    let analysis = get_tag("analysis", response)?.to_string();
595    let score = get_tag("score", response)?
596        .parse()
597        .context("error parsing score")?;
598
599    Ok(JudgeOutput { analysis, score })
600}
601
602fn get_tag(name: &'static str, response: &str) -> Result<String> {
603    let start_tag = format!("<{}>", name);
604    let end_tag = format!("</{}>", name);
605
606    let start_ix = response
607        .find(&start_tag)
608        .context(format!("{} start tag not found", name))?;
609    let content_start_ix = start_ix + start_tag.len();
610
611    let end_ix = content_start_ix
612        + response[content_start_ix..]
613            .find(&end_tag)
614            .context(format!("{} end tag not found", name))?;
615
616    let content = response[content_start_ix..end_ix].trim().unindent();
617
618    anyhow::Ok(content)
619}
620
621pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
622    let repo_name = repo_url
623        .trim_start_matches("https://")
624        .replace(|c: char| !c.is_alphanumeric(), "-");
625    Path::new(REPOS_DIR)
626        .canonicalize()
627        .context(format!("No such directory {REPOS_DIR}"))
628        .unwrap()
629        .join(repo_name)
630}
631
632pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
633    let output = new_smol_command("git")
634        .current_dir(repo_path)
635        .args(args)
636        .output()
637        .await?;
638
639    if output.status.success() {
640        Ok(String::from_utf8(output.stdout)?.trim().to_string())
641    } else {
642        Err(anyhow!(
643            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
644            args.join(" "),
645            repo_path.display(),
646            output.status,
647            String::from_utf8_lossy(&output.stderr),
648            String::from_utf8_lossy(&output.stdout),
649        ))
650    }
651}
652
653pub async fn send_language_model_request(
654    model: Arc<dyn LanguageModel>,
655    request: LanguageModelRequest,
656    cx: &AsyncApp,
657) -> anyhow::Result<String> {
658    match model.stream_completion_text(request, &cx).await {
659        Ok(mut stream) => {
660            let mut full_response = String::new();
661            while let Some(chunk_result) = stream.stream.next().await {
662                match chunk_result {
663                    Ok(chunk_str) => {
664                        full_response.push_str(&chunk_str);
665                    }
666                    Err(err) => {
667                        return Err(anyhow!(
668                            "Error receiving response from language model: {err}"
669                        ));
670                    }
671                }
672            }
673            Ok(full_response)
674        }
675        Err(err) => Err(anyhow!(
676            "Failed to get response from language model. Error was: {err}"
677        )),
678    }
679}
680
#[cfg(test)]
mod test {
    use super::*;

    /// Well-formed replies: surrounding text is ignored, and tag contents are trimmed and
    /// unindented.
    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilation errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilation errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }

    /// Malformed replies are rejected instead of producing partial output.
    #[test]
    fn test_parse_judge_output_errors() {
        // Missing tags.
        assert!(parse_judge_output("<score>3</score>").is_err());
        assert!(parse_judge_output("<analysis>fine</analysis>").is_err());
        // Opening tag without a matching closing tag.
        assert!(parse_judge_output("<analysis>fine<score>3</score>").is_err());
        // Scores that are not a valid `u32`.
        assert!(parse_judge_output("<analysis>fine</analysis><score>high</score>").is_err());
        assert!(parse_judge_output("<analysis>fine</analysis><score>-1</score>").is_err());
    }
}