example.rs

   1use crate::{AgentAppState, ToolMetrics};
   2use agent::{ThreadEvent, ThreadStore};
   3use anyhow::{Context as _, Result, anyhow};
   4use assistant_tool::ToolWorkingSet;
   5use client::proto::LspWorkProgress;
   6use futures::channel::mpsc;
   7use futures::{FutureExt, StreamExt as _, select_biased};
   8use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
   9use handlebars::Handlebars;
  10use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
  11use language_model::{
  12    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  13    MessageContent, Role, StopReason, TokenUsage,
  14};
  15use project::{Project, ProjectPath};
  16use serde::{Deserialize, Serialize};
  17use std::cell::RefCell;
  18use std::fmt::Write as _;
  19use std::fs::File;
  20use std::io::Write as _;
  21use std::rc::Rc;
  22use std::sync::{Arc, Mutex};
  23use std::time::Duration;
  24use std::{
  25    fs,
  26    path::{Path, PathBuf},
  27};
  28use unindent::Unindent as _;
  29use util::ResultExt as _;
  30use util::command::new_smol_command;
  31use util::markdown::MarkdownString;
  32use util::serde::default_true;
  33
/// Directory containing the example definitions (`base.toml`, `prompt.md`, ...).
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Directory where upstream repositories are cloned/fetched.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Directory where per-example git worktrees are created.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Maximum time to wait between consecutive thread events before declaring the agentic loop stalled.
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
  39
/// Parsed contents of an example's `base.toml`: which repository/revision to
/// evaluate against, plus per-example flags.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    // Git URL of the repository to run the example against.
    pub url: String,
    // Revision (commit-ish) to check out in the worktree.
    pub revision: String,
    // File extension used to pick a buffer that starts the language server.
    pub language_extension: Option<String>,
    pub insert_id: Option<String>,
    // Whether to wait for a language server before/after the run (defaults to true).
    #[serde(default = "default_true")]
    pub require_lsp: bool,
    // If false, the run aborts when diagnostics exist before the agent starts.
    #[serde(default)]
    pub allow_preexisting_diagnostics: bool,
}
  51
  52impl ExampleBase {
  53    pub fn repo_name(&self) -> String {
  54        self.url
  55            .split('/')
  56            .next_back()
  57            .unwrap_or(&"")
  58            .trim_end_matches(".git")
  59            .into()
  60    }
  61}
  62
/// A single eval example loaded from disk, plus run-scoped state
/// (output directory and log prefix).
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name (its directory name, possibly suffixed with a repetition number)
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `diff_criteria.md`
    pub diff_criteria: String,
    /// Content of `thread_criteria.md`, if that file exists (it's optional)
    pub thread_criteria: Option<String>,
    /// Path to the directory containing the requests and responses for the agentic loop
    pub run_directory_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
}
  79
/// Everything captured from one agentic run of an example, consumed by the
/// judges and by reporting.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    // Staged `git diff` of the worktree after the agent finished.
    pub repository_diff: String,
    // Whether LSP diagnostics were collected (mirrors `base.require_lsp`).
    pub ran_diagnostics_check: bool,
    pub diagnostics_before: Option<String>,
    pub diagnostics_after: Option<String>,
    // Number of assistant messages produced during the run.
    pub response_count: usize,
    pub token_usage: TokenUsage,
    pub tool_metrics: ToolMetrics,
    // The final request sent to the model; used when judging the thread.
    pub last_request: LanguageModelRequest,
}
  91
/// Template input for the diff-judging prompt (`judge_diff_prompt.hbs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeDiffInput {
    pub repository_diff: String,
    pub ran_diagnostics_check: bool,
    // Omitted from the rendered template context when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_before: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub diagnostics_after: Option<String>,
    pub criteria: String,
}
 102
/// Template input for the thread-judging prompt (`judge_thread_prompt.hbs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeThreadInput {
    pub messages: String,
    pub criteria: String,
}
 108
/// One judge verdict, parsed from the `<analysis>`/`<score>` tags in the
/// model's response (see `JudgeResponse::parse`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeResponse {
    pub analysis: String,
    pub score: u32,
}
 114
