Detailed changes
@@ -4901,6 +4901,37 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "eval"
+version = "0.1.0"
+dependencies = [
+ "agent",
+ "anyhow",
+ "assistant_tool",
+ "assistant_tools",
+ "client",
+ "collections",
+ "context_server",
+ "dap",
+ "env_logger 0.11.8",
+ "fs",
+ "gpui",
+ "gpui_tokio",
+ "language",
+ "language_model",
+ "language_models",
+ "node_runtime",
+ "project",
+ "prompt_store",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "settings",
+ "smol",
+ "toml 0.8.20",
+ "workspace-hack",
+]
+
[[package]]
name = "evals"
version = "0.1.0"
@@ -47,6 +47,7 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
+ "crates/eval",
"crates/evals",
"crates/extension",
"crates/extension_api",
@@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
-use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
+use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::parser::CodeBlockKind;
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
use project::ProjectItem as _;
@@ -897,11 +897,7 @@ impl ActiveThread {
self.save_thread(cx);
cx.notify();
}
- ThreadEvent::UsePendingTools => {
- let tool_uses = self
- .thread
- .update(cx, |thread, cx| thread.use_pending_tools(cx));
-
+ ThreadEvent::UsePendingTools { tool_uses } => {
for tool_use in tool_uses {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -913,11 +909,8 @@ impl ActiveThread {
}
}
ThreadEvent::ToolFinished {
- pending_tool_use,
- canceled,
- ..
+ pending_tool_use, ..
} => {
- let canceled = *canceled;
if let Some(tool_use) = pending_tool_use {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -931,18 +924,6 @@ impl ActiveThread {
cx,
);
}
-
- if self.thread.read(cx).all_tools_finished() {
- let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
- self.thread.update(cx, |thread, cx| {
- thread.attach_tool_results(cx);
- if !canceled {
- thread.send_to_model(model, RequestKind::Chat, cx);
- }
- });
- }
- }
}
ThreadEvent::CheckpointChanged => cx.notify(),
}
@@ -1181,7 +1181,8 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
- cx.emit(ThreadEvent::UsePendingTools);
+ let tool_uses = thread.use_pending_tools(cx);
+ cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
@@ -1369,10 +1370,7 @@ impl Thread {
)
}
- pub fn use_pending_tools(
- &mut self,
- cx: &mut Context<Self>,
- ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
+ pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
@@ -1460,18 +1458,36 @@ impl Thread {
output,
cx,
);
-
- cx.emit(ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- canceled: false,
- });
+ thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
})
.ok();
}
})
}
+ fn tool_finished(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ pending_tool_use: Option<PendingToolUse>,
+ canceled: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if self.all_tools_finished() {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+ self.attach_tool_results(cx);
+ if !canceled {
+ self.send_to_model(model, RequestKind::Chat, cx);
+ }
+ }
+ }
+
+ cx.emit(ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ });
+ }
+
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Insert a user message to contain the tool results.
self.insert_user_message(
@@ -1495,11 +1511,12 @@ impl Thread {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
- cx.emit(ThreadEvent::ToolFinished {
- tool_use_id: pending_tool_use.id.clone(),
- pending_tool_use: Some(pending_tool_use),
- canceled: true,
- });
+ self.tool_finished(
+ pending_tool_use.id.clone(),
+ Some(pending_tool_use),
+ true,
+ cx,
+ );
}
canceled
};
@@ -1866,12 +1883,7 @@ impl Thread {
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
-
- cx.emit(ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use: None,
- canceled: true,
- });
+ self.tool_finished(tool_use_id.clone(), None, true, cx);
}
}
@@ -1897,14 +1909,14 @@ pub enum ThreadEvent {
MessageDeleted(MessageId),
SummaryGenerated,
SummaryChanged,
- UsePendingTools,
+ UsePendingTools {
+ tool_uses: Vec<PendingToolUse>,
+ },
ToolFinished {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
- /// Whether the tool was canceled by the user.
- canceled: bool,
},
CheckpointChanged,
ToolConfirmationNeeded,
@@ -95,11 +95,7 @@ impl HeadlessAssistant {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
- ThreadEvent::UsePendingTools => {
- thread.update(cx, |thread, cx| {
- thread.use_pending_tools(cx);
- });
- }
+ ThreadEvent::UsePendingTools { .. } => {}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
@@ -152,19 +148,6 @@ impl HeadlessAssistant {
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
- if thread.read(cx).all_tools_finished() {
- let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(model) = model_registry.default_model() {
- thread.update(cx, |thread, cx| {
- thread.attach_tool_results(cx);
- thread.send_to_model(model.model, RequestKind::Chat, cx);
- });
- } else {
- println!(
- "Warning: No active language model available to continue conversation"
- );
- }
- }
}
_ => {}
}
@@ -0,0 +1,39 @@
+[package]
+name = "eval"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+
+[dependencies]
+agent.workspace = true
+anyhow.workspace = true
+assistant_tool.workspace = true
+assistant_tools.workspace = true
+client.workspace = true
+collections.workspace = true
+context_server.workspace = true
+dap.workspace = true
+env_logger.workspace = true
+fs.workspace = true
+gpui.workspace = true
+gpui_tokio.workspace = true
+language.workspace = true
+language_model.workspace = true
+language_models.workspace = true
+node_runtime.workspace = true
+project.workspace = true
+prompt_store.workspace = true
+release_channel.workspace = true
+reqwest_client.workspace = true
+serde.workspace = true
+settings.workspace = true
+smol.workspace = true
+toml.workspace = true
+workspace-hack.workspace = true
+
+[[bin]]
+name = "eval"
+path = "src/eval.rs"
+
+[lints]
+workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,7 @@
+# Eval
+
+This eval assumes the working directory is the root of the repository. Run it with:
+
+```sh
+cargo run -p eval
+```
@@ -0,0 +1,2 @@
+path = "../zed_worktree"
+revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
@@ -0,0 +1,3 @@
+Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
+
+The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.
@@ -0,0 +1,229 @@
+use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
+use anyhow::anyhow;
+use assistant_tool::ToolWorkingSet;
+use client::{Client, UserStore};
+use collections::HashMap;
+use dap::DapRegistry;
+use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
+use language::LanguageRegistry;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+};
+use node_runtime::NodeRuntime;
+use project::{Project, RealFs};
+use prompt_store::PromptBuilder;
+use settings::SettingsStore;
+use smol::channel;
+use std::sync::Arc;
+
+/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
+pub struct AgentAppState {
+ pub languages: Arc<LanguageRegistry>,
+ pub client: Arc<Client>,
+ pub user_store: Entity<UserStore>,
+ pub fs: Arc<dyn fs::Fs>,
+ pub node_runtime: NodeRuntime,
+
+ // Additional fields not present in `workspace::AppState`.
+ pub prompt_builder: Arc<PromptBuilder>,
+}
+
+pub struct Agent {
+ // pub thread: Entity<Thread>,
+ // pub project: Entity<Project>,
+ #[allow(dead_code)]
+ pub thread_store: Entity<ThreadStore>,
+ pub tool_use_counts: HashMap<Arc<str>, u32>,
+ pub done_tx: channel::Sender<anyhow::Result<()>>,
+ _subscription: Subscription,
+}
+
+impl Agent {
+ pub fn new(
+ app_state: Arc<AgentAppState>,
+ cx: &mut App,
+ ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
+ let env = None;
+ let project = Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ Arc::new(DapRegistry::default()),
+ app_state.fs.clone(),
+ env,
+ cx,
+ );
+
+ let tools = Arc::new(ToolWorkingSet::default());
+ let thread_store =
+ ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
+
+ let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+
+ let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
+
+ let headless_thread = cx.new(move |cx| Self {
+ _subscription: cx.subscribe(&thread, Self::handle_thread_event),
+ // thread,
+ // project,
+ thread_store,
+ tool_use_counts: HashMap::default(),
+ done_tx,
+ });
+
+ Ok((headless_thread, done_rx))
+ }
+
+ fn handle_thread_event(
+ &mut self,
+ thread: Entity<Thread>,
+ event: &ThreadEvent,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ ThreadEvent::ShowError(err) => self
+ .done_tx
+ .send_blocking(Err(anyhow!("{:?}", err)))
+ .unwrap(),
+ ThreadEvent::DoneStreaming => {
+ let thread = thread.read(cx);
+ if let Some(message) = thread.messages().last() {
+ println!("Message: {}", message.to_string());
+ }
+ if thread.all_tools_finished() {
+ self.done_tx.send_blocking(Ok(())).unwrap()
+ }
+ }
+ ThreadEvent::UsePendingTools { .. } => {}
+ ThreadEvent::ToolConfirmationNeeded => {
+ // Automatically approve all tools that need confirmation in headless mode
+ println!("Tool confirmation needed - automatically approving in headless mode");
+
+ // Get the tools needing confirmation
+ let tools_needing_confirmation: Vec<_> = thread
+ .read(cx)
+ .tools_needing_confirmation()
+ .cloned()
+ .collect();
+
+ // Run each tool that needs confirmation
+ for tool_use in tools_needing_confirmation {
+ if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
+ thread.update(cx, |thread, cx| {
+ println!("Auto-approving tool: {}", tool_use.name);
+
+ // Create a request to send to the tool
+ let request = thread.to_completion_request(RequestKind::Chat, cx);
+ let messages = Arc::new(request.messages);
+
+ // Run the tool
+ thread.run_tool(
+ tool_use.id.clone(),
+ tool_use.ui_text.clone(),
+ tool_use.input.clone(),
+ &messages,
+ tool,
+ cx,
+ );
+ });
+ }
+ }
+ }
+ ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ ..
+ } => {
+ if let Some(pending_tool_use) = pending_tool_use {
+ println!(
+ "Used tool {} with input: {}",
+ pending_tool_use.name, pending_tool_use.input
+ );
+ *self
+ .tool_use_counts
+ .entry(pending_tool_use.name.clone())
+ .or_insert(0) += 1;
+ }
+ if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
+ println!("Tool result: {:?}", tool_result);
+ }
+ }
+ _ => {}
+ }
+ }
+}
+
+pub fn init(cx: &mut App) -> Arc<AgentAppState> {
+ release_channel::init(SemanticVersion::default(), cx);
+ gpui_tokio::init(cx);
+
+ let mut settings_store = SettingsStore::new(cx);
+ settings_store
+ .set_default_settings(settings::default_settings().as_ref(), cx)
+ .unwrap();
+ cx.set_global(settings_store);
+ client::init_settings(cx);
+ Project::init_settings(cx);
+
+ let client = Client::production(cx);
+ cx.set_http_client(client.http_client().clone());
+
+ let git_binary_path = None;
+ let fs = Arc::new(RealFs::new(
+ git_binary_path,
+ cx.background_executor().clone(),
+ ));
+
+ let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+ language::init(cx);
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+ assistant_tools::init(client.http_client().clone(), cx);
+ context_server::init(cx);
+ let stdout_is_a_pty = false;
+ let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
+ agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
+
+ Arc::new(AgentAppState {
+ languages,
+ client,
+ user_store,
+ fs,
+ node_runtime: NodeRuntime::unavailable(),
+ prompt_builder,
+ })
+}
+
+pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model = model_registry
+ .available_models(cx)
+ .find(|model| model.id().0 == model_name);
+
+ let Some(model) = model else {
+ return Err(anyhow!(
+ "No language model named {} was available. Available models: {}",
+ model_name,
+ model_registry
+ .available_models(cx)
+ .map(|model| model.id().0.clone())
+ .collect::<Vec<_>>()
+ .join(", ")
+ ));
+ };
+
+ Ok(model)
+}
+
+pub fn authenticate_model_provider(
+ provider_id: LanguageModelProviderId,
+ cx: &mut App,
+) -> Task<std::result::Result<(), AuthenticateError>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model_provider = model_registry.provider(&provider_id).unwrap();
+ model_provider.authenticate(cx)
+}
@@ -0,0 +1,101 @@
+use agent::Agent;
+use anyhow::Result;
+use gpui::Application;
+use language_model::LanguageModelRegistry;
+use reqwest_client::ReqwestClient;
+use serde::Deserialize;
+use std::{
+ fs,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+mod agent;
+
+#[derive(Debug, Deserialize)]
+pub struct ExampleBase {
+ pub path: PathBuf,
+ pub revision: String,
+}
+
+#[derive(Debug)]
+pub struct Example {
+ pub base: ExampleBase,
+
+ /// Content of the prompt.md file
+ pub prompt: String,
+
+ /// Content of the rubric.md file
+ pub rubric: String,
+}
+
+impl Example {
+ /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
+ pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
+ let base_path = dir_path.as_ref().join("base.toml");
+ let prompt_path = dir_path.as_ref().join("prompt.md");
+ let rubric_path = dir_path.as_ref().join("rubric.md");
+
+ let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
+ base.path = base.path.canonicalize()?;
+
+ Ok(Example {
+ base,
+ prompt: fs::read_to_string(prompt_path)?,
+ rubric: fs::read_to_string(rubric_path)?,
+ })
+ }
+
+ /// Set up the example by checking out the specified Git revision
+ pub fn setup(&self) -> Result<()> {
+ use std::process::Command;
+
+ // Check if the directory exists
+ let path = Path::new(&self.base.path);
+ anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
+
+ // Change to the project directory and checkout the specified revision
+ let output = Command::new("git")
+ .current_dir(&self.base.path)
+ .arg("checkout")
+ .arg(&self.base.revision)
+ .output()?;
+ anyhow::ensure!(
+ output.status.success(),
+ "Failed to checkout revision {}: {}",
+ self.base.revision,
+ String::from_utf8_lossy(&output.stderr),
+ );
+
+ Ok(())
+ }
+}
+
+fn main() {
+ env_logger::init();
+ let http_client = Arc::new(ReqwestClient::new());
+ let app = Application::headless().with_http_client(http_client.clone());
+
+ app.run(move |cx| {
+ let app_state = crate::agent::init(cx);
+ let _agent = Agent::new(app_state, cx);
+
+ let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
+
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.set_default_model(Some(model.clone()), cx);
+ });
+
+ let model_provider_id = model.provider_id();
+
+ let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
+
+ cx.spawn(async move |_cx| {
+ authenticate.await.unwrap();
+ })
+ .detach();
+ });
+
+ // let example =
+ // Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
+ // example.setup()?;
+}