@@ -15,7 +15,7 @@ use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId,
+ LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel,
};
use project::Project;
use rand::prelude::*;
@@ -25,6 +25,7 @@ use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
+ str::FromStr,
sync::mpsc,
};
use util::path;
@@ -1216,7 +1217,7 @@ fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usiz
passed_count as f64 / evaluated_count as f64
};
print!(
- "\r\x1b[KEvaluated {}/{} ({:.2}%)",
+ "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
evaluated_count,
iterations,
passed_ratio * 100.0
@@ -1255,13 +1256,21 @@ impl EditAgentTest {
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let agent_model = SelectedModel::from_str(
+ &std::env::var("ZED_AGENT_MODEL")
+ .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
+ )
+ .unwrap();
+ let judge_model = SelectedModel::from_str(
+ &std::env::var("ZED_JUDGE_MODEL")
+ .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
+ )
+ .unwrap();
let (agent_model, judge_model) = cx
.update(|cx| {
cx.spawn(async move |cx| {
- let agent_model =
- Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
- let judge_model =
- Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
+ let agent_model = Self::load_model(&agent_model, cx).await;
+ let judge_model = Self::load_model(&judge_model, cx).await;
(agent_model.unwrap(), judge_model.unwrap())
})
})
@@ -1276,15 +1285,17 @@ impl EditAgentTest {
}
async fn load_model(
- provider: &str,
- id: &str,
+ selected_model: &SelectedModel,
cx: &mut AsyncApp,
) -> Result<Arc<dyn LanguageModel>> {
let (provider, model) = cx.update(|cx| {
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
- .find(|model| model.provider_id().0 == provider && model.id().0 == id)
+ .find(|model| {
+ model.provider_id() == selected_model.provider
+ && model.id() == selected_model.model
+ })
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
(provider, model)
@@ -4,7 +4,7 @@ use crate::{
};
use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
-use std::sync::Arc;
+use std::{str::FromStr, sync::Arc};
use util::maybe;
pub fn init(cx: &mut App) {
@@ -27,11 +27,36 @@ pub struct LanguageModelRegistry {
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
+#[derive(Debug)]
pub struct SelectedModel {
pub provider: LanguageModelProviderId,
pub model: LanguageModelId,
}
+impl FromStr for SelectedModel {
+ type Err = String;
+
+ /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
+ fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
+ let parts: Vec<&str> = id.split('/').collect();
+ let [provider_id, model_id] = parts.as_slice() else {
+ return Err(format!(
+ "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
+ id
+ ));
+ };
+
+ if provider_id.is_empty() || model_id.is_empty() {
+ return Err(format!("Provider and model ids can't be empty: `{}`", id));
+ }
+
+ Ok(SelectedModel {
+ provider: LanguageModelProviderId(provider_id.to_string().into()),
+ model: LanguageModelId(model_id.to_string().into()),
+ })
+ }
+}
+
#[derive(Clone)]
pub struct ConfiguredModel {
pub provider: Arc<dyn LanguageModelProvider>,