use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
use collections::HashMap;
use dap::DapRegistry;
use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{
    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
    StopReason, TokenUsage,
};
use project::{LspStore, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
use std::fs::File;
use std::io::Write as _;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{
    fs,
    path::{Path, PathBuf},
};
use unindent::Unindent as _;
use util::ResultExt as _;
use util::command::new_smol_command;
use util::serde::default_true;

use crate::AgentAppState;

pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
pub const REPOS_DIR: &str = "./crates/eval/repos";
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

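/// Maximum time to wait between `ThreadEvent`s before concluding the agentic loop has stalled.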
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);

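/// Parsed contents of an example's `base.toml`.
///
/// An illustrative `base.toml` (the repository URL and revision below are hypothetical; the
/// optional `insert_id` field is omitted):
///
/// ```toml
/// url = "https://github.com/example-org/example-repo"
/// revision = "0123456789abcdef0123456789abcdef01234567"
/// language_extension = "rs"
/// require_lsp = true
/// ```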
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    pub url: String,
    pub revision: String,
    pub language_extension: Option<String>,
    pub insert_id: Option<String>,
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}

#[derive(Clone, Debug)]
pub struct Example {
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    pub log_file: Arc<Mutex<File>>,
}

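/// Metrics and artifacts collected from a single run of an example.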
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    pub repository_diff: String,
    pub diagnostics: String,
    pub response_count: usize,
    pub token_usage: TokenUsage,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    pub repository_diff: String,
    pub criteria: String,
}

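/// Judge verdict parsed from the model's response, which is expected to wrap its analysis in
/// `<analysis>` tags and a numeric score in `<score>` tags (see `parse_judge_output`).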
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    pub analysis: String,
    pub score: u32,
}

impl Example {
    /// Load an example from a directory containing `base.toml`, `prompt.md`, and `criteria.md`.
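    ///
    /// Expected layout (the directory name shown is illustrative):
    ///
    /// ```text
    /// examples/my_example/
    /// ├── base.toml    # repository URL, revision, and LSP settings
    /// ├── prompt.md    # the prompt sent to the model
    /// └── criteria.md  # criteria the judge scores against
    /// ```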
    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
        let base_path = dir_path.join("base.toml");
        let prompt_path = dir_path.join("prompt.md");
        let criteria_path = dir_path.join("criteria.md");

        let log_file_path = run_dir.join(format!(
            "{}.md",
            dir_path.file_name().unwrap().to_str().unwrap()
        ));
        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
        println!("{}> Logging to {:?}", name, log_file_path);

        Ok(Example {
            name,
            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
            prompt: fs::read_to_string(prompt_path.clone())?,
            criteria: fs::read_to_string(criteria_path.clone())?,
            log_file,
        })
    }

    pub fn worktree_path(&self) -> PathBuf {
        Path::new(WORKTREES_DIR)
            .canonicalize()
            .context(format!("No such directory {WORKTREES_DIR}"))
            .unwrap()
            .join(&self.name)
    }

    /// Set up the example by checking out the specified Git revision
    pub async fn setup(&self) -> Result<()> {
        let repo_path = repo_path_for_url(&self.base.url);

        println!("{}> Fetching", self.name);

        run_git(
            &repo_path,
            &["fetch", "--depth", "1", "origin", &self.base.revision],
        )
        .await?;

        let worktree_path = self.worktree_path();

        if worktree_path.is_dir() {
            println!("{}> Resetting existing worktree", self.name);

            // TODO: consider including "-x" to remove ignored files. The downside is that it
            // would also remove build artifacts, preventing their incremental reuse.
            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
        } else {
            println!("{}> Creating worktree", self.name);

            let worktree_path_string = worktree_path.to_string_lossy().to_string();

            run_git(
                &repo_path,
                &[
                    "worktree",
                    "add",
                    "-f",
                    &worktree_path_string,
                    &self.base.revision,
                ],
            )
            .await?;
        }

        Ok(())
    }

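    /// Run the example end to end: create a local project for the worktree, optionally wait for
    /// a language server, send the prompt to the model, stream thread events into the log file,
    /// and collect the resulting diff, diagnostics, and token/tool usage into a [`RunOutput`].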
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            let event_handler_task = cx.spawn({
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        let mut log_file = log_file.lock().unwrap();

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                Ok(StopReason::ToolUse) => {}
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::UsePendingTools { tool_uses } => {
                                writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                                for tool_use in tool_uses {
                                    writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                        .log_err();
                                }
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                if let Some(tool_use) = pending_tool_use {
                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
                                    println!("{name}> {message}");
                                    writeln!(&mut log_file, "\n{}", message).log_err();
                                }
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                        writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                        *tool_use_counts
                                            .entry(tool_result.tool_name.clone())
                                            .or_insert(0) += 1;
                                    }
                                })?;
                            }
                            _ => {}
                        }

                        log_file.flush().log_err();
                    }
                }
            });

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }

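    /// Ask the judge model to score the run: renders `judge_prompt.hbs` with the repository diff
    /// and this example's criteria, sends it as a single user message, logs the response, and
    /// parses the `<analysis>`/`<score>` tags out of it.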
    pub async fn judge(
        &mut self,
        model: Arc<dyn LanguageModel>,
        repository_diff: String,
        cx: &AsyncApp,
    ) -> Result<JudgeOutput> {
        let judge_prompt = include_str!("judge_prompt.hbs");
        let judge_prompt_name = "judge_prompt";
        let mut handlebars = Handlebars::new();
        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
        let prompt = handlebars.render(
            judge_prompt_name,
            &JudgeInput {
                repository_diff,
                criteria: self.criteria.clone(),
            },
        )?;

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text(prompt)],
                cache: false,
            }],
            temperature: None,
            tools: Vec::new(),
            stop: Vec::new(),
        };

        let response = send_language_model_request(model, request, cx).await?;

        let mut log_file = self.log_file.lock().unwrap();

        writeln!(&mut log_file, "\n\n").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
        writeln!(&mut log_file, "========================================").log_err();
        writeln!(&mut log_file, "\n{}", &response).log_err();

        parse_judge_output(&response)
    }

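    /// Diff the worktree against HEAD. `git add -N` (intent-to-add) registers untracked files
    /// first so that newly created files show up in `git diff`.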
    pub async fn repository_diff(&self) -> Result<String> {
        let worktree_path = self.worktree_path();
        run_git(&worktree_path, &["add", "-N"]).await?;
        run_git(&worktree_path, &["diff"]).await
    }
}

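/// Resolves once the project's language servers report no pending work, or fails after a
/// five-minute timeout. Returns immediately when there is no pending work or when
/// `ZED_EVAL_SKIP_LS_WAIT` is set.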
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    if cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap()
        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
    {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    let (mut tx, mut rx) = mpsc::channel(1);

    let subscription =
        cx.subscribe(&lsp_store, {
            let name = name.clone();
            move |lsp_store, event, cx| {
                match event {
                    project::LspStoreEvent::LanguageServerUpdate {
                        message:
                            client::proto::update_language_server::Variant::WorkProgress(
                                LspWorkProgress {
                                    message: Some(message),
                                    ..
                                },
                            ),
                        ..
                    } => println!("{name}> ⟲ {message}"),
                    _ => {}
                }

                if !has_pending_lang_server_work(&lsp_store, cx) {
                    tx.try_send(()).ok();
                }
            }
        });

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        drop(subscription);
        result
    })
}

fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
    lsp_store
        .read(cx)
        .language_server_statuses()
        .any(|(_, status)| !status.pending_work.is_empty())
}

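/// Collect LSP diagnostics for every project path that reports errors or warnings, formatting
/// one line per primary diagnostic, e.g. "error at line 42: mismatched types" (the message shown
/// is illustrative).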
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(output)
}

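/// Extract the judge's verdict from its response. The response is expected to contain
/// `<analysis>...</analysis>` and `<score>...</score>` blocks; any surrounding text is ignored.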
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
    let analysis = get_tag("analysis", response)?.to_string();
    let score = get_tag("score", response)?
        .parse()
        .context("error parsing score")?;

    Ok(JudgeOutput { analysis, score })
}

fn get_tag(name: &'static str, response: &str) -> Result<String> {
    let start_tag = format!("<{}>", name);
    let end_tag = format!("</{}>", name);

    let start_ix = response
        .find(&start_tag)
        .context(format!("{} start tag not found", name))?;
    let content_start_ix = start_ix + start_tag.len();

    let end_ix = content_start_ix
        + response[content_start_ix..]
            .find(&end_tag)
            .context(format!("{} end tag not found", name))?;

    let content = response[content_start_ix..end_ix].trim().unindent();

    anyhow::Ok(content)
}

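/// Map a repository URL to its local clone path under [`REPOS_DIR`], replacing every
/// non-alphanumeric character with `-` (e.g. `https://github.com/example/repo` becomes
/// `github-com-example-repo`).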
pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
    let repo_name = repo_url
        .trim_start_matches("https://")
        .replace(|c: char| !c.is_alphanumeric(), "-");
    Path::new(REPOS_DIR)
        .canonicalize()
        .context(format!("No such directory {REPOS_DIR}"))
        .unwrap()
        .join(repo_name)
}

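/// Run a git command in `repo_path`, returning trimmed stdout on success or an error that
/// includes the exit status, stderr, and stdout otherwise.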
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
    let output = new_smol_command("git")
        .current_dir(repo_path)
        .args(args)
        .output()
        .await?;

    if output.status.success() {
        Ok(String::from_utf8(output.stdout)?.trim().to_string())
    } else {
        Err(anyhow!(
            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
            args.join(" "),
            repo_path.display(),
            output.status,
            String::from_utf8_lossy(&output.stderr),
            String::from_utf8_lossy(&output.stdout),
        ))
    }
}

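/// Send a request to the model and stream the completion, echoing each chunk to stdout and
/// returning the accumulated response text.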
pub async fn send_language_model_request(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &AsyncApp,
) -> anyhow::Result<String> {
    match model.stream_completion_text(request, &cx).await {
        Ok(mut stream) => {
            let mut full_response = String::new();
            while let Some(chunk_result) = stream.stream.next().await {
                match chunk_result {
                    Ok(chunk_str) => {
                        print!("{}", &chunk_str);
                        full_response.push_str(&chunk_str);
                    }
                    Err(err) => {
                        return Err(anyhow!(
                            "Error receiving response from language model: {err}"
                        ));
                    }
                }
            }
            Ok(full_response)
        }
        Err(err) => Err(anyhow!(
            "Failed to get response from language model. Error was: {err}"
        )),
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilation errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
656 "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
}