example.rs

  1use agent::{RequestKind, ThreadEvent, ThreadStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::ToolWorkingSet;
  4use client::proto::LspWorkProgress;
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use futures::channel::{mpsc, oneshot};
  8use futures::{FutureExt, StreamExt as _};
  9use gpui::{App, AsyncApp, Entity, Task};
 10use handlebars::Handlebars;
 11use language::{DiagnosticSeverity, OffsetRangeExt};
 12use language_model::{
 13    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 14    StopReason, TokenUsage,
 15};
 16use project::{LspStore, Project, ProjectPath};
 17use serde::{Deserialize, Serialize};
 18use std::fmt::Write as _;
 19use std::fs::File;
 20use std::io::Write as _;
 21use std::sync::{Arc, Mutex};
 22use std::time::Duration;
 23use std::{
 24    fs,
 25    path::{Path, PathBuf},
 26};
 27use unindent::Unindent as _;
 28use util::ResultExt as _;
 29use util::command::new_smol_command;
 30use util::serde::default_true;
 31
 32use crate::AgentAppState;
 33
/// Relative path of the examples directory.
/// NOTE(review): not referenced in this file; presumably holds one sub-directory per
/// example (see `Example::load_from_directory`) — confirm against the caller.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path of the directory under which `repo_path_for_url` places each repository.
/// It is canonicalized on use, so it must already exist.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path of the directory under which `Example::worktree_path` places each
/// example's Git worktree. It is canonicalized on use, so it must already exist.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
 37
/// Settings deserialized from an example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// URL of the repository; `repo_path_for_url` derives the local repository path from it.
    pub url: String,
    /// Git revision that `Example::setup` fetches and checks out.
    pub revision: String,
    /// File extension used by `Example::run` to pick a file to open so that a language
    /// server starts. Required when `require_lsp` is `true`.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in this file — purpose unconfirmed.
    pub insert_id: Option<String>,
    /// Whether `Example::run` must start a language server (and fail when none is
    /// registered) before sending the prompt. Uses `default_true` when omitted.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
 47
/// A single eval example: its configuration, prompt, judging criteria, and log file.
#[derive(Clone, Debug)]
pub struct Example {
    /// Name of the example's directory; also used as its worktree directory name and as
    /// the prefix of console output lines.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    pub log_file: Arc<Mutex<File>>,
}
 60
/// Results collected by `Example::run` once the thread has stopped.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Output of `Example::repository_diff` for the example's worktree.
    pub repository_diff: String,
    /// Error and warning diagnostics as formatted by `query_lsp_diagnostics`, one per line.
    pub diagnostics: String,
    /// Number of thread messages whose role is `Role::Assistant`.
    pub response_count: usize,
    /// The thread's cumulative token usage.
    pub token_usage: TokenUsage,
    /// Number of finished tool uses that produced a tool result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
 69
/// Data rendered into the `judge_prompt.hbs` template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff of the changes to be judged.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
 75
/// Judge verdict parsed from the model's response by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Trimmed, unindented content of the response's `<analysis>` tag.
    pub analysis: String,
    /// Content of the response's `<score>` tag parsed as an unsigned integer.
    pub score: u32,
}
 81
 82impl Example {
 83    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 84    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 85        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
 86        let base_path = dir_path.join("base.toml");
 87        let prompt_path = dir_path.join("prompt.md");
 88        let criteria_path = dir_path.join("criteria.md");
 89
 90        let log_file_path = run_dir.join(format!(
 91            "{}.md",
 92            dir_path.file_name().unwrap().to_str().unwrap()
 93        ));
 94        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
 95        println!("{}> Logging to {:?}", name, log_file_path);
 96
 97        Ok(Example {
 98            name,
 99            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
100            prompt: fs::read_to_string(prompt_path.clone())?,
101            criteria: fs::read_to_string(criteria_path.clone())?,
102            log_file,
103        })
104    }
105
106    pub fn worktree_path(&self) -> PathBuf {
107        Path::new(WORKTREES_DIR)
108            .canonicalize()
109            .context(format!("No such directory {WORKTREES_DIR}"))
110            .unwrap()
111            .join(&self.name)
112    }
113
114    /// Set up the example by checking out the specified Git revision
115    pub async fn setup(&self) -> Result<()> {
116        let repo_path = repo_path_for_url(&self.base.url);
117
118        println!("{}> Fetching", self.name);
119
120        run_git(
121            &repo_path,
122            &["fetch", "--depth", "1", "origin", &self.base.revision],
123        )
124        .await?;
125
126        let worktree_path = self.worktree_path();
127
128        if worktree_path.is_dir() {
129            println!("{}> Resetting existing worktree", self.name);
130
131            // TODO: consider including "-x" to remove ignored files. The downside of this is that
132            // it will also remove build artifacts, and so prevent incremental reuse there.
133            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
134            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
135            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
136        } else {
137            println!("{}> Creating worktree", self.name);
138
139            let worktree_path_string = worktree_path.to_string_lossy().to_string();
140
141            run_git(
142                &repo_path,
143                &[
144                    "worktree",
145                    "add",
146                    "-f",
147                    &worktree_path_string,
148                    &self.base.revision,
149                ],
150            )
151            .await?;
152        }
153
154        Ok(())
155    }
156
    /// Runs this example's prompt against `model` and collects the results.
    ///
    /// Creates a local project over the example's worktree, optionally opens a file with
    /// the configured `language_extension` so that a language server starts, sends the
    /// prompt on a new thread, and waits until the thread stops. Assistant output, tool
    /// uses and tool results are appended to the example's log file along the way.
    ///
    /// Returns a task resolving to a [`RunOutput`] holding the repository diff, the LSP
    /// diagnostics, the assistant response count, the token usage and per-tool use counts.
    ///
    /// # Errors
    ///
    /// The task fails when `require_lsp` is set but `language_extension` is missing, no
    /// matching file is found, or no language server is registered for the opened buffer;
    /// when the `ZED_EVAL_SETUP_ONLY` environment variable is set; when the thread stops
    /// with `MaxTokens` or an error, or reports an error; or when waiting for the language
    /// server, computing the diff, or querying diagnostics fails.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        // The spawned task needs its own copy of the example.
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            // When LSP is required, this holds the buffer's language-server registration
            // handle together with the LSP store; both are kept until diagnostics have been
            // collected at the end of the run.
            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                // Fail early when no language server is registered for the opened buffer.
                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            // Setup-only mode: stop (with an error) once the project and LSP are ready,
            // before any request is sent to the model.
            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            // Write the log header; the guard is scoped so the lock is released before
            // the event subscription below starts locking the same file.
            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            // One-shot completion signal. The sender lives in an `Option` so that only the
            // first terminal event (end of turn, token limit, or error) resolves the run.
            let (tx, rx) = oneshot::channel();
            let mut tx = Some(tx);

            let subscription = cx.subscribe(&thread, {
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                move |thread, event: &ThreadEvent, cx| {
                    let mut log_file = log_file.lock().unwrap();

                    match event {
                        ThreadEvent::Stopped(reason) => match reason {
                            Ok(StopReason::EndTurn) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Ok(())).ok();
                                }
                            }
                            Ok(StopReason::MaxTokens) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
                                }
                            }
                            // Stopping for tool use is not terminal; keep waiting.
                            Ok(StopReason::ToolUse) => {}
                            Err(error) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!(error.clone()))).ok();
                                }
                            }
                        },
                        ThreadEvent::ShowError(thread_error) => {
                            if let Some(tx) = tx.take() {
                                tx.send(Err(anyhow!(thread_error.clone()))).ok();
                            }
                        }
                        ThreadEvent::StreamedAssistantText(_, chunk) => {
                            write!(&mut log_file, "{}", chunk).log_err();
                        }
                        ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                            write!(&mut log_file, "{}", chunk).log_err();
                        }
                        ThreadEvent::UsePendingTools { tool_uses } => {
                            writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                            for tool_use in tool_uses {
                                writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                    .log_err();
                            }
                        }
                        ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            ..
                        } => {
                            if let Some(tool_use) = pending_tool_use {
                                let message = format!("TOOL FINISHED: {}", tool_use.name);
                                println!("{name}> {message}");
                                writeln!(&mut log_file, "\n{}", message).log_err();
                            }
                            // Only finished tool uses that have a result are counted.
                            if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                                writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
                                let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                *tool_use_counts
                                    .entry(tool_result.tool_name.clone())
                                    .or_insert(0) += 1;
                            }
                        }
                        _ => {}
                    }

                    log_file.flush().log_err();
                }
            })?;

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            // First `?`: the sender was dropped without sending; second `?`: the run failed.
            rx.await??;

            // Wait for the language server to go idle again before diagnostics are read.
            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            // Results are collected; stop listening to thread events and release the
            // language-server handle.
            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }
