evals: Make LLMs configurable in edit_agent evals (#30813)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/assistant_tools/src/edit_agent/evals.rs | 29 +++++++++++++------
crates/language_model/src/registry.rs          | 27 +++++++++++++++++
2 files changed, 46 insertions(+), 10 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_agent/evals.rs 🔗

@@ -15,7 +15,7 @@ use gpui::{AppContext, TestAppContext};
 use indoc::{formatdoc, indoc};
 use language_model::{
     LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId,
+    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel,
 };
 use project::Project;
 use rand::prelude::*;
@@ -25,6 +25,7 @@ use std::{
     cmp::Reverse,
     fmt::{self, Display},
     io::Write as _,
+    str::FromStr,
     sync::mpsc,
 };
 use util::path;
@@ -1216,7 +1217,7 @@ fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usiz
         passed_count as f64 / evaluated_count as f64
     };
     print!(
-        "\r\x1b[KEvaluated {}/{} ({:.2}%)",
+        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
         evaluated_count,
         iterations,
         passed_ratio * 100.0
@@ -1255,13 +1256,21 @@ impl EditAgentTest {
 
         fs.insert_tree("/root", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let agent_model = SelectedModel::from_str(
+            &std::env::var("ZED_AGENT_MODEL")
+                .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
+        )
+        .unwrap();
+        let judge_model = SelectedModel::from_str(
+            &std::env::var("ZED_JUDGE_MODEL")
+                .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
+        )
+        .unwrap();
         let (agent_model, judge_model) = cx
             .update(|cx| {
                 cx.spawn(async move |cx| {
-                    let agent_model =
-                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
-                    let judge_model =
-                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
+                    let agent_model = Self::load_model(&agent_model, cx).await;
+                    let judge_model = Self::load_model(&judge_model, cx).await;
                     (agent_model.unwrap(), judge_model.unwrap())
                 })
             })
@@ -1276,15 +1285,17 @@ impl EditAgentTest {
     }
 
     async fn load_model(
-        provider: &str,
-        id: &str,
+        selected_model: &SelectedModel,
         cx: &mut AsyncApp,
     ) -> Result<Arc<dyn LanguageModel>> {
         let (provider, model) = cx.update(|cx| {
             let models = LanguageModelRegistry::read_global(cx);
             let model = models
                 .available_models(cx)
-                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
+                .find(|model| {
+                    model.provider_id() == selected_model.provider
+                        && model.id() == selected_model.model
+                })
                 .unwrap();
             let provider = models.provider(&model.provider_id()).unwrap();
             (provider, model)

crates/language_model/src/registry.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
 };
 use collections::BTreeMap;
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
 use util::maybe;
 
 pub fn init(cx: &mut App) {
@@ -27,11 +27,36 @@ pub struct LanguageModelRegistry {
     inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 }
 
+#[derive(Debug)]
 pub struct SelectedModel {
     pub provider: LanguageModelProviderId,
     pub model: LanguageModelId,
 }
 
+impl FromStr for SelectedModel {
+    type Err = String;
+
+    /// Parses an identifier of the form `provider_id/model_id` into a `SelectedModel`.
+    /// The string must contain exactly one `/`, and neither part may be empty.
+    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
+        let parts: Vec<&str> = id.split('/').collect();
+        let [provider_id, model_id] = parts.as_slice() else {
+            return Err(format!(
+                "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
+                id
+            ));
+        };
+
+        if provider_id.is_empty() || model_id.is_empty() {
+            return Err(format!("Provider and model ids can't be empty: `{}`", id));
+        }
+
+        Ok(SelectedModel {
+            provider: LanguageModelProviderId(provider_id.to_string().into()),
+            model: LanguageModelId(model_id.to_string().into()),
+        })
+    }
+}
+
 #[derive(Clone)]
 pub struct ConfiguredModel {
     pub provider: Arc<dyn LanguageModelProvider>,