example.rs

  1use agent::{RequestKind, ThreadEvent, ThreadStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::ToolWorkingSet;
  4use client::proto::LspWorkProgress;
  5use collections::HashMap;
  6use dap::DapRegistry;
  7use futures::channel::mpsc;
  8use futures::{FutureExt, StreamExt as _, select_biased};
  9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
 10use handlebars::Handlebars;
 11use language::{DiagnosticSeverity, OffsetRangeExt};
 12use language_model::{
 13    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 14    StopReason, TokenUsage,
 15};
 16use project::{LspStore, Project, ProjectPath};
 17use serde::{Deserialize, Serialize};
 18use std::fmt::Write as _;
 19use std::fs::File;
 20use std::io::Write as _;
 21use std::sync::{Arc, Mutex};
 22use std::time::Duration;
 23use std::{
 24    fs,
 25    path::{Path, PathBuf},
 26};
 27use unindent::Unindent as _;
 28use util::ResultExt as _;
 29use util::command::new_smol_command;
 30use util::serde::default_true;
 31
 32use crate::AgentAppState;
 33
/// Relative path of the examples directory.
///
/// NOTE(review): not referenced in this file; presumably consumed by the eval driver — confirm.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path of the directory under which `repo_path_for_url` places each repository.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path of the directory under which `Example::worktree_path` places each
/// example's Git worktree.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Longest time `Example::run` waits for the next `ThreadEvent` before failing the run
/// as stalled (two minutes).
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 39
/// Deserialized content of an example's `base.toml`.
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
    /// Repository URL; `repo_path_for_url` maps it to the local repository directory.
    pub url: String,
    /// Git revision that `Example::setup` fetches (if missing) and checks out.
    pub revision: String,
    /// File extension used by `Example::run` to pick a file to open so that a language
    /// server starts. Required when `require_lsp` is true.
    pub language_extension: Option<String>,
    /// NOTE(review): not read anywhere in this file; purpose is not visible here.
    pub insert_id: Option<String>,
    /// When true, `Example::run` waits for a language server and fails if none is registered
    /// for the opened buffer. Filled in by `default_true` when omitted from `base.toml`.
    #[serde(default = "default_true")]
    pub require_lsp: bool,
}
 49
/// A single eval example together with the paths and handles used to record its run.
#[derive(Clone, Debug)]
pub struct Example {
    /// Example name: the last component of the example directory, with a `-<n>` suffix
    /// appended by `set_repetition_number` for non-zero repetitions.
    pub name: String,
    /// Content of `base.toml`
    pub base: ExampleBase,
    /// Content of `prompt.md`
    pub prompt: String,
    /// Content of `criteria.md`
    pub criteria: String,
    /// Markdown output file to append to; `None` until `setup` has created it.
    pub output_file: Option<Arc<Mutex<File>>>,
    /// Path to the output run directory.
    pub run_dir: PathBuf,
    /// Path to markdown output file
    pub output_file_path: PathBuf,
    /// Prefix used for logging that identifies this example
    pub log_prefix: String,
}
 68
/// Result of `Example::run`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
    /// Output of `Example::repository_diff` taken after the agent stopped.
    pub repository_diff: String,
    /// Error and warning lines produced by `query_lsp_diagnostics`, one per line.
    pub diagnostics: String,
    /// Number of thread messages whose role is `Assistant`.
    pub response_count: usize,
    /// Cumulative token usage reported by the thread.
    pub token_usage: TokenUsage,
    /// Number of finished tool uses that produced a result, keyed by tool name.
    pub tool_use_counts: HashMap<Arc<str>, u32>,
}
 77
/// Data rendered into the `judge_prompt.hbs` template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
    /// Diff passed to `Example::judge` by its caller.
    pub repository_diff: String,
    /// Content of the example's `criteria.md`.
    pub criteria: String,
}
 83
