Detailed changes
@@ -52,7 +52,6 @@ dependencies = [
name = "agent"
version = "0.1.0"
dependencies = [
- "agent_rules",
"anyhow",
"assistant_context_editor",
"assistant_settings",
@@ -116,6 +115,7 @@ dependencies = [
"terminal_view",
"text",
"theme",
+ "thiserror 2.0.12",
"time",
"time_format",
"ui",
@@ -127,57 +127,6 @@ dependencies = [
"zed_actions",
]
-[[package]]
-name = "agent_eval"
-version = "0.1.0"
-dependencies = [
- "agent",
- "anyhow",
- "assistant_tool",
- "assistant_tools",
- "clap",
- "client",
- "collections",
- "context_server",
- "dap",
- "env_logger 0.11.8",
- "fs",
- "futures 0.3.31",
- "gpui",
- "gpui_tokio",
- "language",
- "language_model",
- "language_models",
- "node_runtime",
- "project",
- "prompt_store",
- "release_channel",
- "reqwest_client",
- "serde",
- "serde_json",
- "serde_json_lenient",
- "settings",
- "smol",
- "tempfile",
- "util",
- "walkdir",
- "workspace-hack",
-]
-
-[[package]]
-name = "agent_rules"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "fs",
- "gpui",
- "indoc",
- "prompt_store",
- "util",
- "workspace-hack",
- "worktree",
-]
-
[[package]]
name = "ahash"
version = "0.7.8"
@@ -4910,14 +4859,15 @@ version = "0.1.0"
dependencies = [
"agent",
"anyhow",
+ "assistant_settings",
"assistant_tool",
"assistant_tools",
"client",
- "collections",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
+ "futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
@@ -4930,7 +4880,6 @@ dependencies = [
"reqwest_client",
"serde",
"settings",
- "smol",
"toml 0.8.20",
"workspace-hack",
]
@@ -3,13 +3,11 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
- "crates/agent_rules",
"crates/anthropic",
"crates/askpass",
"crates/assets",
"crates/assistant",
"crates/assistant_context_editor",
- "crates/agent_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -211,14 +209,12 @@ edition = "2024"
activity_indicator = { path = "crates/activity_indicator" }
agent = { path = "crates/agent" }
-agent_rules = { path = "crates/agent_rules" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
-assistant_eval = { path = "crates/agent_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -19,7 +19,6 @@ test-support = [
]
[dependencies]
-agent_rules.workspace = true
anyhow.workspace = true
assistant_context_editor.workspace = true
assistant_settings.workspace = true
@@ -81,6 +80,7 @@ terminal.workspace = true
terminal_view.workspace = true
text.workspace = true
theme.workspace = true
+thiserror.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
@@ -4,7 +4,7 @@ use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
-use crate::thread_store::ThreadStore;
+use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
@@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
-use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
+use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
@@ -668,6 +668,7 @@ impl ActiveThread {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event),
+ cx.subscribe(&thread_store, Self::handle_rules_loading_error),
];
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
@@ -833,10 +834,9 @@ impl ActiveThread {
| ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
- ThreadEvent::DoneStreaming => {
- let thread = self.thread.read(cx);
-
- if !thread.is_generating() {
+ ThreadEvent::Stopped(reason) => match reason {
+ Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
+ let thread = self.thread.read(cx);
self.show_notification(
if thread.used_tools_since_last_user_message() {
"Finished running tools"
@@ -848,7 +848,8 @@ impl ActiveThread {
cx,
);
}
- }
+ _ => {}
+ },
ThreadEvent::ToolConfirmationNeeded => {
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
}
@@ -925,6 +926,19 @@ impl ActiveThread {
}
}
+ fn handle_rules_loading_error(
+ &mut self,
+ _thread_store: Entity<ThreadStore>,
+ error: &RulesLoadingError,
+ cx: &mut Context<Self>,
+ ) {
+ self.last_error = Some(ThreadError::Message {
+ header: "Error loading rules file".into(),
+ message: error.message.clone(),
+ });
+ cx.notify();
+ }
+
fn show_notification(
&mut self,
caption: impl Into<SharedString>,
@@ -2701,12 +2715,13 @@ impl ActiveThread {
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
- let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
- else {
+ let project_context = self.thread.read(cx).project_context();
+ let project_context = project_context.borrow();
+ let Some(project_context) = project_context.as_ref() else {
return div().into_any();
};
- let rules_files = system_prompt_context
+ let rules_files = project_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
@@ -2796,12 +2811,13 @@ impl ActiveThread {
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
- let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
- else {
+ let project_context = self.thread.read(cx).project_context();
+ let project_context = project_context.borrow();
+ let Some(project_context) = project_context.as_ref() else {
return;
};
- let abs_paths = system_prompt_context
+ let abs_paths = project_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
@@ -921,15 +921,16 @@ mod tests {
})
.unwrap();
- let thread_store = cx.update(|cx| {
- ThreadStore::new(
- project.clone(),
- Arc::default(),
- Arc::new(PromptBuilder::new(None).unwrap()),
- cx,
- )
- .unwrap()
- });
+ let thread_store = cx
+ .update(|cx| {
+ ThreadStore::load(
+ project.clone(),
+ Arc::default(),
+ Arc::new(PromptBuilder::new(None).unwrap()),
+ cx,
+ )
+ })
+ .await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
@@ -194,10 +194,12 @@ impl AssistantPanel {
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default());
- let thread_store = workspace.update(cx, |workspace, cx| {
- let project = workspace.project().clone();
- ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
- })??;
+ let thread_store = workspace
+ .update(cx, |workspace, cx| {
+ let project = workspace.project().clone();
+ ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
+ })?
+ .await;
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
@@ -32,8 +32,8 @@ use crate::profile_selector::ProfileSelector;
use crate::thread::{RequestKind, Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
- AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
- ToggleContextPicker, ToggleProfileSelector,
+ AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker,
+ ToggleProfileSelector,
};
pub struct MessageEditor {
@@ -235,8 +235,6 @@ impl MessageEditor {
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
- let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
-
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
@@ -245,16 +243,6 @@ impl MessageEditor {
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
- let (system_prompt_context, load_error) = system_prompt_context_task.await;
-
- thread
- .update(cx, |thread, cx| {
- thread.set_system_prompt_context(system_prompt_context);
- if let Some(load_error) = load_error {
- cx.emit(ThreadEvent::ShowError(load_error));
- }
- })
- .log_err();
thread
.update(cx, |thread, cx| {
@@ -3,14 +3,12 @@ use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
-use agent_rules::load_worktree_rules_file;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
-use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
@@ -21,19 +19,20 @@ use language_model::{
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
+use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
-use project::{Project, Worktree};
-use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
+use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
+use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
- SerializedToolUse,
+ SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
@@ -247,7 +246,7 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
- system_prompt_context: Option<AssistantSystemPromptContext>,
+ project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
@@ -269,6 +268,7 @@ impl Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
+ system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
Self {
@@ -281,7 +281,7 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
- system_prompt_context: None,
+ project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -310,6 +310,7 @@ impl Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
+ project_context: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@@ -350,7 +351,7 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
- system_prompt_context: None,
+ project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -388,6 +389,10 @@ impl Thread {
self.summary.clone()
}
+ pub fn project_context(&self) -> SharedProjectContext {
+ self.project_context.clone()
+ }
+
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -812,86 +817,6 @@ impl Thread {
})
}
- pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
- self.system_prompt_context = Some(context);
- }
-
- pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
- &self.system_prompt_context
- }
-
- pub fn load_system_prompt_context(
- &self,
- cx: &App,
- ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
- let project = self.project.read(cx);
- let tasks = project
- .visible_worktrees(cx)
- .map(|worktree| {
- Self::load_worktree_info_for_system_prompt(
- project.fs().clone(),
- worktree.read(cx),
- cx,
- )
- })
- .collect::<Vec<_>>();
-
- cx.spawn(async |_cx| {
- let results = futures::future::join_all(tasks).await;
- let mut first_err = None;
- let worktrees = results
- .into_iter()
- .map(|(worktree, err)| {
- if first_err.is_none() && err.is_some() {
- first_err = err;
- }
- worktree
- })
- .collect::<Vec<_>>();
- (AssistantSystemPromptContext::new(worktrees), first_err)
- })
- }
-
- fn load_worktree_info_for_system_prompt(
- fs: Arc<dyn Fs>,
- worktree: &Worktree,
- cx: &App,
- ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
- let root_name = worktree.root_name().into();
- let abs_path = worktree.abs_path();
-
- let rules_task = load_worktree_rules_file(fs, worktree, cx);
- let Some(rules_task) = rules_task else {
- return Task::ready((
- WorktreeInfoForSystemPrompt {
- root_name,
- abs_path,
- rules_file: None,
- },
- None,
- ));
- };
-
- cx.spawn(async move |_| {
- let (rules_file, rules_file_error) = match rules_task.await {
- Ok(rules_file) => (Some(rules_file), None),
- Err(err) => (
- None,
- Some(ThreadError::Message {
- header: "Error loading rules file".into(),
- message: format!("{err}").into(),
- }),
- ),
- };
- let worktree_info = WorktreeInfoForSystemPrompt {
- root_name,
- abs_path,
- rules_file,
- };
- (worktree_info, rules_file_error)
- })
- }
-
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -941,10 +866,10 @@ impl Thread {
temperature: None,
};
- if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
+ if let Some(project_context) = self.project_context.borrow().as_ref() {
if let Some(system_prompt) = self
.prompt_builder
- .generate_assistant_system_prompt(system_prompt_context)
+ .generate_assistant_system_prompt(project_context)
.context("failed to generate assistant system prompt")
.log_err()
{
@@ -955,7 +880,7 @@ impl Thread {
});
}
} else {
- log::error!("system_prompt_context not set.")
+ log::error!("project_context not set.")
}
for message in &self.messages {
@@ -1215,7 +1140,7 @@ impl Thread {
thread.cancel_last_completion(cx);
}
}
- cx.emit(ThreadEvent::DoneStreaming);
+ cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
thread.auto_capture_telemetry(cx);
@@ -1963,10 +1888,13 @@ impl Thread {
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Error)]
pub enum ThreadError {
+ #[error("Payment required")]
PaymentRequired,
+ #[error("Max monthly spend reached")]
MaxMonthlySpendReached,
+ #[error("Message {header}: {message}")]
Message {
header: SharedString,
message: SharedString,
@@ -1979,7 +1907,7 @@ pub enum ThreadEvent {
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
- DoneStreaming,
+ Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
@@ -2085,9 +2013,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
- assert_eq!(request.messages.len(), 1);
+ assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
- assert_eq!(request.messages[0].string_contents(), expected_full_message);
+ assert_eq!(request.messages[1].string_contents(), expected_full_message);
}
#[gpui::test]
@@ -2178,20 +2106,20 @@ fn main() {{
});
// The request should contain all 3 messages
- assert_eq!(request.messages.len(), 3);
+ assert_eq!(request.messages.len(), 4);
// Check that the contexts are properly formatted in each message
- assert!(request.messages[0].string_contents().contains("file1.rs"));
- assert!(!request.messages[0].string_contents().contains("file2.rs"));
- assert!(!request.messages[0].string_contents().contains("file3.rs"));
-
- assert!(!request.messages[1].string_contents().contains("file1.rs"));
- assert!(request.messages[1].string_contents().contains("file2.rs"));
+ assert!(request.messages[1].string_contents().contains("file1.rs"));
+ assert!(!request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
- assert!(!request.messages[2].string_contents().contains("file2.rs"));
- assert!(request.messages[2].string_contents().contains("file3.rs"));
+ assert!(request.messages[2].string_contents().contains("file2.rs"));
+ assert!(!request.messages[2].string_contents().contains("file3.rs"));
+
+ assert!(!request.messages[3].string_contents().contains("file1.rs"));
+ assert!(!request.messages[3].string_contents().contains("file2.rs"));
+ assert!(request.messages[3].string_contents().contains("file3.rs"));
}
#[gpui::test]
@@ -2229,9 +2157,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
- assert_eq!(request.messages.len(), 1);
+ assert_eq!(request.messages.len(), 2);
assert_eq!(
- request.messages[0].string_contents(),
+ request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
@@ -2249,13 +2177,13 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
- assert_eq!(request.messages.len(), 2);
+ assert_eq!(request.messages.len(), 3);
assert_eq!(
- request.messages[0].string_contents(),
+ request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
assert_eq!(
- request.messages[1].string_contents(),
+ request.messages[2].string_contents(),
"Are there any good books?"
);
}
@@ -2376,15 +2304,16 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
- let thread_store = cx.update(|_, cx| {
- ThreadStore::new(
- project.clone(),
- Arc::default(),
- Arc::new(PromptBuilder::new(None).unwrap()),
- cx,
- )
- .unwrap()
- });
+ let thread_store = cx
+ .update(|_, cx| {
+ ThreadStore::load(
+ project.clone(),
+ Arc::default(),
+ Arc::new(PromptBuilder::new(None).unwrap()),
+ cx,
+ )
+ })
+ .await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
@@ -1,37 +1,57 @@
use std::borrow::Cow;
-use std::path::PathBuf;
+use std::cell::{Ref, RefCell};
+use std::path::{Path, PathBuf};
+use std::rc::Rc;
use std::sync::Arc;
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
+use fs::Fs;
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
- App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
- prelude::*,
+ App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
+ Subscription, Task, prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
-use project::Project;
-use prompt_store::PromptBuilder;
+use project::{Project, Worktree};
+use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
-use crate::thread::{
- DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
-};
+use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
+
+const RULES_FILE_NAMES: [&'static str; 6] = [
+ ".rules",
+ ".cursorrules",
+ ".windsurfrules",
+ ".clinerules",
+ ".github/copilot-instructions.md",
+ "CLAUDE.md",
+];
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
}
+/// A system prompt shared by all threads created by this ThreadStore
+#[derive(Clone, Default)]
+pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
+
+impl SharedProjectContext {
+ pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
+ self.0.borrow()
+ }
+}
+
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
@@ -39,43 +59,187 @@ pub struct ThreadStore {
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
+ project_context: SharedProjectContext,
_subscriptions: Vec<Subscription>,
}
+pub struct RulesLoadingError {
+ pub message: SharedString,
+}
+
+impl EventEmitter<RulesLoadingError> for ThreadStore {}
+
impl ThreadStore {
- pub fn new(
+ pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
- ) -> Result<Entity<Self>> {
- let this = cx.new(|cx| {
- let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
- let context_server_manager = cx.new(|cx| {
- ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
+ ) -> Task<Entity<Self>> {
+ let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
+ let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
+ cx.foreground_executor().spawn(async move {
+ reload.await;
+ thread_store
+ })
+ }
+
+ fn new(
+ project: Entity<Project>,
+ tools: Arc<ToolWorkingSet>,
+ prompt_builder: Arc<PromptBuilder>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
+ let context_server_manager = cx.new(|cx| {
+ ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
+ });
+ let settings_subscription =
+ cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
+ this.load_default_profile(cx);
});
- let settings_subscription =
- cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
- this.load_default_profile(cx);
- });
+ let project_subscription = cx.subscribe(&project, Self::handle_project_event);
+
+ let this = Self {
+ project,
+ tools,
+ prompt_builder,
+ context_server_manager,
+ context_server_tool_ids: HashMap::default(),
+ threads: Vec::new(),
+ project_context: SharedProjectContext::default(),
+ _subscriptions: vec![settings_subscription, project_subscription],
+ };
+ this.load_default_profile(cx);
+ this.register_context_server_handlers(cx);
+ this.reload(cx).detach_and_log_err(cx);
+ this
+ }
- let this = Self {
- project,
- tools,
- prompt_builder,
- context_server_manager,
- context_server_tool_ids: HashMap::default(),
- threads: Vec::new(),
- _subscriptions: vec![settings_subscription],
- };
- this.load_default_profile(cx);
- this.register_context_server_handlers(cx);
- this.reload(cx).detach_and_log_err(cx);
+ fn handle_project_event(
+ &mut self,
+ _project: Entity<Project>,
+ event: &project::Event,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
+ self.reload_system_prompt(cx).detach();
+ }
+ project::Event::WorktreeUpdatedEntries(_, items) => {
+ if items.iter().any(|(path, _, _)| {
+ RULES_FILE_NAMES
+ .iter()
+ .any(|name| path.as_ref() == Path::new(name))
+ }) {
+ self.reload_system_prompt(cx).detach();
+ }
+ }
+ _ => {}
+ }
+ }
- this
- });
+ pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
+ let project = self.project.read(cx);
+ let tasks = project
+ .visible_worktrees(cx)
+ .map(|worktree| {
+ Self::load_worktree_info_for_system_prompt(
+ project.fs().clone(),
+ worktree.read(cx),
+ cx,
+ )
+ })
+ .collect::<Vec<_>>();
- Ok(this)
+ cx.spawn(async move |this, cx| {
+ let results = futures::future::join_all(tasks).await;
+ let worktrees = results
+ .into_iter()
+ .map(|(worktree, rules_error)| {
+ if let Some(rules_error) = rules_error {
+ this.update(cx, |_, cx| cx.emit(rules_error)).ok();
+ }
+ worktree
+ })
+ .collect::<Vec<_>>();
+ this.update(cx, |this, _cx| {
+ *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
+ })
+ .ok();
+ })
+ }
+
+ fn load_worktree_info_for_system_prompt(
+ fs: Arc<dyn Fs>,
+ worktree: &Worktree,
+ cx: &App,
+ ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
+ let root_name = worktree.root_name().into();
+ let abs_path = worktree.abs_path();
+
+ let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
+ let Some(rules_task) = rules_task else {
+ return Task::ready((
+ WorktreeContext {
+ root_name,
+ abs_path,
+ rules_file: None,
+ },
+ None,
+ ));
+ };
+
+ cx.spawn(async move |_| {
+ let (rules_file, rules_file_error) = match rules_task.await {
+ Ok(rules_file) => (Some(rules_file), None),
+ Err(err) => (
+ None,
+ Some(RulesLoadingError {
+ message: format!("{err}").into(),
+ }),
+ ),
+ };
+ let worktree_info = WorktreeContext {
+ root_name,
+ abs_path,
+ rules_file,
+ };
+ (worktree_info, rules_file_error)
+ })
+ }
+
+ fn load_worktree_rules_file(
+ fs: Arc<dyn Fs>,
+ worktree: &Worktree,
+ cx: &App,
+ ) -> Option<Task<Result<RulesFileContext>>> {
+ let selected_rules_file = RULES_FILE_NAMES
+ .into_iter()
+ .filter_map(|name| {
+ worktree
+ .entry_for_path(name)
+ .filter(|entry| entry.is_file())
+ .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
+ })
+ .next();
+
+ // Note that Cline supports `.clinerules` being a directory, but that is not currently
+ // supported. This doesn't seem to occur often in GitHub repositories.
+ selected_rules_file.map(|(path_in_worktree, abs_path)| {
+ let fs = fs.clone();
+ cx.background_spawn(async move {
+ let abs_path = abs_path?;
+ let text = fs.load(&abs_path).await.with_context(|| {
+ format!("Failed to load assistant rules file {:?}", abs_path)
+ })?;
+ anyhow::Ok(RulesFileContext {
+ path_in_worktree,
+ abs_path: abs_path.into(),
+ text: text.trim().to_string(),
+ })
+ })
+ })
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
@@ -107,6 +271,7 @@ impl ThreadStore {
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
+ self.project_context.clone(),
cx,
)
})
@@ -134,21 +299,12 @@ impl ThreadStore {
this.project.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
+ this.project_context.clone(),
cx,
)
})
})?;
- let (system_prompt_context, load_error) = thread
- .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
- .await;
- thread.update(cx, |thread, cx| {
- thread.set_system_prompt_context(system_prompt_context);
- if let Some(load_error) = load_error {
- cx.emit(ThreadEvent::ShowError(load_error));
- }
- })?;
-
Ok(thread)
})
}
@@ -1,46 +0,0 @@
-[package]
-name = "agent_eval"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[[bin]]
-name = "agent_eval"
-path = "src/main.rs"
-
-[dependencies]
-agent.workspace = true
-anyhow.workspace = true
-assistant_tool.workspace = true
-assistant_tools.workspace = true
-clap.workspace = true
-client.workspace = true
-collections.workspace = true
-context_server.workspace = true
-dap.workspace = true
-env_logger.workspace = true
-fs.workspace = true
-futures.workspace = true
-gpui.workspace = true
-gpui_tokio.workspace = true
-language.workspace = true
-language_model.workspace = true
-language_models.workspace = true
-node_runtime.workspace = true
-project.workspace = true
-prompt_store.workspace = true
-release_channel.workspace = true
-reqwest_client.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-serde_json_lenient.workspace = true
-settings.workspace = true
-smol.workspace = true
-tempfile.workspace = true
-util.workspace = true
-walkdir.workspace = true
-workspace-hack.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,52 +0,0 @@
-// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
-
-use std::process::Command;
-
-fn main() {
- if cfg!(target_os = "macos") {
- println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
-
- // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
- println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
-
- // Seems to be required to enable Swift concurrency
- println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
-
- // Register exported Objective-C selectors, protocols, etc
- println!("cargo:rustc-link-arg=-Wl,-ObjC");
- }
-
- // Populate git sha environment variable if git is available
- println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
- println!(
- "cargo:rustc-env=TARGET={}",
- std::env::var("TARGET").unwrap()
- );
- if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
- if output.status.success() {
- let git_sha = String::from_utf8_lossy(&output.stdout);
- let git_sha = git_sha.trim();
-
- println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
-
- if let Ok(build_profile) = std::env::var("PROFILE") {
- if build_profile == "release" {
- // This is currently the best way to make `cargo build ...`'s build script
- // to print something to stdout without extra verbosity.
- println!(
- "cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
- );
- }
- }
- }
- }
-
- #[cfg(target_os = "windows")]
- {
- #[cfg(target_env = "msvc")]
- {
- // todo(windows): This is to avoid stack overflow. Remove it when solved.
- println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
- }
- }
-}
@@ -1,384 +0,0 @@
-use crate::git_commands::{run_git, setup_temp_repo};
-use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
-use crate::{get_exercise_language, get_exercise_name};
-use agent::RequestKind;
-use anyhow::{Result, anyhow};
-use collections::HashMap;
-use gpui::{App, Task};
-use language_model::{LanguageModel, TokenUsage};
-use serde::{Deserialize, Serialize};
-use std::{
- fs,
- io::Write,
- path::{Path, PathBuf},
- sync::Arc,
- time::{Duration, SystemTime},
-};
-
-#[derive(Debug, Serialize, Deserialize, Clone)]
-pub struct EvalResult {
- pub exercise_name: String,
- pub diff: String,
- pub assistant_response: String,
- pub elapsed_time_ms: u128,
- pub timestamp: u128,
- // Token usage fields
- pub input_tokens: usize,
- pub output_tokens: usize,
- pub total_tokens: usize,
- pub tool_use_counts: usize,
-}
-
-pub struct EvalOutput {
- pub diff: String,
- pub last_message: String,
- pub elapsed_time: Duration,
- pub assistant_response_count: usize,
- pub tool_use_counts: HashMap<Arc<str>, u32>,
- pub token_usage: TokenUsage,
-}
-
-#[derive(Deserialize)]
-pub struct EvalSetup {
- pub url: String,
- pub base_sha: String,
-}
-
-pub struct Eval {
- pub repo_path: PathBuf,
- pub eval_setup: EvalSetup,
- pub user_prompt: String,
-}
-
-impl Eval {
- // Keep this method for potential future use, but mark it as intentionally unused
- #[allow(dead_code)]
- pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
- let prompt_path = path.join("prompt.txt");
- let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
- let setup_path = path.join("setup.json");
- let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
- let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
-
- // Move this internal function inside the load method since it's only used here
- fn repo_dir_name(url: &str) -> String {
- url.trim_start_matches("https://")
- .replace(|c: char| !c.is_alphanumeric(), "_")
- }
-
- let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
-
- Ok(Eval {
- repo_path,
- eval_setup,
- user_prompt,
- })
- }
-
- pub fn run(
- self,
- app_state: Arc<HeadlessAppState>,
- model: Arc<dyn LanguageModel>,
- cx: &mut App,
- ) -> Task<Result<EvalOutput>> {
- cx.spawn(async move |cx| {
- run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
-
- let (assistant, done_rx) =
- cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
-
- let _worktree = assistant
- .update(cx, |assistant, cx| {
- assistant.project.update(cx, |project, cx| {
- project.create_worktree(&self.repo_path, true, cx)
- })
- })?
- .await?;
-
- let start_time = std::time::SystemTime::now();
-
- let (system_prompt_context, load_error) = cx
- .update(|cx| {
- assistant
- .read(cx)
- .thread
- .read(cx)
- .load_system_prompt_context(cx)
- })?
- .await;
-
- if let Some(load_error) = load_error {
- return Err(anyhow!("{:?}", load_error));
- };
-
- assistant.update(cx, |assistant, cx| {
- assistant.thread.update(cx, |thread, cx| {
- let context = vec![];
- thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
- thread.set_system_prompt_context(system_prompt_context);
- thread.send_to_model(model, RequestKind::Chat, cx);
- });
- })?;
-
- done_rx.recv().await??;
-
- // Add this section to check untracked files
- println!("Checking for untracked files:");
- let untracked = run_git(
- &self.repo_path,
- &["ls-files", "--others", "--exclude-standard"],
- )
- .await?;
- if untracked.is_empty() {
- println!("No untracked files found");
- } else {
- // Add all files to git so they appear in the diff
- println!("Adding untracked files to git");
- run_git(&self.repo_path, &["add", "."]).await?;
- }
-
- // get git status
- let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
-
- let elapsed_time = start_time.elapsed()?;
-
- // Get diff of staged changes (the files we just added)
- let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
-
- // Get diff of unstaged changes
- let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
-
- // Combine both diffs
- let diff = if unstaged_diff.is_empty() {
- staged_diff
- } else if staged_diff.is_empty() {
- unstaged_diff
- } else {
- format!(
- "# Staged changes\n{}\n\n# Unstaged changes\n{}",
- staged_diff, unstaged_diff
- )
- };
-
- assistant.update(cx, |assistant, cx| {
- let thread = assistant.thread.read(cx);
- let last_message = thread.messages().last().unwrap();
- if last_message.role != language_model::Role::Assistant {
- return Err(anyhow!("Last message is not from assistant"));
- }
- let assistant_response_count = thread
- .messages()
- .filter(|message| message.role == language_model::Role::Assistant)
- .count();
- Ok(EvalOutput {
- diff,
- last_message: last_message.to_string(),
- elapsed_time,
- assistant_response_count,
- tool_use_counts: assistant.tool_use_counts.clone(),
- token_usage: thread.cumulative_token_usage(),
- })
- })?
- })
- }
-}
-
-impl EvalOutput {
- // Keep this method for potential future use, but mark it as intentionally unused
- #[allow(dead_code)]
- pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
- // Create the output directory if it doesn't exist
- fs::create_dir_all(&output_dir)?;
-
- // Save the diff to a file
- let diff_path = output_dir.join("diff.patch");
- let mut diff_file = fs::File::create(&diff_path)?;
- diff_file.write_all(self.diff.as_bytes())?;
-
- // Save the last message to a file
- let message_path = output_dir.join("assistant_response.txt");
- let mut message_file = fs::File::create(&message_path)?;
- message_file.write_all(self.last_message.as_bytes())?;
-
- // Current metrics for this run
- let current_metrics = serde_json::json!({
- "elapsed_time_ms": self.elapsed_time.as_millis(),
- "assistant_response_count": self.assistant_response_count,
- "tool_use_counts": self.tool_use_counts,
- "token_usage": self.token_usage,
- "eval_output_value": eval_output_value,
- });
-
- // Get current timestamp in milliseconds
- let timestamp = std::time::SystemTime::now()
- .duration_since(std::time::UNIX_EPOCH)?
- .as_millis()
- .to_string();
-
- // Path to metrics file
- let metrics_path = output_dir.join("metrics.json");
-
- // Load existing metrics if the file exists, or create a new object
- let mut historical_metrics = if metrics_path.exists() {
- let metrics_content = fs::read_to_string(&metrics_path)?;
- serde_json::from_str::<serde_json::Value>(&metrics_content)
- .unwrap_or_else(|_| serde_json::json!({}))
- } else {
- serde_json::json!({})
- };
-
- // Add new run with timestamp as key
- if let serde_json::Value::Object(ref mut map) = historical_metrics {
- map.insert(timestamp, current_metrics);
- }
-
- // Write updated metrics back to file
- let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
- let mut metrics_file = fs::File::create(&metrics_path)?;
- metrics_file.write_all(metrics_json.as_bytes())?;
-
- Ok(())
- }
-}
-
-pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
- let instructions_path = exercise_path.join(".docs").join("instructions.md");
- println!("Reading instructions from: {}", instructions_path.display());
- let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
- Ok(instructions)
-}
-
-pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
- let eval_dir = exercise_path.join("evaluation");
- fs::create_dir_all(&eval_dir)?;
-
- let eval_file = eval_dir.join("evals.json");
-
- println!("Saving evaluation results to: {}", eval_file.display());
- println!(
- "Results to save: {} evaluations for exercise path: {}",
- results.len(),
- exercise_path.display()
- );
-
- // Check file existence before reading/writing
- if eval_file.exists() {
- println!("Existing evals.json file found, will update it");
- } else {
- println!("No existing evals.json file found, will create new one");
- }
-
- // Structure to organize evaluations by test name and timestamp
- let mut eval_data: serde_json::Value = if eval_file.exists() {
- let content = fs::read_to_string(&eval_file)?;
- serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
- } else {
- serde_json::json!({})
- };
-
- // Get current timestamp for this batch of results
- let timestamp = SystemTime::now()
- .duration_since(SystemTime::UNIX_EPOCH)?
- .as_millis()
- .to_string();
-
- // Group the new results by test name (exercise name)
- for result in results {
- let exercise_name = &result.exercise_name;
-
- println!("Adding result: exercise={}", exercise_name);
-
- // Ensure the exercise entry exists
- if eval_data.get(exercise_name).is_none() {
- eval_data[exercise_name] = serde_json::json!({});
- }
-
- // Ensure the timestamp entry exists as an object
-        if eval_data[exercise_name].get(&timestamp).is_none() {
-            eval_data[exercise_name][&timestamp] = serde_json::json!({});
-        }
-
- // Add this result under the timestamp with template name as key
- eval_data[exercise_name][×tamp] = serde_json::to_value(&result)?;
- }
-
- // Write back to file with pretty formatting
- let json_content = serde_json::to_string_pretty(&eval_data)?;
- match fs::write(&eval_file, json_content) {
- Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
- Err(e) => println!("✗ Failed to write results file: {}", e),
- }
-
- Ok(())
-}
-
-pub async fn run_exercise_eval(
- exercise_path: PathBuf,
- model: Arc<dyn LanguageModel>,
- app_state: Arc<HeadlessAppState>,
- base_sha: String,
- _framework_path: PathBuf,
- cx: gpui::AsyncApp,
-) -> Result<EvalResult> {
- let exercise_name = get_exercise_name(&exercise_path);
- let language = get_exercise_language(&exercise_path)?;
- let mut instructions = read_instructions(&exercise_path).await?;
- instructions.push_str(&format!(
- "\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
- language
- ));
-
- println!("Running evaluation for exercise: {}", exercise_name);
-
- // Create temporary directory with exercise files
- let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
- let temp_path = temp_dir.path().to_path_buf();
-
- let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
-
- let start_time = SystemTime::now();
-
- // Create a basic eval struct to work with the existing system
- let eval = Eval {
- repo_path: temp_path.clone(),
- eval_setup: EvalSetup {
- url: format!("file://{}", temp_path.display()),
- base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
- },
- user_prompt: instructions.clone(),
- };
-
- // Run the evaluation
- let eval_output = cx
- .update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
- .await?;
-
- // Get diff from git
- let diff = eval_output.diff.clone();
-
- let elapsed_time = start_time.elapsed()?;
-
- // Calculate total tokens as the sum of input and output tokens
- let input_tokens = eval_output.token_usage.input_tokens;
- let output_tokens = eval_output.token_usage.output_tokens;
- let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
- let total_tokens = input_tokens + output_tokens;
-
- // Save results to evaluation directory
- let result = EvalResult {
- exercise_name: exercise_name.clone(),
- diff,
- assistant_response: eval_output.last_message.clone(),
- elapsed_time_ms: elapsed_time.as_millis(),
- timestamp: SystemTime::now()
- .duration_since(SystemTime::UNIX_EPOCH)?
- .as_millis(),
- // Convert u32 token counts to usize
- input_tokens: input_tokens.try_into().unwrap(),
- output_tokens: output_tokens.try_into().unwrap(),
- total_tokens: total_tokens.try_into().unwrap(),
- tool_use_counts: tool_use_counts.try_into().unwrap(),
- };
-
- Ok(result)
-}
@@ -1,149 +0,0 @@
-use anyhow::{Result, anyhow};
-use std::{
- fs,
- path::{Path, PathBuf},
-};
-
-pub fn get_exercise_name(exercise_path: &Path) -> String {
- exercise_path
- .file_name()
- .unwrap_or_default()
- .to_string_lossy()
- .to_string()
-}
-
-pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
- // Extract the language from path (data/python/exercises/... => python)
- let parts: Vec<_> = exercise_path.components().collect();
-
- for (i, part) in parts.iter().enumerate() {
- if i > 0 && part.as_os_str() == "eval_code" {
- if i + 1 < parts.len() {
- let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
- return Ok(language);
- }
- }
- }
-
- Err(anyhow!(
- "Could not determine language from path: {:?}",
- exercise_path
- ))
-}
-
-pub fn find_exercises(
- framework_path: &Path,
- languages: &[&str],
- max_per_language: Option<usize>,
-) -> Result<Vec<PathBuf>> {
- let mut all_exercises = Vec::new();
-
- println!("Searching for exercises in languages: {:?}", languages);
-
- for language in languages {
- let language_dir = framework_path
- .join("eval_code")
- .join(language)
- .join("exercises")
- .join("practice");
-
- println!("Checking language directory: {:?}", language_dir);
- if !language_dir.exists() {
- println!("Warning: Language directory not found: {:?}", language_dir);
- continue;
- }
-
- let mut exercises = Vec::new();
- match fs::read_dir(&language_dir) {
- Ok(entries) => {
- for entry_result in entries {
- match entry_result {
- Ok(entry) => {
- let path = entry.path();
-
- if path.is_dir() {
- // Special handling for "internal" directory
- if *language == "internal" {
- // Check for repo_info.json to validate it's an internal exercise
- let repo_info_path = path.join(".meta").join("repo_info.json");
- let instructions_path =
- path.join(".docs").join("instructions.md");
-
- if repo_info_path.exists() && instructions_path.exists() {
- exercises.push(path);
- }
- } else {
- // Map the language to the file extension - original code
- let language_extension = match *language {
- "python" => "py",
- "go" => "go",
- "rust" => "rs",
- "typescript" => "ts",
- "javascript" => "js",
- "ruby" => "rb",
- "php" => "php",
- "bash" => "sh",
- "multi" => "diff",
- _ => continue, // Skip unsupported languages
- };
-
- // Check if this is a valid exercise with instructions and example
- let instructions_path =
- path.join(".docs").join("instructions.md");
- let has_instructions = instructions_path.exists();
- let example_path = path
- .join(".meta")
- .join(format!("example.{}", language_extension));
- let has_example = example_path.exists();
-
- if has_instructions && has_example {
- exercises.push(path);
- }
- }
- }
- }
- Err(err) => println!("Error reading directory entry: {}", err),
- }
- }
- }
- Err(err) => println!(
- "Error reading directory {}: {}",
- language_dir.display(),
- err
- ),
- }
-
- // Sort exercises by name for consistent selection
- exercises.sort_by(|a, b| {
- let a_name = a.file_name().unwrap_or_default().to_string_lossy();
- let b_name = b.file_name().unwrap_or_default().to_string_lossy();
- a_name.cmp(&b_name)
- });
-
- // Apply the limit if specified
- if let Some(limit) = max_per_language {
- if exercises.len() > limit {
- println!(
- "Limiting {} exercises to {} for language {}",
- exercises.len(),
- limit,
- language
- );
- exercises.truncate(limit);
- }
- }
-
- println!(
- "Found {} exercises for language {}: {:?}",
- exercises.len(),
- language,
- exercises
- .iter()
- .map(|p| p.file_name().unwrap_or_default().to_string_lossy())
- .collect::<Vec<_>>()
- );
- all_exercises.extend(exercises);
- }
-
- Ok(all_exercises)
-}
@@ -1,125 +0,0 @@
-use anyhow::{Result, anyhow};
-use serde::Deserialize;
-use std::{fs, path::Path};
-use tempfile::TempDir;
-use util::command::new_smol_command;
-use walkdir::WalkDir;
-
-#[derive(Debug, Deserialize)]
-pub struct SetupConfig {
- #[serde(rename = "base.sha")]
- pub base_sha: String,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct RepoInfo {
- pub remote_url: String,
- pub head_sha: String,
-}
-
-pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
- let output = new_smol_command("git")
- .current_dir(repo_path)
- .args(args)
- .output()
- .await?;
-
- if output.status.success() {
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
- } else {
- Err(anyhow!(
- "Git command failed: {} with status: {}",
- args.join(" "),
- output.status
- ))
- }
-}
-
-pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
- let setup_path = framework_path.join("setup.json");
- let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
- let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
- Ok(setup_config.base_sha)
-}
-
-pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
- let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
- println!("Reading repo info from: {}", repo_info_path.display());
- let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
- let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
-
- // Remove any quotes from the strings
- let remote_url = repo_info.remote_url.trim_matches('"').to_string();
- let head_sha = repo_info.head_sha.trim_matches('"').to_string();
-
- Ok(RepoInfo {
- remote_url,
- head_sha,
- })
-}
-
-pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
- let temp_dir = TempDir::new()?;
-
- // Check if this is an internal exercise by looking for repo_info.json
- let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
- if repo_info_path.exists() {
- // This is an internal exercise, handle it differently
- let repo_info = read_repo_info(exercise_path).await?;
-
- // Clone the repository to the temp directory
- let url = repo_info.remote_url;
- let clone_path = temp_dir.path();
- println!(
- "Cloning repository from {} to {}",
- url,
- clone_path.display()
- );
- run_git(
- &std::env::current_dir()?,
- &["clone", &url, &clone_path.to_string_lossy()],
- )
- .await?;
-
- // Checkout the specified commit
- println!("Checking out commit: {}", repo_info.head_sha);
- run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
-
- println!("Successfully set up internal repository");
- } else {
- // Original code for regular exercises
- // Copy the exercise files to the temp directory, excluding .docs and .meta
- for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
- let entry = entry?;
- let source_path = entry.path();
-
- // Skip .docs and .meta directories completely
- if source_path.starts_with(exercise_path.join(".docs"))
- || source_path.starts_with(exercise_path.join(".meta"))
- {
- continue;
- }
-
- if source_path.is_file() {
- let relative_path = source_path.strip_prefix(exercise_path)?;
- let dest_path = temp_dir.path().join(relative_path);
-
- // Make sure parent directories exist
- if let Some(parent) = dest_path.parent() {
- fs::create_dir_all(parent)?;
- }
-
- fs::copy(source_path, dest_path)?;
- }
- }
-
- // Initialize git repo in the temp directory
- run_git(temp_dir.path(), &["init"]).await?;
- run_git(temp_dir.path(), &["add", "."]).await?;
- run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
-
- println!("Created temp repo without .docs and .meta directories");
- }
-
- Ok(temp_dir)
-}
@@ -1,229 +0,0 @@
-use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
-use anyhow::anyhow;
-use assistant_tool::ToolWorkingSet;
-use client::{Client, UserStore};
-use collections::HashMap;
-use dap::DapRegistry;
-use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
-use language::LanguageRegistry;
-use language_model::{
- AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
-};
-use node_runtime::NodeRuntime;
-use project::{Project, RealFs};
-use prompt_store::PromptBuilder;
-use settings::SettingsStore;
-use smol::channel;
-use std::sync::Arc;
-
-/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
-pub struct HeadlessAppState {
- pub languages: Arc<LanguageRegistry>,
- pub client: Arc<Client>,
- pub user_store: Entity<UserStore>,
- pub fs: Arc<dyn fs::Fs>,
- pub node_runtime: NodeRuntime,
-
- // Additional fields not present in `workspace::AppState`.
- pub prompt_builder: Arc<PromptBuilder>,
-}
-
-pub struct HeadlessAssistant {
- pub thread: Entity<Thread>,
- pub project: Entity<Project>,
- #[allow(dead_code)]
- pub thread_store: Entity<ThreadStore>,
- pub tool_use_counts: HashMap<Arc<str>, u32>,
- pub done_tx: channel::Sender<anyhow::Result<()>>,
- _subscription: Subscription,
-}
-
-impl HeadlessAssistant {
- pub fn new(
- app_state: Arc<HeadlessAppState>,
- cx: &mut App,
- ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
- let env = None;
- let project = Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- Arc::new(DapRegistry::default()),
- app_state.fs.clone(),
- env,
- cx,
- );
-
- let tools = Arc::new(ToolWorkingSet::default());
- let thread_store =
- ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
-
- let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
-
- let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
-
- let headless_thread = cx.new(move |cx| Self {
- _subscription: cx.subscribe(&thread, Self::handle_thread_event),
- thread,
- project,
- thread_store,
- tool_use_counts: HashMap::default(),
- done_tx,
- });
-
- Ok((headless_thread, done_rx))
- }
-
- fn handle_thread_event(
- &mut self,
- thread: Entity<Thread>,
- event: &ThreadEvent,
- cx: &mut Context<Self>,
- ) {
- match event {
- ThreadEvent::ShowError(err) => self
- .done_tx
- .send_blocking(Err(anyhow!("{:?}", err)))
- .unwrap(),
- ThreadEvent::DoneStreaming => {
- let thread = thread.read(cx);
- if let Some(message) = thread.messages().last() {
- println!("Message: {}", message.to_string());
- }
- if thread.all_tools_finished() {
- self.done_tx.send_blocking(Ok(())).unwrap()
- }
- }
- ThreadEvent::UsePendingTools { .. } => {}
- ThreadEvent::ToolConfirmationNeeded => {
- // Automatically approve all tools that need confirmation in headless mode
- println!("Tool confirmation needed - automatically approving in headless mode");
-
- // Get the tools needing confirmation
- let tools_needing_confirmation: Vec<_> = thread
- .read(cx)
- .tools_needing_confirmation()
- .cloned()
- .collect();
-
- // Run each tool that needs confirmation
- for tool_use in tools_needing_confirmation {
- if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
- thread.update(cx, |thread, cx| {
- println!("Auto-approving tool: {}", tool_use.name);
-
- // Create a request to send to the tool
- let request = thread.to_completion_request(RequestKind::Chat, cx);
- let messages = Arc::new(request.messages);
-
- // Run the tool
- thread.run_tool(
- tool_use.id.clone(),
- tool_use.ui_text.clone(),
- tool_use.input.clone(),
- &messages,
- tool,
- cx,
- );
- });
- }
- }
- }
- ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- ..
- } => {
- if let Some(pending_tool_use) = pending_tool_use {
- println!(
- "Used tool {} with input: {}",
- pending_tool_use.name, pending_tool_use.input
- );
- *self
- .tool_use_counts
- .entry(pending_tool_use.name.clone())
- .or_insert(0) += 1;
- }
- if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
- println!("Tool result: {:?}", tool_result);
- }
- }
- _ => {}
- }
- }
-}
-
-pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
- release_channel::init(SemanticVersion::default(), cx);
- gpui_tokio::init(cx);
-
- let mut settings_store = SettingsStore::new(cx);
- settings_store
- .set_default_settings(settings::default_settings().as_ref(), cx)
- .unwrap();
- cx.set_global(settings_store);
- client::init_settings(cx);
- Project::init_settings(cx);
-
- let client = Client::production(cx);
- cx.set_http_client(client.http_client().clone());
-
- let git_binary_path = None;
- let fs = Arc::new(RealFs::new(
- git_binary_path,
- cx.background_executor().clone(),
- ));
-
- let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
-
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-
- language::init(cx);
- language_model::init(client.clone(), cx);
- language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
- assistant_tools::init(client.http_client().clone(), cx);
- context_server::init(cx);
- let stdout_is_a_pty = false;
- let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
- agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
-
- Arc::new(HeadlessAppState {
- languages,
- client,
- user_store,
- fs,
- node_runtime: NodeRuntime::unavailable(),
- prompt_builder,
- })
-}
-
-pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let model = model_registry
- .available_models(cx)
- .find(|model| model.id().0 == model_name);
-
- let Some(model) = model else {
- return Err(anyhow!(
- "No language model named {} was available. Available models: {}",
- model_name,
- model_registry
- .available_models(cx)
- .map(|model| model.id().0.clone())
- .collect::<Vec<_>>()
- .join(", ")
- ));
- };
-
- Ok(model)
-}
-
-pub fn authenticate_model_provider(
- provider_id: LanguageModelProviderId,
- cx: &mut App,
-) -> Task<std::result::Result<(), AuthenticateError>> {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let model_provider = model_registry.provider(&provider_id).unwrap();
- model_provider.authenticate(cx)
-}
@@ -1,205 +0,0 @@
-mod eval;
-mod get_exercise;
-mod git_commands;
-mod headless_assistant;
-
-use clap::Parser;
-use eval::{run_exercise_eval, save_eval_results};
-use futures::stream::{self, StreamExt};
-use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
-use git_commands::read_base_sha;
-use gpui::Application;
-use headless_assistant::{authenticate_model_provider, find_model};
-use language_model::LanguageModelRegistry;
-use reqwest_client::ReqwestClient;
-use std::{path::PathBuf, sync::Arc};
-
-#[derive(Parser, Debug)]
-#[command(
- name = "agent_eval",
- disable_version_flag = true,
- before_help = "Tool eval runner"
-)]
-struct Args {
- /// Match the names of evals to run.
- #[arg(long)]
- exercise_names: Vec<String>,
- /// Runs all exercises, causes the exercise_names to be ignored.
- #[arg(long)]
- all: bool,
- /// Supported language types to evaluate (default: internal).
- /// Internal is data generated from the agent panel
- #[arg(long, default_value = "internal")]
- languages: String,
- /// Name of the model (default: "claude-3-7-sonnet-latest")
- #[arg(long, default_value = "claude-3-7-sonnet-latest")]
- model_name: String,
- /// Name of the editor model (default: value of `--model_name`).
- #[arg(long)]
- editor_model_name: Option<String>,
- /// Number of evaluations to run concurrently (default: 3)
- #[arg(short, long, default_value = "5")]
- concurrency: usize,
- /// Maximum number of exercises to evaluate per language
- #[arg(long)]
- max_exercises_per_language: Option<usize>,
-}
-
-fn main() {
- env_logger::init();
- let args = Args::parse();
- let http_client = Arc::new(ReqwestClient::new());
- let app = Application::headless().with_http_client(http_client.clone());
-
- // Path to the zed-ace-framework repo
- let framework_path = PathBuf::from("../zed-ace-framework")
- .canonicalize()
- .unwrap();
-
- // Fix the 'languages' lifetime issue by creating owned Strings instead of slices
- let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
-
- println!("Using zed-ace-framework at: {:?}", framework_path);
- println!("Evaluating languages: {:?}", languages);
-
- app.run(move |cx| {
- let app_state = headless_assistant::init(cx);
-
- let model = find_model(&args.model_name, cx).unwrap();
- let editor_model = if let Some(model_name) = &args.editor_model_name {
- find_model(model_name, cx).unwrap()
- } else {
- model.clone()
- };
-
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_default_model(Some(model.clone()), cx);
- });
-
- let model_provider_id = model.provider_id();
- let editor_model_provider_id = editor_model.provider_id();
-
- let framework_path_clone = framework_path.clone();
- let languages_clone = languages.clone();
- let exercise_names = args.exercise_names.clone();
- let all_flag = args.all;
-
- cx.spawn(async move |cx| {
- // Authenticate all model providers first
- cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
- .unwrap()
- .await
- .unwrap();
- cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
- .unwrap()
- .await
- .unwrap();
-
- println!("framework path: {}", framework_path_clone.display());
-
- let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
-
- println!("base sha: {}", base_sha);
-
- let all_exercises = find_exercises(
- &framework_path_clone,
- &languages_clone
- .iter()
- .map(|s| s.as_str())
- .collect::<Vec<_>>(),
- args.max_exercises_per_language,
- )
- .unwrap();
- println!("Found {} exercises total", all_exercises.len());
-
- // Filter exercises if specific ones were requested
- let exercises_to_run = if !exercise_names.is_empty() {
- // If exercise names are specified, filter by them regardless of --all flag
- all_exercises
- .into_iter()
- .filter(|path| {
- let name = get_exercise_name(path);
- exercise_names.iter().any(|filter| name.contains(filter))
- })
- .collect()
- } else if all_flag {
- // Only use all_flag if no exercise names are specified
- all_exercises
- } else {
- // Default behavior (no filters)
- all_exercises
- };
-
- println!("Will run {} exercises", exercises_to_run.len());
-
- // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
- let exercise_tasks: Vec<_> = exercises_to_run
- .into_iter()
- .map(|exercise_path| {
- let exercise_name = get_exercise_name(&exercise_path);
- let model_clone = model.clone();
- let app_state_clone = app_state.clone();
- let base_sha_clone = base_sha.clone();
- let framework_path_clone = framework_path_clone.clone();
- let cx_clone = cx.clone();
-
- async move {
- println!("Processing exercise: {}", exercise_name);
- let mut exercise_results = Vec::new();
-
- match run_exercise_eval(
- exercise_path.clone(),
- model_clone.clone(),
- app_state_clone.clone(),
- base_sha_clone.clone(),
- framework_path_clone.clone(),
- cx_clone.clone(),
- )
- .await
- {
- Ok(result) => {
- println!("Completed {}", exercise_name);
- exercise_results.push(result);
- }
- Err(err) => {
- println!("Error running {}: {}", exercise_name, err);
- }
- }
-
- // Save results for this exercise
- if !exercise_results.is_empty() {
- if let Err(err) =
- save_eval_results(&exercise_path, exercise_results.clone()).await
- {
- println!("Error saving results for {}: {}", exercise_name, err);
- } else {
- println!("Saved results for {}", exercise_name);
- }
- }
-
- exercise_results
- }
- })
- .collect();
-
- println!(
- "Running {} exercises with concurrency: {}",
- exercise_tasks.len(),
- args.concurrency
- );
-
- // Run exercises concurrently, with each exercise running its templates sequentially
- let all_results = stream::iter(exercise_tasks)
- .buffer_unordered(args.concurrency)
- .flat_map(stream::iter)
- .collect::<Vec<_>>()
- .await;
-
- println!("Completed {} evaluation runs", all_results.len());
- cx.update(|cx| cx.quit()).unwrap();
- })
- .detach();
- });
-
- println!("Done running evals");
-}
@@ -1,25 +0,0 @@
-[package]
-name = "agent_rules"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/agent_rules.rs"
-doctest = false
-
-[dependencies]
-anyhow.workspace = true
-fs.workspace = true
-gpui.workspace = true
-prompt_store.workspace = true
-util.workspace = true
-worktree.workspace = true
-workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }
-
-[dev-dependencies]
-indoc.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,51 +0,0 @@
-use std::sync::Arc;
-
-use anyhow::{Context as _, Result};
-use fs::Fs;
-use gpui::{App, AppContext, Task};
-use prompt_store::SystemPromptRulesFile;
-use util::maybe;
-use worktree::Worktree;
-
-const RULES_FILE_NAMES: [&'static str; 6] = [
- ".rules",
- ".cursorrules",
- ".windsurfrules",
- ".clinerules",
- ".github/copilot-instructions.md",
- "CLAUDE.md",
-];
-
-pub fn load_worktree_rules_file(
- fs: Arc<dyn Fs>,
- worktree: &Worktree,
- cx: &App,
-) -> Option<Task<Result<SystemPromptRulesFile>>> {
- let selected_rules_file = RULES_FILE_NAMES
- .into_iter()
- .filter_map(|name| {
- worktree
- .entry_for_path(name)
- .filter(|entry| entry.is_file())
- .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
- })
- .next();
-
- // Note that Cline supports `.clinerules` being a directory, but that is not currently
- // supported. This doesn't seem to occur often in GitHub repositories.
- selected_rules_file.map(|(path_in_worktree, abs_path)| {
- let fs = fs.clone();
- cx.background_spawn(maybe!(async move {
- let abs_path = abs_path?;
- let text = fs
- .load(&abs_path)
- .await
- .with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
- anyhow::Ok(SystemPromptRulesFile {
- path_in_worktree,
- abs_path: abs_path.into(),
- text: text.trim().to_string(),
- })
- }))
- })
-}
@@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
},
}
-#[derive(Debug, Default)]
+#[derive(Clone, Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -179,11 +179,9 @@ pub async fn file_outline(
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
- while parse_status
- .recv()
- .await
- .map_or(false, |status| status != ParseStatus::Idle)
- {}
+ while *parse_status.borrow() != ParseStatus::Idle {
+ parse_status.changed().await?;
+ }
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
@@ -9,12 +9,13 @@ agent.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
+assistant_settings.workspace = true
client.workspace = true
-collections.workspace = true
context_server.workspace = true
dap.workspace = true
env_logger.workspace = true
fs.workspace = true
+futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
language.workspace = true
@@ -27,7 +28,6 @@ release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
settings.workspace = true
-smol.workspace = true
toml.workspace = true
workspace-hack.workspace = true
@@ -1,229 +0,0 @@
-use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
-use anyhow::anyhow;
-use assistant_tool::ToolWorkingSet;
-use client::{Client, UserStore};
-use collections::HashMap;
-use dap::DapRegistry;
-use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
-use language::LanguageRegistry;
-use language_model::{
- AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
-};
-use node_runtime::NodeRuntime;
-use project::{Project, RealFs};
-use prompt_store::PromptBuilder;
-use settings::SettingsStore;
-use smol::channel;
-use std::sync::Arc;
-
-/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
-pub struct AgentAppState {
- pub languages: Arc<LanguageRegistry>,
- pub client: Arc<Client>,
- pub user_store: Entity<UserStore>,
- pub fs: Arc<dyn fs::Fs>,
- pub node_runtime: NodeRuntime,
-
- // Additional fields not present in `workspace::AppState`.
- pub prompt_builder: Arc<PromptBuilder>,
-}
-
-pub struct Agent {
- // pub thread: Entity<Thread>,
- // pub project: Entity<Project>,
- #[allow(dead_code)]
- pub thread_store: Entity<ThreadStore>,
- pub tool_use_counts: HashMap<Arc<str>, u32>,
- pub done_tx: channel::Sender<anyhow::Result<()>>,
- _subscription: Subscription,
-}
-
-impl Agent {
- pub fn new(
- app_state: Arc<AgentAppState>,
- cx: &mut App,
- ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
- let env = None;
- let project = Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- Arc::new(DapRegistry::default()),
- app_state.fs.clone(),
- env,
- cx,
- );
-
- let tools = Arc::new(ToolWorkingSet::default());
- let thread_store =
- ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
-
- let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
-
- let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
-
- let headless_thread = cx.new(move |cx| Self {
- _subscription: cx.subscribe(&thread, Self::handle_thread_event),
- // thread,
- // project,
- thread_store,
- tool_use_counts: HashMap::default(),
- done_tx,
- });
-
- Ok((headless_thread, done_rx))
- }
-
- fn handle_thread_event(
- &mut self,
- thread: Entity<Thread>,
- event: &ThreadEvent,
- cx: &mut Context<Self>,
- ) {
- match event {
- ThreadEvent::ShowError(err) => self
- .done_tx
- .send_blocking(Err(anyhow!("{:?}", err)))
- .unwrap(),
- ThreadEvent::DoneStreaming => {
- let thread = thread.read(cx);
- if let Some(message) = thread.messages().last() {
- println!("Message: {}", message.to_string());
- }
- if thread.all_tools_finished() {
- self.done_tx.send_blocking(Ok(())).unwrap()
- }
- }
- ThreadEvent::UsePendingTools { .. } => {}
- ThreadEvent::ToolConfirmationNeeded => {
- // Automatically approve all tools that need confirmation in headless mode
- println!("Tool confirmation needed - automatically approving in headless mode");
-
- // Get the tools needing confirmation
- let tools_needing_confirmation: Vec<_> = thread
- .read(cx)
- .tools_needing_confirmation()
- .cloned()
- .collect();
-
- // Run each tool that needs confirmation
- for tool_use in tools_needing_confirmation {
- if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
- thread.update(cx, |thread, cx| {
- println!("Auto-approving tool: {}", tool_use.name);
-
- // Create a request to send to the tool
- let request = thread.to_completion_request(RequestKind::Chat, cx);
- let messages = Arc::new(request.messages);
-
- // Run the tool
- thread.run_tool(
- tool_use.id.clone(),
- tool_use.ui_text.clone(),
- tool_use.input.clone(),
- &messages,
- tool,
- cx,
- );
- });
- }
- }
- }
- ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- ..
- } => {
- if let Some(pending_tool_use) = pending_tool_use {
- println!(
- "Used tool {} with input: {}",
- pending_tool_use.name, pending_tool_use.input
- );
- *self
- .tool_use_counts
- .entry(pending_tool_use.name.clone())
- .or_insert(0) += 1;
- }
- if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
- println!("Tool result: {:?}", tool_result);
- }
- }
- _ => {}
- }
- }
-}
-
-pub fn init(cx: &mut App) -> Arc<AgentAppState> {
- release_channel::init(SemanticVersion::default(), cx);
- gpui_tokio::init(cx);
-
- let mut settings_store = SettingsStore::new(cx);
- settings_store
- .set_default_settings(settings::default_settings().as_ref(), cx)
- .unwrap();
- cx.set_global(settings_store);
- client::init_settings(cx);
- Project::init_settings(cx);
-
- let client = Client::production(cx);
- cx.set_http_client(client.http_client().clone());
-
- let git_binary_path = None;
- let fs = Arc::new(RealFs::new(
- git_binary_path,
- cx.background_executor().clone(),
- ));
-
- let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
-
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-
- language::init(cx);
- language_model::init(client.clone(), cx);
- language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
- assistant_tools::init(client.http_client().clone(), cx);
- context_server::init(cx);
- let stdout_is_a_pty = false;
- let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
- agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
-
- Arc::new(AgentAppState {
- languages,
- client,
- user_store,
- fs,
- node_runtime: NodeRuntime::unavailable(),
- prompt_builder,
- })
-}
-
-pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let model = model_registry
- .available_models(cx)
- .find(|model| model.id().0 == model_name);
-
- let Some(model) = model else {
- return Err(anyhow!(
- "No language model named {} was available. Available models: {}",
- model_name,
- model_registry
- .available_models(cx)
- .map(|model| model.id().0.clone())
- .collect::<Vec<_>>()
- .join(", ")
- ));
- };
-
- Ok(model)
-}
-
-pub fn authenticate_model_provider(
- provider_id: LanguageModelProviderId,
- cx: &mut App,
-) -> Task<std::result::Result<(), AuthenticateError>> {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let model_provider = model_registry.provider(&provider_id).unwrap();
- model_provider.authenticate(cx)
-}
@@ -1,74 +1,22 @@
-use agent::Agent;
-use anyhow::Result;
-use gpui::Application;
-use language_model::LanguageModelRegistry;
-use reqwest_client::ReqwestClient;
-use serde::Deserialize;
-use std::{
- fs,
- path::{Path, PathBuf},
- sync::Arc,
+mod example;
+
+use assistant_settings::AssistantSettings;
+use client::{Client, UserStore};
+pub(crate) use example::*;
+
+use ::fs::RealFs;
+use anyhow::anyhow;
+use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
+use language::LanguageRegistry;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
-mod agent;
-
-#[derive(Debug, Deserialize)]
-pub struct ExampleBase {
- pub path: PathBuf,
- pub revision: String,
-}
-
-#[derive(Debug)]
-pub struct Example {
- pub base: ExampleBase,
-
- /// Content of the prompt.md file
- pub prompt: String,
-
- /// Content of the rubric.md file
- pub rubric: String,
-}
-
-impl Example {
- /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
- pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
- let base_path = dir_path.as_ref().join("base.toml");
- let prompt_path = dir_path.as_ref().join("prompt.md");
- let rubric_path = dir_path.as_ref().join("rubric.md");
-
- let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
- base.path = base.path.canonicalize()?;
-
- Ok(Example {
- base,
- prompt: fs::read_to_string(prompt_path)?,
- rubric: fs::read_to_string(rubric_path)?,
- })
- }
-
- /// Set up the example by checking out the specified Git revision
- pub fn setup(&self) -> Result<()> {
- use std::process::Command;
-
- // Check if the directory exists
- let path = Path::new(&self.base.path);
- anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
-
- // Change to the project directory and checkout the specified revision
- let output = Command::new("git")
- .current_dir(&self.base.path)
- .arg("checkout")
- .arg(&self.base.revision)
- .output()?;
- anyhow::ensure!(
- output.status.success(),
- "Failed to checkout revision {}: {}",
- self.base.revision,
- String::from_utf8_lossy(&output.stderr),
- );
-
- Ok(())
- }
-}
+use node_runtime::NodeRuntime;
+use project::Project;
+use prompt_store::PromptBuilder;
+use reqwest_client::ReqwestClient;
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
fn main() {
env_logger::init();
@@ -76,10 +24,9 @@ fn main() {
let app = Application::headless().with_http_client(http_client.clone());
app.run(move |cx| {
- let app_state = crate::agent::init(cx);
- let _agent = Agent::new(app_state, cx);
+ let app_state = init(cx);
- let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
+ let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
@@ -87,15 +34,112 @@ fn main() {
let model_provider_id = model.provider_id();
- let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
+ let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
- cx.spawn(async move |_cx| {
+ cx.spawn(async move |cx| {
authenticate.await.unwrap();
+
+ let example =
+ Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
+ example.setup()?;
+ cx.update(|cx| example.run(model, app_state, cx))?.await?;
+
+ anyhow::Ok(())
})
- .detach();
+ .detach_and_log_err(cx);
});
+}
+
+/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
+pub struct AgentAppState {
+ pub languages: Arc<LanguageRegistry>,
+ pub client: Arc<Client>,
+ pub user_store: Entity<UserStore>,
+ pub fs: Arc<dyn fs::Fs>,
+ pub node_runtime: NodeRuntime,
+
+ // Additional fields not present in `workspace::AppState`.
+ pub prompt_builder: Arc<PromptBuilder>,
+}
+
+pub fn init(cx: &mut App) -> Arc<AgentAppState> {
+ release_channel::init(SemanticVersion::default(), cx);
+ gpui_tokio::init(cx);
+
+ let mut settings_store = SettingsStore::new(cx);
+ settings_store
+ .set_default_settings(settings::default_settings().as_ref(), cx)
+ .unwrap();
+ cx.set_global(settings_store);
+ client::init_settings(cx);
+ Project::init_settings(cx);
+
+ let client = Client::production(cx);
+ cx.set_http_client(client.http_client().clone());
+
+ let git_binary_path = None;
+ let fs = Arc::new(RealFs::new(
+ git_binary_path,
+ cx.background_executor().clone(),
+ ));
+
+ let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+ language::init(cx);
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+ assistant_tools::init(client.http_client().clone(), cx);
+ context_server::init(cx);
+ let stdout_is_a_pty = false;
+ let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
+ agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
+
+ AssistantSettings::override_global(
+ AssistantSettings {
+ always_allow_tool_actions: true,
+ ..AssistantSettings::get_global(cx).clone()
+ },
+ cx,
+ );
+
+ Arc::new(AgentAppState {
+ languages,
+ client,
+ user_store,
+ fs,
+ node_runtime: NodeRuntime::unavailable(),
+ prompt_builder,
+ })
+}
+
+pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model = model_registry
+ .available_models(cx)
+ .find(|model| model.id().0 == model_name);
+
+ let Some(model) = model else {
+ return Err(anyhow!(
+ "No language model named {} was available. Available models: {}",
+ model_name,
+ model_registry
+ .available_models(cx)
+ .map(|model| model.id().0.clone())
+ .collect::<Vec<_>>()
+ .join(", ")
+ ));
+ };
+
+ Ok(model)
+}
- // let example =
- // Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
- // example.setup()?;
+pub fn authenticate_model_provider(
+ provider_id: LanguageModelProviderId,
+ cx: &mut App,
+) -> Task<std::result::Result<(), AuthenticateError>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model_provider = model_registry.provider(&provider_id).unwrap();
+ model_provider.authenticate(cx)
}
@@ -0,0 +1,178 @@
+use agent::{RequestKind, ThreadEvent, ThreadStore};
+use anyhow::{Result, anyhow};
+use assistant_tool::ToolWorkingSet;
+use dap::DapRegistry;
+use futures::channel::oneshot;
+use gpui::{App, Task};
+use language_model::{LanguageModel, StopReason};
+use project::Project;
+use serde::Deserialize;
+use std::process::Command;
+use std::sync::Arc;
+use std::{
+ fs,
+ path::{Path, PathBuf},
+};
+
+use crate::AgentAppState;
+
+#[derive(Debug, Deserialize)]
+pub struct ExampleBase {
+ pub path: PathBuf,
+ pub revision: String,
+}
+
+#[derive(Debug)]
+pub struct Example {
+ pub base: ExampleBase,
+
+ /// Content of the prompt.md file
+ pub prompt: String,
+
+ /// Content of the rubric.md file
+ pub _rubric: String,
+}
+
+impl Example {
+ /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
+ pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
+ let base_path = dir_path.as_ref().join("base.toml");
+ let prompt_path = dir_path.as_ref().join("prompt.md");
+ let rubric_path = dir_path.as_ref().join("rubric.md");
+
+ let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
+ base.path = base.path.canonicalize()?;
+
+ Ok(Example {
+ base,
+ prompt: fs::read_to_string(prompt_path)?,
+ _rubric: fs::read_to_string(rubric_path)?,
+ })
+ }
+
+ /// Set up the example by checking out the specified Git revision
+ pub fn setup(&self) -> Result<()> {
+ // Check if the directory exists
+ let path = Path::new(&self.base.path);
+ anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
+
+ // Change to the project directory and checkout the specified revision
+ let output = Command::new("git")
+ .current_dir(&self.base.path)
+ .arg("checkout")
+ .arg(&self.base.revision)
+ .output()?;
+ anyhow::ensure!(
+ output.status.success(),
+ "Failed to checkout revision {}: {}",
+ self.base.revision,
+ String::from_utf8_lossy(&output.stderr),
+ );
+
+ Ok(())
+ }
+
+ pub fn run(
+ self,
+ model: Arc<dyn LanguageModel>,
+ app_state: Arc<AgentAppState>,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ let project = Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ Arc::new(DapRegistry::default()),
+ app_state.fs.clone(),
+ None,
+ cx,
+ );
+
+ let worktree = project.update(cx, |project, cx| {
+ project.create_worktree(self.base.path, true, cx)
+ });
+
+ let tools = Arc::new(ToolWorkingSet::default());
+ let thread_store =
+ ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
+
+ println!("USER:");
+ println!("{}", self.prompt);
+ println!("ASSISTANT:");
+ cx.spawn(async move |cx| {
+ worktree.await?;
+ let thread_store = thread_store.await;
+ let thread =
+ thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
+
+ let (tx, rx) = oneshot::channel();
+ let mut tx = Some(tx);
+
+ let _subscription =
+ cx.subscribe(
+ &thread,
+ move |thread, event: &ThreadEvent, cx| match event {
+ ThreadEvent::Stopped(reason) => match reason {
+ Ok(StopReason::EndTurn) => {
+ if let Some(tx) = tx.take() {
+ tx.send(Ok(())).ok();
+ }
+ }
+ Ok(StopReason::MaxTokens) => {
+ if let Some(tx) = tx.take() {
+ tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
+ }
+ }
+ Ok(StopReason::ToolUse) => {}
+ Err(error) => {
+ if let Some(tx) = tx.take() {
+ tx.send(Err(anyhow!(error.clone()))).ok();
+ }
+ }
+ },
+ ThreadEvent::ShowError(thread_error) => {
+ if let Some(tx) = tx.take() {
+ tx.send(Err(anyhow!(thread_error.clone()))).ok();
+ }
+ }
+ ThreadEvent::StreamedAssistantText(_, chunk) => {
+ print!("{}", chunk);
+ }
+ ThreadEvent::StreamedAssistantThinking(_, chunk) => {
+ print!("{}", chunk);
+ }
+ ThreadEvent::UsePendingTools { tool_uses } => {
+ println!("\n\nUSING TOOLS:");
+ for tool_use in tool_uses {
+ println!("{}: {}", tool_use.name, tool_use.input);
+ }
+ }
+ ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ ..
+ } => {
+ if let Some(tool_use) = pending_tool_use {
+ println!("\nTOOL FINISHED: {}", tool_use.name);
+ }
+ if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
+ println!("\n{}\n", tool_result.content);
+ }
+ }
+ _ => {}
+ },
+ )?;
+
+ thread.update(cx, |thread, cx| {
+ let context = vec![];
+ thread.insert_user_message(self.prompt.clone(), context, None, cx);
+ thread.send_to_model(model, RequestKind::Chat, cx);
+ })?;
+
+ rx.await??;
+
+ Ok(())
+ })
+ }
+}
@@ -1,7 +1,7 @@
use crate::{
AnyView, AnyWindowHandle, App, AppCell, AppContext, BackgroundExecutor, BorrowAppContext,
- Entity, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, Result, Task,
- VisualContext, Window, WindowHandle,
+ Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation,
+ Result, Subscription, Task, VisualContext, Window, WindowHandle,
};
use anyhow::{Context as _, anyhow};
use derive_more::{Deref, DerefMut};
@@ -154,6 +154,26 @@ impl AsyncApp {
Ok(lock.update(f))
}
+ /// Arrange for the given callback to be invoked whenever the given entity emits an event of a given type.
+ /// The callback is provided a handle to the emitting entity and a reference to the emitted event.
+ pub fn subscribe<T, Event>(
+ &mut self,
+ entity: &Entity<T>,
+ mut on_event: impl FnMut(Entity<T>, &Event, &mut App) + 'static,
+ ) -> Result<Subscription>
+ where
+ T: 'static + EventEmitter<Event>,
+ Event: 'static,
+ {
+ let app = self
+ .app
+ .upgrade()
+ .ok_or_else(|| anyhow!("app was released"))?;
+ let mut lock = app.borrow_mut();
+ let subscription = lock.subscribe(entity, on_event);
+ Ok(subscription)
+ }
+
/// Open a window with the given options based on the root view returned by the given function.
pub fn open_window<V>(
&self,
@@ -16,17 +16,17 @@ use std::{
use text::LineEnding;
use util::{ResultExt, get_system_shell};
-#[derive(Serialize)]
-pub struct AssistantSystemPromptContext {
- pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
+#[derive(Debug, Clone, Serialize)]
+pub struct ProjectContext {
+ pub worktrees: Vec<WorktreeContext>,
pub has_rules: bool,
pub os: String,
pub arch: String,
pub shell: String,
}
-impl AssistantSystemPromptContext {
- pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
+impl ProjectContext {
+ pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
let has_rules = worktrees
.iter()
.any(|worktree| worktree.rules_file.is_some());
@@ -40,15 +40,15 @@ impl AssistantSystemPromptContext {
}
}
-#[derive(Serialize)]
-pub struct WorktreeInfoForSystemPrompt {
+#[derive(Debug, Clone, Serialize)]
+pub struct WorktreeContext {
pub root_name: String,
pub abs_path: Arc<Path>,
- pub rules_file: Option<SystemPromptRulesFile>,
+ pub rules_file: Option<RulesFileContext>,
}
-#[derive(Serialize)]
-pub struct SystemPromptRulesFile {
+#[derive(Debug, Clone, Serialize)]
+pub struct RulesFileContext {
pub path_in_worktree: Arc<Path>,
pub abs_path: Arc<Path>,
pub text: String,
@@ -260,7 +260,7 @@ impl PromptBuilder {
pub fn generate_assistant_system_prompt(
&self,
- context: &AssistantSystemPromptContext,
+ context: &ProjectContext,
) -> Result<String, RenderError> {
self.handlebars
.lock()