/// Combined judge results. `thread` is `None` when the example has no
/// `thread_criteria.md`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    pub thread: Option<JudgeResponse>,
    pub diff: JudgeResponse,
}
 120
 121impl Example {
 122    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 123    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 124        let name = Self::name_from_path(dir_path);
 125        let base_path = dir_path.join("base.toml");
 126        let prompt_path = dir_path.join("prompt.md");
 127        let diff_criteria_path = dir_path.join("diff_criteria.md");
 128        let thread_criteria_path = dir_path.join("thread_criteria.md");
 129        let thread_criteria = if thread_criteria_path.exists() {
 130            Some(fs::read_to_string(thread_criteria_path.clone())?)
 131        } else {
 132            None
 133        };
 134
 135        Ok(Example {
 136            name: name.clone(),
 137            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
 138            prompt: fs::read_to_string(prompt_path.clone())?,
 139            thread_criteria,
 140            diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
 141            run_directory_path: run_dir.to_path_buf(),
 142            log_prefix: name,
 143        })
 144    }
 145
 146    pub fn set_repetition_number(&mut self, repetition_number: u32) {
 147        if repetition_number > 0 {
 148            self.name = format!("{}-{}", self.name, repetition_number);
 149        }
 150    }
 151
 152    pub fn example_output_directory(&self) -> PathBuf {
 153        self.run_directory_path.join(&self.name)
 154    }
 155
 156    pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
 157        self.log_prefix = format!(
 158            "{}{:<width$}\x1b[0m | ",
 159            color,
 160            self.name,
 161            width = name_width
 162        );
 163    }
 164
 165    pub fn name_from_path(path: &Path) -> String {
 166        path.file_name().unwrap().to_string_lossy().to_string()
 167    }
 168
 169    pub fn worktree_path(&self) -> PathBuf {
 170        Path::new(WORKTREES_DIR)
 171            .canonicalize()
 172            .context(format!("No such directory {WORKTREES_DIR}"))
 173            .unwrap()
 174            .join(&self.name)
 175            .join(self.base.repo_name())
 176    }
 177
    /// Set up the example by checking out the specified Git revision
    /// into a dedicated worktree, and create the example's output directory.
    pub async fn setup(&mut self) -> Result<()> {
        let repo_path = repo_path_for_url(&self.base.url);

        // `rev-parse <rev>^{commit}` succeeds only when the revision is
        // already available in the local clone.
        let revision_exists = run_git(
            &repo_path,
            &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
        )
        .await
        .is_ok();

        if !revision_exists {
            println!(
                "{}Fetching revision {}",
                self.log_prefix, &self.base.revision
            );
            // Shallow-fetch just the needed revision.
            run_git(
                &repo_path,
                &["fetch", "--depth", "1", "origin", &self.base.revision],
            )
            .await?;
        }

        let worktree_path = self.worktree_path();

        if worktree_path.is_dir() {
            // Reuse the existing worktree: drop untracked files, discard local
            // edits, then move to the requested revision.
            println!("{}Resetting existing worktree", self.log_prefix);

            // TODO: consider including "-x" to remove ignored files. The downside of this is that
            // it will also remove build artifacts, and so prevent incremental reuse there.
            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
        } else {
            println!("{}Creating worktree", self.log_prefix);

            let worktree_path_string = worktree_path.to_string_lossy().to_string();

            run_git(
                &repo_path,
                &[
                    "worktree",
                    "add",
                    "-f",
                    &worktree_path_string,
                    &self.base.revision,
                ],
            )
            .await?;
        }

        // Ensure the output directory exists before the run writes artifacts.
        std::fs::create_dir_all(self.example_output_directory())?;

        Ok(())
    }
 233
    /// Drive the full agentic loop for this example: open the worktree as a
    /// local project, optionally warm up a language server, send the prompt
    /// to `model`, consume thread events until the turn ends, then collect
    /// the repository diff, diagnostics, and metrics into a `RunOutput`.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        // Headless local project over the example's worktree.
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = cx.new(|_| ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            // Optionally start a language server so diagnostics can be
            // collected before and after the agent runs.
            let lsp = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let lsp_open_handle = project.update(cx, |project, cx| {
                    project.register_buffer_with_language_servers(&language_file_buffer, cx)
                })?;

                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;

                // Keep the handle alive so the server stays registered for the run.
                Some((lsp_open_handle, language_file_buffer))
            } else {
                None
            };

            // Baseline diagnostics; abort unless the example explicitly
            // allows starting from a broken state.
            let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
            if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
                return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
            }

            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await?;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
            // Shared slot for the most recent model request, filled by the
            // request callback below and consumed after the run.
            let last_request = Rc::new(RefCell::new(None));

            // Log every request/response pair to the example's output
            // directory as markdown, plus the tool list on the first request.
            thread.update(cx, |thread, _cx| {
                let mut request_count = 0;
                let example_dir_path = this.example_output_directory();

                let last_request = Rc::clone(&last_request);
                thread.set_request_callback(move |request, response_events| {
                    *last_request.borrow_mut() = Some(request.clone());

                    request_count += 1;
                    let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
                    let last_messages_file_path = example_dir_path.join("last.messages.md");
                    let request_markdown = RequestMarkdown::new(request);
                    let response_events_markdown = response_events_to_markdown(response_events);

                    let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
                    fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
                    fs::write(last_messages_file_path, messages).expect("failed to write last messages file");

                    if request_count == 1 {
                        let tools_file_path = example_dir_path.join("tools.md");
                        fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
                    }
                });
            })?;

            let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default()));

            // Forward thread events into a channel so the handler task below
            // can apply a stall timeout between events.
            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            // Consume events until the turn ends, an error occurs, or no
            // event arrives within THREAD_EVENT_TIMEOUT.
            let event_handler_task = cx.spawn({
                let log_prefix = this.log_prefix.clone();
                let tool_metrics = tool_metrics.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    // Normal completion of the agentic loop.
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                Ok(StopReason::ToolUse) => {
                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                        println!("{}StopReason: Tool use", log_prefix);
                                    }
                                }
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                // Record per-tool success/failure in the metrics.
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_use) = pending_tool_use {
                                        let mut tool_metrics = tool_metrics.lock().unwrap();
                                        if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                            let message = if tool_result.is_error {
                                                format!("TOOL FAILED: {}", tool_use.name)
                                            } else {
                                                format!("TOOL FINISHED: {}", tool_use.name)
                                            };
                                            println!("{log_prefix}{message}");
                                            tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error);
                                        } else {
                                            // No result recorded; counted as success.
                                            let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
                                            println!("{log_prefix}{message}");
                                            tool_metrics.insert(tool_use.name.clone(), true);
                                        }
                                    }
                                })?;
                            }
                            ThreadEvent::ToolConfirmationNeeded => {
                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
                            },
                            ThreadEvent::StreamedCompletion |
                            ThreadEvent::MessageAdded(_) |
                            ThreadEvent::MessageEdited(_) |
                            ThreadEvent::MessageDeleted(_) |
                            ThreadEvent::SummaryChanged |
                            ThreadEvent::SummaryGenerated |
                            ThreadEvent::CheckpointChanged |
                            ThreadEvent::UsageUpdated(_) => {
                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                    println!("{}Event: {:#?}", log_prefix, event);
                                }
                            }
                        }
                    }
                }
            });

            // Kick off the run by sending the example prompt to the model.
            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, cx);
            })?;

            event_handler_task.await?;

            println!("{}Stopped", this.log_prefix);

            // Let the language server settle so post-run diagnostics are complete.
            if let Some((_, language_file_buffer)) = lsp.as_ref() {
                wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
            }

            println!("{}Getting repository diff", this.log_prefix);
            let repository_diff = this.repository_diff().await?;

            let example_output_dir = this.example_output_directory();
            let repository_diff_path = example_output_dir.join("patch.diff");
            let mut repository_diff_output_file = File::create(&repository_diff_path)?;
            writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();

            println!("{}Getting diagnostics", this.log_prefix);
            let diagnostics_after = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;
            println!("{}Got diagnostics", this.log_prefix);

            // The request callback must have fired at least once.
            let Some(last_request) = last_request.borrow_mut().take() else {
                return Err(anyhow!("No requests ran."));
            };

            // Release the event subscription and LSP handle before assembling results.
            drop(subscription);
            drop(lsp);

            if let Some(diagnostics_before) = &diagnostics_before {
                fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
            }

            if let Some(diagnostics_after) = &diagnostics_after {
                fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
            }


            // Assemble the run output from the finished thread.
            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    ran_diagnostics_check: this.base.require_lsp,
                    diagnostics_before,
                    diagnostics_after,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_metrics: tool_metrics.lock().unwrap().clone(),
                    last_request,
                }
            })
        })
    }
 504
    /// Render the diff-judging prompt and ask `model` to score the run's
    /// repository diff against this example's `diff_criteria`.
    ///
    /// Returns the raw model response alongside the parsed judgment.
    async fn judge_diff(
        &self,
        model: Arc<dyn LanguageModel>,
        run_output: &RunOutput,
        judge_number: u32,
        cx: &AsyncApp,
    ) -> Result<(String, JudgeResponse)> {
        // Prompt template is compiled into the binary.
        let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
        let judge_diff_prompt_name = "judge_diff_prompt";
        let mut hbs = Handlebars::new();
        hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;

        let diff_prompt = hbs.render(
            judge_diff_prompt_name,
            &JudgeDiffInput {
                repository_diff: run_output.repository_diff.clone(),
                ran_diagnostics_check: run_output.ran_diagnostics_check,
                diagnostics_before: run_output.diagnostics_before.clone(),
                diagnostics_after: run_output.diagnostics_after.clone(),
                criteria: self.diff_criteria.clone(),
            },
        )?;

        // Single-shot request: one user message, no tools.
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text(diff_prompt)],
                cache: false,
            }],
            temperature: None,
            tools: Vec::new(),
            stop: Vec::new(),
        };

        let diff_response = send_language_model_request(model, request, cx).await?;
        let diff_output = JudgeResponse::parse(&diff_response)?;

        println!(
            "{}Judge #{judge_number} - Diff score: {}",
            self.log_prefix, diff_output.score
        );

        Ok((diff_response, diff_output))
    }
 551
    /// Render the thread-judging prompt and ask `model` to score the agent's
    /// message thread against `thread_criteria`, when criteria exist.
    ///
    /// Returns the raw model response (or an explanatory message) and
    /// `None` for the judgment when the example has no `thread_criteria.md`.
    async fn judge_thread(
        &self,
        model: Arc<dyn LanguageModel>,
        run_output: &RunOutput,
        judge_number: u32,
        cx: &AsyncApp,
    ) -> Result<(String, Option<JudgeResponse>)> {
        if let Some(criteria) = self.thread_criteria.clone() {
            // Prompt template is compiled into the binary.
            let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
            let judge_thread_prompt_name = "judge_thread_prompt";
            let mut hbs = Handlebars::new();
            hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;

            // The judged "thread" is the final request rendered as markdown.
            let request_markdown = RequestMarkdown::new(&run_output.last_request);
            let thread_prompt = hbs.render(
                judge_thread_prompt_name,
                &JudgeThreadInput {
                    messages: request_markdown.messages,
                    criteria,
                },
            )?;

            // Single-shot request: one user message, no tools.
            let request = LanguageModelRequest {
                thread_id: None,
                prompt_id: None,
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text(thread_prompt)],
                    cache: false,
                }],
                temperature: None,
                tools: Vec::new(),
                stop: Vec::new(),
            };

            let thread_response = send_language_model_request(model, request, cx).await?;
            let thread_output = JudgeResponse::parse(&thread_response)?;

            println!(
                "{}Judge #{judge_number} - Thread score: {}",
                self.log_prefix, thread_output.score
            );

            Ok((thread_response, Some(thread_output)))
        } else {
            let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
            Ok((msg, None))
        }
    }
 601
 602    pub async fn judge(
 603        &self,
 604        model: Arc<dyn LanguageModel>,
 605        run_output: &RunOutput,
 606        judge_number: u32,
 607        cx: &AsyncApp,
 608    ) -> Result<JudgeOutput> {
 609        let mut output_file = File::create(
 610            self.example_output_directory()
 611                .join(format!("judge_{}.md", judge_number)),
 612        )
 613        .expect("failed to create judge.md");
 614
 615        println!("{}Running judge #{judge_number}", self.log_prefix);
 616
 617        let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
 618        let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
 619
 620        let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
 621
 622        let (diff_response, diff_output) = diff_result?;
 623        let (thread_response, thread_output) = thread_result?;
 624
 625        writeln!(
 626            &mut output_file,
 627            "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
 628        )
 629        .log_err();
 630
 631        Ok(JudgeOutput {
 632            thread: thread_output,
 633            diff: diff_output,
 634        })
 635    }
 636
 637    async fn repository_diff(&self) -> Result<String> {
 638        let worktree_path = self.worktree_path();
 639        run_git(&worktree_path, &["add", "."]).await?;
 640        run_git(&worktree_path, &["diff", "--staged"]).await
 641    }
 642}
 643
