1use agent::{RequestKind, ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use dap::DapRegistry;
6use futures::channel::{mpsc, oneshot};
7use futures::{FutureExt, StreamExt as _};
8use gpui::{App, AsyncApp, Entity, Task};
9use handlebars::Handlebars;
10use language::{DiagnosticSeverity, OffsetRangeExt};
11use language_model::{
12 LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
13 StopReason, TokenUsage,
14};
15use project::{LspStore, Project, ProjectPath};
16use serde::{Deserialize, Serialize};
17use std::fmt::Write as _;
18use std::fs::File;
19use std::io::Write as _;
20use std::sync::{Arc, Mutex};
21use std::time::Duration;
22use std::{
23 fs,
24 path::{Path, PathBuf},
25};
26use unindent::Unindent as _;
27use util::ResultExt as _;
28use util::command::new_smol_command;
29use util::serde::default_true;
30
31use crate::AgentAppState;
32
/// Directory expected to contain one subdirectory per example (not read
/// directly in this file; presumably enumerated by the caller — confirm).
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Directory holding local checkouts of example repositories
/// (see `repo_path_for_url`).
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Directory in which per-example git worktrees are created
/// (see `Example::worktree_path`).
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
36
/// Per-example configuration, deserialized from the example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// URL of the git repository the example runs against; also used to
    /// derive the local repo path (see `repo_path_for_url`).
    pub url: String,
    /// Git revision to fetch and check out in the example's worktree.
    pub revision: String,
    /// Extension of files in the repo's main language. Used by `run` to open
    /// a matching file so a language server starts; required when
    /// `require_lsp` is true.
    pub language_extension: Option<String>,
    // NOTE(review): not referenced anywhere in this file; presumably an
    // external tracking id — confirm against callers.
    pub insert_id: Option<String>,
    /// When true (the default), `run` opens a file matching
    /// `language_extension` and fails if no language server registers for it.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
46
/// A single eval example: repo/revision configuration, the prompt given to
/// the agent, the judging criteria, and a shared log file for the transcript.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name, taken from the directory it was loaded from.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to. Shared behind `Arc<Mutex<_>>` because
    /// the example is cloned into async tasks and event callbacks.
    pub log_file: Arc<Mutex<File>>,
}
59
/// Result of running an example's agent thread to completion.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// `git diff` of the worktree after the agent finished.
    pub repository_diff: String,
    /// Formatted LSP error/warning diagnostics gathered after the run.
    pub diagnostics: String,
    /// Number of assistant messages in the thread.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
}
67
/// Template input rendered into `judge_prompt.hbs` for the judge model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff produced by the run, shown to the judge.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
73
/// Judge model's verdict, parsed from `<analysis>`/`<score>` tags in its
/// response (see `parse_judge_output`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Free-form analysis text from the `<analysis>` tag.
    pub analysis: String,
    /// Numeric score from the `<score>` tag.
    pub score: u32,
}
79
80impl Example {
81 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
82 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
83 let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
84 let base_path = dir_path.join("base.toml");
85 let prompt_path = dir_path.join("prompt.md");
86 let criteria_path = dir_path.join("criteria.md");
87
88 let log_file_path = run_dir.join(format!(
89 "{}.md",
90 dir_path.file_name().unwrap().to_str().unwrap()
91 ));
92 let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
93 println!("{}> Logging to {:?}", name, log_file_path);
94
95 Ok(Example {
96 name,
97 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
98 prompt: fs::read_to_string(prompt_path.clone())?,
99 criteria: fs::read_to_string(criteria_path.clone())?,
100 log_file,
101 })
102 }
103
104 pub fn worktree_path(&self) -> PathBuf {
105 Path::new(WORKTREES_DIR)
106 .canonicalize()
107 .context(format!("No such directory {WORKTREES_DIR}"))
108 .unwrap()
109 .join(&self.name)
110 }
111
112 /// Set up the example by checking out the specified Git revision
113 pub async fn setup(&self) -> Result<()> {
114 let repo_path = repo_path_for_url(&self.base.url);
115
116 run_git(
117 &repo_path,
118 &["fetch", "--depth", "1", "origin", &self.base.revision],
119 )
120 .await?;
121
122 let worktree_path = self.worktree_path();
123
124 if worktree_path.is_dir() {
125 println!("{}> Resetting existing worktree", self.name);
126
127 // TODO: consider including "-x" to remove ignored files. The downside of this is that
128 // it will also remove build artifacts, and so prevent incremental reuse there.
129 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
130 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
131 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
132 } else {
133 println!("{}> Creating worktree", self.name);
134
135 let worktree_path_string = worktree_path.to_string_lossy().to_string();
136
137 run_git(
138 &repo_path,
139 &[
140 "worktree",
141 "add",
142 "-f",
143 &worktree_path_string,
144 &self.base.revision,
145 ],
146 )
147 .await?;
148 }
149
150 Ok(())
151 }
152
153 pub fn run(
154 &self,
155 model: Arc<dyn LanguageModel>,
156 app_state: Arc<AgentAppState>,
157 cx: &mut App,
158 ) -> Task<Result<RunOutput>> {
159 let project = Project::local(
160 app_state.client.clone(),
161 app_state.node_runtime.clone(),
162 app_state.user_store.clone(),
163 app_state.languages.clone(),
164 Arc::new(DapRegistry::default()),
165 app_state.fs.clone(),
166 None,
167 cx,
168 );
169
170 let worktree_path = self.worktree_path();
171 let worktree = project.update(cx, |project, cx| {
172 project.create_worktree(&worktree_path, true, cx)
173 });
174
175 let tools = Arc::new(ToolWorkingSet::default());
176 let thread_store =
177 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
178 let this = self.clone();
179
180 cx.spawn(async move |cx| {
181 let worktree = worktree.await?;
182
183 // Wait for worktree scan to finish before choosing a file to open.
184 worktree
185 .update(cx, |worktree, _cx| {
186 worktree.as_local().unwrap().scan_complete()
187 })?
188 .await;
189
190 let lsp_open_handle_and_store = if this.base.require_lsp {
191 let language_extension = this.base.language_extension.as_deref().context(
192 "language_extension field is required in base.toml when `require_lsp == true`",
193 )?;
194
195 // Open a file that matches the language to cause LSP to start.
196 let language_file = worktree.read_with(cx, |worktree, _cx| {
197 worktree
198 .files(false, 0)
199 .find_map(|e| {
200 if e.path.clone().extension().and_then(|ext| ext.to_str())
201 == Some(language_extension)
202 {
203 Some(ProjectPath {
204 worktree_id: worktree.id(),
205 path: e.path.clone(),
206 })
207 } else {
208 None
209 }
210 })
211 .context("Failed to find a file for example language")
212 })??;
213
214 let open_language_file_buffer_task = project.update(cx, |project, cx| {
215 project.open_buffer(language_file.clone(), cx)
216 })?;
217
218 let language_file_buffer = open_language_file_buffer_task.await?;
219
220 let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
221 (
222 project.register_buffer_with_language_servers(&language_file_buffer, cx),
223 project.lsp_store().clone(),
224 )
225 })?;
226
227 // TODO: remove this once the diagnostics tool waits for new diagnostics
228 cx.background_executor().timer(Duration::new(5, 0)).await;
229 wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
230
231 lsp_store.update(cx, |lsp_store, cx| {
232 lsp_open_handle.update(cx, |buffer, cx| {
233 buffer.update(cx, |buffer, cx| {
234 let has_language_server = lsp_store
235 .language_servers_for_local_buffer(buffer, cx)
236 .next()
237 .is_some();
238 if has_language_server {
239 Ok(())
240 } else {
241 Err(anyhow!(
242 "`{:?}` was opened to cause the language server to start, \
243 but no language servers are registered for its buffer. \
244 Set `require_lsp = false` in `base.toml` to skip this.",
245 language_file
246 ))
247 }
248 })
249 })
250 })??;
251
252 Some((lsp_open_handle, lsp_store))
253 } else {
254 None
255 };
256
257 if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
258 return Err(anyhow!("Setup only mode"));
259 }
260
261 let thread_store = thread_store.await;
262 let thread =
263 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
264
265 {
266 let mut log_file = this.log_file.lock().unwrap();
267 writeln!(&mut log_file, "👤 USER:").log_err();
268 writeln!(&mut log_file, "{}", this.prompt).log_err();
269 writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
270 log_file.flush().log_err();
271 }
272
273 let (tx, rx) = oneshot::channel();
274 let mut tx = Some(tx);
275
276 let _subscription = cx.subscribe(&thread, {
277 let log_file = this.log_file.clone();
278 let name = this.name.clone();
279 move |thread, event: &ThreadEvent, cx| {
280 let mut log_file = log_file.lock().unwrap();
281
282 match event {
283 ThreadEvent::Stopped(reason) => match reason {
284 Ok(StopReason::EndTurn) => {
285 if let Some(tx) = tx.take() {
286 tx.send(Ok(())).ok();
287 }
288 }
289 Ok(StopReason::MaxTokens) => {
290 if let Some(tx) = tx.take() {
291 tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
292 }
293 }
294 Ok(StopReason::ToolUse) => {}
295 Err(error) => {
296 if let Some(tx) = tx.take() {
297 tx.send(Err(anyhow!(error.clone()))).ok();
298 }
299 }
300 },
301 ThreadEvent::ShowError(thread_error) => {
302 if let Some(tx) = tx.take() {
303 tx.send(Err(anyhow!(thread_error.clone()))).ok();
304 }
305 }
306 ThreadEvent::StreamedAssistantText(_, chunk) => {
307 write!(&mut log_file, "{}", chunk).log_err();
308 }
309 ThreadEvent::StreamedAssistantThinking(_, chunk) => {
310 write!(&mut log_file, "{}", chunk).log_err();
311 }
312 ThreadEvent::UsePendingTools { tool_uses } => {
313 writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
314 for tool_use in tool_uses {
315 writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
316 .log_err();
317 }
318 }
319 ThreadEvent::ToolFinished {
320 tool_use_id,
321 pending_tool_use,
322 ..
323 } => {
324 if let Some(tool_use) = pending_tool_use {
325 let message = format!("TOOL FINISHED: {}", tool_use.name);
326 println!("{name}> {message}");
327 writeln!(&mut log_file, "\n{}", message).log_err();
328 }
329 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
330 let message = format!("\n{}\n", tool_result.content);
331 writeln!(&mut log_file, "{}", message).log_err();
332 }
333 }
334 _ => {}
335 }
336
337 log_file.flush().log_err();
338 }
339 })?;
340
341 thread.update(cx, |thread, cx| {
342 let context = vec![];
343 thread.insert_user_message(this.prompt.clone(), context, None, cx);
344 thread.send_to_model(model, RequestKind::Chat, cx);
345 })?;
346
347 rx.await??;
348
349 if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
350 wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
351 }
352
353 let repository_diff = this.repository_diff().await?;
354 let diagnostics = cx
355 .update(move |cx| {
356 cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
357 })?
358 .await?;
359
360 drop(lsp_open_handle_and_store);
361
362 thread.update(cx, |thread, _cx| {
363 let response_count = thread
364 .messages()
365 .filter(|message| message.role == language_model::Role::Assistant)
366 .count();
367 RunOutput {
368 repository_diff,
369 diagnostics,
370 response_count,
371 token_usage: thread.cumulative_token_usage(),
372 }
373 })
374 })
375 }
376
377 pub async fn judge(
378 &mut self,
379 model: Arc<dyn LanguageModel>,
380 repository_diff: String,
381 cx: &AsyncApp,
382 ) -> Result<JudgeOutput> {
383 let judge_prompt = include_str!("judge_prompt.hbs");
384 let judge_prompt_name = "judge_prompt";
385 let mut handlebars = Handlebars::new();
386 handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
387 let prompt = handlebars.render(
388 judge_prompt_name,
389 &JudgeInput {
390 repository_diff,
391 criteria: self.criteria.clone(),
392 },
393 )?;
394
395 let request = LanguageModelRequest {
396 messages: vec![LanguageModelRequestMessage {
397 role: Role::User,
398 content: vec![MessageContent::Text(prompt)],
399 cache: false,
400 }],
401 temperature: None,
402 tools: Vec::new(),
403 stop: Vec::new(),
404 };
405
406 let response = send_language_model_request(model, request, cx).await?;
407
408 let mut log_file = self.log_file.lock().unwrap();
409
410 writeln!(&mut log_file, "\n\n").log_err();
411 writeln!(&mut log_file, "========================================").log_err();
412 writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
413 writeln!(&mut log_file, "========================================").log_err();
414 writeln!(&mut log_file, "\n{}", &response).log_err();
415
416 parse_judge_output(&response)
417 }
418
419 pub async fn repository_diff(&self) -> Result<String> {
420 let worktree_path = self.worktree_path();
421 run_git(&worktree_path, &["add", "-N"]).await?;
422 run_git(&worktree_path, &["diff"]).await
423 }
424}
425
/// Returns a task that resolves once the language servers in `lsp_store`
/// report no pending work, or fails after a 5-minute timeout.
///
/// Resolves immediately if there is no pending work, or if the
/// `ZED_EVAL_SKIP_LS_WAIT` env var is set.
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    // Fast path: nothing to wait for. NOTE(review): the `unwrap` panics if the
    // app context is gone — acceptable here since this is an eval harness.
    if cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap()
        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
    {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    // Capacity 1 is enough: we only care about the first "idle" notification.
    let (mut tx, mut rx) = mpsc::channel(1);

    // Log progress messages and signal `rx` once all pending work is done.
    let subscription =
        cx.subscribe(&lsp_store, {
            let name = name.clone();
            move |lsp_store, event, cx| {
                match event {
                    project::LspStoreEvent::LanguageServerUpdate {
                        message:
                            client::proto::update_language_server::Variant::WorkProgress(
                                LspWorkProgress {
                                    message: Some(message),
                                    ..
                                },
                            ),
                        ..
                    } => println!("{name}> ⟲ {message}"),
                    _ => {}
                }

                if !has_pending_lang_server_work(&lsp_store, cx) {
                    // `try_send` may fail if the receiver is gone or the
                    // channel is full; both are harmless here.
                    tx.try_send(()).ok();
                }
            }
        });

    cx.spawn(async move |cx| {
        // Race the idle notification against a 5-minute timeout.
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        // Keep the subscription alive until we're done waiting.
        drop(subscription);
        result
    })
}
482
483fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
484 lsp_store
485 .read(cx)
486 .language_server_statuses()
487 .any(|(_, status)| !status.pending_work.is_empty())
488}
489
/// Collects LSP diagnostics from every project path that reports errors or
/// warnings, formatting one line per diagnostic group's primary entry.
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    // Use the cheap summaries first to find which paths need opening at all.
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        // Open the buffer to get a snapshot with position info for each entry.
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            // Report only each group's primary entry, and only errors/warnings.
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            // Rows are zero-based; print one-based line numbers.
            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(output)
}
526
527fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
528 let analysis = get_tag("analysis", response)?.to_string();
529 let score = get_tag("score", response)?
530 .parse()
531 .context("error parsing score")?;
532
533 Ok(JudgeOutput { analysis, score })
534}
535
536fn get_tag(name: &'static str, response: &str) -> Result<String> {
537 let start_tag = format!("<{}>", name);
538 let end_tag = format!("</{}>", name);
539
540 let start_ix = response
541 .find(&start_tag)
542 .context(format!("{} start tag not found", name))?;
543 let content_start_ix = start_ix + start_tag.len();
544
545 let end_ix = content_start_ix
546 + response[content_start_ix..]
547 .find(&end_tag)
548 .context(format!("{} end tag not found", name))?;
549
550 let content = response[content_start_ix..end_ix].trim().unindent();
551
552 anyhow::Ok(content)
553}
554
555pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
556 let repo_name = repo_url
557 .trim_start_matches("https://")
558 .replace(|c: char| !c.is_alphanumeric(), "-");
559 Path::new(REPOS_DIR)
560 .canonicalize()
561 .context(format!("No such directory {REPOS_DIR}"))
562 .unwrap()
563 .join(repo_name)
564}
565
566pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
567 let output = new_smol_command("git")
568 .current_dir(repo_path)
569 .args(args)
570 .output()
571 .await?;
572
573 if output.status.success() {
574 Ok(String::from_utf8(output.stdout)?.trim().to_string())
575 } else {
576 Err(anyhow!(
577 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
578 args.join(" "),
579 repo_path.display(),
580 output.status,
581 String::from_utf8_lossy(&output.stderr),
582 String::from_utf8_lossy(&output.stdout),
583 ))
584 }
585}
586
587pub async fn send_language_model_request(
588 model: Arc<dyn LanguageModel>,
589 request: LanguageModelRequest,
590 cx: &AsyncApp,
591) -> anyhow::Result<String> {
592 match model.stream_completion_text(request, &cx).await {
593 Ok(mut stream) => {
594 let mut full_response = String::new();
595 while let Some(chunk_result) = stream.stream.next().await {
596 match chunk_result {
597 Ok(chunk_str) => {
598 print!("{}", &chunk_str);
599 full_response.push_str(&chunk_str);
600 }
601 Err(err) => {
602 return Err(anyhow!(
603 "Error receiving response from language model: {err}"
604 ));
605 }
606 }
607 }
608 Ok(full_response)
609 }
610 Err(err) => Err(anyhow!(
611 "Failed to get response from language model. Error was: {err}"
612 )),
613 }
614}
615
#[cfg(test)]
mod test {
    use super::*;

    // Covers both a minimal response and one with surrounding text, multi-line
    // analysis content, and blank lines between tags.
    #[test]
    fn test_parse_judge_output() {
        // Simple case: both tags on their own lines, nothing else.
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        // Text outside the tags is ignored; multi-line analysis is trimmed
        // and unindented.
        let response = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = parse_judge_output(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }
}