example.rs

   1use crate::{AgentAppState, ToolMetrics};
   2use agent::{ThreadEvent, ThreadStore};
   3use anyhow::{Context as _, Result, anyhow};
   4use assistant_tool::ToolWorkingSet;
   5use client::proto::LspWorkProgress;
   6use futures::channel::mpsc;
   7use futures::{FutureExt, StreamExt as _, select_biased};
   8use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
   9use handlebars::Handlebars;
  10use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
  11use language_model::{
  12    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  13    MessageContent, Role, StopReason, TokenUsage,
  14};
  15use project::{Project, ProjectPath};
  16use serde::{Deserialize, Serialize};
  17use std::cell::RefCell;
  18use std::fmt::Write as _;
  19use std::fs::File;
  20use std::io::Write as _;
  21use std::rc::Rc;
  22use std::sync::{Arc, Mutex};
  23use std::time::Duration;
  24use std::{
  25    fs,
  26    path::{Path, PathBuf},
  27};
  28use unindent::Unindent as _;
  29use util::ResultExt as _;
  30use util::command::new_smol_command;
  31use util::markdown::MarkdownString;
  32use util::serde::default_true;
  33
  34const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
  35
  36const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
  37
/// Per-example configuration, deserialized from the example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Git URL of the repository the example runs against.
    pub url: String,
    /// Git revision to check out.
    /// `Example::setup` shallow-fetches it (`git fetch --depth 1`) when it is not already
    /// present in the local repository.
    pub revision: String,
    /// Extension (without the leading `.`) of a file to open so a language server starts.
    /// Required when `require_lsp` is `true`.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in the visible part of this file — TODO document its purpose.
    pub insert_id: Option<String>,
    /// Whether to open a `language_extension` file and wait for its language server before
    /// and after the agent runs. Defaults to `true`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    /// Whether the run may proceed when the worktree already has error/warning diagnostics
    /// before the agent starts. Defaults to `false`.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
}
  49
  50impl ExampleBase {
  51    pub fn repo_name(&self) -> String {
  52        self.url
  53            .split('/')
  54            .next_back()
  55            .unwrap_or(&"")
  56            .trim_end_matches(".git")
  57            .into()
  58    }
  59}
  60
/// A single eval example loaded from an example directory, plus the paths used to run it.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name: the example directory's name.
    /// For repetitions greater than 0, `-<n>` is appended.
    pub name: String,
    /// Parsed content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `diff_criteria.md`
    pub diff_criteria: String,
    /// Content of `thread_criteria.md`, if that file exists (it's optional)
    pub thread_criteria: Option<String>,
    /// Directory of the current eval run.
    /// This example's request/response and diff artifacts are written to
    /// `<run_directory_path>/<name>`.
    pub run_directory_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
    /// Git worktree in which the example's revision is checked out and the agent runs.
    pub worktree_path: PathBuf,
    /// Local git repository for `base.url`; the worktree is created from it.
    pub repo_path: PathBuf,
}
  79
/// Results of running the agent on one example, produced by `Example::run`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// `git diff --staged` of the worktree after the run.
    pub repository_diff: String,
    /// Whether the example required a language server (`base.require_lsp`).
    pub ran_diagnostics_check: bool,
    /// Error/warning diagnostics before the agent ran, or `None` if there were none.
    pub diagnostics_before: Option<String>,
    /// Error/warning diagnostics after the agent ran, or `None` if there were none.
    pub diagnostics_after: Option<String>,
    /// Number of assistant messages in the thread.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
    /// Per-tool success/failure metrics collected from `ToolFinished` events.
    pub tool_metrics: ToolMetrics,
    /// The last request sent to the model; the thread judge renders its messages.
    pub last_request: LanguageModelRequest,
}
  91
/// Template data rendered into `judge_diff_prompt.hbs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeDiffInput {
    /// The diff produced by the agent.
    pub repository_diff: String,
    /// Whether language-server diagnostics were collected for this example.
    pub ran_diagnostics_check: bool,
    /// Diagnostics before the run; omitted from the serialized data when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_before: Option<String>,
    /// Diagnostics after the run; omitted from the serialized data when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_after: Option<String>,
    /// Content of the example's `diff_criteria.md`.
    pub criteria: String,
}
 102
/// Template data rendered into `judge_thread_prompt.hbs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeThreadInput {
    /// Markdown rendering of the messages in the run's last model request.
    pub messages: String,
    /// Content of the example's `thread_criteria.md`.
    pub criteria: String,
}
 108
/// A judge's verdict, parsed from the judge model's text response via `JudgeResponse::parse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeResponse {
    /// Free-text analysis from the judge — presumably the reasoning behind `score`.
    pub analysis: String,
    /// Numeric score.
    /// NOTE(review): the scale is defined by the judge prompt templates, which are not visible here.
    pub score: u32,
}
 114
