example.rs

  1use agent::{RequestKind, ThreadEvent, ThreadStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::ToolWorkingSet;
  4use client::proto::LspWorkProgress;
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use futures::channel::mpsc;
  8use futures::{FutureExt, StreamExt as _, select_biased};
  9use gpui::{App, AsyncApp, Entity, Task};
 10use handlebars::Handlebars;
 11use language::{DiagnosticSeverity, OffsetRangeExt};
 12use language_model::{
 13    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 14    StopReason, TokenUsage,
 15};
 16use project::{LspStore, Project, ProjectPath};
 17use serde::{Deserialize, Serialize};
 18use std::fmt::Write as _;
 19use std::fs::File;
 20use std::io::Write as _;
 21use std::sync::{Arc, Mutex};
 22use std::time::Duration;
 23use std::{
 24    fs,
 25    path::{Path, PathBuf},
 26};
 27use unindent::Unindent as _;
 28use util::ResultExt as _;
 29use util::command::new_smol_command;
 30use util::serde::default_true;
 31
 32use crate::AgentAppState;
 33
/// Relative path of the examples directory. Not referenced elsewhere in this file;
/// presumably consumed by the eval runner — TODO confirm.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path under which repositories live; `repo_path_for_url` canonicalizes it
/// and appends a per-URL directory name.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path under which per-example git worktrees live; `Example::worktree_path`
/// canonicalizes it and appends the example name.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Maximum time `Example::run` waits for the next `ThreadEvent` before failing the run
/// as stalled.
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 39
/// Settings deserialized from an example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Repository URL; used to derive the local repository path (see `repo_path_for_url`).
    pub url: String,
    /// Git revision that `Example::setup` fetches and checks out.
    pub revision: String,
    /// File extension used by `Example::run` to pick a file to open so that a language
    /// server starts. Required when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// Not read anywhere in this file.
    pub insert_id: Option<String>,
    /// Whether `Example::run` must start a language server before prompting the model.
    /// When the key is absent, serde falls back to `default_true`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
 49
/// A single eval example: its configuration, prompt, judging criteria, and the
/// markdown log that `run` and `judge` append to.
#[derive(Clone, Debug)]
pub struct Example {
    /// Final component of the example's directory path; also names its worktree
    /// directory and prefixes console output.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown log file to append to
    /// (shared behind a mutex because clones of the example write to the same file).
    pub log_file: Arc<Mutex<File>>,
    /// Path to markdown log file
    pub log_file_path: PathBuf,
}
 64
/// Result of `Example::run`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Output of `git diff` in the example's worktree after the run.
    pub repository_diff: String,
    /// Error and warning diagnostics collected by `query_lsp_diagnostics`, one per line.
    pub diagnostics: String,
    /// Number of thread messages whose role is `Assistant`.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
    /// Number of finished tool calls that produced a result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
 73
/// Data rendered into the `judge_prompt.hbs` template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff of the repository produced by the run being judged.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
 79
/// Verdict parsed from the judge model's response by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Trimmed, unindented content of the response's `<analysis>` tag.
    pub analysis: String,
    /// Content of the response's `<score>` tag parsed as an unsigned integer.
    pub score: u32,
}
 85
 86impl Example {
 87    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 88    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 89        let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
 90        let base_path = dir_path.join("base.toml");
 91        let prompt_path = dir_path.join("prompt.md");
 92        let criteria_path = dir_path.join("criteria.md");
 93
 94        let log_file_path = run_dir.join(format!(
 95            "{}.md",
 96            dir_path.file_name().unwrap().to_str().unwrap()
 97        ));
 98        let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
 99        println!("{}> Logging to {:?}", name, log_file_path);
100
101        Ok(Example {
102            name,
103            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
104            prompt: fs::read_to_string(prompt_path.clone())?,
105            criteria: fs::read_to_string(criteria_path.clone())?,
106            log_file,
107            log_file_path,
108        })
109    }
110
111    pub fn worktree_path(&self) -> PathBuf {
112        Path::new(WORKTREES_DIR)
113            .canonicalize()
114            .context(format!("No such directory {WORKTREES_DIR}"))
115            .unwrap()
116            .join(&self.name)
117    }
118
119    /// Set up the example by checking out the specified Git revision
120    pub async fn setup(&self) -> Result<()> {
121        let repo_path = repo_path_for_url(&self.base.url);
122
123        println!("{}> Fetching", self.name);
124
125        run_git(
126            &repo_path,
127            &["fetch", "--depth", "1", "origin", &self.base.revision],
128        )
129        .await?;
130
131        let worktree_path = self.worktree_path();
132
133        if worktree_path.is_dir() {
134            println!("{}> Resetting existing worktree", self.name);
135
136            // TODO: consider including "-x" to remove ignored files. The downside of this is that
137            // it will also remove build artifacts, and so prevent incremental reuse there.
138            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
139            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
140            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
141        } else {
142            println!("{}> Creating worktree", self.name);
143
144            let worktree_path_string = worktree_path.to_string_lossy().to_string();
145
146            run_git(
147                &repo_path,
148                &[
149                    "worktree",
150                    "add",
151                    "-f",
152                    &worktree_path_string,
153                    &self.base.revision,
154                ],
155            )
156            .await?;
157        }
158
159        Ok(())
160    }
161
    /// Runs the example's prompt against `model` in an agent thread and collects the
    /// results.
    ///
    /// The returned task:
    /// 1. Creates a local project with the example's worktree and waits for the
    ///    worktree scan to complete.
    /// 2. When `require_lsp` is set, opens a file with the configured extension,
    ///    registers its buffer with language servers, and fails if none is registered.
    /// 3. Returns an error early if the `ZED_EVAL_SETUP_ONLY` environment variable is set.
    /// 4. Creates a thread, sends the prompt, and processes thread events (logging them
    ///    to the markdown log) until the turn ends, an error occurs, or no event
    ///    arrives within `THREAD_EVENT_TIMEOUT`.
    /// 5. Gathers the repository diff, LSP diagnostics, response count, token usage and
    ///    per-tool usage counts into a `RunOutput`.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        // Clone so the spawned future can own the example's data.
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            // Kept alive until the end of the run and dropped explicitly below.
            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;

                // Fail the run if no language server ended up attached to the opened buffer.
                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            // Setup-only mode stops here; note that it is reported as an error result.
            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            // Scoped so the log file lock is released before any further awaits.
            {
                let mut log_file = this.log_file.lock().unwrap();
                writeln!(&mut log_file, "👤 USER:").log_err();
                writeln!(&mut log_file, "{}", this.prompt).log_err();
                writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
                log_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            // Forward every thread event into the channel consumed by the handler task below.
            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            let event_handler_task = cx.spawn({
                let log_file = this.log_file.clone();
                let name = this.name.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        // Biased select: a ready event is always taken before the timeout fires.
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        // The guard is acquired after the await above and is released at the
                        // end of each iteration (or on return/break).
                        let mut log_file = log_file.lock().unwrap();

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                // Stopping for tool use is not terminal: keep processing events.
                                Ok(StopReason::ToolUse) => {}
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                                write!(&mut log_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::UsePendingTools { tool_uses } => {
                                writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
                                for tool_use in tool_uses {
                                    writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
                                        .log_err();
                                }
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                if let Some(tool_use) = pending_tool_use {
                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
                                    println!("{name}> {message}");
                                    writeln!(&mut log_file, "\n{}", message).log_err();
                                }
                                // Only tool calls that produced a result are logged and counted.
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                        writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                        *tool_use_counts
                                            .entry(tool_result.tool_name.clone())
                                            .or_insert(0) += 1;
                                    }
                                })?;
                            }
                            _ => {}
                        }

                        log_file.flush().log_err();
                    }
                }
            });

            // Start the conversation only after the event handler exists.
            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            // Wait for the language server again before the diagnostics are queried below.
            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
            }

            let repository_diff = this.repository_diff().await?;
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }
404
405    pub async fn judge(
406        &self,
407        model: Arc<dyn LanguageModel>,
408        repository_diff: String,
409        cx: &AsyncApp,
410    ) -> Result<JudgeOutput> {
411        let judge_prompt = include_str!("judge_prompt.hbs");
412        let judge_prompt_name = "judge_prompt";
413        let mut handlebars = Handlebars::new();
414        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
415        let prompt = handlebars.render(
416            judge_prompt_name,
417            &JudgeInput {
418                repository_diff,
419                criteria: self.criteria.clone(),
420            },
421        )?;
422
423        let request = LanguageModelRequest {
424            messages: vec![LanguageModelRequestMessage {
425                role: Role::User,
426                content: vec![MessageContent::Text(prompt)],
427                cache: false,
428            }],
429            temperature: None,
430            tools: Vec::new(),
431            stop: Vec::new(),
432        };
433
434        let response = send_language_model_request(model, request, cx).await?;
435
436        let mut log_file = self.log_file.lock().unwrap();
437
438        writeln!(&mut log_file, "\n\n").log_err();
439        writeln!(&mut log_file, "========================================").log_err();
440        writeln!(&mut log_file, "              JUDGE OUTPUT              ").log_err();
441        writeln!(&mut log_file, "========================================").log_err();
442        writeln!(&mut log_file, "\n{}", &response).log_err();
443
444        parse_judge_output(&response)
445    }
446
447    pub async fn repository_diff(&self) -> Result<String> {
448        let worktree_path = self.worktree_path();
449        run_git(&worktree_path, &["add", "-N"]).await?;
450        run_git(&worktree_path, &["diff"]).await
451    }
452}
453
454fn wait_for_lang_server(
455    lsp_store: &Entity<LspStore>,
456    name: String,
457    cx: &mut AsyncApp,
458) -> Task<Result<()>> {
459    if cx
460        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
461        .unwrap()
462        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
463    {
464        return Task::ready(anyhow::Ok(()));
465    }
466
467    println!("{}> ⏵ Waiting for language server", name);
468
469    let (mut tx, mut rx) = mpsc::channel(1);
470
471    let subscription =
472        cx.subscribe(&lsp_store, {
473            let name = name.clone();
474            move |lsp_store, event, cx| {
475                match event {
476                    project::LspStoreEvent::LanguageServerUpdate {
477                        message:
478                            client::proto::update_language_server::Variant::WorkProgress(
479                                LspWorkProgress {
480                                    message: Some(message),
481                                    ..
482                                },
483                            ),
484                        ..
485                    } => println!("{name}> ⟲ {message}"),
486                    _ => {}
487                }
488
489                if !has_pending_lang_server_work(&lsp_store, cx) {
490                    tx.try_send(()).ok();
491                }
492            }
493        });
494
495    cx.spawn(async move |cx| {
496        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
497        let result = futures::select! {
498            _ = rx.next() => {
499                println!("{}> ⚑ Language server idle", name);
500                anyhow::Ok(())
501            },
502            _ = timeout.fuse() => {
503                Err(anyhow!("LSP wait timed out after 5 minutes"))
504            }
505        };
506        drop(subscription);
507        result
508    })
509}
510
511fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
512    lsp_store
513        .read(cx)
514        .language_server_statuses()
515        .any(|(_, status)| !status.pending_work.is_empty())
516}
517
518async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
519    let paths_with_diagnostics = project.update(cx, |project, cx| {
520        project
521            .diagnostic_summaries(true, cx)
522            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
523            .map(|(project_path, _, _)| project_path)
524            .collect::<Vec<_>>()
525    })?;
526
527    let mut output = String::new();
528    for project_path in paths_with_diagnostics {
529        let buffer = project
530            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
531            .await?;
532        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
533
534        for (_, group) in snapshot.diagnostic_groups(None) {
535            let entry = &group.entries[group.primary_ix];
536            let range = entry.range.to_point(&snapshot);
537            let severity = match entry.diagnostic.severity {
538                DiagnosticSeverity::ERROR => "error",
539                DiagnosticSeverity::WARNING => "warning",
540                _ => continue,
541            };
542
543            writeln!(
544                output,
545                "{} at line {}: {}",
546                severity,
547                range.start.row + 1,
548                entry.diagnostic.message
549            )?;
550        }
551    }
552    anyhow::Ok(output)
553}
554
555fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
556    let analysis = get_tag("analysis", response)?.to_string();
557    let score = get_tag("score", response)?
558        .parse()
559        .context("error parsing score")?;
560
561    Ok(JudgeOutput { analysis, score })
562}
563
564fn get_tag(name: &'static str, response: &str) -> Result<String> {
565    let start_tag = format!("<{}>", name);
566    let end_tag = format!("</{}>", name);
567
568    let start_ix = response
569        .find(&start_tag)
570        .context(format!("{} start tag not found", name))?;
571    let content_start_ix = start_ix + start_tag.len();
572
573    let end_ix = content_start_ix
574        + response[content_start_ix..]
575            .find(&end_tag)
576            .context(format!("{} end tag not found", name))?;
577
578    let content = response[content_start_ix..end_ix].trim().unindent();
579
580    anyhow::Ok(content)
581}
582
583pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
584    let repo_name = repo_url
585        .trim_start_matches("https://")
586        .replace(|c: char| !c.is_alphanumeric(), "-");
587    Path::new(REPOS_DIR)
588        .canonicalize()
589        .context(format!("No such directory {REPOS_DIR}"))
590        .unwrap()
591        .join(repo_name)
592}
593
594pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
595    let output = new_smol_command("git")
596        .current_dir(repo_path)
597        .args(args)
598        .output()
599        .await?;
600
601    if output.status.success() {
602        Ok(String::from_utf8(output.stdout)?.trim().to_string())
603    } else {
604        Err(anyhow!(
605            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
606            args.join(" "),
607            repo_path.display(),
608            output.status,
609            String::from_utf8_lossy(&output.stderr),
610            String::from_utf8_lossy(&output.stdout),
611        ))
612    }
613}
614
615pub async fn send_language_model_request(
616    model: Arc<dyn LanguageModel>,
617    request: LanguageModelRequest,
618    cx: &AsyncApp,
619) -> anyhow::Result<String> {
620    match model.stream_completion_text(request, &cx).await {
621        Ok(mut stream) => {
622            let mut full_response = String::new();
623            while let Some(chunk_result) = stream.stream.next().await {
624                match chunk_result {
625                    Ok(chunk_str) => {
626                        print!("{}", &chunk_str);
627                        full_response.push_str(&chunk_str);
628                    }
629                    Err(err) => {
630                        return Err(anyhow!(
631                            "Error receiving response from language model: {err}"
632                        ));
633                    }
634                }
635            }
636            Ok(full_response)
637        }
638        Err(err) => Err(anyhow!(
639            "Failed to get response from language model. Error was: {err}"
640        )),
641    }
642}
643
#[cfg(test)]
mod test {
    use super::*;

    /// Covers a single-line response and a multi-line response with surrounding text.
    #[test]
    fn test_parse_judge_output() {
        // Tags on one line each, no surrounding text.
        let inline_response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let parsed = parse_judge_output(&inline_response).unwrap();
        assert_eq!(
            parsed.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(parsed.score, 3);

        // Multi-line analysis: text outside the tags is ignored and the content is
        // trimmed and unindented.
        let multiline_response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let parsed = parse_judge_output(&multiline_response).unwrap();
        assert_eq!(parsed.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(parsed.score, 1);
    }
}
680}