/// Resolve once the language server for `buffer` goes idle (a
/// `DiskBasedDiagnosticsFinished` event arrives), or fail after a 5-minute
/// timeout. Saving the buffer is used to trigger a diagnostics pass.
fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Completion signal, sent when disk-based diagnostics finish.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .update(cx, |project, _| project.lsp_store())
        .unwrap();

    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(&buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    // If a server is already attached, save the buffer now to kick off a
    // diagnostics pass; otherwise the save happens when the server is added
    // (see the subscription below).
    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }

    let subscriptions =
        [
            // Surface LSP work-progress messages in the log while waiting.
            cx.subscribe(&lsp_store, {
                let log_prefix = log_prefix.clone();
                move |_, event, _| match event {
                    project::LspStoreEvent::LanguageServerUpdate {
                        message:
                            client::proto::update_language_server::Variant::WorkProgress(
                                LspWorkProgress {
                                    message: Some(message),
                                    ..
                                },
                            ),
                        ..
                    } => println!("{}{message}", log_prefix),
                    _ => {}
                }
            }),
            cx.subscribe(&project, {
                let buffer = buffer.clone();
                move |project, event, cx| match event {
                    project::Event::LanguageServerAdded(_, _, _) => {
                        // Server just came up: save the buffer to trigger diagnostics.
                        let buffer = buffer.clone();
                        project
                            .update(cx, |project, cx| project.save_buffer(buffer, cx))
                            .detach();
                    }
                    project::Event::DiskBasedDiagnosticsFinished { .. } => {
                        // Signal completion; `ok()` because the receiver may be gone.
                        tx.try_send(()).ok();
                    }
                    _ => {}
                }
            }),
        ];

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        // Keep the subscriptions alive until the wait resolves.
        drop(subscriptions);
        result
    })
}
 726
