example.rs

   1use agent::{ThreadEvent, ThreadStore};
   2use anyhow::{Context as _, Result, anyhow};
   3use assistant_tool::ToolWorkingSet;
   4use client::proto::LspWorkProgress;
   5use collections::HashMap;
   6use dap::DapRegistry;
   7use futures::channel::mpsc;
   8use futures::{FutureExt, StreamExt as _, select_biased};
   9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
  10use handlebars::Handlebars;
  11use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
  12use language_model::{
  13    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  14    MessageContent, Role, StopReason, TokenUsage,
  15};
  16use project::{Project, ProjectPath};
  17use serde::{Deserialize, Serialize};
  18use std::cell::RefCell;
  19use std::fmt::Write as _;
  20use std::fs::File;
  21use std::io::Write as _;
  22use std::rc::Rc;
  23use std::sync::{Arc, Mutex};
  24use std::time::Duration;
  25use std::{
  26    fs,
  27    path::{Path, PathBuf},
  28};
  29use unindent::Unindent as _;
  30use util::ResultExt as _;
  31use util::command::new_smol_command;
  32use util::markdown::MarkdownString;
  33use util::serde::default_true;
  34
  35use crate::AgentAppState;
  36
  37pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
  38pub const REPOS_DIR: &str = "./crates/eval/repos";
  39pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
  40
  41const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
  42
  43#[derive(Clone, Debug, Deserialize)]
  44pub struct ExampleBase {
  45    pub url: String,
  46    pub revision: String,
  47    pub language_extension: Option<String>,
  48    pub insert_id: Option<String>,
  49    #[serde(default = "default_true")]
  50    pub require_lsp: bool,
  51    #[serde(default)]
  52    pub allow_preexisting_diagnostics: bool,
  53}
  54
  55impl ExampleBase {
  56    pub fn repo_name(&self) -> String {
  57        self.url
  58            .split('/')
  59            .next_back()
  60            .unwrap_or(&"")
  61            .trim_end_matches(".git")
  62            .into()
  63    }
  64}
  65
  66#[derive(Clone, Debug)]
  67pub struct Example {
  68    pub name: String,
  69    /// Content of `base.toml`
  70    pub base: ExampleBase,
  71    /// Content of `prompt.md`
  72    pub prompt: String,
  73    /// Content of `diff_criteria.md`
  74    pub diff_criteria: String,
  75    /// Content of `thread_criteria.md`, if that file exists (it's optional)
  76    pub thread_criteria: Option<String>,
  77    /// Path to the directory containing the requests and responses for the agentic loop
  78    pub run_directory_path: PathBuf,
  79    /// Prefix used for logging that identifies this example
  80    pub log_prefix: String,
  81}
  82
  83#[derive(Debug, Serialize, Deserialize, Clone)]
  84pub struct RunOutput {
  85    pub repository_diff: String,
  86    pub ran_diagnostics_check: bool,
  87    pub diagnostics_before: Option<String>,
  88    pub diagnostics_after: Option<String>,
  89    pub response_count: usize,
  90    pub token_usage: TokenUsage,
  91    pub tool_use_counts: HashMap<Arc<str>, u32>,
  92    pub last_request: LanguageModelRequest,
  93}
  94
  95#[derive(Debug, Clone, Serialize, Deserialize)]
  96pub struct JudgeDiffInput {
  97    pub repository_diff: String,
  98    pub ran_diagnostics_check: bool,
  99    #[serde(skip_serializing_if = "Option::is_none")]
 100    pub diagnostics_before: Option<String>,
 101    #[serde(skip_serializing_if = "Option::is_none")]
 102    pub diagnostics_after: Option<String>,
 103    pub criteria: String,
 104}
 105
 106#[derive(Debug, Clone, Serialize, Deserialize)]
 107pub struct JudgeThreadInput {
 108    pub messages: String,
 109    pub criteria: String,
 110}
 111
 112#[derive(Debug, Clone, Serialize, Deserialize)]
 113pub struct JudgeResponse {
 114    pub analysis: String,
 115    pub score: u32,
 116}
 117
 118#[derive(Debug, Clone, Serialize, Deserialize)]
 119pub struct JudgeOutput {
 120    pub thread: Option<JudgeResponse>,
 121    pub diff: JudgeResponse,
 122}
 123
 124impl Example {
 125    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 126    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 127        let name = Self::name_from_path(dir_path);
 128        let base_path = dir_path.join("base.toml");
 129        let prompt_path = dir_path.join("prompt.md");
 130        let diff_criteria_path = dir_path.join("diff_criteria.md");
 131        let thread_criteria_path = dir_path.join("thread_criteria.md");
 132        let thread_criteria = if thread_criteria_path.exists() {
 133            Some(fs::read_to_string(thread_criteria_path.clone())?)
 134        } else {
 135            None
 136        };
 137
 138        Ok(Example {
 139            name: name.clone(),
 140            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
 141            prompt: fs::read_to_string(prompt_path.clone())?,
 142            thread_criteria,
 143            diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
 144            run_directory_path: run_dir.to_path_buf(),
 145            log_prefix: name,
 146        })
 147    }
 148
 149    pub fn set_repetition_number(&mut self, repetition_number: u32) {
 150        if repetition_number > 0 {
 151            self.name = format!("{}-{}", self.name, repetition_number);
 152        }
 153    }
 154
 155    pub fn example_output_directory(&self) -> PathBuf {
 156        self.run_directory_path.join(&self.name)
 157    }
 158
 159    pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
 160        self.log_prefix = format!(
 161            "{}{:<width$}\x1b[0m | ",
 162            color,
 163            self.name,
 164            width = name_width
 165        );
 166    }
 167
 168    pub fn name_from_path(path: &Path) -> String {
 169        path.file_name().unwrap().to_string_lossy().to_string()
 170    }
 171
 172    pub fn worktree_path(&self) -> PathBuf {
 173        Path::new(WORKTREES_DIR)
 174            .canonicalize()
 175            .context(format!("No such directory {WORKTREES_DIR}"))
 176            .unwrap()
 177            .join(&self.name)
 178            .join(self.base.repo_name())
 179    }
 180
 181    /// Set up the example by checking out the specified Git revision
 182    pub async fn setup(&mut self) -> Result<()> {
 183        let repo_path = repo_path_for_url(&self.base.url);
 184
 185        let revision_exists = run_git(
 186            &repo_path,
 187            &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
 188        )
 189        .await
 190        .is_ok();
 191
 192        if !revision_exists {
 193            println!(
 194                "{}Fetching revision {}",
 195                self.log_prefix, &self.base.revision
 196            );
 197            run_git(
 198                &repo_path,
 199                &["fetch", "--depth", "1", "origin", &self.base.revision],
 200            )
 201            .await?;
 202        }
 203
 204        let worktree_path = self.worktree_path();
 205
 206        if worktree_path.is_dir() {
 207            println!("{}Resetting existing worktree", self.log_prefix);
 208
 209            // TODO: consider including "-x" to remove ignored files. The downside of this is that
 210            // it will also remove build artifacts, and so prevent incremental reuse there.
 211            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
 212            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
 213            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
 214        } else {
 215            println!("{}Creating worktree", self.log_prefix);
 216
 217            let worktree_path_string = worktree_path.to_string_lossy().to_string();
 218
 219            run_git(
 220                &repo_path,
 221                &[
 222                    "worktree",
 223                    "add",
 224                    "-f",
 225                    &worktree_path_string,
 226                    &self.base.revision,
 227                ],
 228            )
 229            .await?;
 230        }
 231
 232        std::fs::create_dir_all(self.example_output_directory())?;
 233
 234        Ok(())
 235    }
 236
 237    pub fn run(
 238        &self,
 239        model: Arc<dyn LanguageModel>,
 240        app_state: Arc<AgentAppState>,
 241        cx: &mut App,
 242    ) -> Task<Result<RunOutput>> {
 243        let project = Project::local(
 244            app_state.client.clone(),
 245            app_state.node_runtime.clone(),
 246            app_state.user_store.clone(),
 247            app_state.languages.clone(),
 248            Arc::new(DapRegistry::default()),
 249            app_state.fs.clone(),
 250            None,
 251            cx,
 252        );
 253
 254        let worktree_path = self.worktree_path();
 255        let worktree = project.update(cx, |project, cx| {
 256            project.create_worktree(&worktree_path, true, cx)
 257        });
 258
 259        let tools = cx.new(|_| ToolWorkingSet::default());
 260        let thread_store =
 261            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
 262        let this = self.clone();
 263
 264        cx.spawn(async move |cx| {
 265            let worktree = worktree.await?;
 266
 267            // Wait for worktree scan to finish before choosing a file to open.
 268            worktree
 269                .update(cx, |worktree, _cx| {
 270                    worktree.as_local().unwrap().scan_complete()
 271                })?
 272                .await;
 273
 274            let lsp = if this.base.require_lsp {
 275                let language_extension = this.base.language_extension.as_deref().context(
 276                    "language_extension field is required in base.toml when `require_lsp == true`",
 277                )?;
 278
 279                // Open a file that matches the language to cause LSP to start.
 280                let language_file = worktree.read_with(cx, |worktree, _cx| {
 281                    worktree
 282                        .files(false, 0)
 283                        .find_map(|e| {
 284                            if e.path.clone().extension().and_then(|ext| ext.to_str())
 285                                == Some(language_extension)
 286                            {
 287                                Some(ProjectPath {
 288                                    worktree_id: worktree.id(),
 289                                    path: e.path.clone(),
 290                                })
 291                            } else {
 292                                None
 293                            }
 294                        })
 295                        .context("Failed to find a file for example language")
 296                })??;
 297
 298                let open_language_file_buffer_task = project.update(cx, |project, cx| {
 299                    project.open_buffer(language_file.clone(), cx)
 300                })?;
 301
 302                let language_file_buffer = open_language_file_buffer_task.await?;
 303
 304                let lsp_open_handle = project.update(cx, |project, cx| {
 305                    project.register_buffer_with_language_servers(&language_file_buffer, cx)
 306                })?;
 307
 308                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
 309
 310                Some((lsp_open_handle, language_file_buffer))
 311            } else {
 312                None
 313            };
 314
 315            let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
 316            if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
 317                return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
 318            }
 319
 320            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
 321                return Err(anyhow!("Setup only mode"));
 322            }
 323
 324            let thread_store = thread_store.await?;
 325            let thread =
 326                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
 327            let last_request = Rc::new(RefCell::new(None));
 328
 329            thread.update(cx, |thread, _cx| {
 330                let mut request_count = 0;
 331                let example_dir_path = this.example_output_directory();
 332
 333                let last_request = Rc::clone(&last_request);
 334                thread.set_request_callback(move |request, response_events| {
 335                    *last_request.borrow_mut() = Some(request.clone());
 336
 337                    request_count += 1;
 338                    let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
 339                    let last_messages_file_path = example_dir_path.join("last.messages.md");
 340                    let request_markdown = RequestMarkdown::new(request);
 341                    let response_events_markdown = response_events_to_markdown(response_events);
 342
 343                    let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
 344                    fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
 345                    fs::write(last_messages_file_path, messages).expect("failed to write last messages file");
 346
 347                    if request_count == 1 {
 348                        let tools_file_path = example_dir_path.join("tools.md");
 349                        fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
 350                    }
 351                });
 352            })?;
 353
 354            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
 355                Mutex::new(HashMap::default()).into();
 356
 357            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
 358
 359            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
 360                thread_event_tx.unbounded_send(event.clone()).log_err();
 361            });
 362
 363            let event_handler_task = cx.spawn({
 364                let log_prefix = this.log_prefix.clone();
 365                let tool_use_counts = tool_use_counts.clone();
 366                let thread = thread.downgrade();
 367                async move |cx| {
 368                    loop {
 369                        let event = select_biased! {
 370                            event = thread_event_rx.next() => event,
 371                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
 372                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
 373                            }
 374                        };
 375                        let Some(event) = event else {
 376                            return Err(anyhow!("ThreadEvent channel ended early"));
 377                        };
 378
 379                        match event {
 380                            ThreadEvent::Stopped(reason) => match reason {
 381                                Ok(StopReason::EndTurn) => {
 382                                    return Ok(());
 383                                }
 384                                Ok(StopReason::MaxTokens) => {
 385                                    return Err(anyhow!("Exceeded maximum tokens"));
 386                                }
 387                                Ok(StopReason::ToolUse) => {
 388                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
 389                                        println!("{}StopReason: Tool use", log_prefix);
 390                                    }
 391                                }
 392                                Err(error) => {
 393                                    return Err(anyhow!(error.clone()));
 394                                }
 395                            },
 396                            ThreadEvent::ShowError(thread_error) => {
 397                                break Err(anyhow!(thread_error.clone()));
 398                            }
 399                            ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
 400                            }
 401                            ThreadEvent::ToolFinished {
 402                                tool_use_id,
 403                                pending_tool_use,
 404                                ..
 405                            } => {
 406                                thread.update(cx, |thread, _cx| {
 407                                    if let Some(tool_use) = pending_tool_use {
 408                                        if let Some(tool_result) = thread.tool_result(&tool_use_id) {
 409                                            let message = if tool_result.is_error {
 410                                                format!("TOOL FAILED: {}", tool_use.name)
 411                                            } else {
 412                                                format!("TOOL FINISHED: {}", tool_use.name)
 413                                            };
 414                                            println!("{log_prefix}{message}");
 415                                            let mut tool_use_counts = tool_use_counts.lock().unwrap();
 416                                            *tool_use_counts
 417                                                .entry(tool_result.tool_name.clone())
 418                                                .or_insert(0) += 1;
 419                                        } else {
 420                                            let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
 421                                            println!("{log_prefix}{message}");
 422                                        }
 423                                    }
 424                                })?;
 425                            }
 426                            ThreadEvent::ToolConfirmationNeeded => {
 427                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
 428                            },
 429                            ThreadEvent::StreamedCompletion |
 430                            ThreadEvent::MessageAdded(_) |
 431                            ThreadEvent::MessageEdited(_) |
 432                            ThreadEvent::MessageDeleted(_) |
 433                            ThreadEvent::SummaryChanged |
 434                            ThreadEvent::SummaryGenerated |
 435                            ThreadEvent::CheckpointChanged |
 436                            ThreadEvent::UsageUpdated(_) => {
 437                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
 438                                    println!("{}Event: {:#?}", log_prefix, event);
 439                                }
 440                            }
 441                        }
 442                    }
 443                }
 444            });
 445
 446            thread.update(cx, |thread, cx| {
 447                let context = vec![];
 448                thread.insert_user_message(this.prompt.clone(), context, None, cx);
 449                thread.send_to_model(model, cx);
 450            })?;
 451
 452            event_handler_task.await?;
 453
 454            println!("{}Stopped", this.log_prefix);
 455
 456            if let Some((_, language_file_buffer)) = lsp.as_ref() {
 457                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
 458            }
 459
 460            println!("{}Getting repository diff", this.log_prefix);
 461            let repository_diff = this.repository_diff().await?;
 462
 463            let example_output_dir = this.example_output_directory();
 464            let repository_diff_path = example_output_dir.join("patch.diff");
 465            let mut repository_diff_output_file = File::create(&repository_diff_path)?;
 466            writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
 467
 468            println!("{}Getting diagnostics", this.log_prefix);
 469            let diagnostics_after = cx
 470                .update(move |cx| {
 471                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
 472                })?
 473                .await?;
 474            println!("{}Got diagnostics", this.log_prefix);
 475
 476            let Some(last_request) = last_request.borrow_mut().take() else {
 477                return Err(anyhow!("No requests ran."));
 478            };
 479
 480            drop(subscription);
 481            drop(lsp);
 482
 483            if let Some(diagnostics_before) = &diagnostics_before {
 484                fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
 485            }
 486
 487            if let Some(diagnostics_after) = &diagnostics_after {
 488                fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
 489            }
 490
 491
 492            thread.update(cx, |thread, _cx| {
 493                let response_count = thread
 494                    .messages()
 495                    .filter(|message| message.role == language_model::Role::Assistant)
 496                    .count();
 497                RunOutput {
 498                    repository_diff,
 499                    ran_diagnostics_check: this.base.require_lsp,
 500                    diagnostics_before,
 501                    diagnostics_after,
 502                    response_count,
 503                    token_usage: thread.cumulative_token_usage(),
 504                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
 505                    last_request,
 506                }
 507            })
 508        })
 509    }
 510
 511    async fn judge_diff(
 512        &self,
 513        model: Arc<dyn LanguageModel>,
 514        run_output: &RunOutput,
 515        judge_number: u32,
 516        cx: &AsyncApp,
 517    ) -> Result<(String, JudgeResponse)> {
 518        let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
 519        let judge_diff_prompt_name = "judge_diff_prompt";
 520        let mut hbs = Handlebars::new();
 521        hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
 522
 523        let diff_prompt = hbs.render(
 524            judge_diff_prompt_name,
 525            &JudgeDiffInput {
 526                repository_diff: run_output.repository_diff.clone(),
 527                ran_diagnostics_check: run_output.ran_diagnostics_check,
 528                diagnostics_before: run_output.diagnostics_before.clone(),
 529                diagnostics_after: run_output.diagnostics_after.clone(),
 530                criteria: self.diff_criteria.clone(),
 531            },
 532        )?;
 533
 534        let request = LanguageModelRequest {
 535            thread_id: None,
 536            prompt_id: None,
 537            messages: vec![LanguageModelRequestMessage {
 538                role: Role::User,
 539                content: vec![MessageContent::Text(diff_prompt)],
 540                cache: false,
 541            }],
 542            temperature: None,
 543            tools: Vec::new(),
 544            stop: Vec::new(),
 545        };
 546
 547        let diff_response = send_language_model_request(model, request, cx).await?;
 548        let diff_output = JudgeResponse::parse(&diff_response)?;
 549
 550        println!(
 551            "{}Judge #{judge_number} - Diff score: {}",
 552            self.log_prefix, diff_output.score
 553        );
 554
 555        Ok((diff_response, diff_output))
 556    }
 557
 558    async fn judge_thread(
 559        &self,
 560        model: Arc<dyn LanguageModel>,
 561        run_output: &RunOutput,
 562        judge_number: u32,
 563        cx: &AsyncApp,
 564    ) -> Result<(String, Option<JudgeResponse>)> {
 565        if let Some(criteria) = self.thread_criteria.clone() {
 566            let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
 567            let judge_thread_prompt_name = "judge_thread_prompt";
 568            let mut hbs = Handlebars::new();
 569            hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
 570
 571            let request_markdown = RequestMarkdown::new(&run_output.last_request);
 572            let thread_prompt = hbs.render(
 573                judge_thread_prompt_name,
 574                &JudgeThreadInput {
 575                    messages: request_markdown.messages,
 576                    criteria,
 577                },
 578            )?;
 579
 580            let request = LanguageModelRequest {
 581                thread_id: None,
 582                prompt_id: None,
 583                messages: vec![LanguageModelRequestMessage {
 584                    role: Role::User,
 585                    content: vec![MessageContent::Text(thread_prompt)],
 586                    cache: false,
 587                }],
 588                temperature: None,
 589                tools: Vec::new(),
 590                stop: Vec::new(),
 591            };
 592
 593            let thread_response = send_language_model_request(model, request, cx).await?;
 594            let thread_output = JudgeResponse::parse(&thread_response)?;
 595
 596            println!(
 597                "{}Judge #{judge_number} - Thread score: {}",
 598                self.log_prefix, thread_output.score
 599            );
 600
 601            Ok((thread_response, Some(thread_output)))
 602        } else {
 603            let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
 604            Ok((msg, None))
 605        }
 606    }
 607
 608    pub async fn judge(
 609        &self,
 610        model: Arc<dyn LanguageModel>,
 611        run_output: &RunOutput,
 612        judge_number: u32,
 613        cx: &AsyncApp,
 614    ) -> Result<JudgeOutput> {
 615        let mut output_file = File::create(
 616            self.example_output_directory()
 617                .join(format!("judge_{}.md", judge_number)),
 618        )
 619        .expect("failed to create judge.md");
 620
 621        println!("{}Running judge #{judge_number}", self.log_prefix);
 622
 623        let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
 624        let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
 625
 626        let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
 627
 628        let (diff_response, diff_output) = diff_result?;
 629        let (thread_response, thread_output) = thread_result?;
 630
 631        writeln!(
 632            &mut output_file,
 633            "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
 634        )
 635        .log_err();
 636
 637        Ok(JudgeOutput {
 638            thread: thread_output,
 639            diff: diff_output,
 640        })
 641    }
 642
 643    async fn repository_diff(&self) -> Result<String> {
 644        let worktree_path = self.worktree_path();
 645        run_git(&worktree_path, &["add", "."]).await?;
 646        run_git(&worktree_path, &["diff", "--staged"]).await
 647    }
 648}
 649
 650fn wait_for_lang_server(
 651    project: &Entity<Project>,
 652    buffer: &Entity<Buffer>,
 653    log_prefix: String,
 654    cx: &mut AsyncApp,
 655) -> Task<Result<()>> {
 656    println!("{}⏵ Waiting for language server", log_prefix);
 657
 658    let (mut tx, mut rx) = mpsc::channel(1);
 659
 660    let lsp_store = project
 661        .update(cx, |project, _| project.lsp_store())
 662        .unwrap();
 663
 664    let has_lang_server = buffer
 665        .update(cx, |buffer, cx| {
 666            lsp_store.update(cx, |lsp_store, cx| {
 667                lsp_store
 668                    .language_servers_for_local_buffer(&buffer, cx)
 669                    .next()
 670                    .is_some()
 671            })
 672        })
 673        .unwrap_or(false);
 674
 675    if has_lang_server {
 676        project
 677            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
 678            .unwrap()
 679            .detach();
 680    }
 681
 682    let subscriptions =
 683        [
 684            cx.subscribe(&lsp_store, {
 685                let log_prefix = log_prefix.clone();
 686                move |_, event, _| match event {
 687                    project::LspStoreEvent::LanguageServerUpdate {
 688                        message:
 689                            client::proto::update_language_server::Variant::WorkProgress(
 690                                LspWorkProgress {
 691                                    message: Some(message),
 692                                    ..
 693                                },
 694                            ),
 695                        ..
 696                    } => println!("{}{message}", log_prefix),
 697                    _ => {}
 698                }
 699            }),
 700            cx.subscribe(&project, {
 701                let buffer = buffer.clone();
 702                move |project, event, cx| match event {
 703                    project::Event::LanguageServerAdded(_, _, _) => {
 704                        let buffer = buffer.clone();
 705                        project
 706                            .update(cx, |project, cx| project.save_buffer(buffer, cx))
 707                            .detach();
 708                    }
 709                    project::Event::DiskBasedDiagnosticsFinished { .. } => {
 710                        tx.try_send(()).ok();
 711                    }
 712                    _ => {}
 713                }
 714            }),
 715        ];
 716
 717    cx.spawn(async move |cx| {
 718        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
 719        let result = futures::select! {
 720            _ = rx.next() => {
 721                println!("{}⚑ Language server idle", log_prefix);
 722                anyhow::Ok(())
 723            },
 724            _ = timeout.fuse() => {
 725                Err(anyhow!("LSP wait timed out after 5 minutes"))
 726            }
 727        };
 728        drop(subscriptions);
 729        result
 730    })
 731}
 732
 733async fn query_lsp_diagnostics(
 734    project: Entity<Project>,
 735    cx: &mut AsyncApp,
 736) -> Result<Option<String>> {
 737    let paths_with_diagnostics = project.update(cx, |project, cx| {
 738        project
 739            .diagnostic_summaries(true, cx)
 740            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
 741            .map(|(project_path, _, _)| project_path)
 742            .collect::<Vec<_>>()
 743    })?;
 744
 745    if paths_with_diagnostics.is_empty() {
 746        return Ok(None);
 747    }
 748
 749    let mut output = String::new();
 750    for project_path in paths_with_diagnostics {
 751        let buffer = project
 752            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
 753            .await?;
 754        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 755
 756        for (_, group) in snapshot.diagnostic_groups(None) {
 757            let entry = &group.entries[group.primary_ix];
 758            let range = entry.range.to_point(&snapshot);
 759            let severity = match entry.diagnostic.severity {
 760                DiagnosticSeverity::ERROR => "error",
 761                DiagnosticSeverity::WARNING => "warning",
 762                _ => continue,
 763            };
 764
 765            writeln!(
 766                output,
 767                "{} at line {}: {}",
 768                severity,
 769                range.start.row + 1,
 770                entry.diagnostic.message
 771            )?;
 772        }
 773    }
 774    anyhow::Ok(Some(output))
 775}
 776
 777impl JudgeResponse {
 778    fn parse(response: &str) -> Result<Self> {
 779        let analysis = get_tag("analysis", response)?.to_string();
 780        let score = get_tag("score", response)?
 781            .parse()
 782            .context("error parsing score")?;
 783
 784        Ok(Self { analysis, score })
 785    }
 786}
 787
 788fn get_tag(name: &'static str, response: &str) -> Result<String> {
 789    let start_tag = format!("<{}>", name);
 790    let end_tag = format!("</{}>", name);
 791
 792    let start_ix = response
 793        .find(&start_tag)
 794        .context(format!("{} start tag not found", name))?;
 795    let content_start_ix = start_ix + start_tag.len();
 796
 797    let end_ix = content_start_ix
 798        + response[content_start_ix..]
 799            .find(&end_tag)
 800            .context(format!("{} end tag not found", name))?;
 801
 802    let content = response[content_start_ix..end_ix].trim().unindent();
 803
 804    anyhow::Ok(content)
 805}
 806
 807pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
 808    let repo_name = repo_url
 809        .trim_start_matches("https://")
 810        .replace(|c: char| !c.is_alphanumeric(), "-");
 811    Path::new(REPOS_DIR)
 812        .canonicalize()
 813        .context(format!("No such directory {REPOS_DIR}"))
 814        .unwrap()
 815        .join(repo_name)
 816}
 817
 818pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
 819    let output = new_smol_command("git")
 820        .current_dir(repo_path)
 821        .args(args)
 822        .output()
 823        .await?;
 824
 825    if output.status.success() {
 826        Ok(String::from_utf8(output.stdout)?.trim().to_string())
 827    } else {
 828        Err(anyhow!(
 829            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
 830            args.join(" "),
 831            repo_path.display(),
 832            output.status,
 833            String::from_utf8_lossy(&output.stderr),
 834            String::from_utf8_lossy(&output.stdout),
 835        ))
 836    }
 837}
 838
 839pub async fn send_language_model_request(
 840    model: Arc<dyn LanguageModel>,
 841    request: LanguageModelRequest,
 842    cx: &AsyncApp,
 843) -> anyhow::Result<String> {
 844    match model.stream_completion_text(request, &cx).await {
 845        Ok(mut stream) => {
 846            let mut full_response = String::new();
 847            while let Some(chunk_result) = stream.stream.next().await {
 848                match chunk_result {
 849                    Ok(chunk_str) => {
 850                        full_response.push_str(&chunk_str);
 851                    }
 852                    Err(err) => {
 853                        return Err(anyhow!(
 854                            "Error receiving response from language model: {err}"
 855                        ));
 856                    }
 857                }
 858            }
 859            Ok(full_response)
 860        }
 861        Err(err) => Err(anyhow!(
 862            "Failed to get response from language model. Error was: {err}"
 863        )),
 864    }
 865}
 866
 867struct RequestMarkdown {
 868    tools: String,
 869    messages: String,
 870}
 871
 872impl RequestMarkdown {
 873    fn new(request: &LanguageModelRequest) -> Self {
 874        let mut tools = String::new();
 875        let mut messages = String::new();
 876        let mut assistant_message_number: u32 = 1;
 877
 878        // Print the tools
 879        if !request.tools.is_empty() {
 880            for tool in &request.tools {
 881                write!(&mut tools, "# {}\n\n", tool.name).unwrap();
 882                write!(&mut tools, "{}\n\n", tool.description).unwrap();
 883                write!(
 884                    &mut tools,
 885                    "{}\n",
 886                    MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
 887                )
 888                .unwrap();
 889            }
 890        }
 891
 892        // Print the messages
 893        for message in &request.messages {
 894            match message.role {
 895                Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
 896                Role::User => messages.push_str("# 👤 USER\n\n"),
 897                Role::Assistant => {
 898                    messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
 899                    assistant_message_number += 1;
 900                }
 901            };
 902
 903            for content in &message.content {
 904                match content {
 905                    MessageContent::Text(text) => {
 906                        messages.push_str(text);
 907                        messages.push_str("\n\n");
 908                    }
 909                    MessageContent::Image(_) => {
 910                        messages.push_str("[IMAGE DATA]\n\n");
 911                    }
 912                    MessageContent::Thinking { text, signature } => {
 913                        messages.push_str("**Thinking**:\n\n");
 914                        if let Some(sig) = signature {
 915                            messages.push_str(&format!("Signature: {}\n\n", sig));
 916                        }
 917                        messages.push_str(text);
 918                        messages.push_str("\n");
 919                    }
 920                    MessageContent::RedactedThinking(items) => {
 921                        messages.push_str(&format!(
 922                            "**Redacted Thinking**: {} item(s)\n\n",
 923                            items.len()
 924                        ));
 925                    }
 926                    MessageContent::ToolUse(tool_use) => {
 927                        messages.push_str(&format!(
 928                            "**Tool Use**: {} (ID: {})\n",
 929                            tool_use.name, tool_use.id
 930                        ));
 931                        messages.push_str(&format!(
 932                            "{}\n",
 933                            MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
 934                        ));
 935                    }
 936                    MessageContent::ToolResult(tool_result) => {
 937                        messages.push_str(&format!(
 938                            "**Tool Result**: {} (ID: {})\n\n",
 939                            tool_result.tool_name, tool_result.tool_use_id
 940                        ));
 941                        if tool_result.is_error {
 942                            messages.push_str("**ERROR:**\n");
 943                        }
 944                        messages.push_str(&format!("{}\n\n", tool_result.content));
 945                    }
 946                }
 947            }
 948        }
 949
 950        Self { tools, messages }
 951    }
 952}
 953
 954fn response_events_to_markdown(
 955    response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
 956) -> String {
 957    let mut response = String::new();
 958    // Print the response events if any
 959    response.push_str("# Response\n\n");
 960    let mut text_buffer = String::new();
 961    let mut thinking_buffer = String::new();
 962
 963    let flush_buffers =
 964        |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
 965            if !text_buffer.is_empty() {
 966                output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
 967                text_buffer.clear();
 968            }
 969            if !thinking_buffer.is_empty() {
 970                output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
 971                thinking_buffer.clear();
 972            }
 973        };
 974
 975    for event in response_events {
 976        match event {
 977            Ok(LanguageModelCompletionEvent::Text(text)) => {
 978                text_buffer.push_str(text);
 979            }
 980            Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
 981                thinking_buffer.push_str(text);
 982            }
 983            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 984                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
 985                response.push_str(&format!("**Stop**: {:?}\n\n", reason));
 986            }
 987            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
 988                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
 989                response.push_str(&format!(
 990                    "**Tool Use**: {} (ID: {})\n",
 991                    tool_use.name, tool_use.id
 992                ));
 993                response.push_str(&format!(
 994                    "{}\n",
 995                    MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
 996                ));
 997            }
 998            Ok(
 999                LanguageModelCompletionEvent::UsageUpdate(_)
1000                | LanguageModelCompletionEvent::StartMessage { .. },
1001            ) => {}
1002            Err(error) => {
1003                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1004                response.push_str(&format!("**Error**: {}\n\n", error));
1005            }
1006        }
1007    }
1008
1009    flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1010
1011    response
1012}
1013
1014#[cfg(test)]
1015mod test {
1016    use super::*;
1017    use handlebars::Handlebars;
1018
1019    #[test]
1020    fn test_parse_judge_output() {
1021        let response = r#"
1022            <analysis>The model did a good job but there were still compilations errors.</analysis>
1023            <score>3</score>
1024        "#
1025        .unindent();
1026
1027        let output = JudgeResponse::parse(&response).unwrap();
1028        assert_eq!(
1029            output.analysis,
1030            "The model did a good job but there were still compilations errors."
1031        );
1032        assert_eq!(output.score, 3);
1033
1034        let response = r#"
1035            Text around ignored
1036
1037            <analysis>
1038                Failed to compile:
1039                - Error 1
1040                - Error 2
1041            </analysis>
1042
1043            <score>1</score>
1044        "#
1045        .unindent();
1046
1047        let output = JudgeResponse::parse(&response).unwrap();
1048        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
1049        assert_eq!(output.score, 1);
1050    }
1051
1052    #[test]
1053    fn test_judge_prompt_with_diagnostics() {
1054        // Case 1: Both diagnostics before and after are present
1055        let input = JudgeDiffInput {
1056            repository_diff: "diff content goes here".to_string(),
1057            ran_diagnostics_check: true,
1058            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
1059            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
1060            criteria: "Fix all bugs".to_string(),
1061        };
1062
1063        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
1064
1065        let expected_diagnostics_section = r#"
1066            Take into account the diagnostics before and after applying the change:
1067
1068            <diagnostics_before>
1069            Error at line 10: variable not found
1070            </diagnostics_before>
1071
1072            <diagnostics_after>
1073            Error at line 15: missing semicolon
1074            </diagnostics_after>
1075            "#
1076        .unindent();
1077
1078        assert!(rendered.contains(&expected_diagnostics_section));
1079    }
1080
1081    #[test]
1082    fn test_judge_prompt_with_empty_diagnostics() {
1083        // Case 2: Diagnostics check run but no diagnostics found
1084        let input = JudgeDiffInput {
1085            repository_diff: "diff content goes here".to_string(),
1086            ran_diagnostics_check: true,
1087            diagnostics_before: None,
1088            diagnostics_after: None,
1089            criteria: "Fix all bugs".to_string(),
1090        };
1091
1092        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
1093
1094        let expected_diagnostics_section = r#"
1095            Take into account the diagnostics before and after applying the change:
1096
1097            <diagnostics_before>
1098            No diagnostics before applying the edits.
1099            </diagnostics_before>
1100
1101            <diagnostics_after>
1102            No diagnostics after applying the edits.
1103            </diagnostics_after>
1104            "#
1105        .unindent();
1106
1107        assert!(rendered.contains(&expected_diagnostics_section));
1108    }
1109
1110    #[test]
1111    fn test_judge_prompt_with_mixed_diagnostics() {
1112        let templates = templates();
1113
1114        // Case 3: Before diagnostics present, after diagnostics absent
1115        let input = JudgeDiffInput {
1116            repository_diff: "diff content goes here".to_string(),
1117            ran_diagnostics_check: true,
1118            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
1119            diagnostics_after: None,
1120            criteria: "Fix all bugs".to_string(),
1121        };
1122
1123        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
1124
1125        let expected_diagnostics_section = r#"
1126            Take into account the diagnostics before and after applying the change:
1127
1128            <diagnostics_before>
1129            Error at line 10: variable not found
1130            </diagnostics_before>
1131
1132            <diagnostics_after>
1133            No diagnostics after applying the edits.
1134            </diagnostics_after>
1135            "#
1136        .unindent();
1137
1138        assert!(rendered.contains(&expected_diagnostics_section));
1139
1140        // Case 4: Before diagnostics absent, after diagnostics present
1141        let input = JudgeDiffInput {
1142            repository_diff: "diff content goes here".to_string(),
1143            ran_diagnostics_check: true,
1144            diagnostics_before: None,
1145            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
1146            criteria: "Fix all bugs".to_string(),
1147        };
1148
1149        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
1150
1151        let expected_diagnostics_section = r#"
1152            Take into account the diagnostics before and after applying the change:
1153
1154            <diagnostics_before>
1155            No diagnostics before applying the edits.
1156            </diagnostics_before>
1157
1158            <diagnostics_after>
1159            Error at line 15: missing semicolon
1160            </diagnostics_after>
1161            "#
1162        .unindent();
1163
1164        assert!(rendered.contains(&expected_diagnostics_section));
1165    }
1166
1167    #[test]
1168    fn test_judge_prompt_without_diagnostics() {
1169        let templates = templates();
1170
1171        // Case 5: No diagnostics check run
1172        let input = JudgeDiffInput {
1173            repository_diff: "diff content goes here".to_string(),
1174            ran_diagnostics_check: false,
1175            diagnostics_before: None,
1176            diagnostics_after: None,
1177            criteria: "Fix all bugs".to_string(),
1178        };
1179
1180        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
1181
1182        // Check for the message when no diagnostics were performed
1183        let diagnostics_message = "No diagnostic checks were performed.";
1184
1185        assert!(rendered.contains(diagnostics_message));
1186        assert!(!rendered.contains("<diagnostics_before>"));
1187        assert!(!rendered.contains("<diagnostics_after>"));
1188    }
1189
1190    const JUDGE_PROMPT_NAME: &str = "judge_prompt";
1191
1192    fn templates() -> Handlebars<'static> {
1193        let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
1194        language::LineEnding::normalize(&mut judge_prompt);
1195        let mut handlebars = Handlebars::new();
1196        handlebars
1197            .register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
1198            .unwrap();
1199        handlebars
1200    }
1201}