/// Combined result of one judging pass over a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Thread judgment, or `None` when the example has no `thread_criteria.md`.
    pub thread: Option<JudgeResponse>,
    /// Judgment of the repository diff.
    pub diff: JudgeResponse,
}
 120
 121impl Example {
 122    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 123    pub fn load_from_directory(
 124        dir_path: &Path,
 125        run_dir: &Path,
 126        worktrees_dir: &Path,
 127        repos_dir: &Path,
 128    ) -> Result<Self> {
 129        let name = Self::name_from_path(dir_path);
 130        let base_path = dir_path.join("base.toml");
 131        let prompt_path = dir_path.join("prompt.md");
 132        let diff_criteria_path = dir_path.join("diff_criteria.md");
 133        let thread_criteria_path = dir_path.join("thread_criteria.md");
 134        let thread_criteria = if thread_criteria_path.exists() {
 135            Some(fs::read_to_string(thread_criteria_path.clone())?)
 136        } else {
 137            None
 138        };
 139
 140        let base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
 141
 142        let repo_path = repo_path_for_url(repos_dir, &base.url);
 143
 144        let worktree_path = worktrees_dir
 145            .canonicalize()
 146            .unwrap()
 147            .join(&name)
 148            .join(&base.repo_name());
 149
 150        Ok(Example {
 151            name: name.clone(),
 152            base,
 153            prompt: fs::read_to_string(prompt_path.clone())?,
 154            thread_criteria,
 155            diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
 156            run_directory_path: run_dir.to_path_buf(),
 157            worktree_path,
 158            repo_path,
 159            log_prefix: name,
 160        })
 161    }
 162
 163    pub fn set_repetition_number(&mut self, repetition_number: u32) {
 164        if repetition_number > 0 {
 165            self.name = format!("{}-{}", self.name, repetition_number);
 166        }
 167    }
 168
 169    pub fn example_output_directory(&self) -> PathBuf {
 170        self.run_directory_path.join(&self.name)
 171    }
 172
 173    pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
 174        self.log_prefix = format!(
 175            "{}{:<width$}\x1b[0m | ",
 176            color,
 177            self.name,
 178            width = name_width
 179        );
 180    }
 181
 182    pub fn name_from_path(path: &Path) -> String {
 183        path.file_name().unwrap().to_string_lossy().to_string()
 184    }
 185
 186    /// Set up the example by checking out the specified Git revision
 187    pub async fn setup(&mut self) -> Result<()> {
 188        let revision_exists = run_git(
 189            &self.repo_path,
 190            &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
 191        )
 192        .await
 193        .is_ok();
 194
 195        if !revision_exists {
 196            println!(
 197                "{}Fetching revision {}",
 198                self.log_prefix, &self.base.revision
 199            );
 200            run_git(
 201                &self.repo_path,
 202                &["fetch", "--depth", "1", "origin", &self.base.revision],
 203            )
 204            .await?;
 205        }
 206
 207        if self.worktree_path.is_dir() {
 208            println!("{}Resetting existing worktree", self.log_prefix);
 209
 210            // TODO: consider including "-x" to remove ignored files. The downside of this is that
 211            // it will also remove build artifacts, and so prevent incremental reuse there.
 212            run_git(&self.worktree_path, &["clean", "--force", "-d"]).await?;
 213            run_git(&self.worktree_path, &["reset", "--hard", "HEAD"]).await?;
 214            run_git(&self.worktree_path, &["checkout", &self.base.revision]).await?;
 215        } else {
 216            println!("{}Creating worktree", self.log_prefix);
 217
 218            let worktree_path_string = self.worktree_path.to_string_lossy().to_string();
 219
 220            run_git(
 221                &self.repo_path,
 222                &[
 223                    "worktree",
 224                    "add",
 225                    "-f",
 226                    &worktree_path_string,
 227                    &self.base.revision,
 228                ],
 229            )
 230            .await?;
 231        }
 232
 233        if self.base.url == ZED_REPO_URL {
 234            std::fs::write(self.worktree_path.join(".rules"), std::fs::read(".rules")?)?;
 235        }
 236
 237        std::fs::create_dir_all(self.example_output_directory())?;
 238
 239        Ok(())
 240    }
 241
 242    pub fn run(
 243        &self,
 244        model: Arc<dyn LanguageModel>,
 245        app_state: Arc<AgentAppState>,
 246        cx: &mut App,
 247    ) -> Task<Result<RunOutput>> {
 248        let project = Project::local(
 249            app_state.client.clone(),
 250            app_state.node_runtime.clone(),
 251            app_state.user_store.clone(),
 252            app_state.languages.clone(),
 253            app_state.fs.clone(),
 254            None,
 255            cx,
 256        );
 257
 258        let worktree = project.update(cx, |project, cx| {
 259            project.create_worktree(&self.worktree_path, true, cx)
 260        });
 261
 262        let tools = cx.new(|_| ToolWorkingSet::default());
 263        let thread_store =
 264            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
 265        let this = self.clone();
 266
 267        cx.spawn(async move |cx| {
 268            let worktree = worktree.await?;
 269
 270            // Wait for worktree scan to finish before choosing a file to open.
 271            worktree
 272                .update(cx, |worktree, _cx| {
 273                    worktree.as_local().unwrap().scan_complete()
 274                })?
 275                .await;
 276
 277            let lsp = if this.base.require_lsp {
 278                let language_extension = this.base.language_extension.as_deref().context(
 279                    "language_extension field is required in base.toml when `require_lsp == true`",
 280                )?;
 281
 282                // Open a file that matches the language to cause LSP to start.
 283                let language_file = worktree.read_with(cx, |worktree, _cx| {
 284                    worktree
 285                        .files(false, 0)
 286                        .find_map(|e| {
 287                            if e.path.clone().extension().and_then(|ext| ext.to_str())
 288                                == Some(language_extension)
 289                            {
 290                                Some(ProjectPath {
 291                                    worktree_id: worktree.id(),
 292                                    path: e.path.clone(),
 293                                })
 294                            } else {
 295                                None
 296                            }
 297                        })
 298                        .context("Failed to find a file for example language")
 299                })??;
 300
 301                let open_language_file_buffer_task = project.update(cx, |project, cx| {
 302                    project.open_buffer(language_file.clone(), cx)
 303                })?;
 304
 305                let language_file_buffer = open_language_file_buffer_task.await?;
 306
 307                let lsp_open_handle = project.update(cx, |project, cx| {
 308                    project.register_buffer_with_language_servers(&language_file_buffer, cx)
 309                })?;
 310
 311                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
 312
 313                Some((lsp_open_handle, language_file_buffer))
 314            } else {
 315                None
 316            };
 317
 318            let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
 319            if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
 320                return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
 321            }
 322
 323            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
 324                return Err(anyhow!("Setup only mode"));
 325            }
 326
 327            let example_output_dir = this.example_output_directory();
 328            let last_diff_file_path = example_output_dir.join("last.diff");
 329
 330            // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
 331            // history using undo/redo.
 332            std::fs::write(&last_diff_file_path, "")?;
 333
 334            let thread_store = thread_store.await?;
 335            let thread =
 336                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
 337            let last_request = Rc::new(RefCell::new(None));
 338
 339            thread.update(cx, |thread, _cx| {
 340                let mut request_count = 0;
 341                let last_request = Rc::clone(&last_request);
 342                let previous_diff = Rc::new(RefCell::new("".to_string()));
 343                let example_output_dir = example_output_dir.clone();
 344                let last_diff_file_path = last_diff_file_path.clone();
 345                let this = this.clone();
 346                thread.set_request_callback(move |request, response_events| {
 347                    *last_request.borrow_mut() = Some(request.clone());
 348
 349                    request_count += 1;
 350                    let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
 351                    let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
 352                    let last_messages_file_path = example_output_dir.join("last.messages.md");
 353                    let request_markdown = RequestMarkdown::new(request);
 354                    let response_events_markdown = response_events_to_markdown(response_events);
 355
 356                    let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
 357                    fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
 358                    fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
 359
 360                    let diff_result = smol::block_on(this.repository_diff());
 361                    match diff_result {
 362                        Ok(diff) => {
 363                            if diff != previous_diff.borrow().clone() {
 364                                fs::write(&diff_file_path, &diff).expect("failed to write diff file");
 365                                fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
 366                                *previous_diff.borrow_mut() = diff;
 367                            }
 368                        }
 369                        Err(err) => {
 370                            let error_message = format!("{err:?}");
 371                            fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
 372                            fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
 373                        }
 374                    }
 375
 376                    if request_count == 1 {
 377                        let tools_file_path = example_output_dir.join("tools.md");
 378                        fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
 379                    }
 380                });
 381            })?;
 382
 383            let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default()));
 384
 385            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
 386
 387            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
 388                thread_event_tx.unbounded_send(event.clone()).log_err();
 389            });
 390
 391            let event_handler_task = cx.spawn({
 392                let log_prefix = this.log_prefix.clone();
 393                let tool_metrics = tool_metrics.clone();
 394                let thread = thread.downgrade();
 395                async move |cx| {
 396                    loop {
 397                        let event = select_biased! {
 398                            event = thread_event_rx.next() => event,
 399                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
 400                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
 401                            }
 402                        };
 403                        let Some(event) = event else {
 404                            return Err(anyhow!("ThreadEvent channel ended early"));
 405                        };
 406
 407                        match event {
 408                            ThreadEvent::Stopped(reason) => match reason {
 409                                Ok(StopReason::EndTurn) => {
 410                                    return Ok(());
 411                                }
 412                                Ok(StopReason::MaxTokens) => {
 413                                    return Err(anyhow!("Exceeded maximum tokens"));
 414                                }
 415                                Ok(StopReason::ToolUse) => {
 416                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
 417                                        println!("{}StopReason: Tool use", log_prefix);
 418                                    }
 419                                }
 420                                Err(error) => {
 421                                    return Err(anyhow!(error.clone()));
 422                                }
 423                            },
 424                            ThreadEvent::ShowError(thread_error) => {
 425                                break Err(anyhow!(thread_error.clone()));
 426                            }
 427                            ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
 428                            }
 429                            ThreadEvent::ToolFinished {
 430                                tool_use_id,
 431                                pending_tool_use,
 432                                ..
 433                            } => {
 434                                thread.update(cx, |thread, _cx| {
 435                                    if let Some(tool_use) = pending_tool_use {
 436                                        let mut tool_metrics = tool_metrics.lock().unwrap();
 437                                        if let Some(tool_result) = thread.tool_result(&tool_use_id) {
 438                                            let message = if tool_result.is_error {
 439                                                format!("TOOL FAILED: {}", tool_use.name)
 440                                            } else {
 441                                                format!("TOOL FINISHED: {}", tool_use.name)
 442                                            };
 443                                            println!("{log_prefix}{message}");
 444                                            tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error);
 445                                        } else {
 446                                            let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
 447                                            println!("{log_prefix}{message}");
 448                                            tool_metrics.insert(tool_use.name.clone(), true);
 449                                        }
 450                                    }
 451                                })?;
 452                            }
 453                            ThreadEvent::ToolConfirmationNeeded => {
 454                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
 455                            },
 456                            ThreadEvent::StreamedToolUse { .. } |
 457                            ThreadEvent::StreamedCompletion |
 458                            ThreadEvent::MessageAdded(_) |
 459                            ThreadEvent::MessageEdited(_) |
 460                            ThreadEvent::MessageDeleted(_) |
 461                            ThreadEvent::SummaryChanged |
 462                            ThreadEvent::SummaryGenerated |
 463                            ThreadEvent::CheckpointChanged |
 464                            ThreadEvent::ReceivedTextChunk |
 465                            ThreadEvent::UsageUpdated(_) => {
 466                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
 467                                    println!("{}Event: {:#?}", log_prefix, event);
 468                                }
 469                            }
 470                        }
 471                    }
 472                }
 473            });
 474
 475            thread.update(cx, |thread, cx| {
 476                let context = vec![];
 477                thread.insert_user_message(this.prompt.clone(), context, None, cx);
 478                thread.send_to_model(model, cx);
 479            })?;
 480
 481            event_handler_task.await?;
 482
 483            println!("{}Stopped", this.log_prefix);
 484
 485            if let Some((_, language_file_buffer)) = lsp.as_ref() {
 486                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
 487            }
 488
 489            println!("{}Getting repository diff", this.log_prefix);
 490            let repository_diff = this.repository_diff().await?;
 491            std::fs::write(last_diff_file_path, &repository_diff)?;
 492
 493            println!("{}Getting diagnostics", this.log_prefix);
 494            let diagnostics_after = cx
 495                .update(move |cx| {
 496                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
 497                })?
 498                .await?;
 499            println!("{}Got diagnostics", this.log_prefix);
 500
 501            let Some(last_request) = last_request.borrow_mut().take() else {
 502                return Err(anyhow!("No requests ran."));
 503            };
 504
 505            drop(subscription);
 506            drop(lsp);
 507
 508            if let Some(diagnostics_before) = &diagnostics_before {
 509                fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
 510            }
 511
 512            if let Some(diagnostics_after) = &diagnostics_after {
 513                fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
 514            }
 515
 516
 517            thread.update(cx, |thread, _cx| {
 518                let response_count = thread
 519                    .messages()
 520                    .filter(|message| message.role == language_model::Role::Assistant)
 521                    .count();
 522                RunOutput {
 523                    repository_diff,
 524                    ran_diagnostics_check: this.base.require_lsp,
 525                    diagnostics_before,
 526                    diagnostics_after,
 527                    response_count,
 528                    token_usage: thread.cumulative_token_usage(),
 529                    tool_metrics: tool_metrics.lock().unwrap().clone(),
 530                    last_request,
 531                }
 532            })
 533        })
 534    }
 535
 536    async fn judge_diff(
 537        &self,
 538        model: Arc<dyn LanguageModel>,
 539        run_output: &RunOutput,
 540        judge_number: u32,
 541        cx: &AsyncApp,
 542    ) -> Result<(String, JudgeResponse)> {
 543        let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
 544        let judge_diff_prompt_name = "judge_diff_prompt";
 545        let mut hbs = Handlebars::new();
 546        hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
 547
 548        let diff_prompt = hbs.render(
 549            judge_diff_prompt_name,
 550            &JudgeDiffInput {
 551                repository_diff: run_output.repository_diff.clone(),
 552                ran_diagnostics_check: run_output.ran_diagnostics_check,
 553                diagnostics_before: run_output.diagnostics_before.clone(),
 554                diagnostics_after: run_output.diagnostics_after.clone(),
 555                criteria: self.diff_criteria.clone(),
 556            },
 557        )?;
 558
 559        let request = LanguageModelRequest {
 560            thread_id: None,
 561            prompt_id: None,
 562            messages: vec![LanguageModelRequestMessage {
 563                role: Role::User,
 564                content: vec![MessageContent::Text(diff_prompt)],
 565                cache: false,
 566            }],
 567            temperature: None,
 568            tools: Vec::new(),
 569            stop: Vec::new(),
 570        };
 571
 572        let diff_response = send_language_model_request(model, request, cx).await?;
 573        let diff_output = JudgeResponse::parse(&diff_response)?;
 574
 575        println!(
 576            "{}Judge #{judge_number} - Diff score: {}",
 577            self.log_prefix, diff_output.score
 578        );
 579
 580        Ok((diff_response, diff_output))
 581    }
 582
 583    async fn judge_thread(
 584        &self,
 585        model: Arc<dyn LanguageModel>,
 586        run_output: &RunOutput,
 587        judge_number: u32,
 588        cx: &AsyncApp,
 589    ) -> Result<(String, Option<JudgeResponse>)> {
 590        if let Some(criteria) = self.thread_criteria.clone() {
 591            let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
 592            let judge_thread_prompt_name = "judge_thread_prompt";
 593            let mut hbs = Handlebars::new();
 594            hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
 595
 596            let request_markdown = RequestMarkdown::new(&run_output.last_request);
 597            let thread_prompt = hbs.render(
 598                judge_thread_prompt_name,
 599                &JudgeThreadInput {
 600                    messages: request_markdown.messages,
 601                    criteria,
 602                },
 603            )?;
 604
 605            let request = LanguageModelRequest {
 606                thread_id: None,
 607                prompt_id: None,
 608                messages: vec![LanguageModelRequestMessage {
 609                    role: Role::User,
 610                    content: vec![MessageContent::Text(thread_prompt)],
 611                    cache: false,
 612                }],
 613                temperature: None,
 614                tools: Vec::new(),
 615                stop: Vec::new(),
 616            };
 617
 618            let thread_response = send_language_model_request(model, request, cx).await?;
 619            let thread_output = JudgeResponse::parse(&thread_response)?;
 620
 621            println!(
 622                "{}Judge #{judge_number} - Thread score: {}",
 623                self.log_prefix, thread_output.score
 624            );
 625
 626            Ok((thread_response, Some(thread_output)))
 627        } else {
 628            let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
 629            Ok((msg, None))
 630        }
 631    }
 632
 633    pub async fn judge(
 634        &self,
 635        model: Arc<dyn LanguageModel>,
 636        run_output: &RunOutput,
 637        judge_number: u32,
 638        cx: &AsyncApp,
 639    ) -> Result<JudgeOutput> {
 640        let mut output_file = File::create(
 641            self.example_output_directory()
 642                .join(format!("judge_{}.md", judge_number)),
 643        )
 644        .expect("failed to create judge.md");
 645
 646        println!("{}Running judge #{judge_number}", self.log_prefix);
 647
 648        let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
 649        let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
 650
 651        let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
 652
 653        let (diff_response, diff_output) = diff_result?;
 654        let (thread_response, thread_output) = thread_result?;
 655
 656        writeln!(
 657            &mut output_file,
 658            "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
 659        )
 660        .log_err();
 661
 662        Ok(JudgeOutput {
 663            thread: thread_output,
 664            diff: diff_output,
 665        })
 666    }
 667
 668    async fn repository_diff(&self) -> Result<String> {
 669        run_git(&self.worktree_path, &["add", "."]).await?;
 670        let mut diff_args = vec!["diff", "--staged"];
 671        if self.base.url == ZED_REPO_URL {
 672            diff_args.push(":(exclude).rules");
 673        }
 674        run_git(&self.worktree_path, &diff_args).await
 675    }
 676}
 677
 678fn wait_for_lang_server(
 679    project: &Entity<Project>,
 680    buffer: &Entity<Buffer>,
 681    log_prefix: String,
 682    cx: &mut AsyncApp,
 683) -> Task<Result<()>> {
 684    println!("{}⏵ Waiting for language server", log_prefix);
 685
 686    let (mut tx, mut rx) = mpsc::channel(1);
 687
 688    let lsp_store = project
 689        .update(cx, |project, _| project.lsp_store())
 690        .unwrap();
 691
 692    let has_lang_server = buffer
 693        .update(cx, |buffer, cx| {
 694            lsp_store.update(cx, |lsp_store, cx| {
 695                lsp_store
 696                    .language_servers_for_local_buffer(&buffer, cx)
 697                    .next()
 698                    .is_some()
 699            })
 700        })
 701        .unwrap_or(false);
 702
 703    if has_lang_server {
 704        project
 705            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
 706            .unwrap()
 707            .detach();
 708    }
 709
 710    let subscriptions =
 711        [
 712            cx.subscribe(&lsp_store, {
 713                let log_prefix = log_prefix.clone();
 714                move |_, event, _| match event {
 715                    project::LspStoreEvent::LanguageServerUpdate {
 716                        message:
 717                            client::proto::update_language_server::Variant::WorkProgress(
 718                                LspWorkProgress {
 719                                    message: Some(message),
 720                                    ..
 721                                },
 722                            ),
 723                        ..
 724                    } => println!("{}{message}", log_prefix),
 725                    _ => {}
 726                }
 727            }),
 728            cx.subscribe(&project, {
 729                let buffer = buffer.clone();
 730                move |project, event, cx| match event {
 731                    project::Event::LanguageServerAdded(_, _, _) => {
 732                        let buffer = buffer.clone();
 733                        project
 734                            .update(cx, |project, cx| project.save_buffer(buffer, cx))
 735                            .detach();
 736                    }
 737                    project::Event::DiskBasedDiagnosticsFinished { .. } => {
 738                        tx.try_send(()).ok();
 739                    }
 740                    _ => {}
 741                }
 742            }),
 743        ];
 744
 745    cx.spawn(async move |cx| {
 746        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
 747        let result = futures::select! {
 748            _ = rx.next() => {
 749                println!("{}⚑ Language server idle", log_prefix);
 750                anyhow::Ok(())
 751            },
 752            _ = timeout.fuse() => {
 753                Err(anyhow!("LSP wait timed out after 5 minutes"))
 754            }
 755        };
 756        drop(subscriptions);
 757        result
 758    })
 759}
 760
 761async fn query_lsp_diagnostics(
 762    project: Entity<Project>,
 763    cx: &mut AsyncApp,
 764) -> Result<Option<String>> {
 765    let paths_with_diagnostics = project.update(cx, |project, cx| {
 766        project
 767            .diagnostic_summaries(true, cx)
 768            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
 769            .map(|(project_path, _, _)| project_path)
 770            .collect::<Vec<_>>()
 771    })?;
 772
 773    if paths_with_diagnostics.is_empty() {
 774        return Ok(None);
 775    }
 776
 777    let mut output = String::new();
 778    for project_path in paths_with_diagnostics {
 779        let buffer = project
 780            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
 781            .await?;
 782        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 783
 784        for (_, group) in snapshot.diagnostic_groups(None) {
 785            let entry = &group.entries[group.primary_ix];
 786            let range = entry.range.to_point(&snapshot);
 787            let severity = match entry.diagnostic.severity {
 788                DiagnosticSeverity::ERROR => "error",
 789                DiagnosticSeverity::WARNING => "warning",
 790                _ => continue,
 791            };
 792
 793            writeln!(
 794                output,
 795                "{} at line {}: {}",
 796                severity,
 797                range.start.row + 1,
 798                entry.diagnostic.message
 799            )?;
 800        }
 801    }
 802    anyhow::Ok(Some(output))
 803}
 804
 805impl JudgeResponse {
 806    fn parse(response: &str) -> Result<Self> {
 807        let analysis = get_tag("analysis", response)?.to_string();
 808        let score = get_tag("score", response)?
 809            .parse()
 810            .context("error parsing score")?;
 811
 812        Ok(Self { analysis, score })
 813    }
 814}
 815
 816fn get_tag(name: &'static str, response: &str) -> Result<String> {
 817    let start_tag = format!("<{}>", name);
 818    let end_tag = format!("</{}>", name);
 819
 820    let start_ix = response
 821        .find(&start_tag)
 822        .context(format!("{} start tag not found", name))?;
 823    let content_start_ix = start_ix + start_tag.len();
 824
 825    let end_ix = content_start_ix
 826        + response[content_start_ix..]
 827            .find(&end_tag)
 828            .context(format!("{} end tag not found", name))?;
 829
 830    let content = response[content_start_ix..end_ix].trim().unindent();
 831
 832    anyhow::Ok(content)
 833}
 834
 835pub fn repo_path_for_url(repos_dir: &Path, repo_url: &str) -> PathBuf {
 836    let repo_name = repo_url
 837        .trim_start_matches("https://")
 838        .replace(|c: char| !c.is_alphanumeric(), "-");
 839    Path::new(repos_dir)
 840        .canonicalize()
 841        .context(format!("No such directory {}", repos_dir.display()))
 842        .unwrap()
 843        .join(repo_name)
 844}
 845
 846pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
 847    let output = new_smol_command("git")
 848        .current_dir(repo_path)
 849        .args(args)
 850        .output()
 851        .await?;
 852
 853    if output.status.success() {
 854        Ok(String::from_utf8(output.stdout)?.trim().to_string())
 855    } else {
 856        Err(anyhow!(
 857            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
 858            args.join(" "),
 859            repo_path.display(),
 860            output.status,
 861            String::from_utf8_lossy(&output.stderr),
 862            String::from_utf8_lossy(&output.stdout),
 863        ))
 864    }
 865}
 866
 867pub async fn send_language_model_request(
 868    model: Arc<dyn LanguageModel>,
 869    request: LanguageModelRequest,
 870    cx: &AsyncApp,
 871) -> anyhow::Result<String> {
 872    match model.stream_completion_text(request, &cx).await {
 873        Ok(mut stream) => {
 874            let mut full_response = String::new();
 875            while let Some(chunk_result) = stream.stream.next().await {
 876                match chunk_result {
 877                    Ok(chunk_str) => {
 878                        full_response.push_str(&chunk_str);
 879                    }
 880                    Err(err) => {
 881                        return Err(anyhow!(
 882                            "Error receiving response from language model: {err}"
 883                        ));
 884                    }
 885                }
 886            }
 887            Ok(full_response)
 888        }
 889        Err(err) => Err(anyhow!(
 890            "Failed to get response from language model. Error was: {err}"
 891        )),
 892    }
 893}
 894
 895struct RequestMarkdown {
 896    tools: String,
 897    messages: String,
 898}
 899
 900impl RequestMarkdown {
 901    fn new(request: &LanguageModelRequest) -> Self {
 902        let mut tools = String::new();
 903        let mut messages = String::new();
 904        let mut assistant_message_number: u32 = 1;
 905
 906        // Print the tools
 907        if !request.tools.is_empty() {
 908            for tool in &request.tools {
 909                write!(&mut tools, "# {}\n\n", tool.name).unwrap();
 910                write!(&mut tools, "{}\n\n", tool.description).unwrap();
 911                write!(
 912                    &mut tools,
 913                    "{}\n",
 914                    MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
 915                )
 916                .unwrap();
 917            }
 918        }
 919
 920        // Print the messages
 921        for message in &request.messages {
 922            match message.role {
 923                Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
 924                Role::User => messages.push_str("# 👤 USER\n\n"),
 925                Role::Assistant => {
 926                    messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
 927                    assistant_message_number += 1;
 928                }
 929            };
 930
 931            for content in &message.content {
 932                match content {
 933                    MessageContent::Text(text) => {
 934                        messages.push_str(text);
 935                        messages.push_str("\n\n");
 936                    }
 937                    MessageContent::Image(_) => {
 938                        messages.push_str("[IMAGE DATA]\n\n");
 939                    }
 940                    MessageContent::Thinking { text, signature } => {
 941                        messages.push_str("**Thinking**:\n\n");
 942                        if let Some(sig) = signature {
 943                            messages.push_str(&format!("Signature: {}\n\n", sig));
 944                        }
 945                        messages.push_str(text);
 946                        messages.push_str("\n");
 947                    }
 948                    MessageContent::RedactedThinking(items) => {
 949                        messages.push_str(&format!(
 950                            "**Redacted Thinking**: {} item(s)\n\n",
 951                            items.len()
 952                        ));
 953                    }
 954                    MessageContent::ToolUse(tool_use) => {
 955                        messages.push_str(&format!(
 956                            "**Tool Use**: {} (ID: {})\n",
 957                            tool_use.name, tool_use.id
 958                        ));
 959                        messages.push_str(&format!(
 960                            "{}\n",
 961                            MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
 962                        ));
 963                    }
 964                    MessageContent::ToolResult(tool_result) => {
 965                        messages.push_str(&format!(
 966                            "**Tool Result**: {} (ID: {})\n\n",
 967                            tool_result.tool_name, tool_result.tool_use_id
 968                        ));
 969                        if tool_result.is_error {
 970                            messages.push_str("**ERROR:**\n");
 971                        }
 972                        messages.push_str(&format!("{}\n\n", tool_result.content));
 973                    }
 974                }
 975            }
 976        }
 977
 978        Self { tools, messages }
 979    }
 980}
 981
 982fn response_events_to_markdown(
 983    response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
 984) -> String {
 985    let mut response = String::new();
 986    // Print the response events if any
 987    response.push_str("# Response\n\n");
 988    let mut text_buffer = String::new();
 989    let mut thinking_buffer = String::new();
 990
 991    let flush_buffers =
 992        |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
 993            if !text_buffer.is_empty() {
 994                output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
 995                text_buffer.clear();
 996            }
 997            if !thinking_buffer.is_empty() {
 998                output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
 999                thinking_buffer.clear();
1000            }
1001        };
1002
1003    for event in response_events {
1004        match event {
1005            Ok(LanguageModelCompletionEvent::Text(text)) => {
1006                text_buffer.push_str(text);
1007            }
1008            Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
1009                thinking_buffer.push_str(text);
1010            }
1011            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
1012                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1013                response.push_str(&format!("**Stop**: {:?}\n\n", reason));
1014            }
1015            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1016                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1017                response.push_str(&format!(
1018                    "**Tool Use**: {} (ID: {})\n",
1019                    tool_use.name, tool_use.id
1020                ));
1021                response.push_str(&format!(
1022                    "{}\n",
1023                    MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
1024                ));
1025            }
1026            Ok(
1027                LanguageModelCompletionEvent::UsageUpdate(_)
1028                | LanguageModelCompletionEvent::StartMessage { .. },
1029            ) => {}
1030            Err(error) => {
1031                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1032                response.push_str(&format!("**Error**: {}\n\n", error));
1033            }
1034        }
1035    }
1036
1037    flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1038
1039    response
1040}
1041
#[cfg(test)]
mod test {
    use super::*;
    use handlebars::Handlebars;