/// Collect error/warning diagnostics across the project into a plain-text
/// report (one line per diagnostic group's primary entry).
///
/// Returns `Ok(None)` when no file has any errors or warnings.
async fn query_lsp_diagnostics(
    project: Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<Option<String>> {
    // Only open buffers for files whose summaries report something.
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
            .map(|(project_path, _, _)| project_path)
            .collect::<Vec<_>>()
    })?;

    if paths_with_diagnostics.is_empty() {
        return Ok(None);
    }

    let mut output = String::new();
    for project_path in paths_with_diagnostics {
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
            .await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            // Report only the group's primary diagnostic, skipping
            // severities other than error/warning.
            let entry = &group.entries[group.primary_ix];
            let range = entry.range.to_point(&snapshot);
            let severity = match entry.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };

            // Rows are 0-based; report 1-based line numbers.
            writeln!(
                output,
                "{} at line {}: {}",
                severity,
                range.start.row + 1,
                entry.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(Some(output))
}
 770
 771impl JudgeResponse {
 772    fn parse(response: &str) -> Result<Self> {
 773        let analysis = get_tag("analysis", response)?.to_string();
 774        let score = get_tag("score", response)?
 775            .parse()
 776            .context("error parsing score")?;
 777
 778        Ok(Self { analysis, score })
 779    }
 780}
 781
 782fn get_tag(name: &'static str, response: &str) -> Result<String> {
 783    let start_tag = format!("<{}>", name);
 784    let end_tag = format!("</{}>", name);
 785
 786    let start_ix = response
 787        .find(&start_tag)
 788        .context(format!("{} start tag not found", name))?;
 789    let content_start_ix = start_ix + start_tag.len();
 790
 791    let end_ix = content_start_ix
 792        + response[content_start_ix..]
 793            .find(&end_tag)
 794            .context(format!("{} end tag not found", name))?;
 795
 796    let content = response[content_start_ix..end_ix].trim().unindent();
 797
 798    anyhow::Ok(content)
 799}
 800
 801pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
 802    let repo_name = repo_url
 803        .trim_start_matches("https://")
 804        .replace(|c: char| !c.is_alphanumeric(), "-");
 805    Path::new(REPOS_DIR)
 806        .canonicalize()
 807        .context(format!("No such directory {REPOS_DIR}"))
 808        .unwrap()
 809        .join(repo_name)
 810}
 811
 812pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
 813    let output = new_smol_command("git")
 814        .current_dir(repo_path)
 815        .args(args)
 816        .output()
 817        .await?;
 818
 819    if output.status.success() {
 820        Ok(String::from_utf8(output.stdout)?.trim().to_string())
 821    } else {
 822        Err(anyhow!(
 823            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
 824            args.join(" "),
 825            repo_path.display(),
 826            output.status,
 827            String::from_utf8_lossy(&output.stderr),
 828            String::from_utf8_lossy(&output.stdout),
 829        ))
 830    }
 831}
 832
 833pub async fn send_language_model_request(
 834    model: Arc<dyn LanguageModel>,
 835    request: LanguageModelRequest,
 836    cx: &AsyncApp,
 837) -> anyhow::Result<String> {
 838    match model.stream_completion_text(request, &cx).await {
 839        Ok(mut stream) => {
 840            let mut full_response = String::new();
 841            while let Some(chunk_result) = stream.stream.next().await {
 842                match chunk_result {
 843                    Ok(chunk_str) => {
 844                        full_response.push_str(&chunk_str);
 845                    }
 846                    Err(err) => {
 847                        return Err(anyhow!(
 848                            "Error receiving response from language model: {err}"
 849                        ));
 850                    }
 851                }
 852            }
 853            Ok(full_response)
 854        }
 855        Err(err) => Err(anyhow!(
 856            "Failed to get response from language model. Error was: {err}"
 857        )),
 858    }
 859}
 860