/// Judge verdict extracted from the model response by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
    /// Trimmed, unindented content of the response's `<analysis>` tag.
    pub analysis: String,
    /// Content of the response's `<score>` tag parsed as an unsigned integer.
    pub score: u32,
}
 89
 90impl Example {
 91    /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
 92    pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
 93        let name = Self::name_from_path(dir_path);
 94        let base_path = dir_path.join("base.toml");
 95        let prompt_path = dir_path.join("prompt.md");
 96        let criteria_path = dir_path.join("criteria.md");
 97        let output_file_path = run_dir.join(format!("{}.md", name));
 98
 99        Ok(Example {
100            name: name.clone(),
101            base: toml::from_str(&fs::read_to_string(&base_path)?)?,
102            prompt: fs::read_to_string(prompt_path.clone())?,
103            criteria: fs::read_to_string(criteria_path.clone())?,
104            run_dir: run_dir.to_path_buf(),
105            output_file: None,
106            output_file_path,
107            log_prefix: name,
108        })
109    }
110
111    pub fn set_repetition_number(&mut self, repetition_number: u32) {
112        if repetition_number > 0 {
113            self.name = format!("{}-{}", self.name, repetition_number);
114            self.output_file_path = self.run_dir.join(format!("{}.md", self.name));
115        }
116    }
117
118    pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
119        self.log_prefix = format!(
120            "{}{:<width$}\x1b[0m | ",
121            color,
122            self.name,
123            width = name_width
124        );
125    }
126
127    pub fn name_from_path(path: &Path) -> String {
128        path.file_name().unwrap().to_string_lossy().to_string()
129    }
130
131    pub fn worktree_path(&self) -> PathBuf {
132        Path::new(WORKTREES_DIR)
133            .canonicalize()
134            .context(format!("No such directory {WORKTREES_DIR}"))
135            .unwrap()
136            .join(&self.name)
137    }
138
139    /// Set up the example by checking out the specified Git revision
140    pub async fn setup(&mut self) -> Result<()> {
141        let repo_path = repo_path_for_url(&self.base.url);
142
143        let revision_exists = run_git(
144            &repo_path,
145            &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
146        )
147        .await
148        .is_ok();
149
150        if !revision_exists {
151            println!(
152                "{}Fetching revision {}",
153                self.log_prefix, &self.base.revision
154            );
155            run_git(
156                &repo_path,
157                &["fetch", "--depth", "1", "origin", &self.base.revision],
158            )
159            .await?;
160        }
161
162        let worktree_path = self.worktree_path();
163
164        if worktree_path.is_dir() {
165            println!("{}Resetting existing worktree", self.log_prefix);
166
167            // TODO: consider including "-x" to remove ignored files. The downside of this is that
168            // it will also remove build artifacts, and so prevent incremental reuse there.
169            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
170            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
171            run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
172        } else {
173            println!("{}Creating worktree", self.log_prefix);
174
175            let worktree_path_string = worktree_path.to_string_lossy().to_string();
176
177            run_git(
178                &repo_path,
179                &[
180                    "worktree",
181                    "add",
182                    "-f",
183                    &worktree_path_string,
184                    &self.base.revision,
185                ],
186            )
187            .await?;
188        }
189
190        // Create the output file
191        let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?));
192        self.output_file = Some(output_file);
193
194        Ok(())
195    }
196
197    /// Returns the output file, panicking if it's not set
198    fn output_file(&self) -> Arc<Mutex<File>> {
199        self.output_file
200            .clone()
201            .expect("Output file not created. Call setup() first.")
202    }
203
    /// Runs the example's prompt through `model` in a local project rooted at the example's
    /// worktree and returns a task resolving to the collected [`RunOutput`].
    ///
    /// The spawned task:
    /// 1. waits for the worktree scan to complete;
    /// 2. when `require_lsp` is set, opens a file with the configured `language_extension`,
    ///    registers it with language servers and waits for them to become idle;
    /// 3. creates a thread, writes the prompt to the markdown output file and sends it to
    ///    the model;
    /// 4. logs thread events to the output file until the thread stops, counting tool uses;
    /// 5. writes the repository diff to `<run_dir>/<name>.diff` and gathers diagnostics.
    ///
    /// # Errors
    ///
    /// The task fails when `require_lsp` is set but no matching file or language server is
    /// found, when `ZED_EVAL_SETUP_ONLY` is set ("Setup only mode"), when no thread event
    /// arrives within `THREAD_EVENT_TIMEOUT`, when the thread stops with `MaxTokens` or an
    /// error, and when the diff or diagnostics cannot be produced.
    ///
    /// # Panics
    ///
    /// Panics if `setup()` has not created the output file, if the output-file mutex is
    /// poisoned, or if the thread asks for tool confirmation.
    pub fn run(
        &self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<RunOutput>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree_path = self.worktree_path();
        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        });

        let tools = cx.new(|_| ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
        // The spawned task must own its data, so it works on a clone of the example.
        let this = self.clone();

        cx.spawn(async move |cx| {
            let worktree = worktree.await?;

            // Wait for worktree scan to finish before choosing a file to open.
            worktree
                .update(cx, |worktree, _cx| {
                    worktree.as_local().unwrap().scan_complete()
                })?
                .await;

            // Kept until after diagnostics are gathered; dropped explicitly near the end.
            let lsp_open_handle_and_store = if this.base.require_lsp {
                let language_extension = this.base.language_extension.as_deref().context(
                    "language_extension field is required in base.toml when `require_lsp == true`",
                )?;

                // Open a file that matches the language to cause LSP to start.
                let language_file = worktree.read_with(cx, |worktree, _cx| {
                    worktree
                        .files(false, 0)
                        .find_map(|e| {
                            if e.path.clone().extension().and_then(|ext| ext.to_str())
                                == Some(language_extension)
                            {
                                Some(ProjectPath {
                                    worktree_id: worktree.id(),
                                    path: e.path.clone(),
                                })
                            } else {
                                None
                            }
                        })
                        .context("Failed to find a file for example language")
                })??;

                let open_language_file_buffer_task = project.update(cx, |project, cx| {
                    project.open_buffer(language_file.clone(), cx)
                })?;

                let language_file_buffer = open_language_file_buffer_task.await?;

                let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
                    (
                        project.register_buffer_with_language_servers(&language_file_buffer, cx),
                        project.lsp_store().clone(),
                    )
                })?;

                // TODO: remove this once the diagnostics tool waits for new diagnostics
                cx.background_executor().timer(Duration::new(5, 0)).await;
                wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;

                // Fail early if opening the file did not attach any language server to it.
                lsp_store.update(cx, |lsp_store, cx| {
                    lsp_open_handle.update(cx, |buffer, cx| {
                        buffer.update(cx, |buffer, cx| {
                            let has_language_server = lsp_store
                                .language_servers_for_local_buffer(buffer, cx)
                                .next()
                                .is_some();
                            if has_language_server {
                                Ok(())
                            } else {
                                Err(anyhow!(
                                    "`{:?}` was opened to cause the language server to start, \
                                    but no language servers are registered for its buffer. \
                                    Set `require_lsp = false` in `base.toml` to skip this.",
                                    language_file
                                ))
                            }
                        })
                    })
                })??;

                Some((lsp_open_handle, lsp_store))
            } else {
                None
            };

            // Setup-only mode stops here, before anything is sent to the model.
            if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
                return Err(anyhow!("Setup only mode"));
            }

            let thread_store = thread_store.await?;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            // Scoped so the output-file lock is released before the event handler starts.
            {
                let output_file_ref = this.output_file();
                let mut output_file = output_file_ref.lock().unwrap();
                writeln!(&mut output_file, "👤 USER:").log_err();
                writeln!(&mut output_file, "{}", this.prompt).log_err();
                writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
                output_file.flush().log_err();
            }

            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                Mutex::new(HashMap::default()).into();

            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();

            // Forward every thread event into the channel consumed by the handler task below.
            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
                thread_event_tx.unbounded_send(event.clone()).log_err();
            });

            let event_handler_task = cx.spawn({
                // Need to clone the Arc here because the reference from output_file() won't live long enough
                let output_file = this.output_file.clone().unwrap();
                let log_prefix = this.log_prefix.clone();
                let tool_use_counts = tool_use_counts.clone();
                let thread = thread.downgrade();
                async move |cx| {
                    loop {
                        // Biased: a ready event is always taken before the stall timer is checked.
                        let event = select_biased! {
                            event = thread_event_rx.next() => event,
                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                            }
                        };
                        let Some(event) = event else {
                            return Err(anyhow!("ThreadEvent channel ended early"));
                        };

                        // Lock is taken per event and released at the end of each iteration.
                        let mut output_file = output_file.lock().unwrap();

                        match event {
                            ThreadEvent::Stopped(reason) => match reason {
                                Ok(StopReason::EndTurn) => {
                                    return Ok(());
                                }
                                Ok(StopReason::MaxTokens) => {
                                    return Err(anyhow!("Exceeded maximum tokens"));
                                }
                                // Stopping for tool use is not terminal; keep processing events.
                                Ok(StopReason::ToolUse) => {
                                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                        println!("{}StopReason: Tool use", log_prefix);
                                    }
                                }
                                Err(error) => {
                                    return Err(anyhow!(error.clone()));
                                }
                            },
                            ThreadEvent::ShowError(thread_error) => {
                                break Err(anyhow!(thread_error.clone()));
                            }
                            ThreadEvent::StreamedAssistantText(_, chunk) => {
                                write!(&mut output_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                                write!(&mut output_file, "{}", chunk).log_err();
                            }
                            ThreadEvent::UsePendingTools { tool_uses } => {
                                writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
                                for tool_use in tool_uses {
                                    writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
                                        .log_err();
                                }
                            }
                            ThreadEvent::ToolFinished {
                                tool_use_id,
                                pending_tool_use,
                                ..
                            } => {
                                thread.update(cx, |thread, _cx| {
                                    if let Some(tool_use) = pending_tool_use {
                                        if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                            let message = if tool_result.is_error {
                                                format!("TOOL FAILED: {}", tool_use.name)
                                            } else {
                                                format!("TOOL FINISHED: {}", tool_use.name)
                                            };
                                            println!("{log_prefix}{message}");
                                            writeln!(&mut output_file, "\n{}", message).log_err();
                                            writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
                                            // Only tool uses that produced a result are counted.
                                            let mut tool_use_counts = tool_use_counts.lock().unwrap();
                                            *tool_use_counts
                                                .entry(tool_result.tool_name.clone())
                                                .or_insert(0) += 1;
                                        } else {
                                            let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
                                            println!("{log_prefix}{message}");
                                            writeln!(&mut output_file, "\n{}", message).log_err();
                                        }
                                    }
                                })?;
                            }
                            ThreadEvent::ToolConfirmationNeeded => {
                                panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
                            },
                            // Remaining events are only printed when ZED_EVAL_DEBUG is set.
                            ThreadEvent::StreamedCompletion |
                            ThreadEvent::MessageAdded(_) |
                            ThreadEvent::MessageEdited(_) |
                            ThreadEvent::MessageDeleted(_) |
                            ThreadEvent::SummaryChanged |
                            ThreadEvent::SummaryGenerated |
                            ThreadEvent::CheckpointChanged |
                            ThreadEvent::UsageUpdated(_) => {
                                if std::env::var("ZED_EVAL_DEBUG").is_ok() {
                                    println!("{}Event: {:#?}", log_prefix, event);
                                }
                            }
                        }

                        output_file.flush().log_err();
                    }
                }
            });

            // Send the prompt only after the subscription and handler task are in place.
            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(this.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            event_handler_task.await?;

            println!("{}Stopped", this.log_prefix);

            // Wait for the language server again before the diff and diagnostics are collected.
            if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                wait_for_lang_server(lsp_store, this.log_prefix.clone(), cx).await?;
            }

            println!("{}Getting repository diff", this.log_prefix);
            let repository_diff = this.repository_diff().await?;

            let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name));
            let mut repository_diff_output_file = File::create(&repository_diff_path)?;
            writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();

            println!("{}Getting diagnostics", this.log_prefix);
            let diagnostics = cx
                .update(move |cx| {
                    cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
                })?
                .await?;
            println!("{}Got diagnostics", this.log_prefix);

            drop(subscription);
            drop(lsp_open_handle_and_store);

            thread.update(cx, |thread, _cx| {
                let response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                RunOutput {
                    repository_diff,
                    diagnostics,
                    response_count,
                    token_usage: thread.cumulative_token_usage(),
                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                }
            })
        })
    }