    const JUDGE_PROMPT_NAME: &str = "judge_prompt";

    /// Builds a Handlebars registry containing only the judge prompt template, with line
    /// endings normalized before registration.
    fn templates() -> Handlebars<'static> {
        let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
        language::LineEnding::normalize(&mut judge_prompt);
        let mut registry = Handlebars::new();
        registry
            .register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
            .unwrap();
        registry
    }

    /// Builds a judge input with fixed diff and criteria text and the given diagnostics.
    fn judge_input(
        ran_diagnostics_check: bool,
        diagnostics_before: Option<&str>,
        diagnostics_after: Option<&str>,
    ) -> JudgeDiffInput {
        JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check,
            diagnostics_before: diagnostics_before.map(String::from),
            diagnostics_after: diagnostics_after.map(String::from),
            criteria: "Fix all bugs".to_string(),
        }
    }

    /// Renders the judge prompt template for `input`.
    fn render(registry: &Handlebars<'_>, input: &JudgeDiffInput) -> String {
        registry.render(JUDGE_PROMPT_NAME, input).unwrap()
    }

    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let parsed = JudgeResponse::parse(&response).unwrap();
        assert_eq!(
            parsed.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(parsed.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let parsed = JudgeResponse::parse(&response).unwrap();
        assert_eq!(parsed.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(parsed.score, 1);
    }

    #[test]
    fn test_judge_prompt_with_diagnostics() {
        // Case 1: Both diagnostics before and after are present
        let input = judge_input(
            true,
            Some("Error at line 10: variable not found"),
            Some("Error at line 15: missing semicolon"),
        );
        let rendered = render(&templates(), &input);

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_empty_diagnostics() {
        // Case 2: Diagnostics check run but no diagnostics found
        let input = judge_input(true, None, None);
        let rendered = render(&templates(), &input);

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_mixed_diagnostics() {
        let registry = templates();

        // Case 3: Before diagnostics present, after diagnostics absent
        let input = judge_input(true, Some("Error at line 10: variable not found"), None);
        let rendered = render(&registry, &input);

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));

        // Case 4: Before diagnostics absent, after diagnostics present
        let input = judge_input(true, None, Some("Error at line 15: missing semicolon"));
        let rendered = render(&registry, &input);

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_without_diagnostics() {
        let registry = templates();

        // Case 5: No diagnostics check run
        let input = judge_input(false, None, None);
        let rendered = render(&registry, &input);

        // When no check was run, the prompt says so and omits both diagnostics sections.
        let diagnostics_message = "No diagnostic checks were performed.";

        assert!(rendered.contains(diagnostics_message));
        assert!(!rendered.contains("<diagnostics_before>"));
        assert!(!rendered.contains("<diagnostics_after>"));
    }
}