Lay the groundwork for a Rust-based eval (#28488)

Created by Antonio Scandurra and Nathan Sobo.

We also moved the logic for driving the agentic loop into `Thread`, so that each consumer of `Thread` (including this eval) doesn't have to re-implement it.
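
As a rough illustration (not part of this diff), here is a minimal sketch of what a consumer-side event handler can look like now that `Thread` drives the loop itself. The `EvalObserver` type is hypothetical; the sketch assumes the in-repo `agent` and `gpui` crates:

```rust
use agent::{Thread, ThreadEvent};
use gpui::{Context, Entity};

// Hypothetical consumer: it only observes `ThreadEvent`s and no longer calls
// `use_pending_tools`, `attach_tool_results`, or `send_to_model` itself.
struct EvalObserver;

impl EvalObserver {
    fn handle_thread_event(
        &mut self,
        _thread: Entity<Thread>,
        event: &ThreadEvent,
        _cx: &mut Context<Self>,
    ) {
        match event {
            // The tool uses were already started by `Thread`; we just report them.
            ThreadEvent::UsePendingTools { tool_uses } => {
                for tool_use in tool_uses {
                    println!("running tool: {}", tool_use.name);
                }
            }
            // No need to re-send to the model here; `Thread::tool_finished`
            // decides whether to resume the conversation.
            ThreadEvent::ToolFinished { pending_tool_use, .. } => {
                if let Some(tool_use) = pending_tool_use {
                    println!("finished tool: {}", tool_use.name);
                }
            }
            _ => {}
        }
    }
}
```

In particular, `ThreadEvent::UsePendingTools` now carries the already-started `tool_uses`, and `ToolFinished` no longer needs a `canceled` flag, since `Thread::tool_finished` handles resuming the conversation.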

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                                                |  31 +
Cargo.toml                                                |   1 
crates/agent/src/active_thread.rs                         |  25 
crates/agent/src/thread.rs                                |  62 +
crates/agent_eval/src/headless_assistant.rs               |  19 
crates/eval/Cargo.toml                                    |  39 +
crates/eval/LICENSE-GPL                                   |   1 
crates/eval/README.md                                     |   7 
crates/eval/examples/find_and_replace_diff_card/base.toml |   2 
crates/eval/examples/find_and_replace_diff_card/prompt.md |   3 
crates/eval/examples/find_and_replace_diff_card/rubric.md |   0 
crates/eval/src/agent.rs                                  | 229 +++++++++
crates/eval/src/eval.rs                                   | 101 +++
13 files changed, 455 insertions(+), 65 deletions(-)

Detailed changes

Cargo.lock

@@ -4901,6 +4901,37 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "eval"
+version = "0.1.0"
+dependencies = [
+ "agent",
+ "anyhow",
+ "assistant_tool",
+ "assistant_tools",
+ "client",
+ "collections",
+ "context_server",
+ "dap",
+ "env_logger 0.11.8",
+ "fs",
+ "gpui",
+ "gpui_tokio",
+ "language",
+ "language_model",
+ "language_models",
+ "node_runtime",
+ "project",
+ "prompt_store",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "settings",
+ "smol",
+ "toml 0.8.20",
+ "workspace-hack",
+]
+
 [[package]]
 name = "evals"
 version = "0.1.0"

Cargo.toml

@@ -47,6 +47,7 @@ members = [
     "crates/diagnostics",
     "crates/docs_preprocessor",
     "crates/editor",
+    "crates/eval",
     "crates/evals",
     "crates/extension",
     "crates/extension_api",

crates/agent/src/active_thread.rs

@@ -21,7 +21,7 @@ use gpui::{
     linear_color_stop, linear_gradient, list, percentage, pulsating_between,
 };
 use language::{Buffer, LanguageRegistry};
-use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
+use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 use markdown::parser::CodeBlockKind;
 use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
 use project::ProjectItem as _;
@@ -897,11 +897,7 @@ impl ActiveThread {
                 self.save_thread(cx);
                 cx.notify();
             }
-            ThreadEvent::UsePendingTools => {
-                let tool_uses = self
-                    .thread
-                    .update(cx, |thread, cx| thread.use_pending_tools(cx));
-
+            ThreadEvent::UsePendingTools { tool_uses } => {
                 for tool_use in tool_uses {
                     self.render_tool_use_markdown(
                         tool_use.id.clone(),
@@ -913,11 +909,8 @@ impl ActiveThread {
                 }
             }
             ThreadEvent::ToolFinished {
-                pending_tool_use,
-                canceled,
-                ..
+                pending_tool_use, ..
             } => {
-                let canceled = *canceled;
                 if let Some(tool_use) = pending_tool_use {
                     self.render_tool_use_markdown(
                         tool_use.id.clone(),
@@ -931,18 +924,6 @@ impl ActiveThread {
                         cx,
                     );
                 }
-
-                if self.thread.read(cx).all_tools_finished() {
-                    let model_registry = LanguageModelRegistry::read_global(cx);
-                    if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
-                        self.thread.update(cx, |thread, cx| {
-                            thread.attach_tool_results(cx);
-                            if !canceled {
-                                thread.send_to_model(model, RequestKind::Chat, cx);
-                            }
-                        });
-                    }
-                }
             }
             ThreadEvent::CheckpointChanged => cx.notify(),
         }

crates/agent/src/thread.rs

@@ -1181,7 +1181,8 @@ impl Thread {
                     match result.as_ref() {
                         Ok(stop_reason) => match stop_reason {
                             StopReason::ToolUse => {
-                                cx.emit(ThreadEvent::UsePendingTools);
+                                let tool_uses = thread.use_pending_tools(cx);
+                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                             }
                             StopReason::EndTurn => {}
                             StopReason::MaxTokens => {}
@@ -1369,10 +1370,7 @@ impl Thread {
         )
     }
 
-    pub fn use_pending_tools(
-        &mut self,
-        cx: &mut Context<Self>,
-    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
+    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
         let request = self.to_completion_request(RequestKind::Chat, cx);
         let messages = Arc::new(request.messages);
         let pending_tool_uses = self
@@ -1460,18 +1458,36 @@ impl Thread {
                             output,
                             cx,
                         );
-
-                        cx.emit(ThreadEvent::ToolFinished {
-                            tool_use_id,
-                            pending_tool_use,
-                            canceled: false,
-                        });
+                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                     })
                     .ok();
             }
         })
     }
 
+    fn tool_finished(
+        &mut self,
+        tool_use_id: LanguageModelToolUseId,
+        pending_tool_use: Option<PendingToolUse>,
+        canceled: bool,
+        cx: &mut Context<Self>,
+    ) {
+        if self.all_tools_finished() {
+            let model_registry = LanguageModelRegistry::read_global(cx);
+            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+                self.attach_tool_results(cx);
+                if !canceled {
+                    self.send_to_model(model, RequestKind::Chat, cx);
+                }
+            }
+        }
+
+        cx.emit(ThreadEvent::ToolFinished {
+            tool_use_id,
+            pending_tool_use,
+        });
+    }
+
     pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
         // Insert a user message to contain the tool results.
         self.insert_user_message(
@@ -1495,11 +1511,12 @@ impl Thread {
             let mut canceled = false;
             for pending_tool_use in self.tool_use.cancel_pending() {
                 canceled = true;
-                cx.emit(ThreadEvent::ToolFinished {
-                    tool_use_id: pending_tool_use.id.clone(),
-                    pending_tool_use: Some(pending_tool_use),
-                    canceled: true,
-                });
+                self.tool_finished(
+                    pending_tool_use.id.clone(),
+                    Some(pending_tool_use),
+                    true,
+                    cx,
+                );
             }
             canceled
         };
@@ -1866,12 +1883,7 @@ impl Thread {
 
         self.tool_use
             .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
-
-        cx.emit(ThreadEvent::ToolFinished {
-            tool_use_id,
-            pending_tool_use: None,
-            canceled: true,
-        });
+        self.tool_finished(tool_use_id.clone(), None, true, cx);
     }
 }
 
@@ -1897,14 +1909,14 @@ pub enum ThreadEvent {
     MessageDeleted(MessageId),
     SummaryGenerated,
     SummaryChanged,
-    UsePendingTools,
+    UsePendingTools {
+        tool_uses: Vec<PendingToolUse>,
+    },
     ToolFinished {
         #[allow(unused)]
         tool_use_id: LanguageModelToolUseId,
         /// The pending tool use that corresponds to this tool.
         pending_tool_use: Option<PendingToolUse>,
-        /// Whether the tool was canceled by the user.
-        canceled: bool,
     },
     CheckpointChanged,
     ToolConfirmationNeeded,

crates/agent_eval/src/headless_assistant.rs

@@ -95,11 +95,7 @@ impl HeadlessAssistant {
                     self.done_tx.send_blocking(Ok(())).unwrap()
                 }
             }
-            ThreadEvent::UsePendingTools => {
-                thread.update(cx, |thread, cx| {
-                    thread.use_pending_tools(cx);
-                });
-            }
+            ThreadEvent::UsePendingTools { .. } => {}
             ThreadEvent::ToolConfirmationNeeded => {
                 // Automatically approve all tools that need confirmation in headless mode
                 println!("Tool confirmation needed - automatically approving in headless mode");
@@ -152,19 +148,6 @@ impl HeadlessAssistant {
                 if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                     println!("Tool result: {:?}", tool_result);
                 }
-                if thread.read(cx).all_tools_finished() {
-                    let model_registry = LanguageModelRegistry::read_global(cx);
-                    if let Some(model) = model_registry.default_model() {
-                        thread.update(cx, |thread, cx| {
-                            thread.attach_tool_results(cx);
-                            thread.send_to_model(model.model, RequestKind::Chat, cx);
-                        });
-                    } else {
-                        println!(
-                            "Warning: No active language model available to continue conversation"
-                        );
-                    }
-                }
             }
             _ => {}
         }

crates/eval/Cargo.toml

@@ -0,0 +1,39 @@
+[package]
+name = "eval"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+
+[dependencies]
+agent.workspace = true
+anyhow.workspace = true
+assistant_tool.workspace = true
+assistant_tools.workspace = true
+client.workspace = true
+collections.workspace = true
+context_server.workspace = true
+dap.workspace = true
+env_logger.workspace = true
+fs.workspace = true
+gpui.workspace = true
+gpui_tokio.workspace = true
+language.workspace = true
+language_model.workspace = true
+language_models.workspace = true
+node_runtime.workspace = true
+project.workspace = true
+prompt_store.workspace = true
+release_channel.workspace = true
+reqwest_client.workspace = true
+serde.workspace = true
+settings.workspace = true
+smol.workspace = true
+toml.workspace = true
+workspace-hack.workspace = true
+
+[[bin]]
+name = "eval"
+path = "src/eval.rs"
+
+[lints]
+workspace = true

crates/eval/README.md

@@ -0,0 +1,7 @@
+# Eval
+
+This eval assumes the working directory is the root of the repository. Run it with:
+
+```sh
+cargo run -p eval
+```

crates/eval/examples/find_and_replace_diff_card/prompt.md

@@ -0,0 +1,3 @@
+Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
+
+The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.

crates/eval/src/agent.rs

@@ -0,0 +1,229 @@
+use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
+use anyhow::anyhow;
+use assistant_tool::ToolWorkingSet;
+use client::{Client, UserStore};
+use collections::HashMap;
+use dap::DapRegistry;
+use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
+use language::LanguageRegistry;
+use language_model::{
+    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+};
+use node_runtime::NodeRuntime;
+use project::{Project, RealFs};
+use prompt_store::PromptBuilder;
+use settings::SettingsStore;
+use smol::channel;
+use std::sync::Arc;
+
+/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
+pub struct AgentAppState {
+    pub languages: Arc<LanguageRegistry>,
+    pub client: Arc<Client>,
+    pub user_store: Entity<UserStore>,
+    pub fs: Arc<dyn fs::Fs>,
+    pub node_runtime: NodeRuntime,
+
+    // Additional fields not present in `workspace::AppState`.
+    pub prompt_builder: Arc<PromptBuilder>,
+}
+
+pub struct Agent {
+    // pub thread: Entity<Thread>,
+    // pub project: Entity<Project>,
+    #[allow(dead_code)]
+    pub thread_store: Entity<ThreadStore>,
+    pub tool_use_counts: HashMap<Arc<str>, u32>,
+    pub done_tx: channel::Sender<anyhow::Result<()>>,
+    _subscription: Subscription,
+}
+
+impl Agent {
+    pub fn new(
+        app_state: Arc<AgentAppState>,
+        cx: &mut App,
+    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
+        let env = None;
+        let project = Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            Arc::new(DapRegistry::default()),
+            app_state.fs.clone(),
+            env,
+            cx,
+        );
+
+        let tools = Arc::new(ToolWorkingSet::default());
+        let thread_store =
+            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
+
+        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+
+        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
+
+        let headless_thread = cx.new(move |cx| Self {
+            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
+            // thread,
+            // project,
+            thread_store,
+            tool_use_counts: HashMap::default(),
+            done_tx,
+        });
+
+        Ok((headless_thread, done_rx))
+    }
+
+    fn handle_thread_event(
+        &mut self,
+        thread: Entity<Thread>,
+        event: &ThreadEvent,
+        cx: &mut Context<Self>,
+    ) {
+        match event {
+            ThreadEvent::ShowError(err) => self
+                .done_tx
+                .send_blocking(Err(anyhow!("{:?}", err)))
+                .unwrap(),
+            ThreadEvent::DoneStreaming => {
+                let thread = thread.read(cx);
+                if let Some(message) = thread.messages().last() {
+                    println!("Message: {}", message.to_string());
+                }
+                if thread.all_tools_finished() {
+                    self.done_tx.send_blocking(Ok(())).unwrap()
+                }
+            }
+            ThreadEvent::UsePendingTools { .. } => {}
+            ThreadEvent::ToolConfirmationNeeded => {
+                // Automatically approve all tools that need confirmation in headless mode
+                println!("Tool confirmation needed - automatically approving in headless mode");
+
+                // Get the tools needing confirmation
+                let tools_needing_confirmation: Vec<_> = thread
+                    .read(cx)
+                    .tools_needing_confirmation()
+                    .cloned()
+                    .collect();
+
+                // Run each tool that needs confirmation
+                for tool_use in tools_needing_confirmation {
+                    if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
+                        thread.update(cx, |thread, cx| {
+                            println!("Auto-approving tool: {}", tool_use.name);
+
+                            // Create a request to send to the tool
+                            let request = thread.to_completion_request(RequestKind::Chat, cx);
+                            let messages = Arc::new(request.messages);
+
+                            // Run the tool
+                            thread.run_tool(
+                                tool_use.id.clone(),
+                                tool_use.ui_text.clone(),
+                                tool_use.input.clone(),
+                                &messages,
+                                tool,
+                                cx,
+                            );
+                        });
+                    }
+                }
+            }
+            ThreadEvent::ToolFinished {
+                tool_use_id,
+                pending_tool_use,
+                ..
+            } => {
+                if let Some(pending_tool_use) = pending_tool_use {
+                    println!(
+                        "Used tool {} with input: {}",
+                        pending_tool_use.name, pending_tool_use.input
+                    );
+                    *self
+                        .tool_use_counts
+                        .entry(pending_tool_use.name.clone())
+                        .or_insert(0) += 1;
+                }
+                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
+                    println!("Tool result: {:?}", tool_result);
+                }
+            }
+            _ => {}
+        }
+    }
+}
+
+pub fn init(cx: &mut App) -> Arc<AgentAppState> {
+    release_channel::init(SemanticVersion::default(), cx);
+    gpui_tokio::init(cx);
+
+    let mut settings_store = SettingsStore::new(cx);
+    settings_store
+        .set_default_settings(settings::default_settings().as_ref(), cx)
+        .unwrap();
+    cx.set_global(settings_store);
+    client::init_settings(cx);
+    Project::init_settings(cx);
+
+    let client = Client::production(cx);
+    cx.set_http_client(client.http_client().clone());
+
+    let git_binary_path = None;
+    let fs = Arc::new(RealFs::new(
+        git_binary_path,
+        cx.background_executor().clone(),
+    ));
+
+    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+
+    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+    language::init(cx);
+    language_model::init(client.clone(), cx);
+    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+    assistant_tools::init(client.http_client().clone(), cx);
+    context_server::init(cx);
+    let stdout_is_a_pty = false;
+    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
+    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
+
+    Arc::new(AgentAppState {
+        languages,
+        client,
+        user_store,
+        fs,
+        node_runtime: NodeRuntime::unavailable(),
+        prompt_builder,
+    })
+}
+
+pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
+    let model_registry = LanguageModelRegistry::read_global(cx);
+    let model = model_registry
+        .available_models(cx)
+        .find(|model| model.id().0 == model_name);
+
+    let Some(model) = model else {
+        return Err(anyhow!(
+            "No language model named {} was available. Available models: {}",
+            model_name,
+            model_registry
+                .available_models(cx)
+                .map(|model| model.id().0.clone())
+                .collect::<Vec<_>>()
+                .join(", ")
+        ));
+    };
+
+    Ok(model)
+}
+
+pub fn authenticate_model_provider(
+    provider_id: LanguageModelProviderId,
+    cx: &mut App,
+) -> Task<std::result::Result<(), AuthenticateError>> {
+    let model_registry = LanguageModelRegistry::read_global(cx);
+    let model_provider = model_registry.provider(&provider_id).unwrap();
+    model_provider.authenticate(cx)
+}

crates/eval/src/eval.rs

@@ -0,0 +1,101 @@
+use agent::Agent;
+use anyhow::Result;
+use gpui::Application;
+use language_model::LanguageModelRegistry;
+use reqwest_client::ReqwestClient;
+use serde::Deserialize;
+use std::{
+    fs,
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+mod agent;
+
+#[derive(Debug, Deserialize)]
+pub struct ExampleBase {
+    pub path: PathBuf,
+    pub revision: String,
+}
+
+#[derive(Debug)]
+pub struct Example {
+    pub base: ExampleBase,
+
+    /// Content of the prompt.md file
+    pub prompt: String,
+
+    /// Content of the rubric.md file
+    pub rubric: String,
+}
+
+impl Example {
+    /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
+    pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
+        let base_path = dir_path.as_ref().join("base.toml");
+        let prompt_path = dir_path.as_ref().join("prompt.md");
+        let rubric_path = dir_path.as_ref().join("rubric.md");
+
+        let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
+        base.path = base.path.canonicalize()?;
+
+        Ok(Example {
+            base,
+            prompt: fs::read_to_string(prompt_path)?,
+            rubric: fs::read_to_string(rubric_path)?,
+        })
+    }
+
+    /// Set up the example by checking out the specified Git revision
+    pub fn setup(&self) -> Result<()> {
+        use std::process::Command;
+
+        // Check if the directory exists
+        let path = Path::new(&self.base.path);
+        anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
+
+        // Change to the project directory and checkout the specified revision
+        let output = Command::new("git")
+            .current_dir(&self.base.path)
+            .arg("checkout")
+            .arg(&self.base.revision)
+            .output()?;
+        anyhow::ensure!(
+            output.status.success(),
+            "Failed to checkout revision {}: {}",
+            self.base.revision,
+            String::from_utf8_lossy(&output.stderr),
+        );
+
+        Ok(())
+    }
+}
+
+fn main() {
+    env_logger::init();
+    let http_client = Arc::new(ReqwestClient::new());
+    let app = Application::headless().with_http_client(http_client.clone());
+
+    app.run(move |cx| {
+        let app_state = crate::agent::init(cx);
+        let _agent = Agent::new(app_state, cx);
+
+        let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
+
+        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            registry.set_default_model(Some(model.clone()), cx);
+        });
+
+        let model_provider_id = model.provider_id();
+
+        let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
+
+        cx.spawn(async move |_cx| {
+            authenticate.await.unwrap();
+        })
+        .detach();
+    });
+
+    // let example =
+    //     Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
+    // example.setup()?;
+}