/// Markdown renderings of a [`LanguageModelRequest`], split into the tool
/// definitions section and the conversation transcript.
struct RequestMarkdown {
    // Markdown for each tool: name, description, and JSON input schema.
    tools: String,
    // Markdown transcript of the system/user/assistant messages.
    messages: String,
}
 865
 866impl RequestMarkdown {
 867    fn new(request: &LanguageModelRequest) -> Self {
 868        let mut tools = String::new();
 869        let mut messages = String::new();
 870        let mut assistant_message_number: u32 = 1;
 871
 872        // Print the tools
 873        if !request.tools.is_empty() {
 874            for tool in &request.tools {
 875                write!(&mut tools, "# {}\n\n", tool.name).unwrap();
 876                write!(&mut tools, "{}\n\n", tool.description).unwrap();
 877                write!(
 878                    &mut tools,
 879                    "{}\n",
 880                    MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
 881                )
 882                .unwrap();
 883            }
 884        }
 885
 886        // Print the messages
 887        for message in &request.messages {
 888            match message.role {
 889                Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
 890                Role::User => messages.push_str("# 👤 USER\n\n"),
 891                Role::Assistant => {
 892                    messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
 893                    assistant_message_number += 1;
 894                }
 895            };
 896
 897            for content in &message.content {
 898                match content {
 899                    MessageContent::Text(text) => {
 900                        messages.push_str(text);
 901                        messages.push_str("\n\n");
 902                    }
 903                    MessageContent::Image(_) => {
 904                        messages.push_str("[IMAGE DATA]\n\n");
 905                    }
 906                    MessageContent::Thinking { text, signature } => {
 907                        messages.push_str("**Thinking**:\n\n");
 908                        if let Some(sig) = signature {
 909                            messages.push_str(&format!("Signature: {}\n\n", sig));
 910                        }
 911                        messages.push_str(text);
 912                        messages.push_str("\n");
 913                    }
 914                    MessageContent::RedactedThinking(items) => {
 915                        messages.push_str(&format!(
 916                            "**Redacted Thinking**: {} item(s)\n\n",
 917                            items.len()
 918                        ));
 919                    }
 920                    MessageContent::ToolUse(tool_use) => {
 921                        messages.push_str(&format!(
 922                            "**Tool Use**: {} (ID: {})\n",
 923                            tool_use.name, tool_use.id
 924                        ));
 925                        messages.push_str(&format!(
 926                            "{}\n",
 927                            MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
 928                        ));
 929                    }
 930                    MessageContent::ToolResult(tool_result) => {
 931                        messages.push_str(&format!(
 932                            "**Tool Result**: {} (ID: {})\n\n",
 933                            tool_result.tool_name, tool_result.tool_use_id
 934                        ));
 935                        if tool_result.is_error {
 936                            messages.push_str("**ERROR:**\n");
 937                        }
 938                        messages.push_str(&format!("{}\n\n", tool_result.content));
 939                    }
 940                }
 941            }
 942        }
 943
 944        Self { tools, messages }
 945    }
 946}
 947
 948fn response_events_to_markdown(
 949    response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
 950) -> String {
 951    let mut response = String::new();
 952    // Print the response events if any
 953    response.push_str("# Response\n\n");
 954    let mut text_buffer = String::new();
 955    let mut thinking_buffer = String::new();
 956
 957    let flush_buffers =
 958        |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
 959            if !text_buffer.is_empty() {
 960                output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
 961                text_buffer.clear();
 962            }
 963            if !thinking_buffer.is_empty() {
 964                output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
 965                thinking_buffer.clear();
 966            }
 967        };
 968
 969    for event in response_events {
 970        match event {
 971            Ok(LanguageModelCompletionEvent::Text(text)) => {
 972                text_buffer.push_str(text);
 973            }
 974            Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
 975                thinking_buffer.push_str(text);
 976            }
 977            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 978                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
 979                response.push_str(&format!("**Stop**: {:?}\n\n", reason));
 980            }
 981            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
 982                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
 983                response.push_str(&format!(
 984                    "**Tool Use**: {} (ID: {})\n",
 985                    tool_use.name, tool_use.id
 986                ));
 987                response.push_str(&format!(
 988                    "{}\n",
 989                    MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
 990                ));
 991            }
 992            Ok(
 993                LanguageModelCompletionEvent::UsageUpdate(_)
 994                | LanguageModelCompletionEvent::StartMessage { .. },
 995            ) => {}
 996            Err(error) => {
 997                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
 998                response.push_str(&format!("**Error**: {}\n\n", error));
 999            }
