evals: Configurable judge model (#31282)

Oleksiy Syvokon created

This makes the judge model configurable independently of the agent
model, which is needed for apples-to-apples comparison of different
agent models.

Another change is that `cargo run -p eval` now accepts model names as
`provider_id/model_id` instead of separate `--provider` and `--model`
params.


Release Notes:

- N/A

Change summary

crates/eval/src/eval.rs | 98 ++++++++++++++++++++++--------------------
1 file changed, 51 insertions(+), 47 deletions(-)

Detailed changes

crates/eval/src/eval.rs 🔗

@@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env;
 use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 use gpui_tokio::Tokio;
 use language::LanguageRegistry;
-use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
+use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
 use node_runtime::{NodeBinaryOptions, NodeRuntime};
 use project::Project;
 use project::project_settings::ProjectSettings;
@@ -33,6 +33,7 @@ use std::collections::VecDeque;
 use std::env;
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
+use std::str::FromStr;
 use std::sync::{Arc, LazyLock};
 use util::ResultExt as _;
 
@@ -45,12 +46,12 @@ struct Args {
     /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
     #[arg(value_name = "EXAMPLE_SUBSTRING")]
     filter: Vec<String>,
-    /// ID of model to use.
-    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+    /// provider/model to use for agent
+    #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
     model: String,
-    /// Model provider to use.
-    #[arg(long, default_value = "anthropic")]
-    provider: String,
+    /// provider/model to use for judges
+    #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
+    judge_model: String,
     #[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
     languages: Vec<String>,
     /// How many times to run each example.
@@ -124,25 +125,19 @@ fn main() {
 
         let mut cumulative_tool_metrics = ToolMetrics::default();
 
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
-        let model_provider_id = model.provider_id();
-        let model_provider = model_registry.provider(&model_provider_id).unwrap();
+        let agent_model = load_model(&args.model, cx).unwrap();
+        let judge_model = load_model(&args.judge_model, cx).unwrap();
 
         LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
-            registry.set_default_model(
-                Some(ConfiguredModel {
-                    provider: model_provider.clone(),
-                    model: model.clone(),
-                }),
-                cx,
-            );
+            registry.set_default_model(Some(agent_model.clone()), cx);
         });
 
-        let authenticate_task = model_provider.authenticate(cx);
+        let auth1 = agent_model.provider.authenticate(cx);
+        let auth2 = judge_model.provider.authenticate(cx);
 
         cx.spawn(async move |cx| {
-            authenticate_task.await.unwrap();
+            auth1.await?;
+            auth2.await?;
 
             let mut examples = Vec::new();
 
@@ -273,7 +268,8 @@ fn main() {
 
             future::join_all((0..args.concurrency).map(|_| {
                 let app_state = app_state.clone();
-                let model = model.clone();
+                let model = agent_model.model.clone();
+                let judge_model = judge_model.model.clone();
                 let zed_commit_sha = zed_commit_sha.clone();
                 let zed_branch_name = zed_branch_name.clone();
                 let run_id = run_id.clone();
@@ -291,7 +287,7 @@ fn main() {
                                 .await?;
                             let judge_output = judge_example(
                                 example.clone(),
-                                model.clone(),
+                                judge_model.clone(),
                                 &zed_commit_sha,
                                 &zed_branch_name,
                                 &run_id,
@@ -453,37 +449,45 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
 }
 
 pub fn find_model(
-    provider_id: &str,
-    model_id: &str,
+    model_name: &str,
     model_registry: &LanguageModelRegistry,
     cx: &App,
 ) -> anyhow::Result<Arc<dyn LanguageModel>> {
-    let matching_models = model_registry
+    let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?;
+    model_registry
         .available_models(cx)
-        .filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
-        .collect::<Vec<_>>();
+        .find(|model| model.id() == selected.model && model.provider_id() == selected.provider)
+        .ok_or_else(|| {
+            anyhow::anyhow!(
+                "No language model with ID {}/{} was available. Available models: {}",
+                selected.model.0,
+                selected.provider.0,
+                model_registry
+                    .available_models(cx)
+                    .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )
+        })
+}
 
-    match matching_models.as_slice() {
-        [model] => Ok(model.clone()),
-        [] => anyhow::bail!(
-            "No language model with ID {}/{} was available. Available models: {}",
-            provider_id,
-            model_id,
-            model_registry
-                .available_models(cx)
-                .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
-                .collect::<Vec<_>>()
-                .join(", ")
-        ),
-        _ => anyhow::bail!(
-            "Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
-            model_id,
-            matching_models
-                .iter()
-                .map(|model| model.provider_id().0)
-                .collect::<Vec<_>>()
-        ),
-    }
+pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result<ConfiguredModel> {
+    let model = {
+        let model_registry = LanguageModelRegistry::read_global(cx);
+        find_model(model_name, model_registry, cx)?
+    };
+
+    let provider = {
+        let model_registry = LanguageModelRegistry::read_global(cx);
+        model_registry
+            .provider(&model.provider_id())
+            .ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))?
+    };
+
+    Ok(ConfiguredModel {
+        provider: provider.clone(),
+        model: model.clone(),
+    })
 }
 
 pub fn commit_sha_for_path(repo_path: &Path) -> String {