@@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env;
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
-use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
+use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
@@ -33,6 +33,7 @@ use std::collections::VecDeque;
use std::env;
use std::path::{Path, PathBuf};
use std::rc::Rc;
+use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use util::ResultExt as _;
@@ -45,12 +46,12 @@ struct Args {
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")]
filter: Vec<String>,
- /// ID of model to use.
- #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+ /// provider/model to use for agent
+ #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
model: String,
- /// Model provider to use.
- #[arg(long, default_value = "anthropic")]
- provider: String,
+ /// provider/model to use for judges
+ #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
+ judge_model: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
languages: Vec<String>,
/// How many times to run each example.
@@ -124,25 +125,19 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default();
- let model_registry = LanguageModelRegistry::read_global(cx);
- let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
- let model_provider_id = model.provider_id();
- let model_provider = model_registry.provider(&model_provider_id).unwrap();
+ let agent_model = load_model(&args.model, cx).unwrap();
+ let judge_model = load_model(&args.judge_model, cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_default_model(
- Some(ConfiguredModel {
- provider: model_provider.clone(),
- model: model.clone(),
- }),
- cx,
- );
+ registry.set_default_model(Some(agent_model.clone()), cx);
});
- let authenticate_task = model_provider.authenticate(cx);
+ let auth1 = agent_model.provider.authenticate(cx);
+ let auth2 = judge_model.provider.authenticate(cx);
cx.spawn(async move |cx| {
- authenticate_task.await.unwrap();
+ auth1.await?;
+ auth2.await?;
let mut examples = Vec::new();
@@ -273,7 +268,8 @@ fn main() {
future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone();
- let model = model.clone();
+ let model = agent_model.model.clone();
+ let judge_model = judge_model.model.clone();
let zed_commit_sha = zed_commit_sha.clone();
let zed_branch_name = zed_branch_name.clone();
let run_id = run_id.clone();
@@ -291,7 +287,7 @@ fn main() {
.await?;
let judge_output = judge_example(
example.clone(),
- model.clone(),
+ judge_model.clone(),
&zed_commit_sha,
&zed_branch_name,
&run_id,
@@ -453,37 +449,55 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
}
+/// Finds the available language model identified by `model_name`.
+///
+/// `model_name` is parsed into a [`SelectedModel`]; the result is the first model reported by
+/// `model_registry.available_models(cx)` whose ID and provider ID both equal the parsed ones.
+///
+/// # Errors
+///
+/// Fails when `model_name` does not parse, or when no available model matches. The latter
+/// error lists every available model as `provider/model`.
pub fn find_model(
-provider_id: &str,
-model_id: &str,
+ model_name: &str,
model_registry: &LanguageModelRegistry,
cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
-let matching_models = model_registry
+ let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?;
+ model_registry
.available_models(cx)
-.filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
-.collect::<Vec<_>>();
+ .find(|model| model.id() == selected.model && model.provider_id() == selected.provider)
+ .ok_or_else(|| {
+ anyhow::anyhow!(
+ "No language model with ID {}/{} was available. Available models: {}",
+ // Provider first, so the requested name reads like the `provider/model` list below.
+ selected.provider.0,
+ selected.model.0,
+ model_registry
+ .available_models(cx)
+ .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
+ .collect::<Vec<_>>()
+ .join(", ")
+ )
+ })
+}
-match matching_models.as_slice() {
-[model] => Ok(model.clone()),
-[] => anyhow::bail!(
-"No language model with ID {}/{} was available. Available models: {}",
-provider_id,
-model_id,
-model_registry
-.available_models(cx)
-.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
-.collect::<Vec<_>>()
-.join(", ")
-),
-_ => anyhow::bail!(
-"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
-model_id,
-matching_models
-.iter()
-.map(|model| model.provider_id().0)
-.collect::<Vec<_>>()
-),
-}
+/// Resolves `model_name` into a [`ConfiguredModel`] using the global [`LanguageModelRegistry`].
+///
+/// The model comes from [`find_model`]; the provider is looked up in the same registry by the
+/// model's provider ID.
+///
+/// # Errors
+///
+/// Fails when [`find_model`] fails, or when the registry has no provider for that ID.
+pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result<ConfiguredModel> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model = find_model(model_name, model_registry, cx)?;
+ let provider = model_registry
+ .provider(&model.provider_id())
+ .ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))?;
+
+ // Both handles are owned here, so they move into the result without cloning.
+ Ok(ConfiguredModel { provider, model })
}
pub fn commit_sha_for_path(repo_path: &Path) -> String {