1000        }
1001    }
1002
1003    flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1004
1005    response
1006}
1007
// Tests for judge-response tag parsing and for rendering of the judge
// prompt template across every diagnostics combination.
#[cfg(test)]
mod test {
    use super::*;
    use handlebars::Handlebars;

    // Verifies `<analysis>`/`<score>` extraction, including multi-line,
    // indented tag content surrounded by unrelated text.
    #[test]
    fn test_parse_judge_output() {
        let response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let output = JudgeResponse::parse(&response).unwrap();
        assert_eq!(
            output.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(output.score, 3);

        let response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let output = JudgeResponse::parse(&response).unwrap();
        assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(output.score, 1);
    }

    #[test]
    fn test_judge_prompt_with_diagnostics() {
        // Case 1: Both diagnostics before and after are present
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_empty_diagnostics() {
        // Case 2: Diagnostics check run but no diagnostics found
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: None,
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();

        // The template substitutes placeholder text when a diagnostics field is None.
        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_with_mixed_diagnostics() {
        let templates = templates();

        // Case 3: Before diagnostics present, after diagnostics absent
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: Some("Error at line 10: variable not found".to_string()),
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));

        // Case 4: Before diagnostics absent, after diagnostics present
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: true,
            diagnostics_before: None,
            diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        let expected_diagnostics_section = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
            "#
        .unindent();

        assert!(rendered.contains(&expected_diagnostics_section));
    }

    #[test]
    fn test_judge_prompt_without_diagnostics() {
        let templates = templates();

        // Case 5: No diagnostics check run
        let input = JudgeDiffInput {
            repository_diff: "diff content goes here".to_string(),
            ran_diagnostics_check: false,
            diagnostics_before: None,
            diagnostics_after: None,
            criteria: "Fix all bugs".to_string(),
        };

        let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();

        // Check for the message when no diagnostics were performed
        let diagnostics_message = "No diagnostic checks were performed.";

        assert!(rendered.contains(diagnostics_message));
        assert!(!rendered.contains("<diagnostics_before>"));
        assert!(!rendered.contains("<diagnostics_after>"));
    }

    const JUDGE_PROMPT_NAME: &str = "judge_prompt";

    // Builds a Handlebars registry with the judge prompt registered under
    // JUDGE_PROMPT_NAME; line endings are normalized so tests behave the
    // same across platforms.
    fn templates() -> Handlebars<'static> {
        let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
        language::LineEnding::normalize(&mut judge_prompt);
        let mut handlebars = Handlebars::new();
        handlebars
            .register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
            .unwrap();
        handlebars
    }
}