484
485    pub async fn judge(
486        &self,
487        model: Arc<dyn LanguageModel>,
488        repository_diff: String,
489        judge_repetitions: u32,
490        cx: &AsyncApp,
491    ) -> Result<JudgeOutput> {
492        let judge_prompt = include_str!("judge_prompt.hbs");
493        let judge_prompt_name = "judge_prompt";
494        let mut handlebars = Handlebars::new();
495        handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
496        let prompt = handlebars.render(
497            judge_prompt_name,
498            &JudgeInput {
499                repository_diff,
500                criteria: self.criteria.clone(),
501            },
502        )?;
503
504        let request = LanguageModelRequest {
505            thread_id: None,
506            prompt_id: None,
507            messages: vec![LanguageModelRequestMessage {
508                role: Role::User,
509                content: vec![MessageContent::Text(prompt)],
510                cache: false,
511            }],
512            temperature: None,
513            tools: Vec::new(),
514            stop: Vec::new(),
515        };
516
517        let response = send_language_model_request(model, request, cx).await?;
518
519        let judge_file_path = self.run_dir.join(format!(
520            "{}_judge_{}.md",
521            self.name, // This is the eval_name
522            judge_repetitions
523        ));
524
525        let mut judge_output_file = File::create(&judge_file_path)?;
526        writeln!(&mut judge_output_file, "{}", &response).log_err();
527
528        parse_judge_output(&response)
529    }
530
531    pub async fn repository_diff(&self) -> Result<String> {
532        let worktree_path = self.worktree_path();
533        run_git(&worktree_path, &["add", "-N"]).await?;
534        run_git(&worktree_path, &["diff"]).await
535    }
536}
537
538fn wait_for_lang_server(
539    lsp_store: &Entity<LspStore>,
540    log_prefix: String,
541    cx: &mut AsyncApp,
542) -> Task<Result<()>> {
543    if cx
544        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
545        .unwrap()
546        || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
547    {
548        return Task::ready(anyhow::Ok(()));
549    }
550
551    println!("{}⏵ Waiting for language server", log_prefix);
552
553    let (mut tx, mut rx) = mpsc::channel(1);
554
555    let subscription =
556        cx.subscribe(&lsp_store, {
557            let log_prefix = log_prefix.clone();
558            move |lsp_store, event, cx| {
559                match event {
560                    project::LspStoreEvent::LanguageServerUpdate {
561                        message:
562                            client::proto::update_language_server::Variant::WorkProgress(
563                                LspWorkProgress {
564                                    message: Some(message),
565                                    ..
566                                },
567                            ),
568                        ..
569                    } => println!("{}{message}", log_prefix),
570                    _ => {}
571                }
572
573                if !has_pending_lang_server_work(&lsp_store, cx) {
574                    tx.try_send(()).ok();
575                }
576            }
577        });
578
579    cx.spawn(async move |cx| {
580        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
581        let result = futures::select! {
582            _ = rx.next() => {
583                println!("{}⚑ Language server idle", log_prefix);
584                anyhow::Ok(())
585            },
586            _ = timeout.fuse() => {
587                Err(anyhow!("LSP wait timed out after 5 minutes"))
588            }
589        };
590        drop(subscription);
591        result
592    })
593}
594
595fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
596    lsp_store
597        .read(cx)
598        .language_server_statuses()
599        .any(|(_, status)| !status.pending_work.is_empty())
600}
601
602async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
603    let paths_with_diagnostics = project.update(cx, |project, cx| {
604        project
605            .diagnostic_summaries(true, cx)
606            .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
607            .map(|(project_path, _, _)| project_path)
608            .collect::<Vec<_>>()
609    })?;
610
611    let mut output = String::new();
612    for project_path in paths_with_diagnostics {
613        let buffer = project
614            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
615            .await?;
616        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
617
618        for (_, group) in snapshot.diagnostic_groups(None) {
619            let entry = &group.entries[group.primary_ix];
620            let range = entry.range.to_point(&snapshot);
621            let severity = match entry.diagnostic.severity {
622                DiagnosticSeverity::ERROR => "error",
623                DiagnosticSeverity::WARNING => "warning",
624                _ => continue,
625            };
626
627            writeln!(
628                output,
629                "{} at line {}: {}",
630                severity,
631                range.start.row + 1,
632                entry.diagnostic.message
633            )?;
634        }
635    }
636    anyhow::Ok(output)
637}
638
639fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
640    let analysis = get_tag("analysis", response)?.to_string();
641    let score = get_tag("score", response)?
642        .parse()
643        .context("error parsing score")?;
644
645    Ok(JudgeOutput { analysis, score })
646}
647
648fn get_tag(name: &'static str, response: &str) -> Result<String> {
649    let start_tag = format!("<{}>", name);
650    let end_tag = format!("</{}>", name);
651
652    let start_ix = response
653        .find(&start_tag)
654        .context(format!("{} start tag not found", name))?;
655    let content_start_ix = start_ix + start_tag.len();
656
657    let end_ix = content_start_ix
658        + response[content_start_ix..]
659            .find(&end_tag)
660            .context(format!("{} end tag not found", name))?;
661
662    let content = response[content_start_ix..end_ix].trim().unindent();
663
664    anyhow::Ok(content)
665}
666
667pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
668    let repo_name = repo_url
669        .trim_start_matches("https://")
670        .replace(|c: char| !c.is_alphanumeric(), "-");
671    Path::new(REPOS_DIR)
672        .canonicalize()
673        .context(format!("No such directory {REPOS_DIR}"))
674        .unwrap()
675        .join(repo_name)
676}
677
678pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
679    let output = new_smol_command("git")
680        .current_dir(repo_path)
681        .args(args)
682        .output()
683        .await?;
684
685    if output.status.success() {
686        Ok(String::from_utf8(output.stdout)?.trim().to_string())
687    } else {
688        Err(anyhow!(
689            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
690            args.join(" "),
691            repo_path.display(),
692            output.status,
693            String::from_utf8_lossy(&output.stderr),
694            String::from_utf8_lossy(&output.stdout),
695        ))
696    }
697}
698
699pub async fn send_language_model_request(
700    model: Arc<dyn LanguageModel>,
701    request: LanguageModelRequest,
702    cx: &AsyncApp,
703) -> anyhow::Result<String> {
704    match model.stream_completion_text(request, &cx).await {
705        Ok(mut stream) => {
706            let mut full_response = String::new();
707            while let Some(chunk_result) = stream.stream.next().await {
708                match chunk_result {
709                    Ok(chunk_str) => {
710                        full_response.push_str(&chunk_str);
711                    }
712                    Err(err) => {
713                        return Err(anyhow!(
714                            "Error receiving response from language model: {err}"
715                        ));
716                    }
717                }
718            }
719            Ok(full_response)
720        }
721        Err(err) => Err(anyhow!(
722            "Failed to get response from language model. Error was: {err}"
723        )),
724    }
725}
726
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `parse_judge_output` on a compact response and on a multi-line response
    /// surrounded by unrelated text.
    #[test]
    fn test_parse_judge_output() {
        // Both tags on a single line each.
        let compact_response = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let JudgeOutput { analysis, score } = parse_judge_output(&compact_response).unwrap();
        assert_eq!(
            analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(score, 3);

        // Indented multi-line analysis with text outside the tags.
        let multiline_response = r#"
            Text around ignored

            <analysis>
                Failed to compile:
                - Error 1
                - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let JudgeOutput { analysis, score } = parse_judge_output(&multiline_response).unwrap();
        assert_eq!(analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(score, 1);
    }
}