Detailed changes
@@ -5692,6 +5692,61 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "eval"
+version = "0.1.0"
+dependencies = [
+ "acp_thread",
+ "agent",
+ "agent-client-protocol",
+ "agent_settings",
+ "agent_ui",
+ "anyhow",
+ "async-trait",
+ "buffer_diff",
+ "chrono",
+ "clap",
+ "client",
+ "collections",
+ "debug_adapter_extension",
+ "dirs 4.0.0",
+ "dotenvy",
+ "env_logger 0.11.8",
+ "extension",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "gpui_tokio",
+ "handlebars 4.5.0",
+ "language",
+ "language_extension",
+ "language_model",
+ "language_models",
+ "languages",
+ "markdown",
+ "node_runtime",
+ "pathdiff",
+ "paths",
+ "pretty_assertions",
+ "project",
+ "prompt_store",
+ "rand 0.9.2",
+ "regex",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "settings",
+ "shellexpand 2.1.2",
+ "telemetry",
+ "terminal_view",
+ "toml 0.8.23",
+ "unindent",
+ "util",
+ "uuid",
+ "watch",
+]
+
[[package]]
name = "event-listener"
version = "2.5.3"
@@ -58,7 +58,7 @@ members = [
"crates/edit_prediction_context",
"crates/zeta2_tools",
"crates/editor",
- # "crates/eval",
+ "crates/eval",
"crates/explorer_command_injector",
"crates/extension",
"crates/extension_api",
@@ -1,10 +1,15 @@
use agent_client_protocol as acp;
-
+use anyhow::Result;
use futures::{FutureExt as _, future::Shared};
-use gpui::{App, AppContext, Context, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, Task};
use language::LanguageRegistry;
use markdown::Markdown;
+use project::Project;
+use settings::{Settings as _, SettingsLocation};
use std::{path::PathBuf, process::ExitStatus, sync::Arc, time::Instant};
+use task::Shell;
+use terminal::terminal_settings::TerminalSettings;
+use util::get_default_system_shell_preferring_bash;
pub struct Terminal {
id: acp::TerminalId,
@@ -170,3 +175,68 @@ impl Terminal {
)
}
}
+
+/// Spawns `command` with `args` in a terminal owned by `project` and returns
+/// the resulting terminal entity.
+///
+/// The environment is assembled in layers:
+/// 1. If `cwd` is given, the directory environment reported by
+///    `Project::directory_environment` for that directory, resolved with the
+///    shell from the `TerminalSettings` that apply to that location. An empty
+///    environment is used when `cwd` is `None` or no environment is available.
+/// 2. `PAGER=""`, so commands such as `git` do not block on a pager.
+/// 3. `env_vars`, inserted last.
+///
+/// The command line is wrapped by `task::ShellBuilder` (with
+/// `redirect_stdin_to_dev_null`), using the remote client's default system
+/// shell when the project has one and
+/// `get_default_system_shell_preferring_bash()` otherwise.
+///
+/// # Errors
+///
+/// Returns an error if updating `project` through `cx` fails or if
+/// `create_terminal_task` fails.
+pub async fn create_terminal_entity(
+    command: String,
+    args: &[String],
+    env_vars: Vec<(String, String)>,
+    cwd: Option<PathBuf>,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<Entity<terminal::Terminal>> {
+    let mut env = if let Some(dir) = &cwd {
+        project
+            .update(cx, |project, cx| {
+                let worktree = project.find_worktree(dir.as_path(), cx);
+                let shell = TerminalSettings::get(
+                    worktree.as_ref().map(|(worktree, path)| SettingsLocation {
+                        worktree_id: worktree.read(cx).id(),
+                        path: &path,
+                    }),
+                    cx,
+                )
+                .shell
+                .clone();
+                project.directory_environment(&shell, dir.clone().into(), cx)
+            })?
+            .await
+            .unwrap_or_default()
+    } else {
+        Default::default()
+    };
+
+    // Disables paging for `git` and hopefully other commands
+    env.insert("PAGER".into(), "".into());
+    env.extend(env_vars);
+
+    // Use remote shell or default system shell, as appropriate
+    let shell = project
+        .update(cx, |project, cx| {
+            project
+                .remote_client()
+                .and_then(|r| r.read(cx).default_system_shell())
+                .map(Shell::Program)
+        })?
+        .unwrap_or_else(|| Shell::Program(get_default_system_shell_preferring_bash()));
+    // Fall back to the host platform if the project can no longer be read.
+    let is_windows = project
+        .read_with(cx, |project, cx| project.path_style(cx).is_windows())
+        .unwrap_or(cfg!(windows));
+    // `command` is owned and not used again, so it is moved rather than cloned;
+    // `args` is already a slice reference and is passed through as-is.
+    let (task_command, task_args) = task::ShellBuilder::new(&shell, is_windows)
+        .redirect_stdin_to_dev_null()
+        .build(Some(command), args);
+
+    project
+        .update(cx, |project, cx| {
+            project.create_terminal_task(
+                task::SpawnInTerminal {
+                    command: Some(task_command),
+                    args: task_args,
+                    cwd,
+                    env,
+                    ..Default::default()
+                },
+                cx,
+            )
+        })?
+        .await
+}
@@ -10,6 +10,8 @@ path = "src/agent.rs"
[features]
test-support = ["db/test-support"]
+eval = []
+edit-agent-eval = []
e2e = []
[lints]
@@ -31,7 +31,7 @@ use std::{
use util::path;
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_extract_handle_command_output() {
// Test how well agent generates multiple edit hunks.
//
@@ -108,7 +108,7 @@ fn eval_extract_handle_command_output() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_delete_run_git_blame() {
// Model | Pass rate
// ----------------------------|----------
@@ -171,7 +171,7 @@ fn eval_delete_run_git_blame() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_translate_doc_comments() {
// Model | Pass rate
// ============================================
@@ -234,7 +234,7 @@ fn eval_translate_doc_comments() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
// Model | Pass rate
// ============================================
@@ -360,7 +360,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_disable_cursor_blinking() {
// Model | Pass rate
// ============================================
@@ -446,7 +446,7 @@ fn eval_disable_cursor_blinking() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_from_pixels_constructor() {
// Results for 2025-06-13
//
@@ -656,7 +656,7 @@ fn eval_from_pixels_constructor() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_zode() {
// Model | Pass rate
// ============================================
@@ -763,7 +763,7 @@ fn eval_zode() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_add_overwrite_test() {
// Model | Pass rate
// ============================================
@@ -995,7 +995,7 @@ fn eval_add_overwrite_test() {
}
#[test]
-#[cfg_attr(not(feature = "eval"), ignore)]
+#[cfg_attr(not(feature = "edit-agent-eval"), ignore)]
fn eval_create_empty_file() {
// Check that Edit Agent can create a file without writing its
// thoughts into it. This issue is not specific to empty files, but
@@ -1490,9 +1490,20 @@ impl EditAgentTest {
&std::env::var("ZED_JUDGE_MODEL").unwrap_or("anthropic/claude-4-sonnet-latest".into()),
)
.unwrap();
+
+ let authenticate_provider_tasks = cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .providers()
+ .iter()
+ .map(|p| p.authenticate(cx))
+ .collect::<Vec<_>>()
+ })
+ });
let (agent_model, judge_model) = cx
.update(|cx| {
cx.spawn(async move |cx| {
+ futures::future::join_all(authenticate_provider_tasks).await;
let agent_model = Self::load_model(&agent_model, cx).await;
let judge_model = Self::load_model(&judge_model, cx).await;
(agent_model.unwrap(), judge_model.unwrap())
@@ -1995,7 +1995,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
locations: vec![],
raw_input: Some(json!({})),
raw_output: None,
- meta: None,
+ meta: Some(json!({ "tool_name": "thinking" })),
}
);
let update = expect_tool_call_update_fields(&mut events).await;
@@ -745,7 +745,13 @@ impl Thread {
let title = tool.initial_title(tool_use.input.clone(), cx);
let kind = tool.kind();
- stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+ stream.send_tool_call(
+ &tool_use.id,
+ &tool_use.name,
+ title,
+ kind,
+ tool_use.input.clone(),
+ );
let output = tool_result
.as_ref()
@@ -1044,14 +1050,18 @@ impl Thread {
Ok(())
}
- pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
+ /// Returns the token usage recorded in `request_token_usage` for the most
+ /// recent user message.
+ ///
+ /// Yields `None` when there is no last user message or no usage entry is
+ /// stored under its id.
+ pub fn latest_request_token_usage(&self) -> Option<language_model::TokenUsage> {
let last_user_message = self.last_user_message()?;
let tokens = self.request_token_usage.get(&last_user_message.id)?;
- let model = self.model.clone()?;
+ Some(*tokens)
+ }
+ /// Reports the latest request's usage as an `acp_thread::TokenUsage`.
+ ///
+ /// `used_tokens` is the total of `latest_request_token_usage()` and
+ /// `max_tokens` is the model's limit for the current completion mode.
+ /// Yields `None` when no usage is recorded or no model is set.
+ pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
+ let usage = self.latest_request_token_usage()?;
+ let model = self.model.clone()?;
Some(acp_thread::TokenUsage {
max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
- used_tokens: tokens.total_tokens(),
+ used_tokens: usage.total_tokens(),
})
}
@@ -1094,6 +1104,14 @@ impl Thread {
self.run_turn(cx)
}
+ /// Calls `run_turn` directly, without taking a new user message as input,
+ /// and returns the event receiver it produces.
+ ///
+ /// Only compiled when the `eval` feature is enabled.
+ #[cfg(feature = "eval")]
+ pub fn proceed(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+ self.run_turn(cx)
+ }
+
fn run_turn(
&mut self,
cx: &mut Context<Self>,
@@ -1461,7 +1479,13 @@ impl Thread {
});
if push_new_tool_use {
- event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+ event_stream.send_tool_call(
+ &tool_use.id,
+ &tool_use.name,
+ title,
+ kind,
+ tool_use.input.clone(),
+ );
last_message
.content
.push(AgentMessageContent::ToolUse(tool_use.clone()));
@@ -2256,6 +2280,7 @@ impl ThreadEventStream {
fn send_tool_call(
&self,
id: &LanguageModelToolUseId,
+ tool_name: &str,
title: SharedString,
kind: acp::ToolKind,
input: serde_json::Value,
@@ -2263,6 +2288,7 @@ impl ThreadEventStream {
self.0
.unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
id,
+ tool_name,
title.to_string(),
kind,
input,
@@ -2272,12 +2298,15 @@ impl ThreadEventStream {
fn initial_tool_call(
id: &LanguageModelToolUseId,
+ tool_name: &str,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> acp::ToolCall {
acp::ToolCall {
- meta: None,
+ meta: Some(serde_json::json!({
+ "tool_name": tool_name
+ })),
id: acp::ToolCallId(id.to_string().into()),
title,
kind,
@@ -9,9 +9,7 @@ use futures::io::BufReader;
use project::Project;
use project::agent_server_store::AgentServerCommand;
use serde::Deserialize;
-use settings::{Settings as _, SettingsLocation};
-use task::Shell;
-use util::{ResultExt as _, get_default_system_shell_preferring_bash};
+use util::ResultExt as _;
use std::path::PathBuf;
use std::{any::Any, cell::RefCell};
@@ -23,7 +21,7 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntit
use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
use terminal::TerminalBuilder;
-use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings};
+use terminal::terminal_settings::{AlternateScroll, CursorShape};
#[derive(Debug, Error)]
#[error("Unsupported version")]
@@ -816,62 +814,18 @@ impl acp::Client for ClientDelegate {
let thread = self.session_thread(&args.session_id)?;
let project = thread.read_with(&self.cx, |thread, _cx| thread.project().clone())?;
- let mut env = if let Some(dir) = &args.cwd {
- project
- .update(&mut self.cx.clone(), |project, cx| {
- let worktree = project.find_worktree(dir.as_path(), cx);
- let shell = TerminalSettings::get(
- worktree.as_ref().map(|(worktree, path)| SettingsLocation {
- worktree_id: worktree.read(cx).id(),
- path: &path,
- }),
- cx,
- )
- .shell
- .clone();
- project.directory_environment(&shell, dir.clone().into(), cx)
- })?
- .await
- .unwrap_or_default()
- } else {
- Default::default()
- };
- // Disables paging for `git` and hopefully other commands
- env.insert("PAGER".into(), "".into());
- for var in args.env {
- env.insert(var.name, var.value);
- }
-
- // Use remote shell or default system shell, as appropriate
- let shell = project
- .update(&mut self.cx.clone(), |project, cx| {
- project
- .remote_client()
- .and_then(|r| r.read(cx).default_system_shell())
- .map(Shell::Program)
- })?
- .unwrap_or_else(|| Shell::Program(get_default_system_shell_preferring_bash()));
- let is_windows = project
- .read_with(&self.cx, |project, cx| project.path_style(cx).is_windows())
- .unwrap_or(cfg!(windows));
- let (task_command, task_args) = task::ShellBuilder::new(&shell, is_windows)
- .redirect_stdin_to_dev_null()
- .build(Some(args.command.clone()), &args.args);
-
- let terminal_entity = project
- .update(&mut self.cx.clone(), |project, cx| {
- project.create_terminal_task(
- task::SpawnInTerminal {
- command: Some(task_command),
- args: task_args,
- cwd: args.cwd.clone(),
- env,
- ..Default::default()
- },
- cx,
- )
- })?
- .await?;
+ let terminal_entity = acp_thread::create_terminal_entity(
+ args.command.clone(),
+ &args.args,
+ args.env
+ .into_iter()
+ .map(|env| (env.name, env.value))
+ .collect(),
+ args.cwd.clone(),
+ &project,
+ &mut self.cx.clone(),
+ )
+ .await?;
// Register with renderer
let terminal_entity = thread.update(&mut self.cx.clone(), |thread, cx| {
@@ -19,7 +19,7 @@ path = "src/explorer.rs"
[dependencies]
acp_thread.workspace = true
-agent.workspace = true
+agent = { workspace = true, features = ["eval"] }
agent-client-protocol.workspace = true
agent_settings.workspace = true
agent_ui.workspace = true
@@ -29,7 +29,6 @@ buffer_diff.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true
-cloud_llm_client.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
dirs.workspace = true
@@ -54,13 +53,13 @@ pretty_assertions.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true
+rand.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
-smol.workspace = true
telemetry.workspace = true
terminal_view.workspace = true
toml.workspace = true
@@ -1,6 +1,5 @@
{
- "assistant": {
- "always_allow_tool_actions": true,
- "version": "2"
+ "agent": {
+ "always_allow_tool_actions": true
}
}
@@ -61,9 +61,22 @@ struct Args {
/// Maximum number of examples to run concurrently.
#[arg(long, default_value = "4")]
concurrency: usize,
+ /// Output current environment variables as JSON to stdout
+ #[arg(long, hide = true)]
+ printenv: bool,
}
fn main() {
+ let args = Args::parse();
+
+ // Handle `--printenv` before anything else: project::environment::load_shell_environment()
+ // re-invokes the current executable (std::env::current_exe()) with `--printenv`,
+ // and without this early return those invocations show up as errors in the logs.
+ if args.printenv {
+ util::shell_env::print_env();
+ return;
+ }
+
dotenvy::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
env_logger::init();
@@ -99,7 +112,6 @@ fn main() {
let zed_commit_sha = commit_sha_for_path(&root_dir);
let zed_branch_name = git_branch_for_path(&root_dir);
- let args = Args::parse();
let languages: HashSet<String> = args.languages.into_iter().collect();
let http_client = Arc::new(ReqwestClient::new());
@@ -126,19 +138,20 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default();
- let agent_model = load_model(&args.model, cx).unwrap();
- let judge_model = load_model(&args.judge_model, cx).unwrap();
-
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_default_model(Some(agent_model.clone()), cx);
+ let tasks = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.providers().iter().map(|p| p.authenticate(cx)).collect::<Vec<_>>()
});
- let auth1 = agent_model.provider.authenticate(cx);
- let auth2 = judge_model.provider.authenticate(cx);
-
cx.spawn(async move |cx| {
- auth1.await?;
- auth2.await?;
+ future::join_all(tasks).await;
+ let judge_model = cx.update(|cx| {
+ let agent_model = load_model(&args.model, cx).unwrap();
+ let judge_model = load_model(&args.judge_model, cx).unwrap();
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.set_default_model(Some(agent_model.clone()), cx);
+ });
+ judge_model
+ })?;
let mut examples = Vec::new();
@@ -268,7 +281,6 @@ fn main() {
future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone();
- let model = agent_model.model.clone();
let judge_model = judge_model.model.clone();
let zed_commit_sha = zed_commit_sha.clone();
let zed_branch_name = zed_branch_name.clone();
@@ -283,7 +295,7 @@ fn main() {
let result = async {
example.setup().await?;
let run_output = cx
- .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
+ .update(|cx| example.run(app_state.clone(), cx))?
.await?;
let judge_output = judge_example(
example.clone(),
@@ -524,7 +536,6 @@ async fn judge_example(
diff_evaluation = judge_output.diff.clone(),
thread_evaluation = judge_output.thread,
tool_metrics = run_output.tool_metrics,
- response_count = run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
@@ -3,6 +3,7 @@ use std::{
fmt::{self, Debug},
sync::{Arc, Mutex},
time::Duration,
+ u32,
};
use crate::{
@@ -16,11 +17,10 @@ use agent_settings::AgentProfileId;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
-use cloud_llm_client::CompletionIntent;
use collections::HashMap;
-use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
+use futures::{FutureExt as _, StreamExt, select_biased};
use gpui::{App, AppContext, AsyncApp, Entity};
-use language_model::{LanguageModel, Role, StopReason};
+use language_model::Role;
use util::rel_path::RelPath;
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
@@ -93,7 +93,6 @@ pub struct ExampleContext {
log_prefix: String,
agent_thread: Entity<agent::Thread>,
app: AsyncApp,
- model: Arc<dyn LanguageModel>,
pub assertions: AssertionsReport,
pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
@@ -103,7 +102,6 @@ impl ExampleContext {
meta: ExampleMetadata,
log_prefix: String,
agent_thread: Entity<Thread>,
- model: Arc<dyn LanguageModel>,
app: AsyncApp,
) -> Self {
let assertions = AssertionsReport::new(meta.max_assertions);
@@ -113,26 +111,11 @@ impl ExampleContext {
log_prefix,
agent_thread,
assertions,
- model,
app,
tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
}
}
- pub fn push_user_message(&mut self, text: impl ToString) {
- self.app
- .update_entity(&self.agent_thread, |thread, cx| {
- thread.insert_user_message(
- text.to_string(),
- ContextLoadResult::default(),
- None,
- Vec::new(),
- cx,
- );
- })
- .unwrap();
- }
-
pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
let message = message.to_string();
self.log_assertion(
@@ -204,156 +187,174 @@ impl ExampleContext {
result
}
- pub async fn run_to_end(&mut self) -> Result<Response> {
- self.run_turns(u32::MAX).await
+ /// Sends `prompt` as a user message and runs the thread with a turn budget
+ /// of `u32::MAX` (effectively unlimited).
+ ///
+ /// NOTE(review): `run_turns` races the whole run against a single
+ /// `THREAD_EVENT_TIMEOUT` timer rather than restarting it per event, so a
+ /// run that takes longer than the timeout fails as "stalled" even while
+ /// events are still arriving — confirm this is intended.
+ pub async fn prompt(&mut self, prompt: impl Into<String>) -> Result<Response> {
+ self.prompt_with_max_turns(prompt, u32::MAX).await
}
- pub async fn run_turn(&mut self) -> Result<Response> {
- self.run_turns(1).await
+ /// Sends `prompt` as a single text user message and runs the thread with a
+ /// budget of `max_turns`.
+ ///
+ /// `run_turns` spends one unit of the budget each time a tool call reaches
+ /// `Completed` or `Failed`, and returns the messages collected so far once
+ /// the budget reaches zero.
+ ///
+ /// NOTE(review): the budget is a `u32` decremented with `-= 1`, so passing
+ /// `max_turns == 0` would underflow on the first finished tool call.
+ pub async fn prompt_with_max_turns(
+ &mut self,
+ prompt: impl Into<String>,
+ max_turns: u32,
+ ) -> Result<Response> {
+ let content = vec![UserMessageContent::Text(prompt.into())];
+ self.run_turns(Some(content), max_turns).await
}
- pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
- let (mut tx, mut rx) = mpsc::channel(1);
+ /// Continues the existing thread without a new user message (`run_turns`
+ /// calls `Thread::proceed` when given no prompt), with a budget of
+ /// `max_turns`.
+ pub async fn proceed_with_max_turns(&mut self, max_turns: u32) -> Result<Response> {
+ self.run_turns(None, max_turns).await
+ }
+ async fn run_turns(
+ &mut self,
+ prompt: Option<Vec<UserMessageContent>>,
+ max_turns: u32,
+ ) -> Result<Response> {
let tool_metrics = self.tool_metrics.clone();
let log_prefix = self.log_prefix.clone();
- let _subscription = self.app.subscribe(
- &self.agent_thread,
- move |thread, event: &ThreadEvent, cx| match event {
- ThreadEvent::ShowError(thread_error) => {
- tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
- }
- ThreadEvent::Stopped(reason) => match reason {
- Ok(StopReason::EndTurn) => {
- tx.close_channel();
+
+ let mut remaining_turns = max_turns;
+
+ let mut event_stream = self.agent_thread.update(&mut self.app, |thread, cx| {
+ if let Some(prompt) = prompt {
+ let id = UserMessageId::new();
+ thread.send(id, prompt, cx)
+ } else {
+ thread.proceed(cx)
+ }
+ })??;
+
+ let task = self.app.background_spawn(async move {
+ let mut messages = Vec::new();
+ let mut tool_uses_by_id = HashMap::default();
+ while let Some(event) = event_stream.next().await {
+ match event? {
+ ThreadEvent::UserMessage(user_message) => {
+ messages.push(Message {
+ role: Role::User,
+ text: user_message.to_markdown(),
+ tool_use: Vec::new(),
+ });
}
- Ok(StopReason::ToolUse) => {
- if thread.read(cx).remaining_turns() == 0 {
- tx.close_channel();
+ ThreadEvent::AgentThinking(text) | ThreadEvent::AgentText(text) => {
+ if matches!(
+ messages.last(),
+ Some(Message {
+ role: Role::Assistant,
+ ..
+ })
+ ) {
+ messages.last_mut().unwrap().text.push_str(&text);
+ } else {
+ messages.push(Message {
+ role: Role::Assistant,
+ text,
+ tool_use: Vec::new(),
+ });
}
}
- Ok(StopReason::MaxTokens) => {
- tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
- }
- Ok(StopReason::Refusal) => {
- tx.try_send(Err(anyhow!("Model refused to generate content")))
- .ok();
- }
- Err(err) => {
- tx.try_send(Err(anyhow!(err.clone()))).ok();
+ ThreadEvent::ToolCall(tool_call) => {
+ let meta = tool_call.meta.expect("Missing meta field in tool_call");
+ let tool_name = meta
+ .get("tool_name")
+ .expect("Missing tool_name field in meta")
+ .as_str()
+ .expect("Unknown tool_name content in meta");
+
+ tool_uses_by_id.insert(
+ tool_call.id,
+ ToolUse {
+ name: tool_name.to_string(),
+ value: tool_call.raw_input.unwrap_or_default(),
+ },
+ );
+ if matches!(
+ tool_call.status,
+ acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed
+ ) {
+ panic!("Tool call completed without update");
+ }
}
- },
- ThreadEvent::NewRequest
- | ThreadEvent::StreamedAssistantText(_, _)
- | ThreadEvent::StreamedAssistantThinking(_, _)
- | ThreadEvent::UsePendingTools { .. }
- | ThreadEvent::CompletionCanceled => {}
- ThreadEvent::ToolUseLimitReached => {}
- ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- ..
- } => {
- thread.update(cx, |thread, _cx| {
- if let Some(tool_use) = pending_tool_use {
- let mut tool_metrics = tool_metrics.lock().unwrap();
- if let Some(tool_result) = thread.tool_result(tool_use_id) {
- let message = if tool_result.is_error {
- format!("✖︎ {}", tool_use.name)
- } else {
+ ThreadEvent::ToolCallUpdate(tool_call_update) => {
+ if let acp_thread::ToolCallUpdate::UpdateFields(update) = tool_call_update {
+ if let Some(raw_input) = update.fields.raw_input {
+ if let Some(tool_use) = tool_uses_by_id.get_mut(&update.id) {
+ tool_use.value = raw_input;
+ }
+ }
+
+ if matches!(
+ update.fields.status,
+ Some(acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed)
+ ) {
+ let succeeded =
+ update.fields.status == Some(acp::ToolCallStatus::Completed);
+
+ let tool_use = tool_uses_by_id
+ .remove(&update.id)
+ .expect("Unrecognized tool call completed");
+
+ let log_message = if succeeded {
format!("✔︎ {}", tool_use.name)
+ } else {
+ format!("✖︎ {}", tool_use.name)
};
- println!("{log_prefix}{message}");
+ println!("{log_prefix}{log_message}");
+
tool_metrics
- .insert(tool_result.tool_name.clone(), !tool_result.is_error);
- } else {
- let message =
- format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
- println!("{log_prefix}{message}");
- tool_metrics.insert(tool_use.name.clone(), true);
+ .lock()
+ .unwrap()
+ .insert(tool_use.name.clone().into(), succeeded);
+
+ if let Some(message) = messages.last_mut() {
+ message.tool_use.push(tool_use);
+ } else {
+ messages.push(Message {
+ role: Role::Assistant,
+ text: "".to_string(),
+ tool_use: vec![tool_use],
+ });
+ }
+
+ remaining_turns -= 1;
+ if remaining_turns == 0 {
+ return Ok(messages);
+ }
}
}
- });
- }
- ThreadEvent::InvalidToolInput { .. } => {
- println!("{log_prefix} invalid tool input");
- }
- ThreadEvent::MissingToolUse {
- tool_use_id: _,
- ui_text,
- } => {
- println!("{log_prefix} {ui_text}");
- }
- ThreadEvent::ToolConfirmationNeeded => {
- panic!(
+ }
+ ThreadEvent::ToolCallAuthorization(_) => panic!(
"{}Bug: Tool confirmation should not be required in eval",
log_prefix
- );
- }
- ThreadEvent::StreamedCompletion
- | ThreadEvent::MessageAdded(_)
- | ThreadEvent::MessageEdited(_)
- | ThreadEvent::MessageDeleted(_)
- | ThreadEvent::SummaryChanged
- | ThreadEvent::SummaryGenerated
- | ThreadEvent::ProfileChanged
- | ThreadEvent::ReceivedTextChunk
- | ThreadEvent::StreamedToolUse { .. }
- | ThreadEvent::CheckpointChanged
- | ThreadEvent::CancelEditing => {
- tx.try_send(Ok(())).ok();
- if std::env::var("ZED_EVAL_DEBUG").is_ok() {
- println!("{}Event: {:#?}", log_prefix, event);
- }
- }
- },
- );
-
- let model = self.model.clone();
-
- let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
- thread.set_remaining_turns(iterations);
- thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
- thread.messages().len()
- })?;
-
- loop {
- select_biased! {
- result = rx.next() => {
- if let Some(result) = result {
- result?;
- } else {
- break;
+ ),
+ ThreadEvent::Retry(status) => {
+ println!("{log_prefix} Got retry: {status:?}");
}
- }
- _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
- anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
+ ThreadEvent::Stop(stop_reason) => match stop_reason {
+ acp::StopReason::EndTurn => {}
+ acp::StopReason::MaxTokens => {
+ return Err(anyhow!("Exceeded maximum tokens"));
+ }
+ acp::StopReason::MaxTurnRequests => {
+ return Err(anyhow!("Exceeded maximum turn requests"));
+ }
+ acp::StopReason::Refusal => {
+ return Err(anyhow!("Refusal"));
+ }
+ acp::StopReason::Cancelled => return Err(anyhow!("Cancelled")),
+ },
}
}
- }
+ Ok(messages)
+ });
- let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
- let mut messages = Vec::new();
- for message in thread.messages().skip(message_count_before) {
- messages.push(Message {
- _role: message.role,
- text: message.to_message_content(),
- tool_use: thread
- .tool_uses_for_message(message.id, cx)
- .into_iter()
- .map(|tool_use| ToolUse {
- name: tool_use.name.to_string(),
- value: tool_use.input,
- })
- .collect(),
- });
+ select_biased! {
+ result = task.fuse() => {
+ Ok(Response::new(result?))
}
- messages
- })?;
-
- let response = Response::new(messages);
-
- Ok(response)
+ _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
+ anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
+ }
+ }
}
pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
@@ -488,7 +489,7 @@ impl Response {
Self { messages }
}
- pub fn expect_tool(
+ pub fn expect_tool_call(
&self,
tool_name: &'static str,
cx: &mut ExampleContext,
@@ -505,8 +506,7 @@ impl Response {
})
}
- #[allow(dead_code)]
- pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
+ pub fn tool_calls(&self) -> impl Iterator<Item = &ToolUse> {
self.messages.iter().flat_map(|msg| &msg.tool_use)
}
@@ -517,7 +517,7 @@ impl Response {
#[derive(Debug)]
pub struct Message {
- _role: Role,
+ role: Role,
text: String,
tool_use: Vec<ToolUse>,
}
@@ -27,14 +27,12 @@ impl Example for AddArgToTraitMethod {
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
const FILENAME: &str = "assistant_tool.rs";
- cx.push_user_message(format!(
+ let _ = cx.prompt(format!(
r#"
Add a `window: Option<gpui::AnyWindowHandle>` argument to the `Tool::run` trait method in {FILENAME},
and update all the implementations of the trait and call sites accordingly.
"#
- ));
-
- let _ = cx.run_to_end().await?;
+ )).await?;
// Adds ignored argument to all but `batch_tool`
@@ -29,16 +29,19 @@ impl Example for CodeBlockCitations {
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
const FILENAME: &str = "assistant_tool.rs";
- cx.push_user_message(format!(
- r#"
- Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
-
- Please show each method in a separate code snippet.
- "#
- ));
// Verify that the messages all have the correct formatting.
- let texts: Vec<String> = cx.run_to_end().await?.texts().collect();
+ let texts: Vec<String> = cx
+ .prompt(format!(
+ r#"
+ Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
+
+ Please show each method in a separate code snippet.
+ "#
+ ))
+ .await?
+ .texts()
+ .collect();
let closing_fence = format!("\n{FENCE}");
for text in texts.iter() {
@@ -22,30 +22,26 @@ impl Example for CommentTranslation {
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
- cx.push_user_message(r#"
- Edit the following files and translate all their comments to italian, in this exact order:
+ let response = cx.prompt(
+ r#"
+ Edit the following files and translate all their comments to italian, in this exact order:
- - font-kit/src/family.rs
- - font-kit/src/canvas.rs
- - font-kit/src/error.rs
- "#);
- cx.run_to_end().await?;
+ - font-kit/src/family.rs
+ - font-kit/src/canvas.rs
+ - font-kit/src/error.rs
+ "#
+ ).await?;
let mut create_or_overwrite_count = 0;
- cx.agent_thread().read_with(cx, |thread, cx| {
- for message in thread.messages() {
- for tool_use in thread.tool_uses_for_message(message.id, cx) {
- if tool_use.name == "edit_file" {
- let input: EditFileToolInput = serde_json::from_value(tool_use.input)?;
- if !matches!(input.mode, EditFileMode::Edit) {
- create_or_overwrite_count += 1;
- }
- }
+ for tool_call in response.tool_calls() {
+ if tool_call.name == "edit_file" {
+ let input = tool_call.parse_input::<EditFileToolInput>()?;
+ if !matches!(input.mode, EditFileMode::Edit) {
+ create_or_overwrite_count += 1;
}
}
+ }
- anyhow::Ok(())
- })??;
cx.assert_eq(create_or_overwrite_count, 0, "no_creation_or_overwrite")?;
Ok(())
@@ -48,8 +48,8 @@ impl Example for FileChangeNotificationExample {
})?;
// Start conversation (specific message is not important)
- cx.push_user_message("Find all files in this repo");
- cx.run_turn().await?;
+ cx.prompt_with_max_turns("Find all files in this repo", 1)
+ .await?;
// Edit the README buffer - the model should get a notification on next turn
buffer.update(cx, |buffer, cx| {
@@ -58,7 +58,7 @@ impl Example for FileChangeNotificationExample {
// Run for some more turns.
// The model shouldn't thank us for letting it know about the file change.
- cx.run_turns(3).await?;
+ cx.proceed_with_max_turns(3).await?;
Ok(())
}
@@ -25,18 +25,19 @@ impl Example for FileSearchExample {
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
const FILENAME: &str = "find_replace_file_tool.rs";
- cx.push_user_message(format!(
- r#"
+
+ let prompt = format!(
+ r#"
Look at the `{FILENAME}`. I want to implement a card for it. The card should implement the `Render` trait.
The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for
markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green
background for lines that were added. We should have a div per diff line.
"#
- ));
+ );
- let response = cx.run_turn().await?;
- let tool_use = response.expect_tool("find_path", cx)?;
+ let response = cx.prompt_with_max_turns(prompt, 1).await?;
+ let tool_use = response.expect_tool_call("find_path", cx)?;
let input = tool_use.parse_input::<FindPathToolInput>()?;
let glob = input.glob;
@@ -1,3 +1,4 @@
+use agent::GrepToolInput;
use agent_settings::AgentProfileId;
use anyhow::Result;
use async_trait::async_trait;
@@ -35,9 +36,9 @@ impl Example for GrepParamsEscapementExample {
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
- // cx.push_user_message("How does the precedence/specificity work with Keymap contexts? I am seeing that `MessageEditor > Editor` is lower precendence than `Editor` which is surprising to me, but might be how it works");
- cx.push_user_message("Search for files containing the characters `>` or `<`");
- let response = cx.run_turns(2).await?;
+ let response = cx
+ .prompt_with_max_turns("Search for files containing the characters `>` or `<`", 2)
+ .await?;
let grep_input = response
.find_tool_call("grep")
.and_then(|tool_use| tool_use.parse_input::<GrepToolInput>().ok());
@@ -144,9 +144,8 @@ impl Example for DeclarativeExample {
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
- cx.push_user_message(&self.prompt);
let max_turns = self.metadata.max_turns.unwrap_or(1000);
- let _ = cx.run_turns(max_turns).await;
+ let _ = cx.prompt_with_max_turns(&self.prompt, max_turns).await;
Ok(())
}
@@ -1,3 +1,4 @@
+use agent::{EditFileMode, EditFileToolInput};
use agent_settings::AgentProfileId;
use anyhow::Result;
use async_trait::async_trait;
@@ -35,17 +36,14 @@ impl Example for FileOverwriteExample {
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
- let response = cx.run_turns(1).await?;
- let file_overwritten = if let Some(tool_use) = response.find_tool_call("edit_file") {
- let input = tool_use.parse_input::<EditFileToolInput>()?;
- match input.mode {
- EditFileMode::Edit => false,
- EditFileMode::Create | EditFileMode::Overwrite => {
- input.path.ends_with("src/language_model_selector.rs")
- }
+ let response = cx.proceed_with_max_turns(1).await?;
+ let tool_use = response.expect_tool_call("edit_file", cx)?;
+ let input = tool_use.parse_input::<EditFileToolInput>()?;
+ let file_overwritten = match input.mode {
+ EditFileMode::Edit => false,
+ EditFileMode::Create | EditFileMode::Overwrite => {
+ input.path.ends_with("src/language_model_selector.rs")
}
- } else {
- false
};
cx.assert(!file_overwritten, "File should be edited, not overwritten")
@@ -23,20 +23,19 @@ impl Example for Planets {
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
- cx.push_user_message(
- r#"
+ let response = cx
+ .prompt(
+ r#"
Make a plain JavaScript web page which renders an animated 3D solar system.
Let me drag to rotate the camera around.
Do not use npm.
- "#
- .to_string(),
- );
-
- let response = cx.run_to_end().await?;
+ "#,
+ )
+ .await?;
let mut open_tool_uses = 0;
let mut terminal_tool_uses = 0;
- for tool_use in response.tool_uses() {
+ for tool_use in response.tool_calls() {
if tool_use.name == OpenTool::name() {
open_tool_uses += 1;
} else if tool_use.name == TerminalTool::name() {
@@ -1,36 +1,38 @@
-use agent::Message;
+use agent::ContextServerRegistry;
+use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow, bail};
use client::proto::LspWorkProgress;
use futures::channel::mpsc;
+use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _, future};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _};
use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelToolResultContent, MessageContent, Role, TokenUsage,
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelToolResultContent, MessageContent, Role, TokenUsage,
};
-use project::lsp_store::OpenLspBufferHandle;
-use project::{DiagnosticSummary, Project, ProjectPath};
+use project::{DiagnosticSummary, Project, ProjectPath, lsp_store::OpenLspBufferHandle};
+use prompt_store::{ProjectContext, WorktreeContext};
+use rand::{distr, prelude::*};
use serde::{Deserialize, Serialize};
-use std::cell::RefCell;
-use std::fmt::Write as _;
-use std::fs;
-use std::fs::File;
-use std::io::Write as _;
-use std::path::Path;
-use std::path::PathBuf;
-use std::rc::Rc;
-use std::sync::Arc;
-use std::time::Duration;
+use std::{
+ fmt::Write as _,
+ fs::{self, File},
+ io::Write as _,
+ path::{Path, PathBuf},
+ rc::Rc,
+ sync::{Arc, Mutex},
+ time::Duration,
+};
use unindent::Unindent as _;
-use util::ResultExt as _;
-use util::command::new_smol_command;
-use util::markdown::MarkdownCodeBlock;
+use util::{ResultExt as _, command::new_smol_command, markdown::MarkdownCodeBlock};
-use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult};
-use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion};
-use crate::{AgentAppState, ToolMetrics};
+use crate::{
+ AgentAppState, ToolMetrics,
+ assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
+ example::{Example, ExampleContext, FailedAssertion, JudgeAssertion},
+};
pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
@@ -56,10 +58,9 @@ pub struct RunOutput {
pub diagnostic_summary_after: DiagnosticSummary,
pub diagnostics_before: Option<String>,
pub diagnostics_after: Option<String>,
- pub response_count: usize,
pub token_usage: TokenUsage,
pub tool_metrics: ToolMetrics,
- pub all_messages: String,
+ pub thread_markdown: String,
pub programmatic_assertions: AssertionsReport,
}
@@ -193,12 +194,7 @@ impl ExampleInstance {
.join(self.thread.meta().repo_name())
}
- pub fn run(
- &self,
- model: Arc<dyn LanguageModel>,
- app_state: Arc<AgentAppState>,
- cx: &mut App,
- ) -> Task<Result<RunOutput>> {
+ pub fn run(&self, app_state: Arc<AgentAppState>, cx: &mut App) -> Task<Result<RunOutput>> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
@@ -213,15 +209,6 @@ impl ExampleInstance {
project.create_worktree(self.worktree_path(), true, cx)
});
- let tools = cx.new(|_| ToolWorkingSet::default());
- let prompt_store = None;
- let thread_store = ThreadStore::load(
- project.clone(),
- tools,
- prompt_store,
- app_state.prompt_builder.clone(),
- cx,
- );
let meta = self.thread.meta();
let this = self.clone();
@@ -300,74 +287,62 @@ impl ExampleInstance {
// history using undo/redo.
std::fs::write(&last_diff_file_path, "")?;
- let thread_store = thread_store.await?;
-
+ let thread = cx.update(|cx| {
+ // TODO: Do we want to load rules files here?
+ let worktrees = project.read(cx).visible_worktrees(cx).map(|worktree| {
+ let root_name = worktree.read(cx).root_name_str().into();
+ let abs_path = worktree.read(cx).abs_path();
- let thread =
- thread_store.update(cx, |thread_store, cx| {
- let thread = if let Some(json) = &meta.existing_thread_json {
- let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread");
- thread_store.create_thread_from_serialized(serialized, cx)
- } else {
- thread_store.create_thread(cx)
- };
- thread.update(cx, |thread, cx| {
- thread.set_profile(meta.profile_id.clone(), cx);
- });
- thread
- })?;
-
-
- thread.update(cx, |thread, _cx| {
- let mut request_count = 0;
- let previous_diff = Rc::new(RefCell::new("".to_string()));
- let example_output_dir = this.run_directory.clone();
- let last_diff_file_path = last_diff_file_path.clone();
- let messages_json_file_path = example_output_dir.join("last.messages.json");
- let this = this.clone();
- thread.set_request_callback(move |request, response_events| {
- request_count += 1;
- let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
- let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
- let last_messages_file_path = example_output_dir.join("last.messages.md");
- let request_markdown = RequestMarkdown::new(request);
- let response_events_markdown = response_events_to_markdown(response_events);
- let dialog = ThreadDialog::new(request, response_events);
- let dialog_json = serde_json::to_string_pretty(&dialog.to_combined_request()).unwrap_or_default();
-
- let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
- fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
- fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
- fs::write(&messages_json_file_path, dialog_json).expect("failed to write last.messages.json");
-
- let diff_result = smol::block_on(this.repository_diff());
- match diff_result {
- Ok(diff) => {
- if diff != previous_diff.borrow().clone() {
- fs::write(&diff_file_path, &diff).expect("failed to write diff file");
- fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
- *previous_diff.borrow_mut() = diff;
- }
- }
- Err(err) => {
- let error_message = format!("{err:?}");
- fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
- fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
- }
+ WorktreeContext {
+ root_name,
+ abs_path,
+ rules_file: None,
}
+ }).collect::<Vec<_>>();
+ let project_context = cx.new(|_cx| ProjectContext::new(worktrees, vec![]));
+ let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+
+ let thread = if let Some(json) = &meta.existing_thread_json {
+ let session_id = acp::SessionId(
+ rand::rng()
+ .sample_iter(&distr::Alphanumeric)
+ .take(7)
+ .map(char::from)
+ .collect::<String>()
+ .into(),
+ );
+
+ let db_thread = agent::DbThread::from_json(json.as_bytes()).expect("Can't read serialized thread");
+ cx.new(|cx| agent::Thread::from_db(session_id, db_thread, project.clone(), project_context, context_server_registry, agent::Templates::new(), cx))
+ } else {
+ cx.new(|cx| agent::Thread::new(project.clone(), project_context, context_server_registry, agent::Templates::new(), None, cx))
+ };
- if request_count == 1 {
- let tools_file_path = example_output_dir.join("tools.md");
- fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
- }
+ thread.update(cx, |thread, cx| {
+ thread.add_default_tools(Rc::new(EvalThreadEnvironment {
+ project: project.clone(),
+ }), cx);
+ thread.set_profile(meta.profile_id.clone());
+ thread.set_model(
+ LanguageModelInterceptor::new(
+ LanguageModelRegistry::read_global(cx).default_model().expect("Missing model").model.clone(),
+ this.run_directory.clone(),
+ last_diff_file_path.clone(),
+ this.run_directory.join("last.messages.json"),
+ this.worktree_path(),
+ this.repo_url(),
+ ),
+ cx,
+ );
});
- })?;
+
+ thread
+ }).unwrap();
let mut example_cx = ExampleContext::new(
meta.clone(),
this.log_prefix.clone(),
thread.clone(),
- model.clone(),
cx.clone(),
);
let result = this.thread.conversation(&mut example_cx).await;
@@ -380,7 +355,7 @@ impl ExampleInstance {
println!("{}Stopped", this.log_prefix);
println!("{}Getting repository diff", this.log_prefix);
- let repository_diff = this.repository_diff().await?;
+ let repository_diff = Self::repository_diff(this.worktree_path(), &this.repo_url()).await?;
std::fs::write(last_diff_file_path, &repository_diff)?;
@@ -415,34 +390,28 @@ impl ExampleInstance {
}
thread.update(cx, |thread, _cx| {
- let response_count = thread
- .messages()
- .filter(|message| message.role == language_model::Role::Assistant)
- .count();
RunOutput {
repository_diff,
diagnostic_summary_before,
diagnostic_summary_after,
diagnostics_before,
diagnostics_after,
- response_count,
- token_usage: thread.cumulative_token_usage(),
+ token_usage: thread.latest_request_token_usage().unwrap(),
tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(),
- all_messages: messages_to_markdown(thread.messages()),
+ thread_markdown: thread.to_markdown(),
programmatic_assertions: example_cx.assertions,
}
})
})
}
- async fn repository_diff(&self) -> Result<String> {
- let worktree_path = self.worktree_path();
- run_git(&worktree_path, &["add", "."]).await?;
+ async fn repository_diff(repository_path: PathBuf, repository_url: &str) -> Result<String> {
+ run_git(&repository_path, &["add", "."]).await?;
let mut diff_args = vec!["diff", "--staged"];
- if self.thread.meta().url == ZED_REPO_URL {
+ if repository_url == ZED_REPO_URL {
diff_args.push(":(exclude).rules");
}
- run_git(&worktree_path, &diff_args).await
+ run_git(&repository_path, &diff_args).await
}
pub async fn judge(
@@ -542,7 +511,7 @@ impl ExampleInstance {
hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)
.unwrap();
- let complete_messages = &run_output.all_messages;
+ let complete_messages = &run_output.thread_markdown;
let to_prompt = |assertion: String| {
hbs.render(
judge_thread_prompt_name,
@@ -634,6 +603,273 @@ impl ExampleInstance {
}
}
+struct EvalThreadEnvironment {
+ project: Entity<Project>,
+}
+
+struct EvalTerminalHandle {
+ terminal: Entity<acp_thread::Terminal>,
+}
+
+impl agent::TerminalHandle for EvalTerminalHandle {
+ fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
+ self.terminal.read_with(cx, |term, _cx| term.id().clone())
+ }
+
+ fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
+ self.terminal
+ .read_with(cx, |term, _cx| term.wait_for_exit())
+ }
+
+ fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
+ self.terminal
+ .read_with(cx, |term, cx| term.current_output(cx))
+ }
+}
+
+impl agent::ThreadEnvironment for EvalThreadEnvironment {
+ fn create_terminal(
+ &self,
+ command: String,
+ cwd: Option<PathBuf>,
+ output_byte_limit: Option<u64>,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Rc<dyn agent::TerminalHandle>>> {
+ let project = self.project.clone();
+ cx.spawn(async move |cx| {
+ let language_registry =
+ project.read_with(cx, |project, _cx| project.languages().clone())?;
+ let id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
+ let terminal =
+ acp_thread::create_terminal_entity(command, &[], vec![], cwd.clone(), &project, cx)
+ .await?;
+ let terminal = cx.new(|cx| {
+ acp_thread::Terminal::new(
+ id,
+ "",
+ cwd,
+ output_byte_limit.map(|limit| limit as usize),
+ terminal,
+ language_registry,
+ cx,
+ )
+ })?;
+ Ok(Rc::new(EvalTerminalHandle { terminal }) as Rc<dyn agent::TerminalHandle>)
+ })
+ }
+}
+
+struct LanguageModelInterceptor {
+ model: Arc<dyn LanguageModel>,
+ request_count: Arc<Mutex<usize>>,
+ previous_diff: Arc<Mutex<String>>,
+ example_output_dir: PathBuf,
+ last_diff_file_path: PathBuf,
+ messages_json_file_path: PathBuf,
+ repository_path: PathBuf,
+ repository_url: String,
+}
+
+impl LanguageModelInterceptor {
+ fn new(
+ model: Arc<dyn LanguageModel>,
+ example_output_dir: PathBuf,
+ last_diff_file_path: PathBuf,
+ messages_json_file_path: PathBuf,
+ repository_path: PathBuf,
+ repository_url: String,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ model,
+ request_count: Arc::new(Mutex::new(0)),
+ previous_diff: Arc::new(Mutex::new("".to_string())),
+ example_output_dir,
+ last_diff_file_path,
+ messages_json_file_path,
+ repository_path,
+ repository_url,
+ })
+ }
+}
+
+impl language_model::LanguageModel for LanguageModelInterceptor {
+ fn id(&self) -> language_model::LanguageModelId {
+ self.model.id()
+ }
+
+ fn name(&self) -> language_model::LanguageModelName {
+ self.model.name()
+ }
+
+ fn provider_id(&self) -> language_model::LanguageModelProviderId {
+ self.model.provider_id()
+ }
+
+ fn provider_name(&self) -> language_model::LanguageModelProviderName {
+ self.model.provider_name()
+ }
+
+ fn telemetry_id(&self) -> String {
+ self.model.telemetry_id()
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images()
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_tool_choice(&self, choice: language_model::LanguageModelToolChoice) -> bool {
+ self.model.supports_tool_choice(choice)
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> future::BoxFuture<'static, Result<u64>> {
+ self.model.count_tokens(request, cx)
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> future::BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, language_model::LanguageModelCompletionError>,
+ >,
+ language_model::LanguageModelCompletionError,
+ >,
+ > {
+ let stream = self.model.stream_completion(request.clone(), cx);
+ let request_count = self.request_count.clone();
+ let previous_diff = self.previous_diff.clone();
+ let example_output_dir = self.example_output_dir.clone();
+ let last_diff_file_path = self.last_diff_file_path.clone();
+ let messages_json_file_path = self.messages_json_file_path.clone();
+ let repository_path = self.repository_path.clone();
+ let repository_url = self.repository_url.clone();
+
+ Box::pin(async move {
+ let stream = stream.await?;
+
+ let response_events = Arc::new(Mutex::new(Vec::new()));
+ let request_clone = request.clone();
+
+ let wrapped_stream = stream.then(move |event| {
+ let response_events = response_events.clone();
+ let request = request_clone.clone();
+ let request_count = request_count.clone();
+ let previous_diff = previous_diff.clone();
+ let example_output_dir = example_output_dir.clone();
+ let last_diff_file_path = last_diff_file_path.clone();
+ let messages_json_file_path = messages_json_file_path.clone();
+ let repository_path = repository_path.clone();
+ let repository_url = repository_url.clone();
+
+ async move {
+ let event_result = match &event {
+ Ok(ev) => Ok(ev.clone()),
+ Err(err) => Err(err.to_string()),
+ };
+ response_events.lock().unwrap().push(event_result);
+
+ let should_execute = matches!(
+ &event,
+ Ok(LanguageModelCompletionEvent::Stop { .. }) | Err(_)
+ );
+
+ if should_execute {
+ let current_request_count = {
+ let mut count = request_count.lock().unwrap();
+ *count += 1;
+ *count
+ };
+
+ let messages_file_path =
+ example_output_dir.join(format!("{current_request_count}.messages.md"));
+ let diff_file_path =
+ example_output_dir.join(format!("{current_request_count}.diff"));
+ let last_messages_file_path = example_output_dir.join("last.messages.md");
+
+ let collected_events = response_events.lock().unwrap().clone();
+ let request_markdown = RequestMarkdown::new(&request);
+ let response_events_markdown =
+ response_events_to_markdown(&collected_events);
+ let dialog = ThreadDialog::new(&request, &collected_events);
+ let dialog_json =
+ serde_json::to_string_pretty(&dialog.to_combined_request())
+ .unwrap_or_default();
+
+ let messages = format!(
+ "{}\n\n{}",
+ request_markdown.messages, response_events_markdown
+ );
+ fs::write(&messages_file_path, messages.clone())
+ .expect("failed to write messages file");
+ fs::write(&last_messages_file_path, messages)
+ .expect("failed to write last messages file");
+ fs::write(&messages_json_file_path, dialog_json)
+ .expect("failed to write last.messages.json");
+
+ // Get repository diff
+ let diff_result =
+ ExampleInstance::repository_diff(repository_path, &repository_url)
+ .await;
+
+ match diff_result {
+ Ok(diff) => {
+ let prev_diff = previous_diff.lock().unwrap().clone();
+ if diff != prev_diff {
+ fs::write(&diff_file_path, &diff)
+ .expect("failed to write diff file");
+ fs::write(&last_diff_file_path, &diff)
+ .expect("failed to write last diff file");
+ *previous_diff.lock().unwrap() = diff;
+ }
+ }
+ Err(err) => {
+ let error_message = format!("{err:?}");
+ fs::write(&diff_file_path, &error_message)
+ .expect("failed to write diff error to file");
+ fs::write(&last_diff_file_path, &error_message)
+ .expect("failed to write last diff file");
+ }
+ }
+
+ if current_request_count == 1 {
+ let tools_file_path = example_output_dir.join("tools.md");
+ fs::write(tools_file_path, request_markdown.tools)
+ .expect("failed to write tools file");
+ }
+ }
+
+ event
+ }
+ });
+
+ Ok(Box::pin(wrapped_stream)
+ as futures::stream::BoxStream<
+ 'static,
+ Result<
+ LanguageModelCompletionEvent,
+ language_model::LanguageModelCompletionError,
+ >,
+ >)
+ })
+ }
+}
+
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@@ -825,40 +1061,6 @@ pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}
-fn messages_to_markdown<'a>(message_iter: impl IntoIterator<Item = &'a Message>) -> String {
- let mut messages = String::new();
- let mut assistant_message_number: u32 = 1;
-
- for message in message_iter {
- push_role(&message.role, &mut messages, &mut assistant_message_number);
-
- for segment in &message.segments {
- match segment {
- MessageSegment::Text(text) => {
- messages.push_str(text);
- messages.push_str("\n\n");
- }
- MessageSegment::Thinking { text, signature } => {
- messages.push_str("**Thinking**:\n\n");
- if let Some(sig) = signature {
- messages.push_str(&format!("Signature: {}\n\n", sig));
- }
- messages.push_str(text);
- messages.push_str("\n");
- }
- MessageSegment::RedactedThinking(items) => {
- messages.push_str(&format!(
- "**Redacted Thinking**: {} item(s)\n\n",
- items.len()
- ));
- }
- }
- }
- }
-
- messages
-}
-
fn push_role(role: &Role, buf: &mut String, assistant_message_number: &mut u32) {
match role {
Role::System => buf.push_str("# ⚙️ SYSTEM\n\n"),