389
390    pub async fn judge(
391        &mut self,
392        model: Arc<dyn LanguageModel>,
393        repository_diff: String,
394        cx: &AsyncApp,
395    ) -> Result<JudgeOutput> {
396        let judge_prompt = include_str!("judge_prompt.hbs");
397        let judge_prompt_name = "judge_prompt";
398        let mut handlebars = Handlebars::new();
399        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
400        let prompt = handlebars.render(
401            judge_prompt_name,
402            &JudgeInput {
403                repository_diff,
404                criteria: self.criteria.clone(),
405            },
406        )?;
407
408        let request = LanguageModelRequest {
409            messages: vec![LanguageModelRequestMessage {
410                role: Role::User,
411                content: vec![MessageContent::Text(prompt)],
412                cache: false,
413            }],
414            temperature: None,
415            tools: Vec::new(),
416            stop: Vec::new(),
417        };
418
419        let response = send_language_model_request(model, request, cx).await?;
420
421        let mut log_file = self.log_file.lock().unwrap();
422
423        writeln!(&mut log_file, "\n\n").log_err();
424        writeln!(&mut log_file, "========================================").log_err();
425        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
426        writeln!(&mut log_file, "========================================").log_err();
427        writeln!(&mut log_file, "\n{}", &response).log_err();
428
429        parse_judge_output(&response)
430    }
431
432    pub async fn repository_diff(&self) -> Result<String> {
433        let worktree_path = self.worktree_path();
434        run_git(&worktree_path, &["add", "-N"]).await?;
435        run_git(&worktree_path, &["diff"]).await
436    }
437}
438
439fn wait_for_lang_server(
440    lsp_store: &Entity<LspStore>,
441    name: String,
442    cx: &mut AsyncApp,
443) -> Task<Result<()>> {
444    if cx
445        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
446        .unwrap()
447        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
448    {
449        return Task::ready(anyhow::Ok(()));
450    }
451
452    println!("{}> ⏵ Waiting for language server", name);
453
454    let (mut tx, mut rx) = mpsc::channel(1);
455
456    let subscription =
457        cx.subscribe(&lsp_store, {
458            let name = name.clone();
459            move |lsp_store, event, cx| {
460                match event {
461                    project::LspStoreEvent::LanguageServerUpdate {
462                        message:
463                            client::proto::update_language_server::Variant::WorkProgress(
464                                LspWorkProgress {
465                                    message: Some(message),
466                                    ..
467                                },
468                            ),
469                        ..
470                    } => println!("{name}> ⟲ {message}"),
471                    _ => {}
472                }
473
474                if !has_pending_lang_server_work(&lsp_store, cx) {
475                    tx.try_send(()).ok();
476                }
477            }
478        });
479
480    cx.spawn(async move |cx| {
481        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
482        let result = futures::select! {
483            _ = rx.next() => {
484                println!("{}> ⚑ Language server idle", name);
485                anyhow::Ok(())
486            },
487            _ = timeout.fuse() => {
488                Err(anyhow!("LSP wait timed out after 5 minutes"))
489            }
490        };
491        drop(subscription);
492        result
493    })
494}
495
496fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
497    lsp_store
498        .read(cx)
499        .language_server_statuses()
500        .any(|(_, status)| !status.pending_work.is_empty())
501}
502
503async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
504    let paths_with_diagnostics = project.update(cx, |project, cx| {
505        project
506            .diagnostic_summaries(true, cx)
507            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
508            .map(|(project_path, _, _)| project_path)
509            .collect::<Vec<_>>()
510    })?;
511
512    let mut output = String::new();
513    for project_path in paths_with_diagnostics {
514        let buffer = project
515            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
516            .await?;
517        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
518
519        for (_, group) in snapshot.diagnostic_groups(None) {
520            let entry = &group.entries[group.primary_ix];
521            let range = entry.range.to_point(&snapshot);
522            let severity = match entry.diagnostic.severity {
523                DiagnosticSeverity::ERROR => "error",
524                DiagnosticSeverity::WARNING => "warning",
525                _ => continue,
526            };
527
528            writeln!(
529                output,
530                "{} at line {}: {}",
531                severity,
532                range.start.row + 1,
533                entry.diagnostic.message
534            )?;
535        }
536    }
537    anyhow::Ok(output)
538}
539
540fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
541    let analysis = get_tag("analysis", response)?.to_string();
542    let score = get_tag("score", response)?
543        .parse()
544        .context("error parsing score")?;
545
546    Ok(JudgeOutput { analysis, score })
547}
548
549fn get_tag(name: &'static str, response: &str) -> Result<String> {
550    let start_tag = format!("<{}>", name);
551    let end_tag = format!("</{}>", name);
552
553    let start_ix = response
554        .find(&start_tag)
555        .context(format!("{} start tag not found", name))?;
556    let content_start_ix = start_ix + start_tag.len();
557
558    let end_ix = content_start_ix
559        + response[content_start_ix..]
560            .find(&end_tag)
561            .context(format!("{} end tag not found", name))?;
562
563    let content = response[content_start_ix..end_ix].trim().unindent();
564
565    anyhow::Ok(content)
566}
567
568pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
569    let repo_name = repo_url
570        .trim_start_matches("https://")
571        .replace(|c: char| !c.is_alphanumeric(), "-");
572    Path::new(REPOS_DIR)
573        .canonicalize()
574        .context(format!("No such directory {REPOS_DIR}"))
575        .unwrap()
576        .join(repo_name)
577}
578
579pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
580    let output = new_smol_command("git")
581        .current_dir(repo_path)
582        .args(args)
583        .output()
584        .await?;
585
586    if output.status.success() {
587        Ok(String::from_utf8(output.stdout)?.trim().to_string())
588    } else {
589        Err(anyhow!(
590            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
591            args.join(" "),
592            repo_path.display(),
593            output.status,
594            String::from_utf8_lossy(&output.stderr),
595            String::from_utf8_lossy(&output.stdout),
596        ))
597    }
598}
599
600pub async fn send_language_model_request(
601    model: Arc<dyn LanguageModel>,
602    request: LanguageModelRequest,
603    cx: &AsyncApp,
604) -> anyhow::Result<String> {
605    match model.stream_completion_text(request, &cx).await {
606        Ok(mut stream) => {
607            let mut full_response = String::new();
608            while let Some(chunk_result) = stream.stream.next().await {
609                match chunk_result {
610                    Ok(chunk_str) => {
611                        print!("{}", &chunk_str);
612                        full_response.push_str(&chunk_str);
613                    }
614                    Err(err) => {
615                        return Err(anyhow!(
616                            "Error receiving response from language model: {err}"
617                        ));
618                    }
619                }
620            }
621            Ok(full_response)
622        }
623        Err(err) => Err(anyhow!(
624            "Failed to get response from language model. Error was: {err}"
625        )),
626    }
627}
628
/// Unit tests for parsing the judge's response.
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `parse_judge_output` on a single-line response and on a response with a
    /// multi-line, indented analysis surrounded by unrelated text.
    #[test]
    fn test_parse_judge_output() {
        // Tags on a single line each.
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        // Text outside the tags is ignored, and the multi-line analysis comes back
        // trimmed and unindented.
        let response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
}