Detailed changes
@@ -497,6 +497,8 @@ jobs:
env:
GIT_AUTHOR_NAME: Protobuf Action
GIT_AUTHOR_EMAIL: ci@zed.dev
+ GIT_COMMITTER_NAME: Protobuf Action
+ GIT_COMMITTER_EMAIL: ci@zed.dev
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
@@ -401,6 +401,7 @@ dependencies = [
"unindent",
"url",
"util",
+ "uuid",
"watch",
"workspace",
"zed_actions",
@@ -3110,16 +3111,6 @@ dependencies = [
"uuid",
]
-[[package]]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "cloud_llm_client",
- "indoc",
- "serde",
-]
-
[[package]]
name = "cmake"
version = "0.1.54"
@@ -3594,6 +3585,7 @@ dependencies = [
"settings",
"smol",
"tempfile",
+ "terminal",
"url",
"util",
]
@@ -5117,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5148,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
- "smol",
- "strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5160,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
+ "zeta_prompt",
"zlog",
]
@@ -5173,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
+ "dirs 4.0.0",
"edit_prediction",
- "edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5207,9 +5196,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
- "toml 0.8.23",
"util",
+ "wasmtime",
"watch",
+ "zeta_prompt",
"zlog",
]
@@ -5237,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
+ "zeta_prompt",
"zlog",
]
@@ -5247,6 +5238,7 @@ dependencies = [
"client",
"gpui",
"language",
+ "text",
]
[[package]]
@@ -5257,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5288,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
+ "zeta_prompt",
]
[[package]]
@@ -7763,7 +7755,6 @@ dependencies = [
"tempfile",
"url",
"util",
- "zed-reqwest",
]
[[package]]
@@ -13165,6 +13156,7 @@ dependencies = [
"askpass",
"auto_update",
"dap",
+ "db",
"editor",
"extension_host",
"file_finder",
@@ -13176,6 +13168,7 @@ dependencies = [
"log",
"markdown",
"menu",
+ "node_runtime",
"ordered-float 2.10.1",
"paths",
"picker",
@@ -13194,6 +13187,7 @@ dependencies = [
"util",
"windows-registry 0.6.1",
"workspace",
+ "worktree",
"zed_actions",
]
@@ -20478,7 +20472,7 @@ dependencies = [
[[package]]
name = "zed"
-version = "0.217.0"
+version = "0.218.0"
dependencies = [
"acp_tools",
"activity_indicator",
@@ -20938,6 +20932,13 @@ dependencies = [
"syn 2.0.106",
]
+[[package]]
+name = "zeta_prompt"
+version = "0.1.0"
+dependencies = [
+ "serde",
+]
+
[[package]]
name = "zip"
version = "0.6.6"
@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
- "crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
+ "crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
-cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
+zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
+tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M13.3996 5.59852C13.3994 5.3881 13.3439 5.18144 13.2386 4.99926C13.1333 4.81709 12.9819 4.66581 12.7997 4.56059L8.59996 2.16076C8.41755 2.05544 8.21063 2 8 2C7.78937 2 7.58246 2.05544 7.40004 2.16076L3.20033 4.56059C3.0181 4.66581 2.86674 4.81709 2.76144 4.99926C2.65613 5.18144 2.60059 5.3881 2.60037 5.59852V10.3982C2.60059 10.6086 2.65613 10.8153 2.76144 10.9975C2.86674 11.1796 3.0181 11.3309 3.20033 11.4361L7.40004 13.836C7.58246 13.9413 7.78937 13.9967 8 13.9967C8.21063 13.9967 8.41755 13.9413 8.59996 13.836L12.7997 11.4361C12.9819 11.3309 13.1333 11.1796 13.2386 10.9975C13.3439 10.8153 13.3994 10.6086 13.3996 10.3982V5.59852Z" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M2.78033 4.99857L7.99998 7.99836L13.2196 4.99857" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M8 13.9979V7.99829" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -811,7 +811,10 @@
"context": "PromptEditor",
"bindings": {
"ctrl-[": "agent::CyclePreviousInlineAssist",
- "ctrl-]": "agent::CycleNextInlineAssist"
+ "ctrl-]": "agent::CycleNextInlineAssist",
+ "ctrl-shift-enter": "inline_assistant::ThumbsUpResult",
+ "ctrl-shift-backspace": "inline_assistant::ThumbsDownResult"
+
}
},
{
@@ -878,7 +878,9 @@
"bindings": {
"cmd-alt-/": "agent::ToggleModelSelector",
"ctrl-[": "agent::CyclePreviousInlineAssist",
- "ctrl-]": "agent::CycleNextInlineAssist"
+ "ctrl-]": "agent::CycleNextInlineAssist",
+ "cmd-shift-enter": "inline_assistant::ThumbsUpResult",
+ "cmd-shift-backspace": "inline_assistant::ThumbsDownResult"
}
},
{
@@ -816,7 +816,9 @@
"use_key_equivalents": true,
"bindings": {
"ctrl-[": "agent::CyclePreviousInlineAssist",
- "ctrl-]": "agent::CycleNextInlineAssist"
+ "ctrl-]": "agent::CycleNextInlineAssist",
+ "ctrl-shift-enter": "inline_assistant::ThumbsUpResult",
+ "ctrl-shift-delete": "inline_assistant::ThumbsDownResult"
}
},
{
@@ -180,7 +180,6 @@
"ctrl-w g shift-d": "editor::GoToTypeDefinitionSplit",
"ctrl-w space": "editor::OpenExcerptsSplit",
"ctrl-w g space": "editor::OpenExcerptsSplit",
- "ctrl-6": "pane::AlternateFile",
"ctrl-^": "pane::AlternateFile",
".": "vim::Repeat"
}
@@ -902,7 +901,11 @@
"context": "!Editor && !Terminal",
"bindings": {
":": "command_palette::Toggle",
- "g /": "pane::DeploySearch"
+ "g /": "pane::DeploySearch",
+ "] b": "pane::ActivateNextItem",
+ "[ b": "pane::ActivatePreviousItem",
+ "] shift-b": "pane::ActivateLastItem",
+ "[ shift-b": ["pane::ActivateItem", 0]
}
},
{
@@ -870,6 +870,10 @@
//
// Default: false
"collapse_untracked_diff": false,
+ /// Whether to show entries with tree or flat view in the panel
+ ///
+ /// Default: false
+ "tree_view": false,
"scrollbar": {
// When to show the scrollbar in the git panel.
//
@@ -1815,6 +1819,9 @@
"allowed": false
}
},
+ "CSharp": {
+ "language_servers": ["roslyn", "!omnisharp", "..."]
+ },
"CSS": {
"prettier": {
"allowed": true
@@ -1372,7 +1372,7 @@ impl AcpThread {
let path_style = self.project.read(cx).path_style(cx);
let id = update.tool_call_id.clone();
- let agent = self.connection().telemetry_id();
+ let agent_telemetry_id = self.connection().telemetry_id();
let session = self.session_id();
if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
let status = if matches!(status, ToolCallStatus::Completed) {
@@ -1380,7 +1380,12 @@ impl AcpThread {
} else {
"failed"
};
- telemetry::event!("Agent Tool Call Completed", agent, session, status);
+ telemetry::event!(
+ "Agent Tool Call Completed",
+ agent_telemetry_id,
+ session,
+ status
+ );
}
if let Some(ix) = self.index_for_tool_call(&id) {
@@ -3556,8 +3561,8 @@ mod tests {
}
impl AgentConnection for FakeAgentConnection {
- fn telemetry_id(&self) -> &'static str {
- "fake"
+ fn telemetry_id(&self) -> SharedString {
+ "fake".into()
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -20,7 +20,7 @@ impl UserMessageId {
}
pub trait AgentConnection {
- fn telemetry_id(&self) -> &'static str;
+ fn telemetry_id(&self) -> SharedString;
fn new_thread(
self: Rc<Self>,
@@ -331,8 +331,8 @@ mod test_support {
}
impl AgentConnection for StubAgentConnection {
- fn telemetry_id(&self) -> &'static str {
- "stub"
+ fn telemetry_id(&self) -> SharedString {
+ "stub".into()
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -777,7 +777,7 @@ impl ActionLog {
#[derive(Clone)]
pub struct ActionLogTelemetry {
- pub agent_telemetry_id: &'static str,
+ pub agent_telemetry_id: SharedString,
pub session_id: Arc<str>,
}
@@ -952,8 +952,8 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
}
impl acp_thread::AgentConnection for NativeAgentConnection {
- fn telemetry_id(&self) -> &'static str {
- "zed"
+ fn telemetry_id(&self) -> SharedString {
+ "zed".into()
}
fn new_thread(
@@ -21,10 +21,6 @@ impl NativeAgentServer {
}
impl AgentServer for NativeAgentServer {
- fn telemetry_id(&self) -> &'static str {
- "zed"
- }
-
fn name(&self) -> SharedString {
"Zed Agent".into()
}
@@ -9,6 +9,10 @@ use futures::io::BufReader;
use project::Project;
use project::agent_server_store::AgentServerCommand;
use serde::Deserialize;
+use settings::Settings as _;
+use task::ShellBuilder;
+#[cfg(windows)]
+use task::ShellKind;
use util::ResultExt as _;
use std::path::PathBuf;
@@ -21,7 +25,7 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntit
use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
use terminal::TerminalBuilder;
-use terminal::terminal_settings::{AlternateScroll, CursorShape};
+use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings};
#[derive(Debug, Error)]
#[error("Unsupported version")]
@@ -29,7 +33,7 @@ pub struct UnsupportedVersion;
pub struct AcpConnection {
server_name: SharedString,
- telemetry_id: &'static str,
+ telemetry_id: SharedString,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
@@ -54,7 +58,6 @@ pub struct AcpSession {
pub async fn connect(
server_name: SharedString,
- telemetry_id: &'static str,
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
@@ -64,7 +67,6 @@ pub async fn connect(
) -> Result<Rc<dyn AgentConnection>> {
let conn = AcpConnection::stdio(
server_name,
- telemetry_id,
command.clone(),
root_dir,
default_mode,
@@ -81,7 +83,6 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1
impl AcpConnection {
pub async fn stdio(
server_name: SharedString,
- telemetry_id: &'static str,
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
@@ -89,9 +90,26 @@ impl AcpConnection {
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Self> {
- let mut child = util::command::new_smol_command(&command.path);
+ let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
+ let builder = ShellBuilder::new(&shell, cfg!(windows));
+ #[cfg(windows)]
+ let kind = builder.kind();
+ let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
+
+ let mut child = util::command::new_smol_command(cmd);
+ #[cfg(windows)]
+ if kind == ShellKind::Cmd {
+ use smol::process::windows::CommandExt;
+ for arg in args {
+ child.raw_arg(arg);
+ }
+ } else {
+ child.args(args);
+ }
+ #[cfg(not(windows))]
+ child.args(args);
+
child
- .args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
@@ -199,6 +217,13 @@ impl AcpConnection {
return Err(UnsupportedVersion.into());
}
+ let telemetry_id = response
+ .agent_info
+ // Use the one the agent provides if we have one
+ .map(|info| info.name.into())
+ // Otherwise, just use the name
+ .unwrap_or_else(|| server_name.clone());
+
Ok(Self {
auth_methods: response.auth_methods,
root_dir: root_dir.to_owned(),
@@ -233,8 +258,8 @@ impl Drop for AcpConnection {
}
impl AgentConnection for AcpConnection {
- fn telemetry_id(&self) -> &'static str {
- self.telemetry_id
+ fn telemetry_id(&self) -> SharedString {
+ self.telemetry_id.clone()
}
fn new_thread(
@@ -56,7 +56,6 @@ impl AgentServerDelegate {
pub trait AgentServer: Send {
fn logo(&self) -> ui::IconName;
fn name(&self) -> SharedString;
- fn telemetry_id(&self) -> &'static str;
fn default_mode(&self, _cx: &mut App) -> Option<agent_client_protocol::SessionModeId> {
None
}
@@ -22,10 +22,6 @@ pub struct AgentServerLoginCommand {
}
impl AgentServer for ClaudeCode {
- fn telemetry_id(&self) -> &'static str {
- "claude-code"
- }
-
fn name(&self) -> SharedString {
"Claude Code".into()
}
@@ -83,7 +79,6 @@ impl AgentServer for ClaudeCode {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
- let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -108,7 +103,6 @@ impl AgentServer for ClaudeCode {
.await?;
let connection = crate::acp::connect(
name,
- telemetry_id,
command,
root_dir.as_ref(),
default_mode,
@@ -23,10 +23,6 @@ pub(crate) mod tests {
}
impl AgentServer for Codex {
- fn telemetry_id(&self) -> &'static str {
- "codex"
- }
-
fn name(&self) -> SharedString {
"Codex".into()
}
@@ -84,7 +80,6 @@ impl AgentServer for Codex {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
- let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -110,7 +105,6 @@ impl AgentServer for Codex {
let connection = crate::acp::connect(
name,
- telemetry_id,
command,
root_dir.as_ref(),
default_mode,
@@ -1,4 +1,4 @@
-use crate::{AgentServerDelegate, load_proxy_env};
+use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
@@ -20,11 +20,7 @@ impl CustomAgentServer {
}
}
-impl crate::AgentServer for CustomAgentServer {
- fn telemetry_id(&self) -> &'static str {
- "custom"
- }
-
+impl AgentServer for CustomAgentServer {
fn name(&self) -> SharedString {
self.name.clone()
}
@@ -112,14 +108,12 @@ impl crate::AgentServer for CustomAgentServer {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
- let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let default_mode = self.default_mode(cx);
let default_model = self.default_model(cx);
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
-
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
.update(cx, |store, cx| {
@@ -139,7 +133,6 @@ impl crate::AgentServer for CustomAgentServer {
.await?;
let connection = crate::acp::connect(
name,
- telemetry_id,
command,
root_dir.as_ref(),
default_mode,
@@ -12,10 +12,6 @@ use project::agent_server_store::GEMINI_NAME;
pub struct Gemini;
impl AgentServer for Gemini {
- fn telemetry_id(&self) -> &'static str {
- "gemini-cli"
- }
-
fn name(&self) -> SharedString {
"Gemini CLI".into()
}
@@ -31,7 +27,6 @@ impl AgentServer for Gemini {
cx: &mut App,
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
- let telemetry_id = self.telemetry_id();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
@@ -62,7 +57,6 @@ impl AgentServer for Gemini {
let connection = crate::acp::connect(
name,
- telemetry_id,
command,
root_dir.as_ref(),
default_mode,
@@ -95,6 +95,7 @@ ui.workspace = true
ui_input.workspace = true
url.workspace = true
util.workspace = true
+uuid.workspace = true
watch.workspace = true
workspace.workspace = true
zed_actions.workspace = true
@@ -565,7 +565,33 @@ impl MessageEditor {
if let Some((workspace, selections)) =
self.workspace.upgrade().zip(editor_clipboard_selections)
{
+ let Some(first_selection) = selections.first() else {
+ return;
+ };
+ if let Some(file_path) = &first_selection.file_path {
+ // In case someone pastes selections from another window
+ // with a different project, we don't want to insert the
+ // crease (containing the absolute path) since the agent
+ // cannot access files outside the project.
+ let is_in_project = workspace
+ .read(cx)
+ .project()
+ .read(cx)
+ .project_path_for_absolute_path(file_path, cx)
+ .is_some();
+ if !is_in_project {
+ return;
+ }
+ }
+
cx.stop_propagation();
+ let insertion_target = self
+ .editor
+ .read(cx)
+ .selections
+ .newest_anchor()
+ .start
+ .text_anchor;
let project = workspace.read(cx).project().clone();
for selection in selections {
@@ -587,8 +613,7 @@ impl MessageEditor {
let snapshot = buffer.snapshot(cx);
let (excerpt_id, _, buffer_snapshot) =
snapshot.as_singleton().unwrap();
- let start_offset = buffer_snapshot.len();
- let text_anchor = buffer_snapshot.anchor_before(start_offset);
+ let text_anchor = insertion_target.bias_left(&buffer_snapshot);
editor.insert(&mention_text, window, cx);
editor.insert(" ", window, cx);
@@ -170,7 +170,7 @@ impl ThreadFeedbackState {
}
}
let session_id = thread.read(cx).session_id().clone();
- let agent = thread.read(cx).connection().telemetry_id();
+ let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let task = telemetry.thread_data(&session_id, cx);
let rating = match feedback {
ThreadFeedback::Positive => "positive",
@@ -180,7 +180,7 @@ impl ThreadFeedbackState {
let thread = task.await?;
telemetry::event!(
"Agent Thread Rated",
- agent = agent,
+ agent = agent_telemetry_id,
session_id = session_id,
rating = rating,
thread = thread
@@ -207,13 +207,13 @@ impl ThreadFeedbackState {
self.comments_editor.take();
let session_id = thread.read(cx).session_id().clone();
- let agent = thread.read(cx).connection().telemetry_id();
+ let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let task = telemetry.thread_data(&session_id, cx);
cx.background_spawn(async move {
let thread = task.await?;
telemetry::event!(
"Agent Thread Feedback Comments",
- agent = agent,
+ agent = agent_telemetry_id,
session_id = session_id,
comments = comments,
thread = thread
@@ -333,6 +333,7 @@ impl AcpThreadView {
project: Entity<Project>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
+ track_load_event: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -391,8 +392,9 @@ impl AcpThreadView {
),
];
- let show_codex_windows_warning = crate::ExternalAgent::parse_built_in(agent.as_ref())
- == Some(crate::ExternalAgent::Codex);
+ let show_codex_windows_warning = cfg!(windows)
+ && project.read(cx).is_local()
+ && agent.clone().downcast::<agent_servers::Codex>().is_some();
Self {
agent: agent.clone(),
@@ -404,6 +406,7 @@ impl AcpThreadView {
resume_thread.clone(),
workspace.clone(),
project.clone(),
+ track_load_event,
window,
cx,
),
@@ -448,6 +451,7 @@ impl AcpThreadView {
self.resume_thread_metadata.clone(),
self.workspace.clone(),
self.project.clone(),
+ true,
window,
cx,
);
@@ -461,6 +465,7 @@ impl AcpThreadView {
resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
+ track_load_event: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> ThreadState {
@@ -519,6 +524,10 @@ impl AcpThreadView {
}
};
+ if track_load_event {
+ telemetry::event!("Agent Thread Started", agent = connection.telemetry_id());
+ }
+
let result = if let Some(native_agent) = connection
.clone()
.downcast::<agent::NativeAgentConnection>()
@@ -1133,8 +1142,8 @@ impl AcpThreadView {
let Some(thread) = self.thread() else {
return;
};
- let agent_telemetry_id = self.agent.telemetry_id();
let session_id = thread.read(cx).session_id().clone();
+ let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
let thread = thread.downgrade();
if self.should_be_following {
self.workspace
@@ -1512,6 +1521,7 @@ impl AcpThreadView {
else {
return;
};
+ let agent_telemetry_id = connection.telemetry_id();
// Check for the experimental "terminal-auth" _meta field
let auth_method = connection.auth_methods().iter().find(|m| m.id == method);
@@ -1579,19 +1589,18 @@ impl AcpThreadView {
);
cx.notify();
self.auth_task = Some(cx.spawn_in(window, {
- let agent = self.agent.clone();
async move |this, cx| {
let result = authenticate.await;
match &result {
Ok(_) => telemetry::event!(
"Authenticate Agent Succeeded",
- agent = agent.telemetry_id()
+ agent = agent_telemetry_id
),
Err(_) => {
telemetry::event!(
"Authenticate Agent Failed",
- agent = agent.telemetry_id(),
+ agent = agent_telemetry_id,
)
}
}
@@ -1675,6 +1684,7 @@ impl AcpThreadView {
None,
this.workspace.clone(),
this.project.clone(),
+ true,
window,
cx,
)
@@ -1730,43 +1740,38 @@ impl AcpThreadView {
connection.authenticate(method, cx)
};
cx.notify();
- self.auth_task =
- Some(cx.spawn_in(window, {
- let agent = self.agent.clone();
- async move |this, cx| {
- let result = authenticate.await;
-
- match &result {
- Ok(_) => telemetry::event!(
- "Authenticate Agent Succeeded",
- agent = agent.telemetry_id()
- ),
- Err(_) => {
- telemetry::event!(
- "Authenticate Agent Failed",
- agent = agent.telemetry_id(),
- )
- }
+ self.auth_task = Some(cx.spawn_in(window, {
+ async move |this, cx| {
+ let result = authenticate.await;
+
+ match &result {
+ Ok(_) => telemetry::event!(
+ "Authenticate Agent Succeeded",
+ agent = agent_telemetry_id
+ ),
+ Err(_) => {
+ telemetry::event!("Authenticate Agent Failed", agent = agent_telemetry_id,)
}
+ }
- this.update_in(cx, |this, window, cx| {
- if let Err(err) = result {
- if let ThreadState::Unauthenticated {
- pending_auth_method,
- ..
- } = &mut this.thread_state
- {
- pending_auth_method.take();
- }
- this.handle_thread_error(err, cx);
- } else {
- this.reset(window, cx);
+ this.update_in(cx, |this, window, cx| {
+ if let Err(err) = result {
+ if let ThreadState::Unauthenticated {
+ pending_auth_method,
+ ..
+ } = &mut this.thread_state
+ {
+ pending_auth_method.take();
}
- this.auth_task.take()
- })
- .ok();
- }
- }));
+ this.handle_thread_error(err, cx);
+ } else {
+ this.reset(window, cx);
+ }
+ this.auth_task.take()
+ })
+ .ok();
+ }
+ }));
}
fn spawn_external_agent_login(
@@ -1896,10 +1901,11 @@ impl AcpThreadView {
let Some(thread) = self.thread() else {
return;
};
+ let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
telemetry::event!(
"Agent Tool Call Authorized",
- agent = self.agent.telemetry_id(),
+ agent = agent_telemetry_id,
session = thread.read(cx).session_id(),
option = option_kind
);
@@ -3509,7 +3515,9 @@ impl AcpThreadView {
(method.id.0.clone(), method.name.clone())
};
- Button::new(SharedString::from(method_id.clone()), name)
+ let agent_telemetry_id = connection.telemetry_id();
+
+ Button::new(method_id.clone(), name)
.label_size(LabelSize::Small)
.map(|this| {
if ix == 0 {
@@ -3528,7 +3536,7 @@ impl AcpThreadView {
cx.listener(move |this, _, window, cx| {
telemetry::event!(
"Authenticate Agent Started",
- agent = this.agent.telemetry_id(),
+ agent = agent_telemetry_id,
method = method_id
);
@@ -5376,47 +5384,39 @@ impl AcpThreadView {
)
}
- fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Option<Callout> {
- if self.show_codex_windows_warning {
- Some(
- Callout::new()
- .icon(IconName::Warning)
- .severity(Severity::Warning)
- .title("Codex on Windows")
- .description(
- "For best performance, run Codex in Windows Subsystem for Linux (WSL2)",
- )
- .actions_slot(
- Button::new("open-wsl-modal", "Open in WSL")
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .on_click(cx.listener({
- move |_, _, _window, cx| {
- #[cfg(windows)]
- _window.dispatch_action(
- zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
- cx,
- );
- cx.notify();
- }
- })),
- )
- .dismiss_action(
- IconButton::new("dismiss", IconName::Close)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .tooltip(Tooltip::text("Dismiss Warning"))
- .on_click(cx.listener({
- move |this, _, _, cx| {
- this.show_codex_windows_warning = false;
- cx.notify();
- }
- })),
- ),
+ fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Callout {
+ Callout::new()
+ .icon(IconName::Warning)
+ .severity(Severity::Warning)
+ .title("Codex on Windows")
+ .description("For best performance, run Codex in Windows Subsystem for Linux (WSL2)")
+ .actions_slot(
+ Button::new("open-wsl-modal", "Open in WSL")
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .on_click(cx.listener({
+ move |_, _, _window, cx| {
+ #[cfg(windows)]
+ _window.dispatch_action(
+ zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
+ cx,
+ );
+ cx.notify();
+ }
+ })),
+ )
+ .dismiss_action(
+ IconButton::new("dismiss", IconName::Close)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(Tooltip::text("Dismiss Warning"))
+ .on_click(cx.listener({
+ move |this, _, _, cx| {
+ this.show_codex_windows_warning = false;
+ cx.notify();
+ }
+ })),
)
- } else {
- None
- }
}
fn render_thread_error(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
@@ -5936,12 +5936,8 @@ impl Render for AcpThreadView {
_ => this,
})
.children(self.render_thread_retry_status_callout(window, cx))
- .children({
- if cfg!(windows) && self.project.read(cx).is_local() {
- self.render_codex_windows_warning(cx)
- } else {
- None
- }
+ .when(self.show_codex_windows_warning, |this| {
+ this.child(self.render_codex_windows_warning(cx))
})
.children(self.render_thread_error(window, cx))
.when_some(
@@ -6398,6 +6394,7 @@ pub(crate) mod tests {
project,
history_store,
None,
+ false,
window,
cx,
)
@@ -6475,10 +6472,6 @@ pub(crate) mod tests {
where
C: 'static + AgentConnection + Send + Clone,
{
- fn telemetry_id(&self) -> &'static str {
- "test"
- }
-
fn logo(&self) -> ui::IconName {
ui::IconName::Ai
}
@@ -6505,8 +6498,8 @@ pub(crate) mod tests {
struct SaboteurAgentConnection;
impl AgentConnection for SaboteurAgentConnection {
- fn telemetry_id(&self) -> &'static str {
- "saboteur"
+ fn telemetry_id(&self) -> SharedString {
+ "saboteur".into()
}
fn new_thread(
@@ -6569,8 +6562,8 @@ pub(crate) mod tests {
struct RefusalAgentConnection;
impl AgentConnection for RefusalAgentConnection {
- fn telemetry_id(&self) -> &'static str {
- "refusal"
+ fn telemetry_id(&self) -> SharedString {
+ "refusal".into()
}
fn new_thread(
@@ -6671,6 +6664,7 @@ pub(crate) mod tests {
project.clone(),
history_store.clone(),
None,
+ false,
window,
cx,
)
@@ -842,7 +842,7 @@ impl AgentConfiguration {
.min_w_0()
.child(
h_flex()
- .id(SharedString::from(format!("tooltip-{}", item_id)))
+ .id(format!("tooltip-{}", item_id))
.h_full()
.w_3()
.mr_2()
@@ -982,7 +982,10 @@ impl AgentConfiguration {
} else {
AgentIcon::Name(IconName::Ai)
};
- (name, icon)
+ let display_name = agent_server_store
+ .agent_display_name(&name)
+ .unwrap_or_else(|| name.0.clone());
+ (name, icon, display_name)
})
.collect();
@@ -1089,6 +1092,7 @@ impl AgentConfiguration {
.child(self.render_agent_server(
AgentIcon::Name(IconName::AiClaude),
"Claude Code",
+ "Claude Code",
false,
cx,
))
@@ -1096,6 +1100,7 @@ impl AgentConfiguration {
.child(self.render_agent_server(
AgentIcon::Name(IconName::AiOpenAi),
"Codex CLI",
+ "Codex CLI",
false,
cx,
))
@@ -1103,16 +1108,23 @@ impl AgentConfiguration {
.child(self.render_agent_server(
AgentIcon::Name(IconName::AiGemini),
"Gemini CLI",
+ "Gemini CLI",
false,
cx,
))
.map(|mut parent| {
- for (name, icon) in user_defined_agents {
+ for (name, icon, display_name) in user_defined_agents {
parent = parent
.child(
Divider::horizontal().color(DividerColor::BorderFaded),
)
- .child(self.render_agent_server(icon, name, true, cx));
+ .child(self.render_agent_server(
+ icon,
+ name,
+ display_name,
+ true,
+ cx,
+ ));
}
parent
}),
@@ -1123,11 +1135,13 @@ impl AgentConfiguration {
fn render_agent_server(
&self,
icon: AgentIcon,
- name: impl Into<SharedString>,
+ id: impl Into<SharedString>,
+ display_name: impl Into<SharedString>,
external: bool,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let name = name.into();
+ let id = id.into();
+ let display_name = display_name.into();
let icon = match icon {
AgentIcon::Name(icon_name) => Icon::new(icon_name)
.size(IconSize::Small)
@@ -1137,12 +1151,15 @@ impl AgentConfiguration {
.color(Color::Muted),
};
- let tooltip_id = SharedString::new(format!("agent-source-{}", name));
- let tooltip_message = format!("The {} agent was installed from an extension.", name);
+ let tooltip_id = SharedString::new(format!("agent-source-{}", id));
+ let tooltip_message = format!(
+ "The {} agent was installed from an extension.",
+ display_name
+ );
- let agent_server_name = ExternalAgentServerName(name.clone());
+ let agent_server_name = ExternalAgentServerName(id.clone());
- let uninstall_btn_id = SharedString::from(format!("uninstall-{}", name));
+ let uninstall_btn_id = SharedString::from(format!("uninstall-{}", id));
let uninstall_button = IconButton::new(uninstall_btn_id, IconName::Trash)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
@@ -1166,7 +1183,7 @@ impl AgentConfiguration {
h_flex()
.gap_1p5()
.child(icon)
- .child(Label::new(name))
+ .child(Label::new(display_name))
.when(external, |this| {
this.child(
div()
@@ -87,7 +87,7 @@ impl ConfigureContextServerToolsModal {
v_flex()
.child(
h_flex()
- .id(SharedString::from(format!("tool-header-{}", index)))
+ .id(format!("tool-header-{}", index))
.py_1()
.pl_1()
.pr_2()
@@ -422,7 +422,7 @@ impl ManageProfilesModal {
let is_focused = profile.navigation.focus_handle.contains_focused(window, cx);
div()
- .id(SharedString::from(format!("profile-{}", profile.id)))
+ .id(format!("profile-{}", profile.id))
.track_focus(&profile.navigation.focus_handle)
.on_action({
let profile_id = profile.id.clone();
@@ -431,7 +431,7 @@ impl ManageProfilesModal {
})
})
.child(
- ListItem::new(SharedString::from(format!("profile-{}", profile.id)))
+ ListItem::new(format!("profile-{}", profile.id))
.toggle_state(is_focused)
.inset(true)
.spacing(ListItemSpacing::Sparse)
@@ -63,6 +63,10 @@ impl AgentModelSelector {
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
self.menu_handle.toggle(window, cx);
}
+
+ pub fn active_model(&self, cx: &App) -> Option<language_model::ConfiguredModel> {
+ self.selector.read(cx).delegate.active_model(cx)
+ }
}
impl Render for AgentModelSelector {
@@ -305,6 +305,7 @@ impl ActiveView {
project,
history_store,
prompt_store,
+ false,
window,
cx,
)
@@ -885,10 +886,6 @@ impl AgentPanel {
let server = ext_agent.server(fs, history);
- if !loading {
- telemetry::event!("Agent Thread Started", agent = server.telemetry_id());
- }
-
this.update_in(cx, |this, window, cx| {
let selected_agent = ext_agent.into();
if this.selected_agent != selected_agent {
@@ -905,6 +902,7 @@ impl AgentPanel {
project,
this.history_store.clone(),
this.prompt_store.clone(),
+ !loading,
window,
cx,
)
@@ -2083,8 +2081,11 @@ impl AgentPanel {
for agent_name in agent_names {
let icon_path = agent_server_store.agent_icon(&agent_name);
+ let display_name = agent_server_store
+ .agent_display_name(&agent_name)
+ .unwrap_or_else(|| agent_name.0.clone());
- let mut entry = ContextMenuEntry::new(agent_name.clone());
+ let mut entry = ContextMenuEntry::new(display_name);
if let Some(icon_path) = icon_path {
entry = entry.custom_icon_svg(icon_path);
@@ -160,16 +160,6 @@ pub enum ExternalAgent {
}
impl ExternalAgent {
- pub fn parse_built_in(server: &dyn agent_servers::AgentServer) -> Option<Self> {
- match server.telemetry_id() {
- "gemini-cli" => Some(Self::Gemini),
- "claude-code" => Some(Self::ClaudeCode),
- "codex" => Some(Self::Codex),
- "zed" => Some(Self::NativeAgent),
- _ => None,
- }
- }
-
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,
@@ -119,6 +119,10 @@ impl BufferCodegen {
.push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
}
+ pub fn active_completion(&self, cx: &App) -> Option<String> {
+ self.active_alternative().read(cx).current_completion()
+ }
+
pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
&self.alternatives[self.active_alternative]
}
@@ -241,6 +245,10 @@ impl BufferCodegen {
pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
self.active_alternative().read(cx).last_equal_ranges()
}
+
+ pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
+ self.active_alternative().read(cx).selected_text()
+ }
}
impl EventEmitter<CodegenEvent> for BufferCodegen {}
@@ -264,6 +272,7 @@ pub struct CodegenAlternative {
line_operations: Vec<LineOperation>,
elapsed_time: Option<f64>,
completion: Option<String>,
+ selected_text: Option<String>,
pub message_id: Option<String>,
pub model_explanation: Option<SharedString>,
}
@@ -323,6 +332,7 @@ impl CodegenAlternative {
range,
elapsed_time: None,
completion: None,
+ selected_text: None,
model_explanation: None,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
@@ -608,6 +618,8 @@ impl CodegenAlternative {
.text_for_range(self.range.start..self.range.end)
.collect::<Rope>();
+ self.selected_text = Some(selected_text.to_string());
+
let selection_start = self.range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
@@ -868,6 +880,14 @@ impl CodegenAlternative {
cx.notify();
}
+ pub fn current_completion(&self) -> Option<String> {
+ self.completion.clone()
+ }
+
+ pub fn selected_text(&self) -> Option<&str> {
+ self.selected_text.as_deref()
+ }
+
pub fn stop(&mut self, cx: &mut Context<Self>) {
self.last_equal_ranges.clear();
if self.diff.is_empty() {
@@ -8,10 +8,11 @@ use editor::{
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
actions::{MoveDown, MoveUp},
};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use fs::Fs;
use gpui::{
- AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, Subscription,
- TextStyle, TextStyleRefinement, WeakEntity, Window,
+ AnyElement, App, ClipboardItem, Context, Entity, EventEmitter, FocusHandle, Focusable,
+ Subscription, TextStyle, TextStyleRefinement, WeakEntity, Window, actions,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
@@ -19,14 +20,16 @@ use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptStore;
use settings::Settings;
-use std::cmp;
use std::ops::Range;
use std::rc::Rc;
use std::sync::Arc;
+use std::{cmp, mem};
use theme::ThemeSettings;
use ui::utils::WithRemSize;
use ui::{IconButtonShape, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
-use workspace::Workspace;
+use uuid::Uuid;
+use workspace::notifications::NotificationId;
+use workspace::{Toast, Workspace};
use zed_actions::agent::ToggleModelSelector;
use crate::agent_model_selector::AgentModelSelector;
@@ -39,6 +42,58 @@ use crate::mention_set::{MentionSet, crease_for_mention};
use crate::terminal_codegen::TerminalCodegen;
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
+actions!(inline_assistant, [ThumbsUpResult, ThumbsDownResult]);
+
+pub struct InlineAssistRatingFeatureFlag;
+
+impl FeatureFlag for InlineAssistRatingFeatureFlag {
+ const NAME: &'static str = "inline-assist-rating";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
+enum RatingState {
+ Pending,
+ GeneratedCompletion(Option<String>),
+ Rated(Uuid),
+}
+
+impl RatingState {
+ fn is_pending(&self) -> bool {
+ matches!(self, RatingState::Pending)
+ }
+
+ fn rating_id(&self) -> Option<Uuid> {
+ match self {
+ RatingState::Pending => None,
+ RatingState::GeneratedCompletion(_) => None,
+ RatingState::Rated(id) => Some(*id),
+ }
+ }
+
+ fn rate(&mut self) -> (Uuid, Option<String>) {
+ let id = Uuid::new_v4();
+ let old_state = mem::replace(self, RatingState::Rated(id));
+ let completion = match old_state {
+ RatingState::Pending => None,
+ RatingState::GeneratedCompletion(completion) => completion,
+ RatingState::Rated(_) => None,
+ };
+
+ (id, completion)
+ }
+
+ fn reset(&mut self) {
+ *self = RatingState::Pending;
+ }
+
+ fn generated_completion(&mut self, generated_completion: Option<String>) {
+ *self = RatingState::GeneratedCompletion(generated_completion);
+ }
+}
+
pub struct PromptEditor<T> {
pub editor: Entity<Editor>,
mode: PromptEditorMode,
@@ -54,6 +109,7 @@ pub struct PromptEditor<T> {
_codegen_subscription: Subscription,
editor_subscriptions: Vec<Subscription>,
show_rate_limit_notice: bool,
+ rated: RatingState,
_phantom: std::marker::PhantomData<T>,
}
@@ -153,6 +209,8 @@ impl<T: 'static> Render for PromptEditor<T> {
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::move_down))
+ .on_action(cx.listener(Self::thumbs_up))
+ .on_action(cx.listener(Self::thumbs_down))
.capture_action(cx.listener(Self::cycle_prev))
.capture_action(cx.listener(Self::cycle_next))
.child(
@@ -429,6 +487,7 @@ impl<T: 'static> PromptEditor<T> {
}
self.edited_since_done = true;
+ self.rated.reset();
cx.notify();
}
EditorEvent::Blurred => {
@@ -516,6 +575,121 @@ impl<T: 'static> PromptEditor<T> {
}
}
+ fn thumbs_up(&mut self, _: &ThumbsUpResult, _window: &mut Window, cx: &mut Context<Self>) {
+ if self.rated.is_pending() {
+ self.toast("Still generating...", None, cx);
+ return;
+ }
+
+ if let Some(rating_id) = self.rated.rating_id() {
+ self.toast("Already rated this completion", Some(rating_id), cx);
+ return;
+ }
+
+ let (rating_id, completion) = self.rated.rate();
+
+ let selected_text = match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ codegen.read(cx).selected_text(cx).map(|s| s.to_string())
+ }
+ PromptEditorMode::Terminal { .. } => None,
+ };
+
+ let model_info = self.model_selector.read(cx).active_model(cx);
+ let model_id = {
+ let Some(configured_model) = model_info else {
+ self.toast("No configured model", None, cx);
+ return;
+ };
+
+ configured_model.model.telemetry_id()
+ };
+
+ let prompt = self.editor.read(cx).text(cx);
+
+ telemetry::event!(
+ "Inline Assistant Rated",
+ rating = "positive",
+ model = model_id,
+ prompt = prompt,
+ completion = completion,
+ selected_text = selected_text,
+ rating_id = rating_id.to_string()
+ );
+
+ cx.notify();
+ }
+
+ fn thumbs_down(&mut self, _: &ThumbsDownResult, _window: &mut Window, cx: &mut Context<Self>) {
+ if self.rated.is_pending() {
+ self.toast("Still generating...", None, cx);
+ return;
+ }
+ if let Some(rating_id) = self.rated.rating_id() {
+ self.toast("Already rated this completion", Some(rating_id), cx);
+ return;
+ }
+
+ let (rating_id, completion) = self.rated.rate();
+
+ let selected_text = match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ codegen.read(cx).selected_text(cx).map(|s| s.to_string())
+ }
+ PromptEditorMode::Terminal { .. } => None,
+ };
+
+ let model_info = self.model_selector.read(cx).active_model(cx);
+ let model_telemetry_id = {
+ let Some(configured_model) = model_info else {
+ self.toast("No configured model", None, cx);
+ return;
+ };
+
+ configured_model.model.telemetry_id()
+ };
+
+ let prompt = self.editor.read(cx).text(cx);
+
+ telemetry::event!(
+ "Inline Assistant Rated",
+ rating = "negative",
+ model = model_telemetry_id,
+ prompt = prompt,
+ completion = completion,
+ selected_text = selected_text,
+ rating_id = rating_id.to_string()
+ );
+
+ cx.notify();
+ }
+
+ fn toast(&mut self, msg: &str, uuid: Option<Uuid>, cx: &mut Context<'_, PromptEditor<T>>) {
+ self.workspace
+ .update(cx, |workspace, cx| {
+ enum InlinePromptRating {}
+ workspace.show_toast(
+ {
+ let mut toast = Toast::new(
+ NotificationId::unique::<InlinePromptRating>(),
+ msg.to_string(),
+ )
+ .autohide();
+
+ if let Some(uuid) = uuid {
+ toast = toast.on_click("Click to copy rating ID", move |_, cx| {
+ cx.write_to_clipboard(ClipboardItem::new_string(uuid.to_string()));
+ });
+ };
+
+ toast
+ },
+ cx,
+ );
+ })
+ .ok();
+ }
+
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
if let Some(ix) = self.prompt_history_ix {
if ix > 0 {
@@ -621,6 +795,9 @@ impl<T: 'static> PromptEditor<T> {
.into_any_element(),
]
} else {
+ let show_rating_buttons = cx.has_flag::<InlineAssistRatingFeatureFlag>();
+ let rated = self.rated.rating_id().is_some();
+
let accept = IconButton::new("accept", IconName::Check)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
@@ -632,25 +809,59 @@ impl<T: 'static> PromptEditor<T> {
}))
.into_any_element();
- match &self.mode {
- PromptEditorMode::Terminal { .. } => vec![
- accept,
- IconButton::new("confirm", IconName::PlayFilled)
- .icon_color(Color::Info)
+ let mut buttons = Vec::new();
+
+ if show_rating_buttons {
+ buttons.push(
+ IconButton::new("thumbs-down", IconName::ThumbsDown)
+ .icon_color(if rated { Color::Muted } else { Color::Default })
.shape(IconButtonShape::Square)
- .tooltip(|_window, cx| {
- Tooltip::for_action(
- "Execute Generated Command",
- &menu::SecondaryConfirm,
- cx,
- )
- })
- .on_click(cx.listener(|_, _, _, cx| {
- cx.emit(PromptEditorEvent::ConfirmRequested { execute: true });
+ .disabled(rated)
+ .tooltip(Tooltip::text("Bad result"))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.thumbs_down(&ThumbsDownResult, window, cx);
}))
.into_any_element(),
- ],
- PromptEditorMode::Buffer { .. } => vec![accept],
+ );
+
+ buttons.push(
+ IconButton::new("thumbs-up", IconName::ThumbsUp)
+ .icon_color(if rated { Color::Muted } else { Color::Default })
+ .shape(IconButtonShape::Square)
+ .disabled(rated)
+ .tooltip(Tooltip::text("Good result"))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.thumbs_up(&ThumbsUpResult, window, cx);
+ }))
+ .into_any_element(),
+ );
+ }
+
+ buttons.push(accept);
+
+ match &self.mode {
+ PromptEditorMode::Terminal { .. } => {
+ buttons.push(
+ IconButton::new("confirm", IconName::PlayFilled)
+ .icon_color(Color::Info)
+ .shape(IconButtonShape::Square)
+ .tooltip(|_window, cx| {
+ Tooltip::for_action(
+ "Execute Generated Command",
+ &menu::SecondaryConfirm,
+ cx,
+ )
+ })
+ .on_click(cx.listener(|_, _, _, cx| {
+ cx.emit(PromptEditorEvent::ConfirmRequested {
+ execute: true,
+ });
+ }))
+ .into_any_element(),
+ );
+ buttons
+ }
+ PromptEditorMode::Buffer { .. } => buttons,
}
}
}
@@ -979,6 +1190,7 @@ impl PromptEditor<BufferCodegen> {
editor_subscriptions: Vec::new(),
show_rate_limit_notice: false,
mode,
+ rated: RatingState::Pending,
_phantom: Default::default(),
};
@@ -989,7 +1201,7 @@ impl PromptEditor<BufferCodegen> {
fn handle_codegen_changed(
&mut self,
- _: Entity<BufferCodegen>,
+ codegen: Entity<BufferCodegen>,
cx: &mut Context<PromptEditor<BufferCodegen>>,
) {
match self.codegen_status(cx) {
@@ -998,10 +1210,13 @@ impl PromptEditor<BufferCodegen> {
.update(cx, |editor, _| editor.set_read_only(false));
}
CodegenStatus::Pending => {
+ self.rated.reset();
self.editor
.update(cx, |editor, _| editor.set_read_only(true));
}
CodegenStatus::Done => {
+ let completion = codegen.read(cx).active_completion(cx);
+ self.rated.generated_completion(completion);
self.edited_since_done = false;
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
@@ -1122,6 +1337,7 @@ impl PromptEditor<TerminalCodegen> {
editor_subscriptions: Vec::new(),
mode,
show_rate_limit_notice: false,
+ rated: RatingState::Pending,
_phantom: Default::default(),
};
this.count_lines(cx);
@@ -1154,17 +1370,20 @@ impl PromptEditor<TerminalCodegen> {
}
}
- fn handle_codegen_changed(&mut self, _: Entity<TerminalCodegen>, cx: &mut Context<Self>) {
+ fn handle_codegen_changed(&mut self, codegen: Entity<TerminalCodegen>, cx: &mut Context<Self>) {
match &self.codegen().read(cx).status {
CodegenStatus::Idle => {
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
}
CodegenStatus::Pending => {
+ self.rated = RatingState::Pending;
self.editor
.update(cx, |editor, _| editor.set_read_only(true));
}
CodegenStatus::Done | CodegenStatus::Error(_) => {
+ self.rated
+ .generated_completion(codegen.read(cx).completion());
self.edited_since_done = false;
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
@@ -542,7 +542,7 @@ impl PickerDelegate for ProfilePickerDelegate {
let is_active = active_id == candidate.id;
Some(
- ListItem::new(SharedString::from(candidate.id.0.clone()))
+ ListItem::new(candidate.id.0.clone())
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -135,6 +135,12 @@ impl TerminalCodegen {
cx.notify();
}
+ pub fn completion(&self) -> Option<String> {
+ self.transaction
+ .as_ref()
+ .map(|transaction| transaction.completion.clone())
+ }
+
pub fn stop(&mut self, cx: &mut Context<Self>) {
self.status = CodegenStatus::Done;
self.generation = Task::ready(());
@@ -167,27 +173,32 @@ pub const CLEAR_INPUT: &str = "\x03";
const CARRIAGE_RETURN: &str = "\x0d";
struct TerminalTransaction {
+ completion: String,
terminal: Entity<Terminal>,
}
impl TerminalTransaction {
pub fn start(terminal: Entity<Terminal>) -> Self {
- Self { terminal }
+ Self {
+ completion: String::new(),
+ terminal,
+ }
}
pub fn push(&mut self, hunk: String, cx: &mut App) {
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
let input = Self::sanitize_input(hunk);
+ self.completion.push_str(&input);
self.terminal
.update(cx, |terminal, _| terminal.input(input.into_bytes()));
}
- pub fn undo(&self, cx: &mut App) {
+ pub fn undo(self, cx: &mut App) {
self.terminal
.update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
}
- pub fn complete(&self, cx: &mut App) {
+ pub fn complete(self, cx: &mut App) {
self.terminal
.update(cx, |terminal, _| terminal.input(CARRIAGE_RETURN.as_bytes()));
}
@@ -32,7 +32,7 @@ struct Detect;
trait InstalledApp {
fn zed_version_string(&self) -> String;
- fn launch(&self, ipc_url: String) -> anyhow::Result<()>;
+ fn launch(&self, ipc_url: String, user_data_dir: Option<&str>) -> anyhow::Result<()>;
fn run_foreground(
&self,
ipc_url: String,
@@ -588,7 +588,7 @@ fn main() -> Result<()> {
if args.foreground {
app.run_foreground(url, user_data_dir.as_deref())?;
} else {
- app.launch(url)?;
+ app.launch(url, user_data_dir.as_deref())?;
sender.join().unwrap()?;
if let Some(handle) = stdin_pipe_handle {
handle.join().unwrap()?;
@@ -709,14 +709,18 @@ mod linux {
)
}
- fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
- let sock_path = paths::data_dir().join(format!(
+ fn launch(&self, ipc_url: String, user_data_dir: Option<&str>) -> anyhow::Result<()> {
+ let data_dir = user_data_dir
+ .map(PathBuf::from)
+ .unwrap_or_else(|| paths::data_dir().clone());
+
+ let sock_path = data_dir.join(format!(
"zed-{}.sock",
*release_channel::RELEASE_CHANNEL_NAME
));
let sock = UnixDatagram::unbound()?;
if sock.connect(&sock_path).is_err() {
- self.boot_background(ipc_url)?;
+ self.boot_background(ipc_url, user_data_dir)?;
} else {
sock.send(ipc_url.as_bytes())?;
}
@@ -742,7 +746,11 @@ mod linux {
}
impl App {
- fn boot_background(&self, ipc_url: String) -> anyhow::Result<()> {
+ fn boot_background(
+ &self,
+ ipc_url: String,
+ user_data_dir: Option<&str>,
+ ) -> anyhow::Result<()> {
let path = &self.0;
match fork::fork() {
@@ -756,8 +764,13 @@ mod linux {
if fork::close_fd().is_err() {
eprintln!("failed to close_fd: {}", std::io::Error::last_os_error());
}
- let error =
- exec::execvp(path.clone(), &[path.as_os_str(), &OsString::from(ipc_url)]);
+ let mut args: Vec<OsString> =
+ vec![path.as_os_str().to_owned(), OsString::from(ipc_url)];
+ if let Some(dir) = user_data_dir {
+ args.push(OsString::from("--user-data-dir"));
+ args.push(OsString::from(dir));
+ }
+ let error = exec::execvp(path.clone(), &args);
// if exec succeeded, we never get here.
eprintln!("failed to exec {:?}: {}", path, error);
process::exit(1)
@@ -943,11 +956,14 @@ mod windows {
)
}
- fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
+ fn launch(&self, ipc_url: String, user_data_dir: Option<&str>) -> anyhow::Result<()> {
if check_single_instance() {
- std::process::Command::new(self.0.clone())
- .arg(ipc_url)
- .spawn()?;
+ let mut cmd = std::process::Command::new(self.0.clone());
+ cmd.arg(ipc_url);
+ if let Some(dir) = user_data_dir {
+ cmd.arg("--user-data-dir").arg(dir);
+ }
+ cmd.spawn()?;
} else {
unsafe {
let pipe = CreateFileW(
@@ -1096,7 +1112,7 @@ mod mac_os {
format!("Zed {} – {}", self.version(), self.path().display(),)
}
- fn launch(&self, url: String) -> anyhow::Result<()> {
+ fn launch(&self, url: String, user_data_dir: Option<&str>) -> anyhow::Result<()> {
match self {
Self::App { app_bundle, .. } => {
let app_path = app_bundle;
@@ -1146,8 +1162,11 @@ mod mac_os {
format!("Cloning descriptor for file {subprocess_stdout_file:?}")
})?;
let mut command = std::process::Command::new(executable);
- let command = command
- .env(FORCE_CLI_MODE_ENV_VAR_NAME, "")
+ command.env(FORCE_CLI_MODE_ENV_VAR_NAME, "");
+ if let Some(dir) = user_data_dir {
+ command.arg("--user-data-dir").arg(dir);
+ }
+ command
.stderr(subprocess_stdout_file)
.stdout(subprocess_stdin_file)
.arg(url);
@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
-tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
+tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true
@@ -1,18 +0,0 @@
-[package]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-publish.workspace = true
-edition.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/cloud_zeta2_prompt.rs"
-
-[dependencies]
-anyhow.workspace = true
-cloud_llm_client.workspace = true
-indoc.workspace = true
-serde.workspace = true
@@ -1,485 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{
- self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
-};
-use indoc::indoc;
-use std::cmp;
-use std::fmt::Write;
-use std::path::Path;
-use std::sync::Arc;
-
-pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
-
-pub const CURSOR_MARKER: &str = "<|user_cursor|>";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-
-const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
- You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
-
- ## Edit History
-
- "#};
-
-const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
- ---
-
- Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
- Do not include the cursor marker in your output.
- If you're editing multiple files, be sure to reflect filename in the hunk's header.
- "};
-
-const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
- # Instructions
-
- You are an edit prediction agent in a code editor.
-
- Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
- Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
- Always continue along the user's current trajectory, rather than changing course.
-
- ## Output Format
-
- You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
- along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
-
- <edits path="my-project/src/myapp/cli.py">
- <old_text>
- OLD TEXT 1 HERE
- </old_text>
- <new_text>
- NEW TEXT 1 HERE
- </new_text>
-
- <old_text>
- OLD TEXT 1 HERE
- </old_text>
- <new_text>
- NEW TEXT 1 HERE
- </new_text>
- </edits>
-
- - Specify the file to edit using the `path` attribute.
- - Use `<old_text>` and `<new_text>` tags to replace content
- - `<old_text>` must exactly match existing file content, including indentation
- - `<old_text>` cannot be empty
- - Do not escape quotes, newlines, or other characters within tags
- - Always close all tags properly
- - Don't include the <|user_cursor|> marker in your output.
-
- ## Edit History
-
-"#};
-
-const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
- ---
-
- Remember that the edits in the edit history have already been applied.
-"#};
-
-pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
- let prompt_data = PromptData {
- events: request.events.clone(),
- cursor_point: request.cursor_point,
- cursor_path: request.excerpt_path.clone(),
- included_files: request.related_files.clone(),
- };
- match request.prompt_format {
- PromptFormat::MinimalQwen => {
- return Ok(MinimalQwenPrompt.render(&prompt_data));
- }
- PromptFormat::SeedCoder1120 => {
- return Ok(SeedCoder1120Prompt.render(&prompt_data));
- }
- _ => (),
- };
-
- let insertions = match request.prompt_format {
- PromptFormat::Minimal | PromptFormat::OldTextNewText => {
- vec![(request.cursor_point, CURSOR_MARKER)]
- }
- PromptFormat::OnlySnippets => vec![],
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- let mut prompt = match request.prompt_format {
- PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
- PromptFormat::OnlySnippets => String::new(),
- PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- if request.events.is_empty() {
- prompt.push_str("(No edit history)\n\n");
- } else {
- let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
- "The following are the latest edits made by the user, from earlier to later.\n\n"
- } else {
- "Here are the latest edits made by the user, from earlier to later.\n\n"
- };
- prompt.push_str(edit_preamble);
- push_events(&mut prompt, &request.events);
- }
-
- let excerpts_preamble = match request.prompt_format {
- PromptFormat::Minimal => indoc! {"
- ## Part of the file under the cursor
-
- (The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history has been applied.
- We only show part of the file around the cursor.
- You can only edit exactly this part of the file.
- We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
- "},
- PromptFormat::OldTextNewText => indoc! {"
- ## Code Excerpts
-
- Here is some excerpts of code that you should take into account to predict the next edit.
-
- The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
-
- In addition other excerpts are included to better understand what the edit will be, including the declaration
- or references of symbols around the cursor, or other similar code snippets that may need to be updated
- following patterns that appear in the edit history.
-
- Consider each of them carefully in relation to the edit history, and that the user may not have navigated
- to the next place they want to edit yet.
-
- Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
- "},
- PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- indoc! {"
- ## Code Excerpts
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history have been applied.
- "}
- }
- };
-
- prompt.push_str(excerpts_preamble);
- prompt.push('\n');
-
- let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
- for related_file in &request.related_files {
- if request.prompt_format == PromptFormat::Minimal {
- write_codeblock_with_filename(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- } else {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- }
- }
-
- match request.prompt_format {
- PromptFormat::OldTextNewText => {
- prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
- }
- PromptFormat::Minimal => {
- prompt.push_str(MINIMAL_PROMPT_REMINDER);
- }
- _ => {}
- }
-
- Ok(prompt)
-}
-
-pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
- match prompt_format {
- PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
- _ => GenerationParams::default(),
- }
-}
-
-pub fn write_codeblock<'a>(
- path: &Path,
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-fn write_codeblock_with_filename<'a>(
- path: &Path,
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-pub fn write_excerpts<'a>(
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &mut String,
-) {
- let mut current_row = Line(0);
- let mut sorted_insertions = sorted_insertions.iter().peekable();
-
- for excerpt in excerpts {
- if excerpt.start_line > current_row {
- writeln!(output, "…").unwrap();
- }
- if excerpt.text.is_empty() {
- return;
- }
-
- current_row = excerpt.start_line;
-
- for mut line in excerpt.text.lines() {
- if include_line_numbers {
- write!(output, "{}|", current_row.0 + 1).unwrap();
- }
-
- while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
- match current_row.cmp(&insertion_location.line) {
- cmp::Ordering::Equal => {
- let (prefix, suffix) = line.split_at(insertion_location.column as usize);
- output.push_str(prefix);
- output.push_str(insertion_marker);
- line = suffix;
- sorted_insertions.next();
- }
- cmp::Ordering::Less => break,
- cmp::Ordering::Greater => {
- sorted_insertions.next();
- break;
- }
- }
- }
- output.push_str(line);
- output.push('\n');
- current_row.0 += 1;
- }
- }
-
- if current_row < file_line_count {
- writeln!(output, "…").unwrap();
- }
-}
-
-pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
- if events.is_empty() {
- return;
- };
-
- writeln!(output, "`````diff").unwrap();
- for event in events {
- writeln!(output, "{}", event).unwrap();
- }
- writeln!(output, "`````\n").unwrap();
-}
-
-struct PromptData {
- events: Vec<Arc<Event>>,
- cursor_point: Point,
- cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
- included_files: Vec<RelatedFile>,
-}
-
-#[derive(Default)]
-pub struct GenerationParams {
- pub temperature: Option<f32>,
- pub top_p: Option<f32>,
- pub stop: Option<Vec<String>>,
-}
-
-trait PromptFormatter {
- fn render(&self, data: &PromptData) -> String;
-
- fn generation_params() -> GenerationParams {
- return GenerationParams::default();
- }
-}
-
-struct MinimalQwenPrompt;
-
-impl PromptFormatter for MinimalQwenPrompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "{instructions}\n\n{edit_history}\n\n{context}",
- instructions = MinimalQwenPrompt::INSTRUCTIONS,
- edit_history = edit_history,
- context = context
- )
- }
-}
-
-impl MinimalQwenPrompt {
- const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
-
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- format!(
- "The following are the latest edits made by the user, from earlier to later.\n\n{}",
- events_str
- )
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- write!(context, "<|fim_prefix|>").unwrap();
- write_excerpts(
- &related_file.excerpts,
- &[(data.cursor_point, "<|fim_suffix|>")],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- writeln!(context, "<|fim_middle|>").unwrap();
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-}
-
-struct SeedCoder1120Prompt;
-
-impl PromptFormatter for SeedCoder1120Prompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "# Edit History:\n{edit_history}\n\n{context}",
- edit_history = edit_history,
- context = context
- )
- }
-
- fn generation_params() -> GenerationParams {
- GenerationParams {
- temperature: Some(0.2),
- top_p: Some(0.9),
- stop: Some(vec!["<[end_of_sentence]>".into()]),
- }
- }
-}
-
-impl SeedCoder1120Prompt {
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- events_str
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
- context.push_str(&fim_prompt);
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-
- fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
- let mut buf = String::new();
- const FIM_SUFFIX: &str = "<[fim-suffix]>";
- const FIM_PREFIX: &str = "<[fim-prefix]>";
- const FIM_MIDDLE: &str = "<[fim-middle]>";
- write!(buf, "{}", FIM_PREFIX).unwrap();
- write_excerpts(
- &file.excerpts,
- &[(cursor_point, FIM_SUFFIX)],
- file.max_row,
- true,
- &mut buf,
- );
-
- // Swap prefix and suffix parts
- let index = buf.find(FIM_SUFFIX).unwrap();
- let prefix = &buf[..index];
- let suffix = &buf[index..];
-
- format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
- }
-}
@@ -2,7 +2,7 @@ mod persistence;
use std::{
cmp::{self, Reverse},
- collections::HashMap,
+ collections::{HashMap, VecDeque},
sync::Arc,
time::Duration,
};
@@ -19,6 +19,7 @@ use gpui::{
ParentElement, Render, Styled, Task, WeakEntity, Window,
};
use persistence::COMMAND_PALETTE_HISTORY;
+use picker::Direction;
use picker::{Picker, PickerDelegate};
use postage::{sink::Sink, stream::Stream};
use settings::Settings;
@@ -163,6 +164,7 @@ pub struct CommandPaletteDelegate {
Task<()>,
postage::dispatch::Receiver<(Vec<Command>, Vec<StringMatch>, CommandInterceptResult)>,
)>,
+ query_history: QueryHistory,
}
struct Command {
@@ -170,6 +172,91 @@ struct Command {
action: Box<dyn Action>,
}
+#[derive(Default)]
+struct QueryHistory {
+ history: Option<VecDeque<String>>,
+ cursor: Option<usize>,
+ prefix: Option<String>,
+}
+
+impl QueryHistory {
+ fn history(&mut self) -> &mut VecDeque<String> {
+ self.history.get_or_insert_with(|| {
+ COMMAND_PALETTE_HISTORY
+ .list_recent_queries()
+ .unwrap_or_default()
+ .into_iter()
+ .collect()
+ })
+ }
+
+ fn add(&mut self, query: String) {
+ if let Some(pos) = self.history().iter().position(|h| h == &query) {
+ self.history().remove(pos);
+ }
+ self.history().push_back(query);
+ self.cursor = None;
+ self.prefix = None;
+ }
+
+ fn validate_cursor(&mut self, current_query: &str) -> Option<usize> {
+ if let Some(pos) = self.cursor {
+ if self.history().get(pos).map(|s| s.as_str()) != Some(current_query) {
+ self.cursor = None;
+ self.prefix = None;
+ }
+ }
+ self.cursor
+ }
+
+ fn previous(&mut self, current_query: &str) -> Option<&str> {
+ if self.validate_cursor(current_query).is_none() {
+ self.prefix = Some(current_query.to_string());
+ }
+
+ let prefix = self.prefix.clone().unwrap_or_default();
+ let start_index = self.cursor.unwrap_or(self.history().len());
+
+ for i in (0..start_index).rev() {
+ if self
+ .history()
+ .get(i)
+ .is_some_and(|e| e.starts_with(&prefix))
+ {
+ self.cursor = Some(i);
+ return self.history().get(i).map(|s| s.as_str());
+ }
+ }
+ None
+ }
+
+ fn next(&mut self, current_query: &str) -> Option<&str> {
+ let selected = self.validate_cursor(current_query)?;
+ let prefix = self.prefix.clone().unwrap_or_default();
+
+ for i in (selected + 1)..self.history().len() {
+ if self
+ .history()
+ .get(i)
+ .is_some_and(|e| e.starts_with(&prefix))
+ {
+ self.cursor = Some(i);
+ return self.history().get(i).map(|s| s.as_str());
+ }
+ }
+ None
+ }
+
+ fn reset_cursor(&mut self) {
+ self.cursor = None;
+ self.prefix = None;
+ }
+
+ fn is_navigating(&self) -> bool {
+ self.cursor.is_some()
+ }
+}
+
impl Clone for Command {
fn clone(&self) -> Self {
Self {
@@ -196,6 +283,7 @@ impl CommandPaletteDelegate {
previous_focus_handle,
latest_query: String::new(),
updating_matches: None,
+ query_history: Default::default(),
}
}
@@ -271,6 +359,11 @@ impl CommandPaletteDelegate {
// so we need to return an Option here
self.commands.get(action_ix)
}
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn seed_history(&mut self, queries: &[&str]) {
+ self.query_history.history = Some(queries.iter().map(|s| s.to_string()).collect());
+ }
}
impl PickerDelegate for CommandPaletteDelegate {
@@ -280,6 +373,38 @@ impl PickerDelegate for CommandPaletteDelegate {
"Execute a command...".into()
}
+ fn select_history(
+ &mut self,
+ direction: Direction,
+ query: &str,
+ _window: &mut Window,
+ _cx: &mut App,
+ ) -> Option<String> {
+ match direction {
+ Direction::Up => {
+ let should_use_history =
+ self.selected_ix == 0 || self.query_history.is_navigating();
+ if should_use_history {
+ if let Some(query) = self.query_history.previous(query).map(|s| s.to_string()) {
+ return Some(query);
+ }
+ }
+ }
+ Direction::Down => {
+ if self.query_history.is_navigating() {
+ if let Some(query) = self.query_history.next(query).map(|s| s.to_string()) {
+ return Some(query);
+ } else {
+ let prefix = self.query_history.prefix.take().unwrap_or_default();
+ self.query_history.reset_cursor();
+ return Some(prefix);
+ }
+ }
+ }
+ }
+ None
+ }
+
fn match_count(&self) -> usize {
self.matches.len()
}
@@ -439,6 +564,12 @@ impl PickerDelegate for CommandPaletteDelegate {
self.dismissed(window, cx);
return;
}
+
+ if !self.latest_query.is_empty() {
+ self.query_history.add(self.latest_query.clone());
+ self.query_history.reset_cursor();
+ }
+
let action_ix = self.matches[self.selected_ix].candidate_id;
let command = self.commands.swap_remove(action_ix);
telemetry::event!(
@@ -588,7 +719,7 @@ mod tests {
use super::*;
use editor::Editor;
use go_to_line::GoToLine;
- use gpui::TestAppContext;
+ use gpui::{TestAppContext, VisualTestContext};
use language::Point;
use project::Project;
use settings::KeymapFile;
@@ -799,7 +930,9 @@ mod tests {
"bindings": {
"cmd-n": "workspace::NewFile",
"enter": "menu::Confirm",
- "cmd-shift-p": "command_palette::Toggle"
+ "cmd-shift-p": "command_palette::Toggle",
+ "up": "menu::SelectPrevious",
+ "down": "menu::SelectNext"
}
}
]"#,
@@ -808,4 +941,264 @@ mod tests {
app_state
})
}
+
+ fn open_palette_with_history(
+ workspace: &Entity<Workspace>,
+ history: &[&str],
+ cx: &mut VisualTestContext,
+ ) -> Entity<Picker<CommandPaletteDelegate>> {
+ cx.simulate_keystrokes("cmd-shift-p");
+ cx.run_until_parked();
+
+ let palette = workspace.update(cx, |workspace, cx| {
+ workspace
+ .active_modal::<CommandPalette>(cx)
+ .unwrap()
+ .read(cx)
+ .picker
+ .clone()
+ });
+
+ palette.update(cx, |palette, _cx| {
+ palette.delegate.seed_history(history);
+ });
+
+ palette
+ }
+
+ #[gpui::test]
+ async fn test_history_navigation_basic(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette = open_palette_with_history(&workspace, &["backspace", "select all"], cx);
+
+ // Query should be empty initially
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "");
+ });
+
+ // Press up - should load most recent query "select all"
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select all");
+ });
+
+ // Press up again - should load "backspace"
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "backspace");
+ });
+
+ // Press down - should go back to "select all"
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select all");
+ });
+
+ // Press down again - should clear query (exit history mode)
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_history_mode_exit_on_typing(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette = open_palette_with_history(&workspace, &["backspace"], cx);
+
+ // Press up to enter history mode
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "backspace");
+ });
+
+ // Type something - should append to the history query
+ cx.simulate_input("x");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "backspacex");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_history_navigation_with_suggestions(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette = open_palette_with_history(&workspace, &["editor: close", "editor: open"], cx);
+
+ // Open palette with a query that has multiple matches
+ cx.simulate_input("editor");
+ cx.background_executor.run_until_parked();
+
+ // Should have multiple matches, selected_ix should be 0
+ palette.read_with(cx, |palette, _| {
+ assert!(palette.delegate.matches.len() > 1);
+ assert_eq!(palette.delegate.selected_ix, 0);
+ });
+
+ // Press down - should navigate to next suggestion (not history)
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, _| {
+ assert_eq!(palette.delegate.selected_ix, 1);
+ });
+
+ // Press up - should go back to first suggestion
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, _| {
+ assert_eq!(palette.delegate.selected_ix, 0);
+ });
+
+ // Press up again at top - should enter history mode and show previous query
+ // that matches the "editor" prefix
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "editor: open");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_history_prefix_search(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette = open_palette_with_history(
+ &workspace,
+ &["open file", "select all", "select line", "backspace"],
+ cx,
+ );
+
+ // Type "sel" as a prefix
+ cx.simulate_input("sel");
+ cx.background_executor.run_until_parked();
+
+ // Press up - should get "select line" (most recent matching "sel")
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select line");
+ });
+
+ // Press up again - should get "select all" (next matching "sel")
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select all");
+ });
+
+ // Press up again - should stay at "select all" (no more matches for "sel")
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select all");
+ });
+
+ // Press down - should go back to "select line"
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "select line");
+ });
+
+ // Press down again - should return to original prefix "sel"
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "sel");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_history_prefix_search_no_matches(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette =
+ open_palette_with_history(&workspace, &["open file", "backspace", "select all"], cx);
+
+ // Type "xyz" as a prefix that doesn't match anything
+ cx.simulate_input("xyz");
+ cx.background_executor.run_until_parked();
+
+ // Press up - should stay at "xyz" (no matches)
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "xyz");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_history_empty_prefix_searches_all(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+ let project = Project::test(app_state.fs.clone(), [], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let palette = open_palette_with_history(&workspace, &["alpha", "beta", "gamma"], cx);
+
+ // With empty query, press up - should get "gamma" (most recent)
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "gamma");
+ });
+
+ // Press up - should get "beta"
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "beta");
+ });
+
+ // Press up - should get "alpha"
+ cx.simulate_keystrokes("up");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "alpha");
+ });
+
+ // Press down - should get "beta"
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "beta");
+ });
+
+ // Press down - should get "gamma"
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "gamma");
+ });
+
+ // Press down - should return to empty string (exit history mode)
+ cx.simulate_keystrokes("down");
+ cx.background_executor.run_until_parked();
+ palette.read_with(cx, |palette, cx| {
+ assert_eq!(palette.query(cx), "");
+ });
+ }
}
@@ -123,6 +123,16 @@ impl CommandPaletteDB {
ORDER BY COUNT(1) DESC
}
}
+
+ query! {
+ pub fn list_recent_queries() -> Result<Vec<String>> {
+ SELECT user_query
+ FROM command_invocations
+ WHERE user_query != ""
+ GROUP BY user_query
+ ORDER BY MAX(last_invoked) ASC
+ }
+ }
}
#[cfg(test)]
@@ -33,6 +33,7 @@ smol.workspace = true
tempfile.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
+terminal.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
@@ -8,9 +8,12 @@ use futures::{
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
};
use gpui::AsyncApp;
+use settings::Settings as _;
use smol::channel;
use smol::process::Child;
+use terminal::terminal_settings::TerminalSettings;
use util::TryFutureExt as _;
+use util::shell_builder::ShellBuilder;
use crate::client::ModelContextServerBinary;
use crate::transport::Transport;
@@ -28,9 +31,14 @@ impl StdioTransport {
working_directory: &Option<PathBuf>,
cx: &AsyncApp,
) -> Result<Self> {
- let mut command = util::command::new_smol_command(&binary.executable);
+ let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
+ let builder = ShellBuilder::new(&shell, cfg!(windows));
+ let (command, args) =
+ builder.build(Some(binary.executable.display().to_string()), &binary.args);
+
+ let mut command = util::command::new_smol_command(command);
command
- .args(&binary.args)
+ .args(args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
@@ -807,7 +807,7 @@ impl Copilot {
.ok();
}
language::BufferEvent::FileHandleChanged
- | language::BufferEvent::LanguageChanged => {
+ | language::BufferEvent::LanguageChanged(_) => {
let new_language_id = id_for_language(buffer.read(cx).language());
let Ok(new_uri) = uri_for_buffer(&buffer, cx) else {
return Ok(());
@@ -317,7 +317,7 @@ impl PickerDelegate for AttachModalDelegate {
let candidate = self.candidates.get(hit.candidate_id)?;
Some(
- ListItem::new(SharedString::from(format!("process-entry-{ix}")))
+ ListItem::new(format!("process-entry-{ix}"))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -327,7 +327,7 @@ impl PickerDelegate for AttachModalDelegate {
.child(Label::new(format!("{} {}", candidate.name, candidate.pid)))
.child(
div()
- .id(SharedString::from(format!("process-entry-{ix}-command")))
+ .id(format!("process-entry-{ix}-command"))
.tooltip(Tooltip::text(
candidate
.command
@@ -1519,7 +1519,7 @@ impl PickerDelegate for DebugDelegate {
});
Some(
- ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
+ ListItem::new(format!("debug-scenario-selection-{ix}"))
.inset(true)
.start_slot::<IconWithIndicator>(icon)
.spacing(ListItemSpacing::Sparse)
@@ -286,10 +286,10 @@ impl Item for SubView {
impl Render for SubView {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
- .id(SharedString::from(format!(
+ .id(format!(
"subview-container-{}",
self.kind.to_shared_string()
- )))
+ ))
.on_hover(cx.listener(|this, hovered, _, cx| {
this.hovered = *hovered;
cx.notify();
@@ -484,10 +484,7 @@ pub(crate) fn new_debugger_pane(
let deemphasized = !pane.has_focus(window, cx);
let item_ = item.boxed_clone();
div()
- .id(SharedString::from(format!(
- "debugger_tab_{}",
- item.item_id().as_u64()
- )))
+ .id(format!("debugger_tab_{}", item.item_id().as_u64()))
.p_1()
.rounded_md()
.cursor_pointer()
@@ -155,6 +155,8 @@ pub enum RequestMessage {
content: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ reasoning_content: Option<String>,
},
User {
content: String,
@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
- channel::{
- mpsc::{self, UnboundedReceiver},
- oneshot,
- },
+ channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
+
+#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;
-mod xml_edits;
+
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
+pub struct EditPredictionModelInput {
+ project: Entity<Project>,
+ buffer: Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ position: Anchor,
+ events: Vec<Arc<zeta_prompt::Event>>,
+ related_files: Arc<[RelatedFile]>,
+ recent_paths: VecDeque<ProjectPath>,
+ trigger: PredictEditsRequestTrigger,
+ diagnostic_search_range: Range<Point>,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+}
+
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
- pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
- EditPredictionRequested(EditPredictionRequestedDebugEvent),
+ EditPredictionStarted(EditPredictionStartedDebugEvent),
+ EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
-pub struct EditPredictionRequestedDebugEvent {
- pub inputs: EditPredictionInputs,
- pub retrieval_time: Duration,
+pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
- pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+ pub prompt: Option<String>,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionFinishedDebugEvent {
+ pub buffer: WeakEntity<Buffer>,
+ pub position: Anchor,
+ pub model_output: Option<String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
- events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ events: VecDeque<Arc<zeta_prompt::Event>>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
- context_updates_tx: smol::channel::Sender<()>,
- context_updates_rx: smol::channel::Receiver<()>,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
- pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+ pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
self.events
.iter()
.cloned()
@@ -376,7 +387,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
- ) -> Option<Arc<predict_edits_v3::Event>> {
+ ) -> Option<Arc<zeta_prompt::Event>> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +407,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
- Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,7 +492,6 @@ impl EditPredictionStore {
},
),
update_required: false,
- debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +546,6 @@ impl EditPredictionStore {
self.eval_cache = Some(cache);
}
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +564,35 @@ impl EditPredictionStore {
}
}
+ pub fn edit_history_for_project(
+ &self,
+ project: &Entity<Project>,
+ ) -> Vec<Arc<zeta_prompt::Event>> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project_state| project_state.events.iter().cloned().collect())
+ .unwrap_or_default()
+ }
+
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
- ) -> &'a [RelatedFile] {
+ ) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
- .unwrap_or(&[])
+ .unwrap_or_else(|| vec![].into())
+ }
+
+ pub fn context_for_project_with_buffers<'a>(
+ &'a self,
+ project: &Entity<Project>,
+ cx: &'a App,
+ ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +623,21 @@ impl EditPredictionStore {
cx: &mut Context<Self>,
) -> &mut ProjectState {
let entity_id = project.entity_id();
- let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
- cx.subscribe(
- &related_excerpt_store,
- move |this, _, event, _| match event {
- RelatedExcerptStoreEvent::StartedRefresh => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalStarted(
- ContextRetrievalStartedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- search_prompt: String::new(),
- },
- ))
- .ok();
- }
- }
- RelatedExcerptStoreEvent::FinishedRefresh {
- cache_hit_count,
- cache_miss_count,
- mean_definition_latency,
- max_definition_latency,
- } => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalFinished(
- ContextRetrievalFinishedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- metadata: vec![
- (
- "Cache Hits",
- format!(
- "{}/{}",
- cache_hit_count,
- cache_hit_count + cache_miss_count
- )
- .into(),
- ),
- (
- "Max LSP Time",
- format!(
- "{} ms",
- max_definition_latency.as_millis()
- )
- .into(),
- ),
- (
- "Mean LSP Time",
- format!(
- "{} ms",
- mean_definition_latency.as_millis()
- )
- .into(),
- ),
- ],
- },
- ))
- .ok();
- }
- if let Some(project_state) = this.projects.get(&entity_id) {
- project_state.context_updates_tx.send_blocking(()).ok();
- }
- }
- },
- )
+ cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
+ this.handle_excerpt_store_event(entity_id, event);
+ })
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
- context_updates_rx,
- context_updates_tx,
+ debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +649,79 @@ impl EditPredictionStore {
})
}
- pub fn project_context_updates(
- &self,
+ pub fn remove_project(&mut self, project: &Entity<Project>) {
+ self.projects.remove(&project.entity_id());
+ }
+
+ fn handle_excerpt_store_event(
+ &mut self,
+ project_entity_id: EntityId,
+ event: &RelatedExcerptStoreEvent,
+ ) {
+ if let Some(project_state) = self.projects.get(&project_entity_id) {
+ if let Some(debug_tx) = project_state.debug_tx.clone() {
+ match event {
+ RelatedExcerptStoreEvent::StartedRefresh => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalStarted(
+ ContextRetrievalStartedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ search_prompt: String::new(),
+ },
+ ))
+ .ok();
+ }
+ RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ } => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalFinished(
+ ContextRetrievalFinishedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ metadata: vec![
+ (
+ "Cache Hits",
+ format!(
+ "{}/{}",
+ cache_hit_count,
+ cache_hit_count + cache_miss_count
+ )
+ .into(),
+ ),
+ (
+ "Max LSP Time",
+ format!("{} ms", max_definition_latency.as_millis())
+ .into(),
+ ),
+ (
+ "Mean LSP Time",
+ format!("{} ms", mean_definition_latency.as_millis())
+ .into(),
+ ),
+ ],
+ },
+ ))
+ .ok();
+ }
+ }
+ }
+ }
+ }
+
+ pub fn debug_info(
+ &mut self,
project: &Entity<Project>,
- ) -> Option<smol::channel::Receiver<()>> {
- let project_state = self.projects.get(&project.entity_id())?;
- Some(project_state.context_updates_rx.clone())
+ cx: &mut Context<Self>,
+ ) -> mpsc::UnboundedReceiver<DebugEvent> {
+ let project_state = self.get_or_init_project(project, cx);
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ project_state.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
}
fn handle_project_event(
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
+ let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
- self.context_for_project(&project, cx).to_vec()
+ self.context_for_project(&project, cx)
} else {
- Vec::new()
+ Vec::new().into()
+ };
+
+ let inputs = EditPredictionModelInput {
+ project: project.clone(),
+ buffer: active_buffer.clone(),
+ snapshot: snapshot.clone(),
+ position,
+ events,
+ related_files,
+ recent_paths: project_state.recent_paths.clone(),
+ trigger,
+ diagnostic_search_range: diagnostic_search_range.clone(),
+ debug_tx,
};
let task = match self.edit_prediction_model {
- EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- related_files,
- trigger,
- cx,
- ),
- EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
- EditPredictionModel::Mercury => self.mercury.request_prediction(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
+ EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+ EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
+ EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
}
}
+ #[cfg(feature = "eval-support")]
+ pub fn set_context_for_buffer(
+ &mut self,
+ project: &Entity<Project>,
+ related_files: Vec<RelatedFile>,
+ cx: &mut Context<Self>,
+ ) {
+ self.get_or_init_project(project, cx)
+ .context
+ .update(cx, |store, _| {
+ store.set_related_files(related_files);
+ });
+ }
+
fn is_file_open_source(
&self,
project: &Entity<Project>,
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
- fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
- Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}
@@ -1,5 +1,5 @@
use super::*;
-use crate::zeta1::MAX_EVENT_TOKENS;
+use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
-use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
+use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
cx.run_until_parked();
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ @@ ... @@
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#},
+ ))
.unwrap();
cx.run_until_parked();
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- const NO_OP_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How
- Bye
- "};
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(NO_OP_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
- let response = model_response(SIMPLE_DIFF);
+ let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
- let second_response = model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are
- Bye
- "});
+ let second_response = model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "},
+ );
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
- let second_response = model_response(SIMPLE_DIFF);
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
- let (_, respond_third) = requests.predict.next().await.unwrap();
+ let (request3, respond_third) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let third_response = model_response(SIMPLE_DIFF);
+ let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
-fn model_response(text: &str) -> open_ai::Response {
+// Generate a model response that would apply the given diff to the active file.
+fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
+ let prompt = match &request.messages[0] {
+ open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ } => content,
+ _ => panic!("unexpected request {request:?}"),
+ };
+
+ let open = "<editable_region>\n";
+ let close = "</editable_region>";
+ let cursor = "<|user_cursor|>";
+
+ let start_ix = open.len() + prompt.find(open).unwrap();
+ let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
+ let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
+ let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
- let completion = EditPrediction {
+ let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: Default::default(),
- included_files: Default::default(),
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: Line(0),
- column: 0,
- },
+ related_files: Default::default(),
cursor_path: Path::new("").into(),
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
+ cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
-use project::{Project, ProjectPath};
-use std::{
- collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
-};
+use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
+use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
+use zeta_prompt::ZetaPromptInput;
use crate::{
- EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
- pub fn request_prediction(
+ pub(crate) fn request_prediction(
&self,
- _project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- _recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- _diagnostic_search_range: Range<Point>,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ events,
+ related_files,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
+ let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
- let offset_range = editable_range.to_offset(&snapshot);
- let prompt = build_prompt(
- &events,
- &related_files,
- &snapshot,
- full_path.as_ref(),
- cursor_point,
- editable_range,
- context_range.clone(),
- );
-
- let inputs = EditPredictionInputs {
- events: events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(
- context_range.start.row,
- ),
- text: snapshot
- .text_for_range(context_range.clone())
- .collect::<String>()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let context_offset_range = context_range.to_offset(&snapshot);
+
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = zeta_prompt::ZetaPromptInput {
+ events,
+ related_files,
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
+ - context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt: (editable_offset_range.start
+ - context_offset_range.start)
+ ..(editable_offset_range.end - context_offset_range.start),
};
+ let prompt = build_prompt(&inputs);
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ prompt: Some(prompt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ model_output: Some(response_str.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
- .text_for_range(offset_range.clone())
+ .text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(offset_range.start + range.start)
- ..snapshot.anchor_before(offset_range.start + range.end),
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot
+ .anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
- let buffer = active_buffer.clone();
-
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
-fn build_prompt(
- events: &[Arc<Event>],
- related_files: &[RelatedFile],
- cursor_buffer: &BufferSnapshot,
- cursor_buffer_path: &Path,
- cursor_point: Point,
- editable_range: Range<Point>,
- context_range: Range<Point>,
-) -> String {
+fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
- for related_file in related_files {
+ for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
- prompt.push_str(related_file.path.path.as_unix_str());
+ prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
- prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+ prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
- let prefix_range = context_range.start..editable_range.start;
- let suffix_range = editable_range.end..context_range.end;
-
- prompt.extend(cursor_buffer.text_for_range(prefix_range));
+ prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
- let range_before_cursor = editable_range.start..cursor_point;
- let range_after_cursor = cursor_point..editable_range.end;
- prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
+ );
prompt.push_str(CURSOR_TAG);
- prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
+ );
});
- prompt.extend(cursor_buffer.text_for_range(suffix_range));
+ prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
- for event in events {
- writeln!(prompt, "{event}").unwrap();
+ for event in inputs.events.iter() {
+ zeta_prompt::write_event(prompt, &event);
}
},
);
@@ -1,14 +1,14 @@
use std::{
ops::Range,
- path::Path,
sync::Arc,
time::{Duration, Instant},
};
use cloud_llm_client::EditPredictionRejectReason;
+use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
-use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
-use serde::Serialize;
+use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
+use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -39,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -53,7 +53,7 @@ impl EditPredictionResult {
.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
- interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
+ interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})
@@ -93,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
- pub inputs: EditPredictionInputs,
-}
-
-#[derive(Debug, Clone, Serialize)]
-pub struct EditPredictionInputs {
- pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
- pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
- pub cursor_path: Arc<Path>,
+ pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -109,7 +101,7 @@ impl EditPrediction {
&self,
new_snapshot: &TextBufferSnapshot,
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
+ interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
@@ -130,57 +122,14 @@ impl std::fmt::Debug for EditPrediction {
}
}
-pub fn interpolate_edits(
- old_snapshot: &TextBufferSnapshot,
- new_snapshot: &TextBufferSnapshot,
- current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
-) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- let mut edits = Vec::new();
-
- let mut model_edits = current_edits.iter().peekable();
- for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
- while let Some((model_old_range, _)) = model_edits.peek() {
- let model_old_range = model_old_range.to_offset(old_snapshot);
- if model_old_range.end < user_edit.old.start {
- let (model_old_range, model_new_text) = model_edits.next().unwrap();
- edits.push((model_old_range.clone(), model_new_text.clone()));
- } else {
- break;
- }
- }
-
- if let Some((model_old_range, model_new_text)) = model_edits.peek() {
- let model_old_offset_range = model_old_range.to_offset(old_snapshot);
- if user_edit.old == model_old_offset_range {
- let user_new_text = new_snapshot
- .text_for_range(user_edit.new.clone())
- .collect::<String>();
-
- if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
- if !model_suffix.is_empty() {
- let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.into()));
- }
-
- model_edits.next();
- continue;
- }
- }
- }
-
- return None;
- }
-
- edits.extend(model_edits.cloned());
-
- if edits.is_empty() { None } else { Some(edits) }
-}
-
#[cfg(test)]
mod tests {
+ use std::path::Path;
+
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
+ use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -199,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: vec![],
- included_files: vec![],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: cloud_llm_client::predict_edits_v3::Line(0),
- column: 0,
- },
+ related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
+ cursor_offset_in_excerpt: 0,
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
-use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
- collections::VecDeque,
fmt::{self, Write as _},
- ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
-use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
+use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- diagnostic_search_range: Range<Point>,
+ inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
- let full_path: Arc<Path> = snapshot
+ let full_path: Arc<Path> = inputs
+ .snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
- let project_file = project::File::from_dyn(snapshot.file());
+ let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
- let offset = position.to_offset(&snapshot);
+ let offset = inputs.position.to_offset(&inputs.snapshot);
- let recent_buffers = recent_paths.iter().cloned();
+ let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
+ let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
+ if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
- let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
- let text = snapshot.text();
+ let text = inputs.snapshot.text();
let mut recent_changes = String::new();
- for event in &events {
+ for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
- let retrieval_chunks = related_files
+ let retrieval_chunks = inputs
+ .related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
- file_path: related_file.path.path.as_unix_str().to_string(),
- start_line: excerpt.point_range.start.row as usize,
- end_line: excerpt.point_range.end.row as usize,
+ file_path: related_file.path.to_string_lossy().to_string(),
+ start_line: excerpt.row_range.start as usize,
+ end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
- let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let diagnostic_entries = inputs
+ .snapshot
+ .diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
- let inputs = EditPredictionInputs {
- events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(0),
- text: request_body.file_contents.into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let ep_inputs = zeta_prompt::ZetaPromptInput {
+ events: inputs.events,
+ related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
+ cursor_excerpt: request_body.file_contents.into(),
+ // we actually don't know
+ editable_range_in_excerpt: 0..inputs.snapshot.len(),
+ cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
- let old_text = snapshot
+ let old_text = inputs
+ .snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
+ inputs
+ .snapshot
+ .anchor_after(response.start_index + range.start)
+ ..inputs
+ .snapshot
+ .anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
- snapshot,
+ inputs.snapshot,
response_received_at,
- inputs,
+ ep_inputs,
))
});
- let buffer = active_buffer.clone();
+ let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
-fn write_event(
- event: &cloud_llm_client::predict_edits_v3::Event,
- f: &mut impl fmt::Write,
-) -> fmt::Result {
+fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
- cloud_llm_client::predict_edits_v3::Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -14,68 +14,18 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
+use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
-pub async fn parse_diff<'a>(
- diff_str: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- let mut diff = DiffParser::new(diff_str);
- let mut edited_buffer = None;
- let mut edits = Vec::new();
-
- while let Some(event) = diff.next()? {
- match event {
- DiffEvent::Hunk {
- path: file_path,
- hunk,
- } => {
- let (buffer, ranges) = match edited_buffer {
- None => {
- edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
- edited_buffer
- .as_ref()
- .context("Model tried to edit a file that wasn't included")?
- }
- Some(ref current) => current,
- };
-
- edits.extend(
- resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
- .with_context(|| format!("Diff:\n{diff_str}"))?,
- );
- }
- DiffEvent::FileEnd { renamed_to } => {
- let (buffer, _) = edited_buffer
- .take()
- .context("Got a FileEnd event before an Hunk event")?;
-
- if renamed_to.is_some() {
- anyhow::bail!("edit predictions cannot rename files");
- }
-
- if diff.next()?.is_some() {
- anyhow::bail!("Edited more than one file");
- }
-
- return Ok((buffer, edits));
- }
- }
- }
-
- Err(anyhow::anyhow!("No EOF"))
-}
-
-#[derive(Debug)]
-pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
+#[derive(Clone, Debug)]
+pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
#[must_use]
-pub async fn apply_diff<'a>(
- diff_str: &'a str,
+pub async fn apply_diff(
+ diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
-) -> Result<OpenedBuffers<'a>> {
+) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();
for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
})??
.await?;
- included_files.insert(path, buffer);
+ included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
- .get_mut(&file_path)
+ .get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
+pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
+ let mut diff = DiffParser::new(diff_str);
+
+ let mut text = text.to_string();
+
+ while let Some(event) = diff.next()? {
+ match event {
+ DiffEvent::Hunk { hunk, .. } => {
+ let hunk_offset = text
+ .find(&hunk.context)
+ .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
+ for edit in hunk.edits.iter().rev() {
+ let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
+ text.replace_range(range, &edit.text);
+ }
+ }
+ DiffEvent::FileEnd { .. } => {}
+ }
+ }
+
+ Ok(text)
+}
+
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
- use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
});
}
- #[gpui::test]
- async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- three
- four
- five
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- five
- "#};
-
- let final_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- 3
- four
- five
- "#};
-
- apply_diff(diff, &project, &mut cx.to_async())
- .await
- .expect_err("Non-unique edits should fail");
-
- let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
- ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
-
- let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
- .await
- .unwrap();
-
- assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edits, None, cx);
- assert_eq!(buffer.text(), final_text);
- });
- }
-
- #[gpui::test]
- async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one two three four
- -five six seven eight
- +five SIX seven eight!
- nine ten eleven twelve
- "#};
-
- let (buffer, edits) = parse_diff(diff, |_path| {
- Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::<Vec<_>>();
- assert_eq!(
- edits,
- &[
- (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
- (Point::new(1, 20)..Point::new(1, 20), "!".into())
- ]
- );
- }
-
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);
@@ -1,637 +0,0 @@
-use anyhow::{Context as _, Result};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
-use std::{cmp, ops::Range, path::Path, sync::Arc};
-
-const EDITS_TAG_NAME: &'static str = "edits";
-const OLD_TEXT_TAG_NAME: &'static str = "old_text";
-const NEW_TEXT_TAG_NAME: &'static str = "new_text";
-const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
-
-pub async fn parse_xml_edits<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- parse_xml_edits_inner(input, get_buffer)
- .await
- .with_context(|| format!("Failed to parse XML edits:\n{input}"))
-}
-
-async fn parse_xml_edits_inner<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- let xml_edits = extract_xml_replacements(input)?;
-
- let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
- .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
-
- let mut all_edits = vec![];
- for (old_text, new_text) in xml_edits.replacements {
- let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
- let matched_old_text = buffer
- .text_for_range(match_range.clone())
- .collect::<String>();
- let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
- all_edits.extend(
- edits_within_hunk
- .into_iter()
- .map(move |(inner_range, inner_text)| {
- (
- buffer.anchor_after(match_range.start + inner_range.start)
- ..buffer.anchor_before(match_range.start + inner_range.end),
- inner_text,
- )
- }),
- );
- }
-
- Ok((buffer, all_edits))
-}
-
-fn fuzzy_match_in_ranges(
- old_text: &str,
- buffer: &BufferSnapshot,
- context_ranges: &[Range<Anchor>],
-) -> Result<Range<usize>> {
- let mut state = FuzzyMatcher::new(buffer, old_text);
- let mut best_match = None;
- let mut tie_match_range = None;
-
- for range in context_ranges {
- let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
- match (best_match_cost, state.match_range(range.to_offset(buffer))) {
- (Some(lowest_cost), Some((new_cost, new_range))) => {
- if new_cost == lowest_cost {
- tie_match_range = Some(new_range);
- } else if new_cost < lowest_cost {
- tie_match_range.take();
- best_match = Some((new_cost, new_range));
- }
- }
- (None, Some(new_match)) => {
- best_match = Some(new_match);
- }
- (None, None) | (Some(_), None) => {}
- };
- }
-
- if let Some((_, best_match_range)) = best_match {
- if let Some(tie_match_range) = tie_match_range {
- anyhow::bail!(
- "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
- best_match_range.clone(),
- buffer.text_for_range(best_match_range).collect::<String>(),
- tie_match_range.clone(),
- buffer.text_for_range(tie_match_range).collect::<String>()
- );
- }
- return Ok(best_match_range);
- }
-
- anyhow::bail!(
- "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
- old_text,
- context_ranges
- .iter()
- .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
- .collect::<Vec<String>>()
- .join("```\n```")
- );
-}
-
-#[derive(Debug)]
-struct XmlEdits<'a> {
- file_path: &'a str,
- /// Vec of (old_text, new_text) pairs
- replacements: Vec<(&'a str, &'a str)>,
-}
-
-fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
- let mut cursor = 0;
-
- let (edits_body_start, edits_attrs) =
- find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
-
- let file_path = edits_attrs
- .trim_start()
- .strip_prefix("path")
- .context("no path attribute on edits tag")?
- .trim_end()
- .strip_prefix('=')
- .context("no value for path attribute")?
- .trim()
- .trim_start_matches('"')
- .trim_end_matches('"');
-
- cursor = edits_body_start;
- let mut edits_list = Vec::new();
-
- while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
- let old_body_end = find_tag_close(input, &mut cursor)?;
- let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
-
- let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
- .context("no new_text tag following old_text")?;
- let new_body_end = find_tag_close(input, &mut cursor)?;
- let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
-
- edits_list.push((old_text, new_text));
- }
-
- Ok(XmlEdits {
- file_path,
- replacements: edits_list,
- })
-}
-
-/// Trims a single leading and trailing newline
-fn trim_surrounding_newlines(input: &str) -> &str {
- let start = input.strip_prefix('\n').unwrap_or(input);
- let end = start.strip_suffix('\n').unwrap_or(start);
- end
-}
-
-fn find_tag_open<'a>(
- input: &'a str,
- cursor: &mut usize,
- expected_tag: &str,
-) -> Result<Option<(usize, &'a str)>> {
- let mut search_pos = *cursor;
-
- while search_pos < input.len() {
- let Some(tag_start) = input[search_pos..].find("<") else {
- break;
- };
- let tag_start = search_pos + tag_start;
- if !input[tag_start + 1..].starts_with(expected_tag) {
- search_pos = search_pos + tag_start + 1;
- continue;
- };
-
- let after_tag_name = tag_start + expected_tag.len() + 1;
- let close_bracket = input[after_tag_name..]
- .find('>')
- .with_context(|| format!("missing > after <{}", expected_tag))?;
- let attrs_end = after_tag_name + close_bracket;
- let body_start = attrs_end + 1;
-
- let attributes = input[after_tag_name..attrs_end].trim();
- *cursor = body_start;
-
- return Ok(Some((body_start, attributes)));
- }
-
- Ok(None)
-}
-
-fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
- let mut depth = 1;
- let mut search_pos = *cursor;
-
- while search_pos < input.len() && depth > 0 {
- let Some(bracket_offset) = input[search_pos..].find('<') else {
- break;
- };
- let bracket_pos = search_pos + bracket_offset;
-
- if input[bracket_pos..].starts_with("</")
- && let Some(close_end) = input[bracket_pos + 2..].find('>')
- {
- let close_start = bracket_pos + 2;
- let tag_name = input[close_start..close_start + close_end].trim();
-
- if XML_TAGS.contains(&tag_name) {
- depth -= 1;
- if depth == 0 {
- *cursor = close_start + close_end + 1;
- return Ok(bracket_pos);
- }
- }
- search_pos = close_start + close_end + 1;
- continue;
- } else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
- let close_bracket_pos = bracket_pos + close_bracket_offset;
- let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
- if XML_TAGS.contains(&tag_name) {
- depth += 1;
- }
- }
-
- search_pos = bracket_pos + 1;
- }
-
- anyhow::bail!("no closing tag found")
-}
-
-const REPLACEMENT_COST: u32 = 1;
-const INSERTION_COST: u32 = 3;
-const DELETION_COST: u32 = 10;
-
-/// A fuzzy matcher that can process text chunks incrementally
-/// and return the best match found so far at each step.
-struct FuzzyMatcher<'a> {
- snapshot: &'a BufferSnapshot,
- query_lines: Vec<&'a str>,
- matrix: SearchMatrix,
-}
-
-impl<'a> FuzzyMatcher<'a> {
- fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
- let query_lines = old_text.lines().collect();
- Self {
- snapshot,
- query_lines,
- matrix: SearchMatrix::new(0),
- }
- }
-
- fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
- let point_range = range.to_point(&self.snapshot);
- let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
-
- self.matrix
- .reset(self.query_lines.len() + 1, buffer_line_count + 1);
- let query_line_count = self.query_lines.len();
-
- for row in 0..query_line_count {
- let query_line = self.query_lines[row].trim();
- let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
-
- self.matrix.set(
- row + 1,
- 0,
- SearchState::new(leading_deletion_cost, SearchDirection::Up),
- );
-
- let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
-
- let mut col = 0;
- while let Some(buffer_line) = buffer_lines.next() {
- let buffer_line = buffer_line.trim();
- let up = SearchState::new(
- self.matrix
- .get(row, col + 1)
- .cost
- .saturating_add(DELETION_COST),
- SearchDirection::Up,
- );
- let left = SearchState::new(
- self.matrix
- .get(row + 1, col)
- .cost
- .saturating_add(INSERTION_COST),
- SearchDirection::Left,
- );
- let diagonal = SearchState::new(
- if query_line == buffer_line {
- self.matrix.get(row, col).cost
- } else if fuzzy_eq(query_line, buffer_line) {
- self.matrix.get(row, col).cost + REPLACEMENT_COST
- } else {
- self.matrix
- .get(row, col)
- .cost
- .saturating_add(DELETION_COST + INSERTION_COST)
- },
- SearchDirection::Diagonal,
- );
- self.matrix
- .set(row + 1, col + 1, up.min(left).min(diagonal));
- col += 1;
- }
- }
-
- // Find all matches with the best cost
- let mut best_cost = u32::MAX;
- let mut matches_with_best_cost = Vec::new();
-
- for col in 1..=buffer_line_count {
- let cost = self.matrix.get(query_line_count, col).cost;
- if cost < best_cost {
- best_cost = cost;
- matches_with_best_cost.clear();
- matches_with_best_cost.push(col as u32);
- } else if cost == best_cost {
- matches_with_best_cost.push(col as u32);
- }
- }
-
- // Find ranges for the matches
- for &match_end_col in &matches_with_best_cost {
- let mut matched_lines = 0;
- let mut query_row = query_line_count;
- let mut match_start_col = match_end_col;
- while query_row > 0 && match_start_col > 0 {
- let current = self.matrix.get(query_row, match_start_col as usize);
- match current.direction {
- SearchDirection::Diagonal => {
- query_row -= 1;
- match_start_col -= 1;
- matched_lines += 1;
- }
- SearchDirection::Up => {
- query_row -= 1;
- }
- SearchDirection::Left => {
- match_start_col -= 1;
- }
- }
- }
-
- let buffer_row_start = match_start_col + point_range.start.row;
- let buffer_row_end = match_end_col + point_range.start.row;
-
- let matched_buffer_row_count = buffer_row_end - buffer_row_start;
- let matched_ratio = matched_lines as f32
- / (matched_buffer_row_count as f32).max(query_line_count as f32);
- if matched_ratio >= 0.8 {
- let buffer_start_ix = self
- .snapshot
- .point_to_offset(Point::new(buffer_row_start, 0));
- let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
- buffer_row_end - 1,
- self.snapshot.line_len(buffer_row_end - 1),
- ));
- return Some((best_cost, buffer_start_ix..buffer_end_ix));
- }
- }
-
- None
- }
-}
-
-fn fuzzy_eq(left: &str, right: &str) -> bool {
- const THRESHOLD: f64 = 0.8;
-
- let min_levenshtein = left.len().abs_diff(right.len());
- let min_normalized_levenshtein =
- 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
- if min_normalized_levenshtein < THRESHOLD {
- return false;
- }
-
- strsim::normalized_levenshtein(left, right) >= THRESHOLD
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
-enum SearchDirection {
- Up,
- Left,
- Diagonal,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct SearchState {
- cost: u32,
- direction: SearchDirection,
-}
-
-impl SearchState {
- fn new(cost: u32, direction: SearchDirection) -> Self {
- Self { cost, direction }
- }
-}
-
-struct SearchMatrix {
- cols: usize,
- rows: usize,
- data: Vec<SearchState>,
-}
-
-impl SearchMatrix {
- fn new(cols: usize) -> Self {
- SearchMatrix {
- cols,
- rows: 0,
- data: Vec::new(),
- }
- }
-
- fn reset(&mut self, rows: usize, cols: usize) {
- self.rows = rows;
- self.cols = cols;
- self.data
- .fill(SearchState::new(0, SearchDirection::Diagonal));
- self.data.resize(
- self.rows * self.cols,
- SearchState::new(0, SearchDirection::Diagonal),
- );
- }
-
- fn get(&self, row: usize, col: usize) -> SearchState {
- debug_assert!(row < self.rows);
- debug_assert!(col < self.cols);
- self.data[row * self.cols + col]
- }
-
- fn set(&mut self, row: usize, col: usize, state: SearchState) {
- debug_assert!(row < self.rows && col < self.cols);
- self.data[row * self.cols + col] = state;
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::TestAppContext;
- use indoc::indoc;
- use language::Point;
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- #[test]
- fn test_extract_xml_edits() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- old content
- </old_text>
- <new_text>
- new content
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "old content");
- assert_eq!(result.replacements[0].1, "new content");
- }
-
- #[test]
- fn test_extract_xml_edits_with_wrong_closing_tags() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- old content
- </new_text>
- <new_text>
- new content
- </old_text>
- </ edits >
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "old content");
- assert_eq!(result.replacements[0].1, "new content");
- }
-
- #[test]
- fn test_extract_xml_edits_with_xml_like_content() {
- let input = indoc! {r#"
- <edits path="component.tsx">
- <old_text>
- <foo><bar></bar></foo>
- </old_text>
- <new_text>
- <foo><bar><baz></baz></bar></foo>
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
- assert_eq!(
- result.replacements[0].1,
- "<foo><bar><baz></baz></bar></foo>"
- );
- }
-
- #[test]
- fn test_extract_xml_edits_with_conflicting_content() {
- let input = indoc! {r#"
- <edits path="component.tsx">
- <old_text>
- <new_text></new_text>
- </old_text>
- <new_text>
- <old_text></old_text>
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "<new_text></new_text>");
- assert_eq!(result.replacements[0].1, "<old_text></old_text>");
- }
-
- #[test]
- fn test_extract_xml_edits_multiple_pairs() {
- let input = indoc! {r#"
- Some reasoning before edits. Lots of thinking going on here
-
- <edits path="test.rs">
- <old_text>
- first old
- </old_text>
- <new_text>
- first new
- </new_text>
- <old_text>
- second old
- </edits>
- <new_text>
- second new
- </old_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 2);
- assert_eq!(result.replacements[0].0, "first old");
- assert_eq!(result.replacements[0].1, "first new");
- assert_eq!(result.replacements[1].0, "second old");
- assert_eq!(result.replacements[1].1, "second new");
- }
-
- #[test]
- fn test_extract_xml_edits_unexpected_eof() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- first old
- </
- "#};
-
- extract_xml_replacements(input).expect_err("Unexpected end of file");
- }
-
- #[gpui::test]
- async fn test_parse_xml_edits(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- thirteen fourteen fifteen
- sixteen seventeen eighteen
- "#};
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let edits = indoc! {r#"
- <edits path="root/file1">
- <old_text>
- nine ten eleven twelve
- </old_text>
- <new_text>
- nine TEN eleven twelve!
- </new_text>
- </edits>
- "#};
-
- let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
- let (buffer, edits) = parse_xml_edits(edits, |_path| {
- Some((&buffer_snapshot, included_ranges.as_slice()))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::<Vec<_>>();
- assert_eq!(
- edits,
- &[
- (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
- (Point::new(2, 22)..Point::new(2, 22), "!".into())
- ]
- );
- }
-
- fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-
- FakeFs::new(cx.background_executor.clone())
- }
-}
@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
- EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
- prediction::{EditPredictionInputs, EditPredictionResult},
+ prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
- predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
- Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
+ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ project,
+ buffer,
+ snapshot,
+ position,
+ events,
+ trigger,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
- let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = store.can_collect_file(project, file, cx);
+ let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
- git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
- let inputs = EditPredictionInputs {
+ let context_start_offset = context_range.start.to_offset(&snapshot);
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = ZetaPromptInput {
events: included_events.into(),
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
- text: snapshot
- .text_for_range(context_range)
- .collect::<String>()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ related_files: vec![].into(),
cursor_path: full_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt: (editable_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset),
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
- // let response = perform_predict_edits(PerformPredictEditsParams {
- // client,
- // llm_token,
- // app_version,
- // body,
- // })
- // .await;
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(serde_json::to_string(&inputs).unwrap()),
+ position,
+ },
+ ))
+ .ok();
+ }
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ model_output: Some(response.output_excerpt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,
@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
- DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
- EditPredictionRequestedDebugEvent, EditPredictionStore,
+ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Result, anyhow, bail};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
-use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use edit_prediction_context::{EditPredictionExcerpt, Line};
-use edit_prediction_context::{RelatedExcerpt, RelatedFile};
-use futures::channel::oneshot;
-use gpui::{Entity, Task, prelude::*};
-use language::{Anchor, BufferSnapshot};
-use language::{Buffer, Point, ToOffset as _, ToPoint};
-use project::{Project, ProjectItem as _};
+use anyhow::{Result, anyhow};
+use cloud_llm_client::EditPredictionRejectReason;
+use gpui::{Task, prelude::*};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
-use std::{
- env,
- path::Path,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::CURSOR_MARKER;
+use zeta_prompt::format_zeta_prompt;
+
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- active_snapshot: BufferSnapshot,
- position: Anchor,
- events: Vec<Arc<Event>>,
- mut included_files: Vec<RelatedFile>,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ related_files,
+ events,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
- let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
- let Some((excerpt_path, active_project_path)) = active_snapshot
+ let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
- .zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
- let debug_tx = store.debug_tx.clone();
-
- let file = active_buffer.read(cx).file();
-
- let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
-
- // TODO data collection
- let can_collect_data = file
- .as_ref()
- .map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
async move {
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = Instant::now();
-
- let excerpt_options = options.context;
-
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &excerpt_options,
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
- let related_excerpt = RelatedExcerpt {
- anchor_range: excerpt_anchor_range.clone(),
- point_range: Point::new(excerpt.line_range.start.0, 0)
- ..Point::new(excerpt.line_range.end.0, 0),
- text: active_snapshot.as_rope().slice(excerpt.range),
- };
-
- if let Some(buffer_ix) = included_files
- .iter()
- .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
- {
- let file = &mut included_files[buffer_ix];
- file.excerpts.push(related_excerpt);
- file.merge_excerpts();
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- let active_file = RelatedFile {
- path: active_project_path,
- buffer: active_buffer.downgrade(),
- excerpts: vec![related_excerpt],
- max_row: active_snapshot.max_point().row,
- };
- included_files.push(active_file);
- }
-
- let included_files = included_files
- .iter()
- .map(|related_file| predict_edits_v3::RelatedFile {
- path: Arc::from(related_file.path.path.as_std_path()),
- max_row: Line(related_file.max_row),
- excerpts: related_file
- .excerpts
- .iter()
- .map(|excerpt| predict_edits_v3::Excerpt {
- start_line: Line(excerpt.point_range.start.row),
- text: excerpt.text.to_string().into(),
- })
- .collect(),
- })
- .collect::<Vec<_>>();
-
- let cloud_request = predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- related_files: included_files,
+ let cursor_offset = position.to_offset(&snapshot);
+ let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+ &snapshot,
+ related_files,
events,
- can_collect_data,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- excerpt_parent: None,
- git_info: None,
- trigger,
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let inputs = EditPredictionInputs {
- included_files: cloud_request.related_files,
- events: cloud_request.events,
- cursor_point: cloud_request.cursor_point,
- cursor_path: cloud_request.excerpt_path,
- };
-
- let retrieval_time = Instant::now() - before_retrieval;
+ excerpt_path,
+ cursor_offset,
+ );
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
+ let prompt = format_zeta_prompt(&prompt_input);
+ if let Some(debug_tx) = &debug_tx {
debug_tx
- .unbounded_send(DebugEvent::EditPredictionRequested(
- EditPredictionRequestedDebugEvent {
- inputs: inputs.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok(prompt) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(prompt.clone()),
position,
- response_rx,
},
))
.ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), Duration::ZERO))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
- let prompt = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.or(Some(0.7)),
+ stop: Default::default(),
+ temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
- let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
)
.await;
let received_response_at = Instant::now();
- let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ model_output: Some(output_text.clone()),
+ },
+ ))
+ .ok();
+ }
+
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
- let get_buffer_from_context = |path: &Path| {
- if Some(path) == active_file_full_path.as_deref() {
- Some((
- &active_snapshot,
- std::slice::from_ref(&excerpt_anchor_range),
- ))
- } else {
- None
- }
- };
-
- let (_, edits) = match options.prompt_format {
- PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
+ let old_text = snapshot
+ .text_for_range(editable_offset_range.clone())
+ .collect::<String>();
+ let edits: Vec<_> = language::text_diff(&old_text, &output_text)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot.anchor_before(editable_offset_range.start + range.end),
+ text,
+ )
+ })
+ .collect();
anyhow::Ok((
Some((
request_id,
Some((
- inputs,
- active_buffer,
- active_snapshot.clone(),
+ prompt_input,
+ buffer,
+ snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
))
})
}
+
+pub fn zeta2_prompt_input(
+ snapshot: &language::BufferSnapshot,
+ related_files: Arc<[zeta_prompt::RelatedFile]>,
+ events: Vec<Arc<zeta_prompt::Event>>,
+ excerpt_path: Arc<Path>,
+ cursor_offset: usize,
+) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
+ let cursor_point = cursor_offset.to_point(snapshot);
+
+ let (editable_range, context_range) =
+ crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ snapshot,
+ MAX_CONTEXT_TOKENS,
+ MAX_REWRITE_TOKENS,
+ );
+
+ let context_start_offset = context_range.start.to_offset(snapshot);
+ let editable_offset_range = editable_range.to_offset(snapshot);
+ let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
+ let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset);
+
+ let prompt_input = zeta_prompt::ZetaPromptInput {
+ cursor_path: excerpt_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt,
+ cursor_offset_in_excerpt,
+ events,
+ related_files,
+ };
+ (editable_offset_range, prompt_input)
+}
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
-name = "ep_cli"
+name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
-edit_prediction_context.workspace = true
+dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
-toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
+wasmtime.workspace = true
+zeta_prompt.workspace = true
zlog.workspace = true
+# Wasmtime is included as a dependency in order to enable the same
+# features that are enabled in Zed.
+#
+# If we don't enable these features we get crashes when creating
+# a Tree-sitter WasmStore.
+[package.metadata.cargo-machete]
+ignored = ["wasmtime"]
+
[dev-dependencies]
indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }
@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
+use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
+use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
- fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new() -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<AnthropicResponse> {
let request = AnthropicRequest {
- model,
+ model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
- fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new(cache_path: &Path) -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
- let connection = sqlez::connection::Connection::open_file(&cache_path);
+ let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
- let response = self.lookup(&model, max_tokens, &messages)?;
+ let response = self.lookup(model, max_tokens, &messages)?;
if let Some(response) = response {
return Ok(Some(response));
}
- self.mark_for_batch(&model, max_tokens, &messages)?;
+ self.mark_for_batch(model, max_tokens, &messages)?;
Ok(None)
}
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
}
}
}
- log::info!("Uploaded {} successful requests", success_count);
+ log::info!("Downloaded {} successful requests", success_count);
}
}
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
.join("\n")
}
-pub enum LlmClient {
+pub enum AnthropicClient {
// No batching
Plain(PlainLlmClient),
Batch(BatchingLlmClient),
Dummy,
}
-impl LlmClient {
- pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
- Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+impl AnthropicClient {
+ pub fn plain() -> Result<Self> {
+ Ok(Self::Plain(PlainLlmClient::new()?))
}
- pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
- Ok(Self::Batch(BatchingLlmClient::new(
- cache_path,
- http_client,
- )?))
+ pub fn batch(cache_path: &Path) -> Result<Self> {
+ Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
}
#[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
pub async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
match self {
- LlmClient::Plain(plain_llm_client) => plain_llm_client
+ AnthropicClient::Plain(plain_llm_client) => plain_llm_client
.generate(model, max_tokens, messages)
.await
.map(Some),
- LlmClient::Batch(batching_llm_client) => {
+ AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
.generate(model, max_tokens, messages)
.await
}
- LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
pub async fn sync_batches(&self) -> Result<()> {
match self {
- LlmClient::Plain(_) => Ok(()),
- LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
- LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ AnthropicClient::Plain(_) => Ok(()),
+ AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+ AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
}
@@ -1,641 +0,0 @@
-use crate::metrics::{self, Scores};
-use std::{
- collections::HashMap,
- io::{IsTerminal, Write},
- sync::Arc,
-};
-
-use anyhow::Result;
-use edit_prediction::{EditPredictionStore, udiff::DiffLine};
-use gpui::{AsyncApp, Entity};
-use project::Project;
-use util::ResultExt as _;
-
-use crate::{
- EvaluateArguments, PredictionOptions,
- example::{Example, NamedExample},
- headless::ZetaCliAppState,
- paths::print_run_data_dir,
- predict::{PredictionDetails, perform_predict, setup_store},
-};
-
-#[derive(Debug)]
-pub(crate) struct ExecutionData {
- execution_id: String,
- diff: String,
- reasoning: String,
-}
-
-pub async fn run_evaluate(
- args: EvaluateArguments,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) {
- if args.example_paths.is_empty() {
- eprintln!("No examples provided");
- return;
- }
-
- let all_tasks = args.example_paths.into_iter().map(|path| {
- let options = args.options.clone();
- let app_state = app_state.clone();
- let example = NamedExample::load(&path).expect("Failed to load example");
-
- cx.spawn(async move |cx| {
- let project = example.setup_project(&app_state, cx).await.unwrap();
-
- let providers = (0..args.repetitions)
- .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
- .collect::<Vec<_>>();
-
- let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
-
- let tasks = providers
- .into_iter()
- .enumerate()
- .map(move |(repetition_ix, store)| {
- let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
- let example = example.clone();
- let project = project.clone();
- let options = options.clone();
-
- cx.spawn(async move |cx| {
- let name = example.name.clone();
- run_evaluate_one(
- example,
- repetition_ix,
- project,
- store,
- options,
- !args.skip_prediction,
- cx,
- )
- .await
- .map_err(|err| (err, name, repetition_ix))
- })
- });
- futures::future::join_all(tasks).await
- })
- });
- let all_results = futures::future::join_all(all_tasks).await;
-
- write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
- if let Some(mut output_file) =
- std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
- {
- write_aggregated_scores(&mut output_file, &all_results).log_err();
- };
-
- if args.repetitions > 1 {
- if let Err(e) = write_bucketed_analysis(&all_results) {
- eprintln!("Failed to write bucketed analysis: {:?}", e);
- }
- }
-
- print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
-}
-
-fn write_aggregated_scores(
- w: &mut impl std::io::Write,
- all_results: &Vec<
- Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
- >,
-) -> Result<()> {
- let mut successful = Vec::new();
- let mut failed_count = 0;
-
- for result in all_results.iter().flatten() {
- match result {
- Ok((eval_result, _execution_data)) => successful.push(eval_result),
- Err((err, name, repetition_ix)) => {
- if failed_count == 0 {
- writeln!(w, "## Errors\n")?;
- }
-
- failed_count += 1;
- writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
- }
- }
- }
-
- if successful.len() > 1 {
- let edit_scores = successful
- .iter()
- .filter_map(|r| r.edit_scores.clone())
- .collect::<Vec<_>>();
- let has_edit_predictions = edit_scores.len() > 0;
- let aggregated_result = EvaluationResult {
- context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
- edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
- prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
- generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
- / successful.len(),
- };
-
- writeln!(w, "\n{}", "-".repeat(80))?;
- writeln!(w, "\n## TOTAL SCORES")?;
- writeln!(w, "{:#}", aggregated_result)?;
- }
-
- if successful.len() + failed_count > 1 {
- writeln!(
- w,
- "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
- successful.len(),
- successful.len() + failed_count,
- (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
- )?;
- }
-
- Ok(())
-}
-
-pub async fn run_evaluate_one(
- example: NamedExample,
- repetition_ix: Option<u16>,
- project: Entity<Project>,
- store: Entity<EditPredictionStore>,
- prediction_options: PredictionOptions,
- predict: bool,
- cx: &mut AsyncApp,
-) -> Result<(EvaluationResult, ExecutionData)> {
- let predict_result = perform_predict(
- example.clone(),
- project,
- store,
- repetition_ix,
- prediction_options,
- cx,
- )
- .await?;
-
- let evaluation_result = evaluate(&example.example, &predict_result, predict);
-
- if repetition_ix.is_none() {
- write_eval_result(
- &example,
- &predict_result,
- &evaluation_result,
- &mut std::io::stdout(),
- std::io::stdout().is_terminal(),
- predict,
- )?;
- }
-
- if let Some(mut results_file) =
- std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
- {
- write_eval_result(
- &example,
- &predict_result,
- &evaluation_result,
- &mut results_file,
- false,
- predict,
- )
- .log_err();
- }
-
- let execution_data = ExecutionData {
- execution_id: if let Some(rep_ix) = repetition_ix {
- format!("{:03}", rep_ix)
- } else {
- example.name.clone()
- },
- diff: predict_result.diff.clone(),
- reasoning: std::fs::read_to_string(
- predict_result
- .run_example_dir
- .join("prediction_response.md"),
- )
- .unwrap_or_default(),
- };
-
- anyhow::Ok((evaluation_result, execution_data))
-}
-
-fn write_eval_result(
- example: &NamedExample,
- predictions: &PredictionDetails,
- evaluation_result: &EvaluationResult,
- out: &mut impl Write,
- use_color: bool,
- predict: bool,
-) -> Result<()> {
- if predict {
- writeln!(
- out,
- "## Expected edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &example.example.expected_patch,
- &predictions.diff,
- use_color
- )
- )?;
- writeln!(
- out,
- "## Actual edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &predictions.diff,
- &example.example.expected_patch,
- use_color
- )
- )?;
- }
-
- writeln!(out, "{:#}", evaluation_result)?;
-
- anyhow::Ok(())
-}
-
-#[derive(Debug, Default, Clone)]
-pub struct EditScores {
- pub line_match: Scores,
- pub chr_f: f64,
-}
-
-impl EditScores {
- pub fn aggregate(scores: &[EditScores]) -> EditScores {
- let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
- let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
-
- EditScores { line_match, chr_f }
- }
-}
-
-#[derive(Debug, Default)]
-pub struct EvaluationResult {
- pub edit_scores: Option<EditScores>,
- pub context_scores: Scores,
- pub prompt_len: usize,
- pub generated_len: usize,
-}
-
-impl std::fmt::Display for EvaluationResult {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- if f.alternate() {
- self.fmt_table(f)
- } else {
- self.fmt_markdown(f)
- }
- }
-}
-
-impl EvaluationResult {
- fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- r#"
-### Context Scores
-{}
-"#,
- self.context_scores.to_markdown(),
- )?;
- if let Some(scores) = &self.edit_scores {
- write!(
- f,
- r#"
- ### Edit Prediction Scores
- {}"#,
- scores.line_match.to_markdown()
- )?;
- }
- Ok(())
- }
-
- fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- writeln!(f, "#### Prompt Statistics")?;
- writeln!(f, "─────────────────────────")?;
- writeln!(f, "Prompt_len Generated_len")?;
- writeln!(f, "─────────────────────────")?;
- writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
- writeln!(f)?;
- writeln!(f)?;
- writeln!(f, "#### Performance Scores")?;
- writeln!(
- f,
- "──────────────────────────────────────────────────────────────────"
- )?;
- writeln!(
- f,
- " TP FP FN Precision Recall F1"
- )?;
- writeln!(
- f,
- "──────────────────────────────────────────────────────────────────"
- )?;
- writeln!(
- f,
- "Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
- self.context_scores.true_positives,
- self.context_scores.false_positives,
- self.context_scores.false_negatives,
- self.context_scores.precision() * 100.0,
- self.context_scores.recall() * 100.0,
- self.context_scores.f1_score() * 100.0
- )?;
- if let Some(edit_scores) = &self.edit_scores {
- let line_match = &edit_scores.line_match;
- writeln!(f, "Edit Prediction")?;
- writeln!(
- f,
- " ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
- line_match.true_positives,
- line_match.false_positives,
- line_match.false_negatives,
- line_match.precision() * 100.0,
- line_match.recall() * 100.0,
- line_match.f1_score() * 100.0
- )?;
- writeln!(
- f,
- " └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
- "-", "-", "-", "-", "-", edit_scores.chr_f
- )?;
- }
- Ok(())
- }
-}
-
-fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
- let mut eval_result = EvaluationResult {
- prompt_len: preds.prompt_len,
- generated_len: preds.generated_len,
- ..Default::default()
- };
-
- if predict {
- // todo: alternatives for patches
- let expected_patch = example
- .expected_patch
- .lines()
- .map(DiffLine::parse)
- .collect::<Vec<_>>();
- let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
-
- let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
- let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
-
- eval_result.edit_scores = Some(EditScores { line_match, chr_f });
- }
-
- eval_result
-}
-
-/// Return annotated `patch_a` so that:
-/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
-/// Additions and deletions that are present in `patch_b` will be highlighted in green.
-pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
- let green = if use_color { "\x1b[32m✓ " } else { "" };
- let red = if use_color { "\x1b[31m✗ " } else { "" };
- let neutral = if use_color { " " } else { "" };
- let reset = if use_color { "\x1b[0m" } else { "" };
- let lines_a = patch_a.lines().map(DiffLine::parse);
- let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
-
- let annotated = lines_a
- .map(|line| match line {
- DiffLine::Addition(_) | DiffLine::Deletion(_) => {
- if lines_b.contains(&line) {
- format!("{green}{line}{reset}")
- } else {
- format!("{red}{line}{reset}")
- }
- }
- _ => format!("{neutral}{line}{reset}"),
- })
- .collect::<Vec<String>>();
-
- annotated.join("\n")
-}
-
-fn write_bucketed_analysis(
- all_results: &Vec<
- Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
- >,
-) -> Result<()> {
- #[derive(Debug)]
- struct EditBucket {
- diff: String,
- is_correct: bool,
- execution_indices: Vec<String>,
- reasoning_samples: Vec<String>,
- }
-
- let mut total_executions = 0;
- let mut empty_predictions = Vec::new();
- let mut errors = Vec::new();
-
- let mut buckets: HashMap<String, EditBucket> = HashMap::new();
-
- for result in all_results.iter().flatten() {
- total_executions += 1;
-
- let (evaluation_result, execution_data) = match result {
- Ok((eval_result, execution_data)) => {
- if execution_data.diff.is_empty() {
- empty_predictions.push(execution_data);
- continue;
- }
- (eval_result, execution_data)
- }
- Err(err) => {
- errors.push(err);
- continue;
- }
- };
-
- buckets
- .entry(execution_data.diff.clone())
- .and_modify(|bucket| {
- bucket
- .execution_indices
- .push(execution_data.execution_id.clone());
- bucket
- .reasoning_samples
- .push(execution_data.reasoning.clone());
- })
- .or_insert_with(|| EditBucket {
- diff: execution_data.diff.clone(),
- is_correct: {
- evaluation_result
- .edit_scores
- .as_ref()
- .map_or(false, |edit_scores| {
- edit_scores.line_match.false_positives == 0
- && edit_scores.line_match.false_negatives == 0
- && edit_scores.line_match.true_positives > 0
- })
- },
- execution_indices: vec![execution_data.execution_id.clone()],
- reasoning_samples: vec![execution_data.reasoning.clone()],
- });
- }
-
- let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
- sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
- (true, false) => std::cmp::Ordering::Less,
- (false, true) => std::cmp::Ordering::Greater,
- _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
- });
-
- let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
- let mut output = std::fs::File::create(&output_path)?;
-
- writeln!(output, "# Bucketed Edit Analysis\n")?;
-
- writeln!(output, "## Summary\n")?;
- writeln!(output, "- **Total executions**: {}", total_executions)?;
-
- let correct_count: usize = sorted_buckets
- .iter()
- .filter(|b| b.is_correct)
- .map(|b| b.execution_indices.len())
- .sum();
-
- let incorrect_count: usize = sorted_buckets
- .iter()
- .filter(|b| !b.is_correct)
- .map(|b| b.execution_indices.len())
- .sum();
-
- writeln!(
- output,
- "- **Correct predictions**: {} ({:.1}%)",
- correct_count,
- (correct_count as f64 / total_executions as f64) * 100.0
- )?;
-
- writeln!(
- output,
- "- **Incorrect predictions**: {} ({:.1}%)",
- incorrect_count,
- (incorrect_count as f64 / total_executions as f64) * 100.0
- )?;
-
- writeln!(
- output,
- "- **No Predictions**: {} ({:.1}%)",
- empty_predictions.len(),
- (empty_predictions.len() as f64 / total_executions as f64) * 100.0
- )?;
-
- let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
- writeln!(
- output,
- "- **Unique incorrect edit patterns**: {}\n",
- unique_incorrect
- )?;
-
- writeln!(output, "---\n")?;
-
- for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
- if idx == 0 {
- writeln!(
- output,
- "## Correct Predictions ({} occurrences)\n",
- bucket.execution_indices.len()
- )?;
- }
-
- writeln!(output, "**Predicted Edit:**\n")?;
- writeln!(output, "```diff")?;
- writeln!(output, "{}", bucket.diff)?;
- writeln!(output, "```\n")?;
-
- writeln!(
- output,
- "**Executions:** {}\n",
- bucket.execution_indices.join(", ")
- )?;
- writeln!(output, "---\n")?;
- }
-
- for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
- writeln!(
- output,
- "## Incorrect Prediction #{} ({} occurrences)\n",
- idx + 1,
- bucket.execution_indices.len()
- )?;
-
- writeln!(output, "**Predicted Edit:**\n")?;
- writeln!(output, "```diff")?;
- writeln!(output, "{}", bucket.diff)?;
- writeln!(output, "```\n")?;
-
- writeln!(
- output,
- "**Executions:** {}\n",
- bucket.execution_indices.join(", ")
- )?;
-
- for (exec_id, reasoning) in bucket
- .execution_indices
- .iter()
- .zip(bucket.reasoning_samples.iter())
- {
- writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
- }
-
- writeln!(output, "\n---\n")?;
- }
-
- if !empty_predictions.is_empty() {
- writeln!(
- output,
- "## No Predictions ({} occurrences)\n",
- empty_predictions.len()
- )?;
-
- for execution_data in &empty_predictions {
- writeln!(
- output,
- "{}",
- fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
- )?;
- }
- writeln!(output, "\n---\n")?;
- }
-
- if !errors.is_empty() {
- writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
-
- for (err, name, repetition_ix) in &errors {
- writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
- }
- writeln!(output, "\n---\n")?;
- }
-
- fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
- let exec_content = format!(
- "\n### Execution {} `{}/{}/prediction_response.md`{}",
- exec_id,
- crate::paths::RUN_DIR.display(),
- exec_id,
- indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
- );
- indent_text(&exec_content, 2)
- }
-
- fn indent_text(text: &str, spaces: usize) -> String {
- let indent = " ".repeat(spaces);
- text.lines()
- .collect::<Vec<_>>()
- .join(&format!("\n{}", indent))
- }
-
- Ok(())
-}
-
-fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
- let err = format!("{err:?}")
- .replace("<edits", "```xml\n<edits")
- .replace("</edits>", "</edits>\n```");
- format!(
- "### ERROR {name}{}\n\n{err}\n",
- repetition_ix
- .map(|ix| format!(" [RUN {ix:03}]"))
- .unwrap_or_default()
- )
-}
@@ -1,59 +1,103 @@
+use crate::{
+ PredictionProvider, PromptFormat,
+ metrics::ClassificationMetrics,
+ paths::{REPOS_DIR, WORKTREES_DIR},
+};
+use anyhow::{Context as _, Result};
+use edit_prediction::udiff::OpenedBuffers;
+use gpui::Entity;
+use http_client::Url;
+use language::{Anchor, Buffer};
+use project::Project;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
use std::{
borrow::Cow,
- cell::RefCell,
- fmt::{self, Display},
- fs,
- hash::Hash,
- hash::Hasher,
- io::Write,
+ io::{Read, Write},
mem,
path::{Path, PathBuf},
- sync::{Arc, OnceLock},
};
+use zeta_prompt::RelatedFile;
-use crate::headless::ZetaCliAppState;
-use anyhow::{Context as _, Result, anyhow};
-use clap::ValueEnum;
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use collections::HashMap;
-use edit_prediction::udiff::OpenedBuffers;
-use futures::{
- AsyncWriteExt as _,
- lock::{Mutex, OwnedMutexGuard},
-};
-use futures::{FutureExt as _, future::Shared};
-use gpui::{AsyncApp, Entity, Task, http_client::Url};
-use language::{Anchor, Buffer};
-use project::{Project, ProjectPath};
-use pulldown_cmark::CowStr;
-use serde::{Deserialize, Serialize};
-use util::{paths::PathStyle, rel_path::RelPath};
-
-use crate::paths::{REPOS_DIR, WORKTREES_DIR};
-
-const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
-const EDIT_HISTORY_HEADING: &str = "Edit History";
-const CURSOR_POSITION_HEADING: &str = "Cursor Position";
-const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
-const REPOSITORY_URL_FIELD: &str = "repository_url";
-const REVISION_FIELD: &str = "revision";
-
-#[derive(Debug, Clone)]
-pub struct NamedExample {
- pub name: String,
- pub example: Example,
-}
-
-#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
+ #[serde(default)]
+ pub name: String,
pub repository_url: String,
pub revision: String,
pub uncommitted_diff: String,
- pub cursor_path: PathBuf,
+ pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
+
+ /// The full content of the file where an edit is being predicted, and the
+ /// actual cursor offset.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub buffer: Option<ExampleBuffer>,
+
+ /// The context retrieved for the prediction. This requires the worktree to
+ /// be loaded and the language server to be started.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub context: Option<ExampleContext>,
+
+ /// The input and expected output from the edit prediction model.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompt: Option<ExamplePrompt>,
+
+ /// The actual predictions from the model.
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub predictions: Vec<ExamplePrediction>,
+
+ /// The scores, for how well the actual predictions match the expected
+ /// predictions.
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub score: Vec<ExampleScore>,
+
+ /// The application state used to process this example.
+ #[serde(skip)]
+ pub state: Option<ExampleState>,
+}
+
+#[derive(Clone, Debug)]
+pub struct ExampleState {
+ pub project: Entity<Project>,
+ pub buffer: Entity<Buffer>,
+ pub cursor_position: Anchor,
+ pub _open_buffers: OpenedBuffers,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleContext {
+ pub files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleBuffer {
+ pub content: String,
+ pub cursor_row: u32,
+ pub cursor_column: u32,
+ pub cursor_offset: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrompt {
+ pub input: String,
+ pub expected_output: String,
+ pub format: PromptFormat,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrediction {
+ pub actual_patch: String,
+ pub actual_output: String,
+ pub provider: PredictionProvider,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleScore {
+ pub delta_chr_f: f32,
+ pub line_match: ClassificationMetrics,
}
impl Example {
@@ -90,485 +134,244 @@ impl Example {
}
}
- pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
- let (repo_owner, repo_name) = self.repo_name()?;
-
- let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
- let repo_lock = lock_repo(&repo_dir).await;
+ pub fn worktree_path(&self) -> PathBuf {
+ WORKTREES_DIR
+ .join(&self.name)
+ .join(self.repo_name().unwrap().1.as_ref())
+ }
- if !repo_dir.is_dir() {
- fs::create_dir_all(&repo_dir)?;
- run_git(&repo_dir, &["init"]).await?;
- run_git(
- &repo_dir,
- &["remote", "add", "origin", &self.repository_url],
- )
- .await?;
- }
+ pub fn repo_path(&self) -> PathBuf {
+ let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
+ REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
+ }
+}
- // Resolve the example to a revision, fetching it if needed.
- let revision = run_git(
- &repo_dir,
- &["rev-parse", &format!("{}^{{commit}}", self.revision)],
- )
- .await;
- let revision = if let Ok(revision) = revision {
- revision
+pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
+ let mut examples = Vec::new();
+
+ let stdin_path: PathBuf = PathBuf::from("-");
+
+ let inputs = if inputs.is_empty() {
+ &[stdin_path]
+ } else {
+ inputs
+ };
+
+ for path in inputs {
+ let is_stdin = path.as_path() == Path::new("-");
+ let content = if is_stdin {
+ let mut buffer = String::new();
+ std::io::stdin()
+ .read_to_string(&mut buffer)
+ .expect("Failed to read from stdin");
+ buffer
} else {
- if run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &self.revision],
- )
- .await
- .is_err()
- {
- run_git(&repo_dir, &["fetch", "origin"]).await?;
- }
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
- if revision != self.revision {
- run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
- }
- revision
+ std::fs::read_to_string(path)
+ .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
};
-
- // Create the worktree for this example if needed.
- let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
- if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+ let filename = path.file_stem().unwrap().to_string_lossy().to_string();
+ let ext = if !is_stdin {
+ path.extension()
+ .map(|ext| ext.to_string_lossy().to_string())
+ .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
} else {
- let worktree_path_string = worktree_path.to_string_lossy();
- run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
- run_git(
- &repo_dir,
- &["worktree", "add", "-f", &worktree_path_string, &file_name],
- )
- .await?;
- }
- drop(repo_lock);
-
- // Apply the uncommitted diff for this example.
- if !self.uncommitted_diff.is_empty() {
- let mut apply_process = smol::process::Command::new("git")
- .current_dir(&worktree_path)
- .args(&["apply", "-"])
- .stdin(std::process::Stdio::piped())
- .spawn()?;
-
- let mut stdin = apply_process.stdin.take().unwrap();
- stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
- stdin.close().await?;
- drop(stdin);
-
- let apply_result = apply_process.output().await?;
- if !apply_result.status.success() {
- anyhow::bail!(
- "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
- apply_result.status,
- String::from_utf8_lossy(&apply_result.stderr),
- String::from_utf8_lossy(&apply_result.stdout),
- );
+ "jsonl".to_string()
+ };
+
+ match ext.as_ref() {
+ "json" => {
+ let mut example =
+ serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
+ panic!("Failed to parse example file: {}\n{error}", path.display())
+ });
+ if example.name.is_empty() {
+ example.name = filename;
+ }
+ examples.push(example);
+ }
+ "jsonl" => examples.extend(
+ content
+ .lines()
+ .enumerate()
+ .map(|(line_ix, line)| {
+ let mut example =
+ serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
+ panic!(
+ "Failed to parse example on {}:{}",
+ path.display(),
+ line_ix + 1
+ )
+ });
+ if example.name.is_empty() {
+ example.name = format!("{filename}-{line_ix}")
+ }
+ example
+ })
+ .collect::<Vec<Example>>(),
+ ),
+ "md" => {
+ examples.push(parse_markdown_example(filename, &content).unwrap());
+ }
+ ext => {
+ panic!("{} has invalid example extension `{ext}`", path.display())
}
}
-
- Ok(worktree_path)
- }
-
- pub fn unique_name(&self) -> String {
- let mut hasher = std::hash::DefaultHasher::new();
- self.hash(&mut hasher);
- let disambiguator = hasher.finish();
- let hash = format!("{:04x}", disambiguator);
- format!("{}_{}", &self.revision[..8], &hash[..4])
}
+ examples
}
-pub type ActualExcerpt = Excerpt;
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct Excerpt {
- pub path: PathBuf,
- pub text: String,
-}
-
-#[derive(ValueEnum, Debug, Clone)]
-pub enum ExampleFormat {
- Json,
- Toml,
- Md,
+pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
+ let mut content = String::new();
+ for example in examples {
+ let line = serde_json::to_string(example).unwrap();
+ content.push_str(&line);
+ content.push('\n');
+ }
+ if let Some(output_path) = output_path {
+ std::fs::write(output_path, content).expect("Failed to write examples");
+ } else {
+ std::io::stdout().write_all(&content.as_bytes()).unwrap();
+ }
}
-impl NamedExample {
- pub fn load(path: impl AsRef<Path>) -> Result<Self> {
- let path = path.as_ref();
- let content = std::fs::read_to_string(path)?;
- let ext = path.extension();
-
- match ext.and_then(|s| s.to_str()) {
- Some("json") => Ok(Self {
- name: path.file_stem().unwrap_or_default().display().to_string(),
- example: serde_json::from_str(&content)?,
- }),
- Some("toml") => Ok(Self {
- name: path.file_stem().unwrap_or_default().display().to_string(),
- example: toml::from_str(&content)?,
- }),
- Some("md") => Self::parse_md(&content),
- Some(_) => {
- anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
- }
- None => {
- anyhow::bail!(
- "Failed to determine example type since the file does not have an extension."
- );
- }
- }
+fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
+ use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
+
+ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
+ const EDIT_HISTORY_HEADING: &str = "Edit History";
+ const CURSOR_POSITION_HEADING: &str = "Cursor Position";
+ const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
+ const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
+ const REPOSITORY_URL_FIELD: &str = "repository_url";
+ const REVISION_FIELD: &str = "revision";
+
+ let parser = Parser::new(input);
+
+ let mut example = Example {
+ name: id,
+ repository_url: String::new(),
+ revision: String::new(),
+ uncommitted_diff: String::new(),
+ cursor_path: PathBuf::new().into(),
+ cursor_position: String::new(),
+ edit_history: String::new(),
+ expected_patch: String::new(),
+ buffer: None,
+ context: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ state: None,
+ };
+
+ let mut name = String::new();
+ let mut text = String::new();
+ let mut block_info: CowStr = "".into();
+
+ #[derive(PartialEq)]
+ enum Section {
+ UncommittedDiff,
+ EditHistory,
+ CursorPosition,
+ ExpectedExcerpts,
+ ExpectedPatch,
+ Other,
}
- pub fn parse_md(input: &str) -> Result<Self> {
- use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
-
- let parser = Parser::new(input);
-
- let mut named = NamedExample {
- name: String::new(),
- example: Example {
- repository_url: String::new(),
- revision: String::new(),
- uncommitted_diff: String::new(),
- cursor_path: PathBuf::new(),
- cursor_position: String::new(),
- edit_history: String::new(),
- expected_patch: String::new(),
- },
- };
+ let mut current_section = Section::Other;
- let mut text = String::new();
- let mut block_info: CowStr = "".into();
-
- #[derive(PartialEq)]
- enum Section {
- UncommittedDiff,
- EditHistory,
- CursorPosition,
- ExpectedExcerpts,
- ExpectedPatch,
- Other,
- }
+ for event in parser {
+ match event {
+ Event::Text(line) => {
+ text.push_str(&line);
- let mut current_section = Section::Other;
-
- for event in parser {
- match event {
- Event::Text(line) => {
- text.push_str(&line);
-
- if !named.name.is_empty()
- && current_section == Section::Other
- // in h1 section
- && let Some((field, value)) = line.split_once('=')
- {
- match field.trim() {
- REPOSITORY_URL_FIELD => {
- named.example.repository_url = value.trim().to_string();
- }
- REVISION_FIELD => {
- named.example.revision = value.trim().to_string();
- }
- _ => {}
- }
- }
- }
- Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
- if !named.name.is_empty() {
- anyhow::bail!(
- "Found multiple H1 headings. There should only be one with the name of the example."
- );
- }
- named.name = mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
- let title = mem::take(&mut text);
- current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
- Section::UncommittedDiff
- } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
- Section::EditHistory
- } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
- Section::CursorPosition
- } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
- Section::ExpectedPatch
- } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
- Section::ExpectedExcerpts
- } else {
- Section::Other
- };
- }
- Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
- mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
- mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(level)) => {
- anyhow::bail!("Unexpected heading level: {level}");
- }
- Event::Start(Tag::CodeBlock(kind)) => {
- match kind {
- CodeBlockKind::Fenced(info) => {
- block_info = info;
- }
- CodeBlockKind::Indented => {
- anyhow::bail!("Unexpected indented codeblock");
- }
- };
- }
- Event::Start(_) => {
- text.clear();
- block_info = "".into();
- }
- Event::End(TagEnd::CodeBlock) => {
- let block_info = block_info.trim();
- match current_section {
- Section::UncommittedDiff => {
- named.example.uncommitted_diff = mem::take(&mut text);
- }
- Section::EditHistory => {
- named.example.edit_history.push_str(&mem::take(&mut text));
- }
- Section::CursorPosition => {
- named.example.cursor_path = block_info.into();
- named.example.cursor_position = mem::take(&mut text);
- }
- Section::ExpectedExcerpts => {
- mem::take(&mut text);
+ if let Some((field, value)) = line.split_once('=') {
+ match field.trim() {
+ REPOSITORY_URL_FIELD => {
+ example.repository_url = value.trim().to_string();
}
- Section::ExpectedPatch => {
- named.example.expected_patch = mem::take(&mut text);
+ REVISION_FIELD => {
+ example.revision = value.trim().to_string();
}
- Section::Other => {}
+ _ => {}
}
}
- _ => {}
}
- }
-
- if named.example.cursor_path.as_path() == Path::new("")
- || named.example.cursor_position.is_empty()
- {
- anyhow::bail!("Missing cursor position codeblock");
- }
-
- Ok(named)
- }
-
- pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
- match format {
- ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
- ExampleFormat::Toml => {
- Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
+ Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
+ if !name.is_empty() {
+ anyhow::bail!(
+ "Found multiple H1 headings. There should only be one with the name of the example."
+ );
+ }
+ name = mem::take(&mut text);
}
- ExampleFormat::Md => Ok(write!(out, "{}", self)?),
- }
- }
-
- pub async fn setup_project(
- &self,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
- ) -> Result<Entity<Project>> {
- let worktree_path = self.setup_worktree().await?;
-
- static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
-
- AUTHENTICATED
- .get_or_init(|| {
- let client = app_state.client.clone();
- cx.spawn(async move |cx| {
- client
- .sign_in_with_optional_connect(true, cx)
- .await
- .unwrap();
- })
- .shared()
- })
- .clone()
- .await;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- anyhow::Ok(project)
- }
-
- pub async fn setup_worktree(&self) -> Result<PathBuf> {
- self.example.setup_worktree(self.file_name()).await
- }
-
- pub fn file_name(&self) -> String {
- self.name
- .chars()
- .map(|c| {
- if c.is_whitespace() {
- '-'
+ Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
+ let title = mem::take(&mut text);
+ current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+ Section::UncommittedDiff
+ } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+ Section::EditHistory
+ } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
+ Section::CursorPosition
+ } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
+ Section::ExpectedPatch
+ } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
+ Section::ExpectedExcerpts
} else {
- c.to_ascii_lowercase()
+ Section::Other
+ };
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
+ mem::take(&mut text);
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
+ mem::take(&mut text);
+ }
+ Event::End(TagEnd::Heading(level)) => {
+ anyhow::bail!("Unexpected heading level: {level}");
+ }
+ Event::Start(Tag::CodeBlock(kind)) => {
+ match kind {
+ CodeBlockKind::Fenced(info) => {
+ block_info = info;
+ }
+ CodeBlockKind::Indented => {
+ anyhow::bail!("Unexpected indented codeblock");
+ }
+ };
+ }
+ Event::Start(_) => {
+ text.clear();
+ block_info = "".into();
+ }
+ Event::End(TagEnd::CodeBlock) => {
+ let block_info = block_info.trim();
+ match current_section {
+ Section::UncommittedDiff => {
+ example.uncommitted_diff = mem::take(&mut text);
+ }
+ Section::EditHistory => {
+ example.edit_history.push_str(&mem::take(&mut text));
+ }
+ Section::CursorPosition => {
+ example.cursor_path = Path::new(block_info).into();
+ example.cursor_position = mem::take(&mut text);
+ }
+ Section::ExpectedExcerpts => {
+ mem::take(&mut text);
+ }
+ Section::ExpectedPatch => {
+ example.expected_patch = mem::take(&mut text);
+ }
+ Section::Other => {}
}
- })
- .collect()
- }
-
- pub async fn cursor_position(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<(Entity<Buffer>, Anchor)> {
- let worktree = project.read_with(cx, |project, cx| {
- project.visible_worktrees(cx).next().unwrap()
- })?;
- let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
- let cursor_buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: worktree.read(cx).id(),
- path: cursor_path,
- },
- cx,
- )
- })?
- .await?;
- let cursor_offset_within_excerpt = self
- .example
- .cursor_position
- .find(CURSOR_MARKER)
- .ok_or_else(|| anyhow!("missing cursor marker"))?;
- let mut cursor_excerpt = self.example.cursor_position.clone();
- cursor_excerpt.replace_range(
- cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
- "",
- );
- let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
- let text = buffer.text();
-
- let mut matches = text.match_indices(&cursor_excerpt);
- let Some((excerpt_offset, _)) = matches.next() else {
- anyhow::bail!(
- "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
- );
- };
- assert!(matches.next().is_none());
-
- Ok(excerpt_offset)
- })??;
-
- let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
- let cursor_anchor =
- cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
- Ok((cursor_buffer, cursor_anchor))
- }
-
- #[must_use]
- pub async fn apply_edit_history(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<OpenedBuffers<'_>> {
- edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
- }
-}
-
-async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
- let output = smol::process::Command::new("git")
- .current_dir(repo_path)
- .args(args)
- .output()
- .await?;
-
- anyhow::ensure!(
- output.status.success(),
- "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
- args.join(" "),
- repo_path.display(),
- output.status,
- String::from_utf8_lossy(&output.stderr),
- String::from_utf8_lossy(&output.stdout),
- );
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
-}
-
-impl Display for NamedExample {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "# {}\n\n", self.name)?;
- write!(
- f,
- "{REPOSITORY_URL_FIELD} = {}\n",
- self.example.repository_url
- )?;
- write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
-
- write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
- write!(f, "`````diff\n")?;
- write!(f, "{}", self.example.uncommitted_diff)?;
- write!(f, "`````\n")?;
-
- if !self.example.edit_history.is_empty() {
- write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
- }
-
- write!(
- f,
- "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
- self.example.cursor_path.display(),
- self.example.cursor_position
- )?;
- write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
-
- if !self.example.expected_patch.is_empty() {
- write!(
- f,
- "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
- self.example.expected_patch
- )?;
+ }
+ _ => {}
}
-
- Ok(())
}
-}
-
-thread_local! {
- static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
-}
+ if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
+ anyhow::bail!("Missing cursor position codeblock");
+ }
-#[must_use]
-pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
- REPO_LOCKS
- .with(|cell| {
- cell.borrow_mut()
- .entry(path.as_ref().to_path_buf())
- .or_default()
- .clone()
- })
- .lock_owned()
- .await
+ Ok(example)
}
@@ -0,0 +1,280 @@
+use crate::{
+ PromptFormat,
+ example::{Example, ExamplePrompt},
+ headless::EpAppState,
+ retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
+use gpui::AsyncApp;
+use std::sync::Arc;
+use zeta_prompt::format_zeta_prompt;
+
+pub async fn run_format_prompt(
+ example: &mut Example,
+ prompt_format: PromptFormat,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
+) {
+ run_context_retrieval(example, app_state, cx.clone()).await;
+
+ let prompt = match prompt_format {
+ PromptFormat::Teacher => TeacherPrompt::format(example),
+ PromptFormat::Zeta2 => {
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ let state = example.state.as_ref().unwrap();
+ let snapshot = state
+ .buffer
+ .read_with(&cx, |buffer, _| buffer.snapshot())
+ .unwrap();
+ let project = state.project.clone();
+ let (_, input) = ep_store
+ .update(&mut cx, |ep_store, _cx| {
+ zeta2_prompt_input(
+ &snapshot,
+ example.context.as_ref().unwrap().files.clone(),
+ ep_store.edit_history_for_project(&project),
+ example.cursor_path.clone(),
+ example.buffer.as_ref().unwrap().cursor_offset,
+ )
+ })
+ .unwrap();
+ format_zeta_prompt(&input)
+ }
+ };
+
+ example.prompt = Some(ExamplePrompt {
+ input: prompt,
+ expected_output: example.expected_patch.clone(), // TODO
+ format: prompt_format,
+ });
+}
+
+pub trait PromptFormatter {
+ fn format(example: &Example) -> String;
+}
+
+pub trait PromptParser {
+ /// Return unified diff patch of prediction given raw LLM response
+ fn parse(example: &Example, response: &str) -> String;
+}
+
+pub struct TeacherPrompt;
+
+impl PromptFormatter for TeacherPrompt {
+ fn format(example: &Example) -> String {
+ let edit_history = Self::format_edit_history(&example.edit_history);
+ let context = Self::format_context(example);
+ let editable_region = Self::format_editable_region(example);
+
+ let prompt = Self::PROMPT
+ .replace("{{context}}", &context)
+ .replace("{{edit_history}}", &edit_history)
+ .replace("{{editable_region}}", &editable_region);
+
+ prompt
+ }
+}
+
+impl TeacherPrompt {
+ const PROMPT: &str = include_str!("teacher.prompt.md");
+ pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
+ pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+
+ /// Truncate edit history to this number of last lines
+ const MAX_HISTORY_LINES: usize = 128;
+
+ fn format_edit_history(edit_history: &str) -> String {
+ // Strip comments ("garbage lines") from edit history
+ let lines = edit_history
+ .lines()
+ .filter(|&s| Self::is_udiff_content_line(s))
+ .collect::<Vec<_>>();
+
+ let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
+ &lines[lines.len() - Self::MAX_HISTORY_LINES..]
+ } else {
+ &lines
+ };
+
+ if history_lines.is_empty() {
+ return "(No edit history)".to_string();
+ }
+
+ history_lines.join("\n")
+ }
+
+ fn format_context(example: &Example) -> String {
+ if example.context.is_none() {
+ panic!("Missing context retriever step");
+ }
+
+ let mut prompt = String::new();
+ zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
+
+ prompt
+ }
+
+ fn format_editable_region(example: &Example) -> String {
+ let mut result = String::new();
+
+ let path_str = example.cursor_path.to_string_lossy();
+ result.push_str(&format!("`````path=\"{path_str}\"\n"));
+ result.push_str(Self::EDITABLE_REGION_START);
+
+ // TODO: control number of lines around cursor
+ result.push_str(&example.cursor_position);
+ if !example.cursor_position.ends_with('\n') {
+ result.push('\n');
+ }
+
+ result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
+ result.push_str("`````");
+
+ result
+ }
+
+ fn extract_editable_region(text: &str) -> String {
+ let start = text
+ .find(Self::EDITABLE_REGION_START)
+ .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
+ let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
+
+ let region = &text[start..end];
+
+ region.replace("<|user_cursor|>", "")
+ }
+
+ fn is_udiff_content_line(s: &str) -> bool {
+ s.starts_with("-")
+ || s.starts_with("+")
+ || s.starts_with(" ")
+ || s.starts_with("---")
+ || s.starts_with("+++")
+ || s.starts_with("@@")
+ }
+}
+
+impl PromptParser for TeacherPrompt {
+ fn parse(example: &Example, response: &str) -> String {
+ // Ideally, we should always be able to find cursor position in the retrieved context.
+ // In reality, sometimes we don't find it for these reasons:
+ // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
+ // (can be fixed by getting cursor coordinates at the load_example stage)
+ // 2. Context retriever just didn't include cursor line.
+ //
+ // In that case, fallback to using `cursor_position` as excerpt.
+ let cursor_file = &example
+ .buffer
+ .as_ref()
+ .expect("`buffer` should be filled in in the context collection step")
+ .content;
+
+ // Extract updated (new) editable region from the model response
+ let new_editable_region = extract_last_codeblock(response);
+
+ // Reconstruct old editable region we sent to the model
+ let old_editable_region = Self::format_editable_region(example);
+ let old_editable_region = Self::extract_editable_region(&old_editable_region);
+ if !cursor_file.contains(&old_editable_region) {
+ panic!("Something's wrong: editable_region is not found in the cursor file")
+ }
+
+ // Apply editable region to a larger context and compute diff.
+ // This is needed to get a better context lines around the editable region
+ let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
+ let diff = language::unified_diff(&cursor_file, &edited_file);
+
+ let diff = indoc::formatdoc! {"
+ --- a/{path}
+ +++ b/{path}
+ {diff}
+ ",
+ path = example.cursor_path.to_string_lossy(),
+ diff = diff,
+ };
+
+ diff
+ }
+}
+
+fn extract_last_codeblock(text: &str) -> String {
+ let mut last_block = None;
+ let mut search_start = 0;
+
+ while let Some(start) = text[search_start..].find("```") {
+ let start = start + search_start;
+ let bytes = text.as_bytes();
+ let mut backtick_end = start;
+
+ while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
+ backtick_end += 1;
+ }
+
+ let backtick_count = backtick_end - start;
+ let closing_backticks = "`".repeat(backtick_count);
+
+ while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
+ backtick_end += 1;
+ }
+
+ if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
+ let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+ last_block = Some(code_block.to_string());
+ search_start = backtick_end + end_pos + backtick_count;
+ } else {
+ break;
+ }
+ }
+
+ last_block.unwrap_or_else(|| text.to_string())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_extract_last_code_block() {
+ let text = indoc::indoc! {"
+ Some thinking
+
+ ```
+ first block
+ ```
+
+ `````path='something' lines=1:2
+ last block
+ `````
+ "};
+ let last_block = extract_last_codeblock(text);
+ assert_eq!(last_block, "last block");
+ }
+
+ #[test]
+ fn test_extract_editable_region() {
+ let text = indoc::indoc! {"
+ some lines
+ are
+ here
+ <|editable_region_start|>
+ one
+ two three
+
+ <|editable_region_end|>
+ more
+ lines here
+ "};
+ let parsed = TeacherPrompt::extract_editable_region(text);
+ assert_eq!(
+ parsed,
+ indoc::indoc! {"
+ one
+ two three
+
+ "}
+ );
+ }
+}
@@ -16,7 +16,7 @@ use std::sync::Arc;
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
-pub struct ZetaCliAppState {
+pub struct EpAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
}
// TODO: dedupe with crates/eval/src/eval.rs
-pub fn init(cx: &mut App) -> ZetaCliAppState {
+pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
let app_version = AppVersion::load(
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
prompt_store::init(cx);
terminal_view::init(cx);
- ZetaCliAppState {
+ EpAppState {
languages,
client,
user_store,
@@ -0,0 +1,320 @@
+use crate::{
+ example::{Example, ExampleBuffer, ExampleState},
+ headless::EpAppState,
+};
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use edit_prediction::EditPredictionStore;
+use edit_prediction::udiff::OpenedBuffers;
+use futures::{
+ AsyncWriteExt as _,
+ lock::{Mutex, OwnedMutexGuard},
+};
+use gpui::{AsyncApp, Entity};
+use language::{Anchor, Buffer, ToOffset, ToPoint};
+use project::buffer_store::BufferStoreEvent;
+use project::{Project, ProjectPath};
+use std::{
+ cell::RefCell,
+ fs,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use util::{paths::PathStyle, rel_path::RelPath};
+use zeta_prompt::CURSOR_MARKER;
+
+pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
+ if example.state.is_some() {
+ return;
+ }
+
+ let project = setup_project(example, &app_state, &mut cx).await;
+ let buffer_store = project
+ .read_with(&cx, |project, _| project.buffer_store().clone())
+ .unwrap();
+
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ move |_, event, cx| match event {
+ BufferStoreEvent::BufferAdded(buffer) => {
+ ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })
+ .unwrap()
+ .detach();
+
+ let _open_buffers = apply_edit_history(example, &project, &mut cx)
+ .await
+ .unwrap();
+ let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
+ example.buffer = buffer
+ .read_with(&cx, |buffer, _cx| {
+ let cursor_point = cursor_position.to_point(&buffer);
+ Some(ExampleBuffer {
+ content: buffer.text(),
+ cursor_row: cursor_point.row,
+ cursor_column: cursor_point.column,
+ cursor_offset: cursor_position.to_offset(&buffer),
+ })
+ })
+ .unwrap();
+ example.state = Some(ExampleState {
+ buffer,
+ project,
+ cursor_position,
+ _open_buffers,
+ });
+}
+
+async fn cursor_position(
+ example: &Example,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> (Entity<Buffer>, Anchor) {
+ let worktree = project
+ .read_with(cx, |project, cx| {
+ project.visible_worktrees(cx).next().unwrap()
+ })
+ .unwrap();
+
+ let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
+ .unwrap()
+ .into_arc();
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ let cursor_offset_within_excerpt = example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))
+ .unwrap();
+ let mut cursor_excerpt = example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
+ panic!(
+ "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
+ );
+ });
+ assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
+ excerpt_offset
+ }).unwrap();
+
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_anchor = cursor_buffer
+ .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
+ .unwrap();
+
+ (cursor_buffer, cursor_anchor)
+}
+
+async fn setup_project(
+ example: &mut Example,
+ app_state: &Arc<EpAppState>,
+ cx: &mut AsyncApp,
+) -> Entity<Project> {
+ setup_worktree(example).await;
+
+ let project = cx
+ .update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })
+ .unwrap();
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&example.worktree_path(), true, cx)
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })
+ .unwrap()
+ .await;
+ project
+}
+
+pub async fn setup_worktree(example: &Example) {
+ let repo_dir = example.repo_path();
+ let repo_lock = lock_repo(&repo_dir).await;
+
+ if !repo_dir.is_dir() {
+ fs::create_dir_all(&repo_dir).unwrap();
+ run_git(&repo_dir, &["init"]).await.unwrap();
+ run_git(
+ &repo_dir,
+ &["remote", "add", "origin", &example.repository_url],
+ )
+ .await
+ .unwrap();
+ }
+
+ // Resolve the example to a revision, fetching it if needed.
+ let revision = run_git(
+ &repo_dir,
+ &["rev-parse", &format!("{}^{{commit}}", example.revision)],
+ )
+ .await;
+ let revision = if let Ok(revision) = revision {
+ revision
+ } else {
+ if run_git(
+ &repo_dir,
+ &["fetch", "--depth", "1", "origin", &example.revision],
+ )
+ .await
+ .is_err()
+ {
+ run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
+ }
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
+ .await
+ .unwrap();
+ if revision != example.revision {
+ run_git(&repo_dir, &["tag", &example.revision, &revision])
+ .await
+ .unwrap();
+ }
+ revision
+ };
+
+ // Create the worktree for this example if needed.
+ let worktree_path = example.worktree_path();
+ if worktree_path.is_dir() {
+ run_git(&worktree_path, &["clean", "--force", "-d"])
+ .await
+ .unwrap();
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"])
+ .await
+ .unwrap();
+ run_git(&worktree_path, &["checkout", revision.as_str()])
+ .await
+ .unwrap();
+ } else {
+ let worktree_path_string = worktree_path.to_string_lossy();
+ run_git(
+ &repo_dir,
+ &["branch", "-f", &example.name, revision.as_str()],
+ )
+ .await
+ .unwrap();
+ run_git(
+ &repo_dir,
+ &[
+ "worktree",
+ "add",
+ "-f",
+ &worktree_path_string,
+ &example.name,
+ ],
+ )
+ .await
+ .unwrap();
+ }
+ drop(repo_lock);
+
+ // Apply the uncommitted diff for this example.
+ if !example.uncommitted_diff.is_empty() {
+ let mut apply_process = smol::process::Command::new("git")
+ .current_dir(&worktree_path)
+ .args(&["apply", "-"])
+ .stdin(std::process::Stdio::piped())
+ .spawn()
+ .unwrap();
+
+ let mut stdin = apply_process.stdin.take().unwrap();
+ stdin
+ .write_all(example.uncommitted_diff.as_bytes())
+ .await
+ .unwrap();
+ stdin.close().await.unwrap();
+ drop(stdin);
+
+ let apply_result = apply_process.output().await.unwrap();
+ if !apply_result.status.success() {
+ panic!(
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
+ }
+ }
+}
+
+async fn apply_edit_history(
+ example: &Example,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Result<OpenedBuffers> {
+ edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
+}
+
+thread_local! {
+ static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+ REPO_LOCKS
+ .with(|cell| {
+ cell.borrow_mut()
+ .entry(path.as_ref().to_path_buf())
+ .or_default()
+ .clone()
+ })
+ .lock_owned()
+ .await
+}
+
+async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+ let output = smol::process::Command::new("git")
+ .current_dir(repo_path)
+ .args(args)
+ .output()
+ .await?;
+
+ anyhow::ensure!(
+ output.status.success(),
+ "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+ args.join(" "),
+ repo_path.display(),
+ output.status,
+ String::from_utf8_lossy(&output.stderr),
+ String::from_utf8_lossy(&output.stdout),
+ );
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+}
@@ -1,522 +1,196 @@
-mod evaluate;
+mod anthropic_client;
mod example;
+mod format_prompt;
mod headless;
+mod load_project;
mod metrics;
mod paths;
mod predict;
-mod source_location;
-mod training;
-mod util;
+mod retrieve_context;
+mod score;
-use crate::{
- evaluate::run_evaluate,
- example::{ExampleFormat, NamedExample},
- headless::ZetaCliAppState,
- predict::run_predict,
- source_location::SourceLocation,
- training::{context::ContextType, distill::run_distill},
- util::{open_buffer, open_buffer_with_language_server},
-};
-use ::util::{ResultExt, paths::PathStyle};
-use anyhow::{Result, anyhow};
-use clap::{Args, Parser, Subcommand, ValueEnum};
-use cloud_llm_client::predict_edits_v3;
-use edit_prediction::udiff::DiffLine;
-use edit_prediction_context::EditPredictionExcerptOptions;
-use gpui::{Application, AsyncApp, Entity, prelude::*};
-use language::{Bias, Buffer, BufferSnapshot, Point};
-use metrics::delta_chr_f;
-use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
+use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
+use edit_prediction::EditPredictionStore;
+use gpui::Application;
use reqwest_client::ReqwestClient;
-use std::io::{self};
-use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
+use serde::{Deserialize, Serialize};
+use std::{path::PathBuf, sync::Arc};
+
+use crate::example::{read_examples, write_examples};
+use crate::format_prompt::run_format_prompt;
+use crate::load_project::run_load_project;
+use crate::predict::run_prediction;
+use crate::retrieve_context::run_context_retrieval;
+use crate::score::run_scoring;
#[derive(Parser, Debug)]
-#[command(name = "zeta")]
-struct ZetaCliArgs {
+#[command(name = "ep")]
+struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
+ #[clap(long, default_value_t = 10)]
+ max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
+ #[clap(global = true)]
+ inputs: Vec<PathBuf>,
+ #[arg(long, short, global = true)]
+ output: Option<PathBuf>,
+ #[arg(long, short, global = true)]
+ in_place: bool,
}
#[derive(Subcommand, Debug)]
enum Command {
- Context(ContextArgs),
- Predict(PredictArguments),
- Eval(EvaluateArguments),
- Distill(DistillArguments),
- ConvertExample {
- path: PathBuf,
- #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
- output_format: ExampleFormat,
- },
- Score {
- golden_patch: PathBuf,
- actual_patch: PathBuf,
- },
+ /// Parse markdown examples and output a combined .jsonl file
+ ParseExample,
+ /// Create git worktrees for each example and load file contents
+ LoadBuffer,
+ /// Retrieve context for input examples.
+ Context,
+ /// Generate a prompt string for a specific model
+ FormatPrompt(FormatPromptArgs),
+ /// Runs edit prediction
+ Predict(PredictArgs),
+ /// Computes a score based on actual and expected patches
+ Score(PredictArgs),
+ /// Print aggregated scores
+ Eval(PredictArgs),
+ /// Remove git repositories and worktrees
Clean,
}
#[derive(Debug, Args)]
-struct ContextArgs {
- #[arg(long)]
- provider: ContextProvider,
- #[arg(long)]
- worktree: PathBuf,
- #[arg(long)]
- cursor: SourceLocation,
- #[arg(long)]
- use_language_server: bool,
- #[arg(long)]
- edit_history: Option<FileOrStdin>,
- #[clap(flatten)]
- zeta2_args: Zeta2Args,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum ContextProvider {
- Zeta1,
- #[default]
- Zeta2,
-}
-
-#[derive(Clone, Debug, Args)]
-struct Zeta2Args {
- #[arg(long, default_value_t = 8192)]
- max_prompt_bytes: usize,
- #[arg(long, default_value_t = 2048)]
- max_excerpt_bytes: usize,
- #[arg(long, default_value_t = 1024)]
- min_excerpt_bytes: usize,
- #[arg(long, default_value_t = 0.66)]
- target_before_cursor_over_total_bytes: f32,
- #[arg(long, default_value_t = 1024)]
- max_diagnostic_bytes: usize,
- #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+struct FormatPromptArgs {
+ #[clap(long)]
prompt_format: PromptFormat,
- #[arg(long, value_enum, default_value_t = Default::default())]
- output_format: OutputFormat,
- #[arg(long, default_value_t = 42)]
- file_indexing_parallelism: usize,
- #[arg(long, default_value_t = false)]
- disable_imports_gathering: bool,
- #[arg(long, default_value_t = u8::MAX)]
- max_retrieved_definitions: u8,
}
-#[derive(Debug, Args)]
-pub struct PredictArguments {
- #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
- format: PredictionsOutputFormat,
- example_path: PathBuf,
- #[clap(flatten)]
- options: PredictionOptions,
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
+enum PromptFormat {
+ Teacher,
+ Zeta2,
}
#[derive(Debug, Args)]
-pub struct DistillArguments {
- split_commit_dataset: PathBuf,
- #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
- context_type: ContextType,
- #[clap(long)]
- batch: Option<String>,
-}
-
-#[derive(Clone, Debug, Args)]
-pub struct PredictionOptions {
- #[clap(flatten)]
- zeta2: Zeta2Args,
+struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
- #[clap(long, value_enum, default_value_t = CacheMode::default())]
- cache: CacheMode,
-}
-
-#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
-pub enum CacheMode {
- /// Use cached LLM requests and responses, except when multiple repetitions are requested
- #[default]
- Auto,
- /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
- #[value(alias = "request")]
- Requests,
- /// Ignore existing cache entries for both LLM and search.
- Skip,
- /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
- /// Useful for reproducing results and fixing bugs outside of search queries
- Force,
-}
-
-impl CacheMode {
- fn use_cached_llm_responses(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Requests | CacheMode::Force)
- }
-
- fn use_cached_search_results(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Force)
- }
-
- fn assert_not_auto(&self) {
- assert_ne!(
- *self,
- CacheMode::Auto,
- "Cache mode should not be auto at this point!"
- );
- }
-}
-
-#[derive(clap::ValueEnum, Debug, Clone)]
-pub enum PredictionsOutputFormat {
- Json,
- Md,
- Diff,
+ #[clap(long, default_value_t = 1)]
+ repetitions: usize,
}
-#[derive(Debug, Args)]
-pub struct EvaluateArguments {
- example_paths: Vec<PathBuf>,
- #[clap(flatten)]
- options: PredictionOptions,
- #[clap(short, long, default_value_t = 1, alias = "repeat")]
- repetitions: u16,
- #[arg(long)]
- skip_prediction: bool,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
+ Sweep,
+ Mercury,
Zeta1,
- #[default]
Zeta2,
- Sweep,
-}
-
-fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
- edit_prediction::ZetaOptions {
- context: EditPredictionExcerptOptions {
- max_bytes: args.max_excerpt_bytes,
- min_bytes: args.min_excerpt_bytes,
- target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
- },
- max_prompt_bytes: args.max_prompt_bytes,
- prompt_format: args.prompt_format.into(),
- }
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum PromptFormat {
- OnlySnippets,
- #[default]
- OldTextNewText,
- Minimal,
- MinimalQwen,
- SeedCoder1120,
+ Teacher,
}
-impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
- fn into(self) -> predict_edits_v3::PromptFormat {
- match self {
- Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
- Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
- Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
- Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
- Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
+impl EpArgs {
+ fn output_path(&self) -> Option<PathBuf> {
+ if self.in_place {
+ if self.inputs.len() == 1 {
+ self.inputs.first().cloned()
+ } else {
+ panic!("--in-place requires exactly one input file")
+ }
+ } else {
+ self.output.clone()
}
}
}
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
-enum OutputFormat {
- #[default]
- Prompt,
- Request,
- Full,
-}
-
-#[derive(Debug, Clone)]
-enum FileOrStdin {
- File(PathBuf),
- Stdin,
-}
+fn main() {
+ zlog::init();
+ zlog::init_output_stderr();
+ let args = EpArgs::parse();
-impl FileOrStdin {
- async fn read_to_string(&self) -> Result<String, std::io::Error> {
- match self {
- FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
- FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
- }
+ if args.printenv {
+ ::util::shell_env::print_env();
+ return;
}
-}
-
-impl FromStr for FileOrStdin {
- type Err = <PathBuf as FromStr>::Err;
- fn from_str(s: &str) -> Result<Self, Self::Err> {
- match s {
- "-" => Ok(Self::Stdin),
- _ => Ok(Self::File(PathBuf::from_str(s)?)),
+ let output = args.output_path();
+ let command = match args.command {
+ Some(cmd) => cmd,
+ None => {
+ EpArgs::command().print_help().unwrap();
+ return;
}
- }
-}
-
-struct LoadedContext {
- full_path_str: String,
- snapshot: BufferSnapshot,
- clipped_cursor: Point,
- worktree: Entity<Worktree>,
- project: Entity<Project>,
- buffer: Entity<Buffer>,
- lsp_open_handle: Option<OpenLspBufferHandle>,
-}
-
-async fn load_context(
- args: &ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<LoadedContext> {
- let ContextArgs {
- worktree: worktree_path,
- cursor,
- use_language_server,
- ..
- } = args;
-
- let worktree_path = worktree_path.canonicalize()?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
-
- let mut ready_languages = HashSet::default();
- let (lsp_open_handle, buffer) = if *use_language_server {
- let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
- project.clone(),
- worktree.clone(),
- cursor.path.clone(),
- &mut ready_languages,
- cx,
- )
- .await?;
- (Some(lsp_open_handle), buffer)
- } else {
- let buffer =
- open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
- (None, buffer)
};
- let full_path_str = worktree
- .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
- .display(PathStyle::local())
- .to_string();
-
- let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
- let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
- if clipped_cursor != cursor.point {
- let max_row = snapshot.max_point().row;
- if cursor.point.row < max_row {
- return Err(anyhow!(
- "Cursor position {:?} is out of bounds (line length is {})",
- cursor.point,
- snapshot.line_len(cursor.point.row)
- ));
- } else {
- return Err(anyhow!(
- "Cursor position {:?} is out of bounds (max row is {})",
- cursor.point,
- max_row
- ));
+ match &command {
+ Command::Clean => {
+ std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
+ return;
}
+ _ => {}
}
- Ok(LoadedContext {
- full_path_str,
- snapshot,
- clipped_cursor,
- worktree,
- project,
- buffer,
- lsp_open_handle,
- })
-}
-
-async fn zeta2_context(
- args: ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let LoadedContext {
- worktree,
- project,
- buffer,
- clipped_cursor,
- lsp_open_handle: _handle,
- ..
- } = load_context(&args, app_state, cx).await?;
-
- // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
- // the whole worktree.
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
- let output = cx
- .update(|cx| {
- let store = cx.new(|cx| {
- edit_prediction::EditPredictionStore::new(
- app_state.client.clone(),
- app_state.user_store.clone(),
- cx,
- )
- });
- store.update(cx, |store, cx| {
- store.set_options(zeta2_args_to_options(&args.zeta2_args));
- store.register_buffer(&buffer, &project, cx);
- });
- cx.spawn(async move |cx| {
- let updates_rx = store.update(cx, |store, cx| {
- let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
- store.set_use_context(true);
- store.refresh_context(&project, &buffer, cursor, cx);
- store.project_context_updates(&project).unwrap()
- })?;
-
- updates_rx.recv().await.ok();
-
- let context = store.update(cx, |store, cx| {
- store.context_for_project(&project, cx).to_vec()
- })?;
-
- anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
- })
- })?
- .await?;
-
- Ok(output)
-}
-
-async fn zeta1_context(
- args: ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<edit_prediction::zeta1::GatherContextOutput> {
- let LoadedContext {
- full_path_str,
- snapshot,
- clipped_cursor,
- ..
- } = load_context(&args, app_state, cx).await?;
-
- let events = match args.edit_history {
- Some(events) => events.read_to_string().await?,
- None => String::new(),
- };
-
- let prompt_for_events = move || (events, 0);
- cx.update(|cx| {
- edit_prediction::zeta1::gather_context(
- full_path_str,
- &snapshot,
- clipped_cursor,
- prompt_for_events,
- cloud_llm_client::PredictEditsRequestTrigger::Cli,
- cx,
- )
- })?
- .await
-}
-
-fn main() {
- zlog::init();
- zlog::init_output_stderr();
- let args = ZetaCliArgs::parse();
+ let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
+ EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
+
cx.spawn(async move |cx| {
- match args.command {
- None => {
- if args.printenv {
- ::util::shell_env::print_env();
- } else {
- panic!("Expected a command");
- }
- }
- Some(Command::Context(context_args)) => {
- let result = match context_args.provider {
- ContextProvider::Zeta1 => {
- let context =
- zeta1_context(context_args, &app_state, cx).await.unwrap();
- serde_json::to_string_pretty(&context.body).unwrap()
- }
- ContextProvider::Zeta2 => {
- zeta2_context(context_args, &app_state, cx).await.unwrap()
+ match &command {
+ Command::Predict(args) => predict::sync_batches(&args.provider).await,
+ _ => (),
+ };
+
+ for data in examples.chunks_mut(args.max_parallelism) {
+ let mut futures = Vec::new();
+ for example in data.iter_mut() {
+ let cx = cx.clone();
+ let app_state = app_state.clone();
+ futures.push(async {
+ match &command {
+ Command::ParseExample => {}
+ Command::LoadBuffer => {
+ run_load_project(example, app_state.clone(), cx).await;
+ }
+ Command::Context => {
+ run_context_retrieval(example, app_state, cx).await;
+ }
+ Command::FormatPrompt(args) => {
+ run_format_prompt(example, args.prompt_format, app_state, cx).await;
+ }
+ Command::Predict(args) => {
+ run_prediction(
+ example,
+ Some(args.provider),
+ args.repetitions,
+ app_state.clone(),
+ cx,
+ )
+ .await;
+ }
+ Command::Score(args) | Command::Eval(args) => {
+ run_scoring(example, &args, app_state, cx).await;
+ }
+ Command::Clean => {
+ unreachable!()
+ }
}
- };
- println!("{}", result);
- }
- Some(Command::Predict(arguments)) => {
- run_predict(arguments, &app_state, cx).await;
- }
- Some(Command::Eval(arguments)) => {
- run_evaluate(arguments, &app_state, cx).await;
+ });
}
- Some(Command::Distill(arguments)) => {
- let _guard = cx
- .update(|cx| gpui_tokio::Tokio::handle(cx))
- .unwrap()
- .enter();
- run_distill(arguments).await.log_err();
- }
- Some(Command::ConvertExample {
- path,
- output_format,
- }) => {
- let example = NamedExample::load(path).unwrap();
- example.write(output_format, io::stdout()).unwrap();
- }
- Some(Command::Score {
- golden_patch,
- actual_patch,
- }) => {
- let golden_content = std::fs::read_to_string(golden_patch).unwrap();
- let actual_content = std::fs::read_to_string(actual_patch).unwrap();
-
- let golden_diff: Vec<DiffLine> = golden_content
- .lines()
- .map(|line| DiffLine::parse(line))
- .collect();
+ futures::future::join_all(futures).await;
+ }
- let actual_diff: Vec<DiffLine> = actual_content
- .lines()
- .map(|line| DiffLine::parse(line))
- .collect();
+ if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+ write_examples(&examples, output.as_ref());
+ }
- let score = delta_chr_f(&golden_diff, &actual_diff);
- println!("{:.2}", score);
- }
- Some(Command::Clean) => {
- std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
- }
+ match &command {
+ Command::Predict(args) => predict::sync_batches(&args.provider).await,
+ Command::Eval(_) => score::print_report(&examples),
+ _ => (),
};
let _ = cx.update(|cx| cx.quit());
@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
+use serde::{Deserialize, Serialize};
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
-#[derive(Default, Debug, Clone)]
-pub struct Scores {
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
+pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
-impl Scores {
- pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+impl ClassificationMetrics {
+ pub fn from_sets(
+ expected: &HashSet<String>,
+ actual: &HashSet<String>,
+ ) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
- pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
+ pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
}
}
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
- pub fn to_markdown(&self) -> String {
- format!(
- "
-Precision : {:.4}
-Recall : {:.4}
-F1 Score : {:.4}
-True Positives : {}
-False Positives : {}
-False Negatives : {}",
- self.precision(),
- self.recall(),
- self.f1_score(),
- self.true_positives,
- self.false_positives,
- self.false_negatives
- )
- }
-
- pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
+ pub fn aggregate<'a>(
+ scores: impl Iterator<Item = &'a ClassificationMetrics>,
+ ) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
false_negatives += score.false_negatives;
}
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
}
}
-pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
+pub fn line_match_score(
+ expected_patch: &[DiffLine],
+ actual_patch: &[DiffLine],
+) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
.map(|line| line.to_string())
.collect();
- Scores::from_sets(&expected_change_lines, &actual_change_lines)
+ ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
let expected_counts = ngram_delta_to_counts(&expected_delta);
let actual_counts = ngram_delta_to_counts(&actual_delta);
- let score = Scores::from_counts(&expected_counts, &actual_counts);
+ let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
total_precision += score.precision();
total_recall += score.recall();
}
@@ -1,57 +1,25 @@
-use std::{env, path::PathBuf, sync::LazyLock};
+use std::{
+ path::{Path, PathBuf},
+ sync::LazyLock,
+};
-pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
-pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+ let dir = dirs::home_dir().unwrap().join(".zed_ep");
+ ensure_dir(&dir)
+});
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
+pub static WORKTREES_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
- TARGET_ZETA_DIR
+ DATA_DIR
.join("runs")
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
-pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
-
-pub fn print_run_data_dir(deep: bool, use_color: bool) {
- println!("\n## Run Data\n");
- let mut files = Vec::new();
-
- let current_dir = std::env::current_dir().unwrap();
- for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
- let file = file.unwrap();
- if file.file_type().unwrap().is_dir() && deep {
- for file in std::fs::read_dir(file.path()).unwrap() {
- let path = file.unwrap().path();
- let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
- files.push(format!(
- "- {}/{}{}{}",
- path.parent().unwrap().display(),
- if use_color { "\x1b[34m" } else { "" },
- path.file_name().unwrap().display(),
- if use_color { "\x1b[0m" } else { "" },
- ));
- }
- } else {
- let path = file.path();
- let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
- files.push(format!(
- "- {}/{}{}{}",
- path.parent().unwrap().display(),
- if use_color { "\x1b[34m" } else { "" },
- path.file_name().unwrap().display(),
- if use_color { "\x1b[0m" } else { "" }
- ));
- }
- }
- files.sort();
-
- for file in files {
- println!("{}", file);
- }
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
+pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
- println!(
- "\n💡 Tip of the day: {} always points to the latest run\n",
- LATEST_EXAMPLE_RUN_DIR.display()
- );
+fn ensure_dir(path: &Path) -> PathBuf {
+ std::fs::create_dir_all(path).expect("Failed to create directory");
+ path.to_path_buf()
}
@@ -1,374 +1,271 @@
-use crate::example::{ActualExcerpt, NamedExample};
-use crate::headless::ZetaCliAppState;
-use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
- CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
+ PredictionProvider, PromptFormat,
+ anthropic_client::AnthropicClient,
+ example::{Example, ExamplePrediction},
+ format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
+ headless::EpAppState,
+ load_project::run_load_project,
+ paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
+ retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use gpui::{AppContext as _, AsyncApp, Task};
+use std::{
+ fs,
+ sync::{
+ Arc, Mutex, OnceLock,
+ atomic::{AtomicUsize, Ordering::SeqCst},
+ },
};
-use ::serde::Serialize;
-use anyhow::{Context, Result, anyhow};
-use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
-use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
-use futures::StreamExt as _;
-use gpui::{AppContext, AsyncApp, Entity};
-use project::Project;
-use project::buffer_store::BufferStoreEvent;
-use serde::Deserialize;
-use std::fs;
-use std::io::{IsTerminal, Write};
-use std::path::PathBuf;
-use std::sync::Arc;
-use std::sync::Mutex;
-use std::time::{Duration, Instant};
-pub async fn run_predict(
- args: PredictArguments,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
+pub async fn run_prediction(
+ example: &mut Example,
+ provider: Option<PredictionProvider>,
+ repetition_count: usize,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
) {
- let example = NamedExample::load(args.example_path).unwrap();
- let project = example.setup_project(app_state, cx).await.unwrap();
- let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
- let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let result = perform_predict(example, project, store, None, args.options, cx)
- .await
- .unwrap();
- result.write(args.format, std::io::stdout()).unwrap();
-
- print_run_data_dir(true, std::io::stdout().is_terminal());
-}
-
-pub fn setup_store(
- provider: PredictionProvider,
- project: &Entity<Project>,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<Entity<EditPredictionStore>> {
- let store = cx.new(|cx| {
- edit_prediction::EditPredictionStore::new(
- app_state.client.clone(),
- app_state.user_store.clone(),
- cx,
- )
- })?;
+ if !example.predictions.is_empty() {
+ return;
+ }
- store.update(cx, |store, _cx| {
- let model = match provider {
- PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
- PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
- };
- store.set_edit_prediction_model(model);
- })?;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
- let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+ let provider = provider.unwrap();
- cx.subscribe(&buffer_store, {
- let project = project.clone();
- let store = store.clone();
- move |_, event, cx| match event {
- BufferStoreEvent::BufferAdded(buffer) => {
- store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
- }
- _ => {}
+ if matches!(provider, PredictionProvider::Teacher) {
+ if example.prompt.is_none() {
+ run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
- })?
- .detach();
- anyhow::Ok(store)
-}
-
-pub async fn perform_predict(
- example: NamedExample,
- project: Entity<Project>,
- store: Entity<EditPredictionStore>,
- repetition_ix: Option<u16>,
- options: PredictionOptions,
- cx: &mut AsyncApp,
-) -> Result<PredictionDetails> {
- let mut cache_mode = options.cache;
- if repetition_ix.is_some() {
- if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
- panic!("Repetitions are not supported in Auto cache mode");
- } else {
- cache_mode = CacheMode::Skip;
- }
- } else if cache_mode == CacheMode::Auto {
- cache_mode = CacheMode::Requests;
+ let batched = true;
+ return predict_anthropic(example, repetition_count, batched).await;
}
- let mut example_run_dir = RUN_DIR.join(&example.file_name());
- if let Some(repetition_ix) = repetition_ix {
- example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
- }
- fs::create_dir_all(&example_run_dir)?;
- if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
- fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+ if matches!(
+ provider,
+ PredictionProvider::Zeta1 | PredictionProvider::Zeta2
+ ) {
+ static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
+ AUTHENTICATED
+ .get_or_init(|| {
+ let client = app_state.client.clone();
+ cx.spawn(async move |cx| {
+ client
+ .sign_in_with_optional_connect(true, cx)
+ .await
+ .unwrap();
+ })
+ .shared()
+ })
+ .clone()
+ .await;
}
- #[cfg(unix)]
- std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
- .context("creating latest link")?;
-
- #[cfg(windows)]
- std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
- .context("creating latest link")?;
-
- store.update(cx, |store, _cx| {
- store.with_eval_cache(Arc::new(RunCache {
- example_run_dir: example_run_dir.clone(),
- cache_mode,
- }));
- })?;
-
- let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
-
- let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
-
- let prompt_format = options.zeta2.prompt_format;
-
- store.update(cx, |store, _cx| {
- let mut options = store.options().clone();
- options.prompt_format = prompt_format.into();
- store.set_options(options);
- })?;
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
- let mut debug_task = gpui::Task::ready(Ok(()));
+ ep_store
+ .update(&mut cx, |store, _cx| {
+ let model = match provider {
+ PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
+ PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
+ PredictionProvider::Teacher => unreachable!(),
+ };
+ store.set_edit_prediction_model(model);
+ })
+ .unwrap();
+ let state = example.state.as_ref().unwrap();
+ let run_dir = RUN_DIR.join(&example.name);
- if options.provider == crate::PredictionProvider::Zeta2 {
- let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
+ let updated_example = Arc::new(Mutex::new(example.clone()));
+ let current_run_ix = Arc::new(AtomicUsize::new(0));
- debug_task = cx.background_spawn({
- let result = result.clone();
- async move {
- let mut start_time = None;
- let mut retrieval_finished_at = None;
- while let Some(event) = debug_rx.next().await {
- match event {
- edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
- start_time = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_prompt.md"),
- &info.search_prompt,
- )?;
+ let mut debug_rx = ep_store
+ .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
+ .unwrap();
+ let debug_task = cx.background_spawn({
+ let updated_example = updated_example.clone();
+ let current_run_ix = current_run_ix.clone();
+ let run_dir = run_dir.clone();
+ async move {
+ while let Some(event) = debug_rx.next().await {
+ let run_ix = current_run_ix.load(SeqCst);
+ let mut updated_example = updated_example.lock().unwrap();
+
+ let run_dir = if repetition_count > 1 {
+ run_dir.join(format!("{:03}", run_ix))
+ } else {
+ run_dir.clone()
+ };
+
+ match event {
+ DebugEvent::EditPredictionStarted(request) => {
+ assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+ if let Some(prompt) = request.prompt {
+ fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
}
- edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
- retrieval_finished_at = Some(info.timestamp);
- for (key, value) in &info.metadata {
- if *key == "search_queries" {
- fs::write(
- example_run_dir.join("search_queries.json"),
- value.as_bytes(),
- )?;
- }
- }
+ }
+ DebugEvent::EditPredictionFinished(request) => {
+ assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+ if let Some(output) = request.model_output {
+ fs::write(run_dir.join("prediction_response.md"), &output)?;
+ updated_example
+ .predictions
+ .last_mut()
+ .unwrap()
+ .actual_output = output;
}
- edit_prediction::DebugEvent::EditPredictionRequested(request) => {
- let prediction_started_at = Instant::now();
- start_time.get_or_insert(prediction_started_at);
- let prompt = request.local_prompt.unwrap_or_default();
- fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
- {
- let mut result = result.lock().unwrap();
- result.prompt_len = prompt.chars().count();
-
- for included_file in request.inputs.included_files {
- let insertions =
- vec![(request.inputs.cursor_point, CURSOR_MARKER)];
- result.excerpts.extend(included_file.excerpts.iter().map(
- |excerpt| ActualExcerpt {
- path: included_file.path.components().skip(1).collect(),
- text: String::from(excerpt.text.as_ref()),
- },
- ));
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.inputs.cursor_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut result.excerpts_text,
- );
- }
- }
-
- let response =
- request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response =
- edit_prediction::open_ai_response::text_from_response(response)
- .unwrap_or_default();
- let prediction_finished_at = Instant::now();
- fs::write(example_run_dir.join("prediction_response.md"), &response)?;
-
- let mut result = result.lock().unwrap();
- result.generated_len = response.chars().count();
- result.retrieval_time =
- retrieval_finished_at.unwrap() - start_time.unwrap();
- result.prediction_time = prediction_finished_at - prediction_started_at;
- result.total_time = prediction_finished_at - start_time.unwrap();
-
+ if run_ix >= repetition_count {
break;
}
}
+ _ => {}
}
- anyhow::Ok(())
}
- });
-
- store.update(cx, |store, cx| {
- store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
- })?;
- }
-
- let prediction = store
- .update(cx, |store, cx| {
- store.request_prediction(
- &project,
- &cursor_buffer,
- cursor_anchor,
- cloud_llm_client::PredictEditsRequestTrigger::Cli,
- cx,
- )
- })?
- .await?;
-
- debug_task.await?;
-
- let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
-
- result.diff = prediction
- .and_then(|prediction| {
- let prediction = prediction.prediction.ok()?;
- prediction.edit_preview.as_unified_diff(&prediction.edits)
- })
- .unwrap_or_default();
-
- anyhow::Ok(result)
-}
-
-struct RunCache {
- cache_mode: CacheMode,
- example_run_dir: PathBuf,
-}
+ anyhow::Ok(())
+ }
+ });
-impl RunCache {
- fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
- CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
- }
+ for ix in 0..repetition_count {
+ current_run_ix.store(ix, SeqCst);
+ let run_dir = if repetition_count > 1 {
+ run_dir.join(format!("{:03}", ix))
+ } else {
+ run_dir.clone()
+ };
- fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
- CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
+ fs::create_dir_all(&run_dir).unwrap();
+ if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
+ fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ }
+ #[cfg(unix)]
+ std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ #[cfg(windows)]
+ std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+
+ updated_example
+ .lock()
+ .unwrap()
+ .predictions
+ .push(ExamplePrediction {
+ actual_patch: String::new(),
+ actual_output: String::new(),
+ provider,
+ });
+
+ let prediction = ep_store
+ .update(&mut cx, |store, cx| {
+ store.request_prediction(
+ &state.project,
+ &state.buffer,
+ state.cursor_position,
+ cloud_llm_client::PredictEditsRequestTrigger::Cli,
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+
+ updated_example
+ .lock()
+ .unwrap()
+ .predictions
+ .last_mut()
+ .unwrap()
+ .actual_patch = prediction
+ .and_then(|prediction| {
+ let prediction = prediction.prediction.ok()?;
+ prediction.edit_preview.as_unified_diff(&prediction.edits)
+ })
+ .unwrap_or_default();
}
- fn link_to_run(&self, key: &EvalCacheKey) {
- let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
- fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+ ep_store
+ .update(&mut cx, |store, _| {
+ store.remove_project(&state.project);
+ })
+ .unwrap();
+ debug_task.await.unwrap();
- let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
- fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
- }
+ *example = Arc::into_inner(updated_example)
+ .unwrap()
+ .into_inner()
+ .unwrap();
}
-impl EvalCache for RunCache {
- fn read(&self, key: EvalCacheKey) -> Option<String> {
- let path = RunCache::output_cache_path(&key);
-
- if path.exists() {
- let use_cache = match key.0 {
- EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
- EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
- self.cache_mode.use_cached_llm_responses()
- }
- };
- if use_cache {
- log::info!("Using cache entry: {}", path.display());
- self.link_to_run(&key);
- Some(fs::read_to_string(path).unwrap())
- } else {
- log::trace!("Skipping cached entry: {}", path.display());
- None
- }
- } else if matches!(self.cache_mode, CacheMode::Force) {
- panic!(
- "No cached entry found for {:?}. Run without `--cache force` at least once.",
- key.0
- );
- } else {
- None
- }
- }
-
- fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
- fs::create_dir_all(&*CACHE_DIR).unwrap();
+async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
+ let llm_model_name = "claude-sonnet-4-5";
+ let max_tokens = 16384;
+ let llm_client = if batched {
+ AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
+ } else {
+ AnthropicClient::plain()
+ };
+ let llm_client = llm_client.expect("Failed to create LLM client");
+
+ let prompt = example
+ .prompt
+ .as_ref()
+ .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
+
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt.input.clone(),
+ cache_control: None,
+ }],
+ }];
+
+ let Some(response) = llm_client
+ .generate(llm_model_name, max_tokens, messages)
+ .await
+ .unwrap()
+ else {
+ // Request stashed for batched processing
+ return;
+ };
+
+ let actual_output = response
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ anthropic::ResponseContent::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
- let input_path = RunCache::input_cache_path(&key);
- fs::write(&input_path, input).unwrap();
+ let actual_patch = TeacherPrompt::parse(example, &actual_output);
- let output_path = RunCache::output_cache_path(&key);
- log::trace!("Writing cache entry: {}", output_path.display());
- fs::write(&output_path, output).unwrap();
+ let prediction = ExamplePrediction {
+ actual_patch,
+ actual_output,
+ provider: PredictionProvider::Teacher,
+ };
- self.link_to_run(&key);
- }
+ example.predictions.push(prediction);
}
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct PredictionDetails {
- pub diff: String,
- pub excerpts: Vec<ActualExcerpt>,
- pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
- pub retrieval_time: Duration,
- pub prediction_time: Duration,
- pub total_time: Duration,
- pub run_example_dir: PathBuf,
- pub prompt_len: usize,
- pub generated_len: usize,
-}
-
-impl PredictionDetails {
- pub fn new(run_example_dir: PathBuf) -> Self {
- Self {
- diff: Default::default(),
- excerpts: Default::default(),
- excerpts_text: Default::default(),
- retrieval_time: Default::default(),
- prediction_time: Default::default(),
- total_time: Default::default(),
- run_example_dir,
- prompt_len: 0,
- generated_len: 0,
+pub async fn sync_batches(provider: &PredictionProvider) {
+ match provider {
+ PredictionProvider::Teacher => {
+ let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
+ let llm_client =
+ AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
+ llm_client
+ .sync_batches()
+ .await
+ .expect("Failed to sync batches");
}
- }
-
- pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
- let formatted = match format {
- PredictionsOutputFormat::Md => self.to_markdown(),
- PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
- PredictionsOutputFormat::Diff => self.diff.clone(),
- };
-
- Ok(out.write_all(formatted.as_bytes())?)
- }
-
- pub fn to_markdown(&self) -> String {
- format!(
- "## Excerpts\n\n\
- {}\n\n\
- ## Prediction\n\n\
- {}\n\n\
- ## Time\n\n\
- Retrieval: {}ms\n\
- Prediction: {}ms\n\n\
- Total: {}ms\n",
- self.excerpts_text,
- self.diff,
- self.retrieval_time.as_millis(),
- self.prediction_time.as_millis(),
- self.total_time.as_millis(),
- )
+ _ => (),
}
}
@@ -1,106 +1,136 @@
-use anyhow::{Result, anyhow};
-use futures::channel::mpsc;
-use futures::{FutureExt as _, StreamExt as _};
+use crate::{
+ example::{Example, ExampleContext},
+ headless::EpAppState,
+ load_project::run_load_project,
+};
+use anyhow::Result;
+use collections::HashSet;
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity, Task};
-use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
-use project::lsp_store::OpenLspBufferHandle;
-use project::{Project, ProjectPath, Worktree};
-use std::collections::HashSet;
-use std::sync::Arc;
-use std::time::Duration;
-use util::rel_path::RelPath;
-
-pub fn open_buffer(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- cx: &AsyncApp,
-) -> Task<Result<Entity<Buffer>>> {
- cx.spawn(async move |cx| {
- let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
- worktree_id: worktree.id(),
- path,
- })?;
-
- let buffer = project
- .update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?;
-
- let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
- while *parse_status.borrow() != ParseStatus::Idle {
- parse_status.changed().await?;
+use language::{Buffer, LanguageNotFound};
+use project::Project;
+use std::{sync::Arc, time::Duration};
+
+pub async fn run_context_retrieval(
+ example: &mut Example,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
+) {
+ if example.context.is_some() {
+ return;
+ }
+
+ run_load_project(example, app_state.clone(), cx.clone()).await;
+
+ let state = example.state.as_ref().unwrap();
+ let project = state.project.clone();
+
+ let _lsp_handle = project
+ .update(&mut cx, |project, cx| {
+ project.register_buffer_with_language_servers(&state.buffer, cx)
+ })
+ .unwrap();
+
+ wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
+
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ let mut events = ep_store
+ .update(&mut cx, |store, cx| {
+ store.register_buffer(&state.buffer, &project, cx);
+ store.set_use_context(true);
+ store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
+ store.debug_info(&project, cx)
+ })
+ .unwrap();
+
+ while let Some(event) = events.next().await {
+ match event {
+ DebugEvent::ContextRetrievalFinished(_) => {
+ break;
+ }
+ _ => {}
}
+ }
- Ok(buffer)
- })
+ let context_files = ep_store
+ .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
+ .unwrap();
+
+ example.context = Some(ExampleContext {
+ files: context_files,
+ });
}
-pub async fn open_buffer_with_language_server(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- ready_languages: &mut HashSet<LanguageId>,
+async fn wait_for_language_server_to_start(
+ example: &Example,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
cx: &mut AsyncApp,
-) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
- let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
-
- let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
- (
- project.register_buffer_with_language_servers(&buffer, cx),
- project.path_style(cx),
- )
- })?;
-
- let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+) {
+ let language_registry = project
+ .read_with(cx, |project, _| project.languages().clone())
+ .unwrap();
let result = language_registry
- .load_language_for_file_path(path.as_std_path())
+ .load_language_for_file_path(&example.cursor_path)
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
- anyhow::bail!(error);
+ panic!("Failed to load language for file path: {}", error);
}
- let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
- buffer.language().map(|language| language.id())
- })?
+ let Some(language_id) = buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer.language().map(|language| language.id())
+ })
+ .unwrap()
else {
- return Err(anyhow!("No language for {}", path.display(path_style)));
+ panic!("No language for {:?}", example.cursor_path);
};
- let log_prefix = format!("{} | ", path.display(path_style));
+ let mut ready_languages = HashSet::default();
+ let log_prefix = format!("{} | ", example.name);
if !ready_languages.contains(&language_id) {
- wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
+ wait_for_lang_server(&project, &buffer, log_prefix, cx)
+ .await
+ .unwrap();
ready_languages.insert(language_id);
}
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+ let lsp_store = project
+ .read_with(cx, |project, _cx| project.lsp_store())
+ .unwrap();
// hacky wait for buffer to be registered with the language server
for _ in 0..100 {
- let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
- buffer.update(cx, |buffer, cx| {
- lsp_store
- .language_servers_for_local_buffer(&buffer, cx)
- .next()
- .map(|(_, language_server)| language_server.server_id())
+ if lsp_store
+ .update(cx, |lsp_store, cx| {
+ buffer.update(cx, |buffer, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(&buffer, cx)
+ .next()
+ .map(|(_, language_server)| language_server.server_id())
+ })
})
- })?
- else {
+ .unwrap()
+ .is_some()
+ {
+ return;
+ } else {
cx.background_executor()
.timer(Duration::from_millis(10))
.await;
- continue;
- };
-
- return Ok((lsp_open_handle, language_server_id, buffer));
+ }
}
- return Err(anyhow!("No language server found for buffer"));
+ panic!("No language server found for buffer");
}
-// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@@ -0,0 +1,119 @@
+use crate::{
+ PredictArgs,
+ example::{Example, ExampleScore},
+ headless::EpAppState,
+ metrics::{self, ClassificationMetrics},
+ predict::run_prediction,
+};
+use edit_prediction::udiff::DiffLine;
+use gpui::AsyncApp;
+use std::sync::Arc;
+
+pub async fn run_scoring(
+ example: &mut Example,
+ args: &PredictArgs,
+ app_state: Arc<EpAppState>,
+ cx: AsyncApp,
+) {
+ run_prediction(
+ example,
+ Some(args.provider),
+ args.repetitions,
+ app_state,
+ cx,
+ )
+ .await;
+
+ let expected_patch = parse_patch(&example.expected_patch);
+
+ let mut scores = vec![];
+
+ for pred in &example.predictions {
+ let actual_patch = parse_patch(&pred.actual_patch);
+ let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
+ let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
+
+ scores.push(ExampleScore {
+ delta_chr_f,
+ line_match,
+ });
+ }
+
+ example.score = scores;
+}
+
+fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
+ patch.lines().map(DiffLine::parse).collect()
+}
+
+pub fn print_report(examples: &[Example]) {
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
+ "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
+ );
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+
+ let mut all_line_match_scores = Vec::new();
+ let mut all_delta_chr_f_scores = Vec::new();
+
+ for example in examples {
+ for score in example.score.iter() {
+ let line_match = &score.line_match;
+
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+ truncate_name(&example.name, 30),
+ line_match.true_positives,
+ line_match.false_positives,
+ line_match.false_negatives,
+ line_match.precision() * 100.0,
+ line_match.recall() * 100.0,
+ line_match.f1_score() * 100.0,
+ score.delta_chr_f
+ );
+
+ all_line_match_scores.push(line_match.clone());
+ all_delta_chr_f_scores.push(score.delta_chr_f);
+ }
+ }
+
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+
+ if !all_line_match_scores.is_empty() {
+ let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+ let avg_delta_chr_f: f32 =
+ all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
+
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+ "TOTAL",
+ total_line_match.true_positives,
+ total_line_match.false_positives,
+ total_line_match.false_negatives,
+ total_line_match.precision() * 100.0,
+ total_line_match.recall() * 100.0,
+ total_line_match.f1_score() * 100.0,
+ avg_delta_chr_f
+ );
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+ }
+
+ eprintln!("\n");
+}
+
+fn truncate_name(name: &str, max_len: usize) -> String {
+ if name.len() <= max_len {
+ name.to_string()
+ } else {
+ format!("{}...", &name[..max_len - 3])
+ }
+}
@@ -1,70 +0,0 @@
-use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
-
-use ::util::{paths::PathStyle, rel_path::RelPath};
-use anyhow::{Result, anyhow};
-use language::Point;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-
-#[derive(Debug, Clone, Hash, Eq, PartialEq)]
-pub struct SourceLocation {
- pub path: Arc<RelPath>,
- pub point: Point,
-}
-
-impl Serialize for SourceLocation {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(&self.to_string())
- }
-}
-
-impl<'de> Deserialize<'de> for SourceLocation {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- let s = String::deserialize(deserializer)?;
- s.parse().map_err(serde::de::Error::custom)
- }
-}
-
-impl Display for SourceLocation {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "{}:{}:{}",
- self.path.display(PathStyle::Posix),
- self.point.row + 1,
- self.point.column + 1
- )
- }
-}
-
-impl FromStr for SourceLocation {
- type Err = anyhow::Error;
-
- fn from_str(s: &str) -> Result<Self> {
- let parts: Vec<&str> = s.split(':').collect();
- if parts.len() != 3 {
- return Err(anyhow!(
- "Invalid source location. Expected 'file.rs:line:column', got '{}'",
- s
- ));
- }
-
- let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
- let line: u32 = parts[1]
- .parse()
- .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
- let column: u32 = parts[2]
- .parse()
- .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
-
- // Convert from 1-based to 0-based indexing
- let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
-
- Ok(SourceLocation { path, point })
- }
-}
@@ -46,3 +46,7 @@ Output example:
## Code Context
{{context}}
+
+## Editable region
+
+{{editable_region}}
@@ -1,89 +0,0 @@
-use std::path::Path;
-
-use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
-
-#[derive(Debug, Clone, Default, clap::ValueEnum)]
-pub enum ContextType {
- #[default]
- CurrentFile,
-}
-
-const MAX_CONTEXT_SIZE: usize = 32768;
-
-pub fn collect_context(
- context_type: &ContextType,
- worktree_dir: &Path,
- cursor: SourceLocation,
-) -> String {
- let context = match context_type {
- ContextType::CurrentFile => {
- let file_path = worktree_dir.join(cursor.path.as_std_path());
- let context = std::fs::read_to_string(&file_path).unwrap_or_default();
-
- let context = add_special_tags(&context, worktree_dir, cursor);
- context
- }
- };
-
- let region_end_offset = context.find(TeacherModel::REGION_END);
-
- if context.len() <= MAX_CONTEXT_SIZE {
- return context;
- }
-
- if let Some(region_end_offset) = region_end_offset
- && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
- {
- let to_truncate = context.len() - MAX_CONTEXT_SIZE;
- format!(
- "[...{} bytes truncated]\n{}\n",
- to_truncate,
- &context[to_truncate..]
- )
- } else {
- format!(
- "{}\n[...{} bytes truncated]\n",
- &context[..MAX_CONTEXT_SIZE],
- context.len() - MAX_CONTEXT_SIZE
- )
- }
-}
-
-/// Add <|editable_region_start/end|> tags
-fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
- let path = worktree_dir.join(cursor.path.as_std_path());
- let file = std::fs::read_to_string(&path).unwrap_or_default();
- let lines = file.lines().collect::<Vec<_>>();
- let cursor_row = cursor.point.row as usize;
- let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
- let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
-
- let snippet = lines[start_line..end_line].join("\n");
-
- if context.contains(&snippet) {
- let mut cursor_line = lines[cursor_row].to_string();
- cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
-
- let mut snippet_with_tags_lines = vec![];
- snippet_with_tags_lines.push(TeacherModel::REGION_START);
- snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
- snippet_with_tags_lines.push(&cursor_line);
- snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
- snippet_with_tags_lines.push(TeacherModel::REGION_END);
- let snippet_with_tags = snippet_with_tags_lines.join("\n");
-
- context.replace(&snippet, &snippet_with_tags)
- } else {
- log::warn!(
- "Can't find area around the cursor in the context; proceeding without special tags"
- );
- context.to_string()
- }
-}
-
-pub fn strip_special_tags(context: &str) -> String {
- context
- .replace(TeacherModel::REGION_START, "")
- .replace(TeacherModel::REGION_END, "")
- .replace(TeacherModel::USER_CURSOR, "")
-}
@@ -1,94 +0,0 @@
-use serde::Deserialize;
-use std::sync::Arc;
-
-use crate::{
- DistillArguments,
- example::Example,
- source_location::SourceLocation,
- training::{
- context::ContextType,
- llm_client::LlmClient,
- teacher::{TeacherModel, TeacherOutput},
- },
-};
-use anyhow::Result;
-use reqwest_client::ReqwestClient;
-
-#[derive(Debug, Deserialize)]
-pub struct SplitCommit {
- repo_url: String,
- commit_sha: String,
- edit_history: String,
- expected_patch: String,
- cursor_position: String,
-}
-
-pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
- let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
- .expect("Failed to read split commit dataset")
- .lines()
- .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
- .collect();
-
- let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
-
- let llm_client = if let Some(cache_path) = arguments.batch {
- LlmClient::batch(&cache_path, http_client)?
- } else {
- LlmClient::plain(http_client)?
- };
-
- let mut teacher = TeacherModel::new(
- "claude-sonnet-4-5".to_string(),
- ContextType::CurrentFile,
- llm_client,
- );
-
- let mut num_marked_for_batching = 0;
-
- for commit in split_commits {
- if let Some(distilled) = distill_one(&mut teacher, commit).await? {
- println!("{}", serde_json::to_string(&distilled)?);
- } else {
- if num_marked_for_batching == 0 {
- log::warn!("Marked for batching");
- }
- num_marked_for_batching += 1;
- }
- }
-
- eprintln!(
- "{} requests are marked for batching",
- num_marked_for_batching
- );
- let llm_client = teacher.client;
- llm_client.sync_batches().await?;
-
- Ok(())
-}
-
-pub async fn distill_one(
- teacher: &mut TeacherModel,
- commit: SplitCommit,
-) -> Result<Option<TeacherOutput>> {
- let cursor: SourceLocation = commit
- .cursor_position
- .parse()
- .expect("Failed to parse cursor position");
-
- let path = cursor.path.to_rel_path_buf();
-
- let example = Example {
- repository_url: commit.repo_url,
- revision: commit.commit_sha,
- uncommitted_diff: commit.edit_history.clone(),
- cursor_path: path.as_std_path().to_path_buf(),
- cursor_position: commit.cursor_position,
- edit_history: commit.edit_history, // todo: trim
- expected_patch: commit.expected_patch,
- };
-
- let prediction = teacher.predict(example).await;
-
- prediction
-}
@@ -1,4 +0,0 @@
-pub mod context;
-pub mod distill;
-pub mod llm_client;
-pub mod teacher;
@@ -1,266 +0,0 @@
-use crate::{
- example::Example,
- source_location::SourceLocation,
- training::{
- context::{ContextType, collect_context, strip_special_tags},
- llm_client::LlmClient,
- },
-};
-use anthropic::{Message, RequestContent, ResponseContent, Role};
-use anyhow::Result;
-
-pub struct TeacherModel {
- pub llm_name: String,
- pub context: ContextType,
- pub client: LlmClient,
-}
-
-#[derive(Debug, serde::Serialize)]
-pub struct TeacherOutput {
- parsed_output: String,
- prompt: String,
- raw_llm_response: String,
- context: String,
- diff: String,
-}
-
-impl TeacherModel {
- const PROMPT: &str = include_str!("teacher.prompt.md");
- pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
- pub(crate) const REGION_END: &str = "<|editable_region_end|>";
- pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
-
- /// Number of lines to include before the cursor position
- pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
-
- /// Number of lines to include after the cursor position
- pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
-
- /// Truncate edit history to this number of last lines
- const MAX_HISTORY_LINES: usize = 128;
-
- pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
- TeacherModel {
- llm_name,
- context,
- client,
- }
- }
-
- pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
- let name = input.unique_name();
- let worktree_dir = input.setup_worktree(name).await?;
- let cursor: SourceLocation = input
- .cursor_position
- .parse()
- .expect("Failed to parse cursor position");
-
- let context = collect_context(&self.context, &worktree_dir, cursor.clone());
- let edit_history = Self::format_edit_history(&input.edit_history);
-
- let prompt = Self::PROMPT
- .replace("{{context}}", &context)
- .replace("{{edit_history}}", &edit_history);
-
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let Some(response) = self
- .client
- .generate(self.llm_name.clone(), 16384, messages)
- .await?
- else {
- return Ok(None);
- };
-
- let response_text = response
- .content
- .into_iter()
- .filter_map(|content| match content {
- ResponseContent::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<String>>()
- .join("\n");
-
- let parsed_output = self.parse_response(&response_text);
-
- let original_editable_region = Self::extract_editable_region(&context);
- let context_after_edit = context.replace(&original_editable_region, &parsed_output);
- let context_after_edit = strip_special_tags(&context_after_edit);
- let context_before_edit = strip_special_tags(&context);
- let diff = language::unified_diff(&context_before_edit, &context_after_edit);
-
- // zeta distill --batch batch_results.txt
- // zeta distill
- // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
- // - store LLM requests in a batch, don't actual send the request
- // - send the batch (2000 requests) after all inputs are processed
- // 2. `zeta send-batches`
- // - upload the batch to Anthropic
-
- // https://platform.claude.com/docs/en/build-with-claude/batch-processing
- // https://crates.io/crates/anthropic-sdk-rust
-
- // - poll for results
- // - when ready, store results in cache (a database)
- // 3. `zeta distill` again
- // - use the cached results this time
-
- Ok(Some(TeacherOutput {
- parsed_output,
- prompt,
- raw_llm_response: response_text,
- context,
- diff,
- }))
- }
-
- fn parse_response(&self, content: &str) -> String {
- let codeblock = Self::extract_last_codeblock(content);
- let editable_region = Self::extract_editable_region(&codeblock);
-
- editable_region
- }
-
- /// Extract content from the last code-fenced block if any, or else return content as is
- fn extract_last_codeblock(text: &str) -> String {
- let mut last_block = None;
- let mut search_start = 0;
-
- while let Some(start) = text[search_start..].find("```") {
- let start = start + search_start;
- let bytes = text.as_bytes();
- let mut backtick_end = start;
-
- while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
- backtick_end += 1;
- }
-
- let backtick_count = backtick_end - start;
- let closing_backticks = "`".repeat(backtick_count);
-
- if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
- let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
- last_block = Some(code_block.to_string());
- search_start = backtick_end + end_pos + backtick_count;
- } else {
- break;
- }
- }
-
- last_block.unwrap_or_else(|| text.to_string())
- }
-
- fn extract_editable_region(text: &str) -> String {
- let start = text
- .find(Self::REGION_START)
- .map_or(0, |pos| pos + Self::REGION_START.len());
- let end = text.find(Self::REGION_END).unwrap_or(text.len());
-
- text[start..end].to_string()
- }
-
- /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
- fn format_edit_history(edit_history: &str) -> String {
- let lines = edit_history
- .lines()
- .filter(|&s| Self::is_content_line(s))
- .collect::<Vec<_>>();
-
- let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
- &lines[lines.len() - Self::MAX_HISTORY_LINES..]
- } else {
- &lines
- };
- history_lines.join("\n")
- }
-
- fn is_content_line(s: &str) -> bool {
- s.starts_with("-")
- || s.starts_with("+")
- || s.starts_with(" ")
- || s.starts_with("---")
- || s.starts_with("+++")
- || s.starts_with("@@")
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_parse_response() {
- let teacher = TeacherModel::new(
- "test".to_string(),
- ContextType::CurrentFile,
- LlmClient::dummy(),
- );
- let response = "This is a test response.";
- let parsed = teacher.parse_response(response);
- assert_eq!(parsed, response.to_string());
-
- let response = indoc::indoc! {"
- Some thinking
-
- `````
- actual response
- `````
- "};
- let parsed = teacher.parse_response(response);
- assert_eq!(parsed, "actual response");
- }
-
- #[test]
- fn test_extract_last_code_block() {
- let text = indoc::indoc! {"
- Some thinking
-
- ```
- first block
- ```
-
- `````
- last block
- `````
- "};
- let last_block = TeacherModel::extract_last_codeblock(text);
- assert_eq!(last_block, "last block");
- }
-
- #[test]
- fn test_extract_editable_region() {
- let teacher = TeacherModel::new(
- "test".to_string(),
- ContextType::CurrentFile,
- LlmClient::dummy(),
- );
- let response = indoc::indoc! {"
- some lines
- are
- here
- <|editable_region_start|>
- one
- two three
-
- <|editable_region_end|>
- more
- lines here
- "};
- let parsed = teacher.parse_response(response);
- assert_eq!(
- parsed,
- indoc::indoc! {"
- one
- two three
-
- "}
- );
- }
-}
@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
env_logger.workspace = true
@@ -1,6 +1,6 @@
-use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
+use zeta_prompt::RelatedExcerpt;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
input_ranges
.into_iter()
- .map(|range| {
- let offset_range = range.to_offset(buffer);
- RelatedExcerpt {
- point_range: range,
- anchor_range: buffer.anchor_before(offset_range.start)
- ..buffer.anchor_after(offset_range.end),
- text: buffer.as_rope().slice(offset_range),
- }
+ .map(|range| RelatedExcerpt {
+ row_range: range.start.row..range.end.row,
+ text: buffer.text_for_range(range).collect(),
})
.collect()
}
@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
-use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
+ path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
+pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
- related_files: Vec<RelatedFile>,
+ related_files: Arc<[RelatedFile]>,
+ related_file_buffers: Vec<Entity<Buffer>>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
anchor_range: Range<Anchor>,
}
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedFile {
- #[serde(serialize_with = "serialize_project_path")]
- pub path: ProjectPath,
- #[serde(skip)]
- pub buffer: WeakEntity<Buffer>,
- pub excerpts: Vec<RelatedExcerpt>,
- pub max_row: u32,
-}
-
-impl RelatedFile {
- pub fn merge_excerpts(&mut self) {
- self.excerpts.sort_unstable_by(|a, b| {
- a.point_range
- .start
- .cmp(&b.point_range.start)
- .then(b.point_range.end.cmp(&a.point_range.end))
- });
-
- let mut index = 1;
- while index < self.excerpts.len() {
- if self.excerpts[index - 1]
- .point_range
- .end
- .cmp(&self.excerpts[index].point_range.start)
- .is_ge()
- {
- let removed = self.excerpts.remove(index);
- if removed
- .point_range
- .end
- .cmp(&self.excerpts[index - 1].point_range.end)
- .is_gt()
- {
- self.excerpts[index - 1].point_range.end = removed.point_range.end;
- self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
- }
- } else {
- index += 1;
- }
- }
- }
-}
-
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedExcerpt {
- #[serde(skip)]
- pub anchor_range: Range<Anchor>,
- #[serde(serialize_with = "serialize_point_range")]
- pub point_range: Range<Point>,
- #[serde(serialize_with = "serialize_rope")]
- pub text: Rope,
-}
-
-fn serialize_project_path<S: Serializer>(
- project_path: &ProjectPath,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- project_path.path.serialize(serializer)
-}
-
-fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
- rope.to_string().serialize(serializer)
-}
-
-fn serialize_point_range<S: Serializer>(
- range: &Range<Point>,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- [
- [range.start.row, range.start.column],
- [range.end.row, range.end.column],
- ]
- .serialize(serializer)
-}
-
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
- related_files: Vec::new(),
+ related_files: Vec::new().into(),
+ related_file_buffers: Vec::new(),
cache: Default::default(),
identifier_line_count: IDENTIFIER_LINE_COUNT,
}
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
self.update_tx.unbounded_send((buffer, position)).ok();
}
- pub fn related_files(&self) -> &[RelatedFile] {
- &self.related_files
+ pub fn related_files(&self) -> Arc<[RelatedFile]> {
+ self.related_files.clone()
+ }
+
+ pub fn related_files_with_buffers(
+ &self,
+ ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
+ self.related_files
+ .iter()
+ .cloned()
+ .zip(self.related_file_buffers.iter().cloned())
+ }
+
+ pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
+ self.related_files = files.into();
}
async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
- let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
+ let (new_cache, related_files, related_file_buffers) =
+ rebuild_related_files(&project, new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
this.update(cx, |this, cx| {
this.cache = new_cache;
- this.related_files = related_files;
+ this.related_files = related_files.into();
+ this.related_file_buffers = related_file_buffers;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}
async fn rebuild_related_files(
+ project: &Entity<Project>,
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
-) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+) -> Result<(
+ HashMap<Identifier, Arc<CacheEntry>>,
+ Vec<RelatedFile>,
+ Vec<Entity<Buffer>>,
+)> {
let mut snapshots = HashMap::default();
+ let mut worktree_root_names = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
+ let worktree_id = definition.path.worktree_id;
+ if let hash_map::Entry::Vacant(e) =
+ worktree_root_names.entry(definition.path.worktree_id)
+ {
+ project.read_with(cx, |project, cx| {
+ if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
+ e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
+ }
+ })?;
+ }
}
}
Ok(cx
.background_spawn(async move {
- let mut files = Vec::<RelatedFile>::new();
+ let mut files = Vec::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
@@ -369,20 +327,37 @@ async fn rebuild_related_files(
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
- files.push(RelatedFile {
- path: project_path.clone(),
- buffer: buffer.downgrade(),
- excerpts,
- max_row: snapshot.max_point().row,
- });
+ let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
+ continue;
+ };
+
+ let path = Path::new(&format!(
+ "{}/{}",
+ root_name,
+ project_path.path.as_unix_str()
+ ))
+ .into();
+
+ files.push((
+ buffer,
+ RelatedFile {
+ path,
+ excerpts,
+ max_row: snapshot.max_point().row,
+ },
+ ));
}
- files.sort_by_key(|file| file.path.clone());
- (new_entries, files)
+ files.sort_by_key(|(_, file)| file.path.clone());
+ let (related_buffers, related_files) = files.into_iter().unzip();
+
+ (new_entries, related_files, related_buffers)
})
.await)
}
+const MAX_TARGET_LEN: usize = 128;
+
fn process_definition(
location: LocationLink,
project: &Entity<Project>,
@@ -395,6 +370,15 @@ fn process_definition(
if worktree.read(cx).is_single_file() {
return None;
}
+
+ // If the target range is large, it likely means we requested the definition of an entire module.
+ // For individual definitions, the target range should be small as it only covers the symbol.
+ let buffer = location.target.buffer.read(cx);
+ let target_len = anchor_range.to_offset(&buffer).len();
+ if target_len > MAX_TARGET_LEN {
+ return None;
+ }
+
Some(CachedDefinition {
path: ProjectPath {
worktree_id: file.worktree_id(cx),
@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
&excerpts,
&[
(
- "src/company.rs",
+ "root/src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
}"}],
),
(
- "src/main.rs",
+ "root/src/main.rs",
&[
indoc! {"
pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
],
),
(
- "src/person.rs",
+ "root/src/person.rs",
&[
indoc! {"
impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
- (file.path.path.as_unix_str(), excerpts)
+ (file.path.to_str().unwrap(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
if excerpt.text.is_empty() {
continue;
}
- if current_row < excerpt.point_range.start.row {
+ if current_row < excerpt.row_range.start {
writeln!(&mut output, "…").unwrap();
}
- current_row = excerpt.point_range.start.row;
+ current_row = excerpt.row_range.start;
for line in excerpt.text.to_string().lines() {
output.push_str(line);
@@ -15,3 +15,4 @@ path = "src/edit_prediction_types.rs"
client.workspace = true
gpui.workspace = true
language.workspace = true
+text.workspace = true
@@ -2,7 +2,7 @@ use std::{ops::Range, sync::Arc};
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
+use language::{Anchor, Buffer, OffsetRangeExt};
// TODO: Find a better home for `Direction`.
//
@@ -252,8 +252,8 @@ where
/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
/// edit is not a prefix of a predicted insertion.
pub fn interpolate_edits(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
+ old_snapshot: &text::BufferSnapshot,
+ new_snapshot: &text::BufferSnapshot,
current_edits: &[(Range<Anchor>, Arc<str>)],
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
let mut edits = Vec::new();
@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
-use text::OffsetRangeExt;
+use text::Point;
use ui::{
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
) -> Self {
let store = EditPredictionStore::global(client, user_store, cx);
- let mut debug_rx = store.update(cx, |store, _| store.debug_info());
+ let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
self.handle_context_retrieval_finished(info, window, cx);
}
}
- DebugEvent::EditPredictionRequested(_) => {}
+ DebugEvent::EditPredictionStarted(_) => {}
+ DebugEvent::EditPredictionFinished(_) => {}
}
}
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
- let project = self.project.clone();
let related_files = self
.store
.read(cx)
- .context_for_project(&self.project, cx)
- .to_vec();
+ .context_for_project_with_buffers(&self.project, cx)
+ .map_or(Vec::new(), |files| files.collect());
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
- for related_file in related_files {
- let (buffer, point_ranges): (_, Vec<_>) =
- if let Some(buffer) = related_file.buffer.upgrade() {
- let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
-
- (
- buffer,
- related_file
- .excerpts
- .iter()
- .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
- .collect(),
- )
- } else {
- (
- project
- .update(cx, |project, cx| {
- project.open_buffer(related_file.path.clone(), cx)
- })?
- .await?,
- related_file
- .excerpts
- .iter()
- .map(|excerpt| excerpt.point_range.clone())
- .collect(),
- )
- };
+ for (related_file, buffer) in related_files {
+ let point_ranges = related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| {
+ Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
+ })
+ .collect::<Vec<_>>();
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));
@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
-use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
- write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
+ formatted_inputs.push_str("```diff\n");
+ zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
+ formatted_inputs.push_str("```\n\n");
}
- write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
-
- for included_file in &prediction.inputs.included_files {
- let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
+ write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
+ for included_file in prediction.inputs.related_files.as_ref() {
write!(
&mut formatted_inputs,
"### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
)
.unwrap();
- write_codeblock(
- &included_file.path,
- &included_file.excerpts,
- if included_file.path == prediction.inputs.cursor_path {
- cursor_insertions.as_slice()
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut formatted_inputs,
- );
+ for excerpt in included_file.excerpts.iter() {
+ write!(
+ &mut formatted_inputs,
+ "```{}\n{}\n```\n",
+ included_file.path.display(),
+ excerpt.text
+ )
+ .unwrap();
+ }
}
+ write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
+
+ writeln!(
+ &mut formatted_inputs,
+ "```{}\n{}<CURSOR>{}\n```\n",
+ prediction.inputs.cursor_path.display(),
+ &prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
+ &prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
+ )
+ .unwrap();
+
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {
@@ -680,6 +680,10 @@ actions!(
ReloadFile,
/// Rewraps text to fit within the preferred line length.
Rewrap,
+ /// Rotates selections or lines backward.
+ RotateSelectionsBackward,
+ /// Rotates selections or lines forward.
+ RotateSelectionsForward,
/// Runs flycheck diagnostics.
RunFlycheck,
/// Scrolls the cursor to the bottom of the viewport.
@@ -206,6 +206,13 @@ impl CodeContextMenu {
CodeContextMenu::CodeActions(_) => (),
}
}
+
+ pub fn primary_scroll_handle(&self) -> UniformListScrollHandle {
+ match self {
+ CodeContextMenu::Completions(menu) => menu.scroll_handle.clone(),
+ CodeContextMenu::CodeActions(menu) => menu.scroll_handle.clone(),
+ }
+ }
}
pub enum ContextMenuOrigin {
@@ -303,6 +310,7 @@ impl CompletionsMenu {
is_incomplete: bool,
buffer: Entity<Buffer>,
completions: Box<[Completion]>,
+ scroll_handle: Option<UniformListScrollHandle>,
display_options: CompletionDisplayOptions,
snippet_sort_order: SnippetSortOrder,
language_registry: Option<Arc<LanguageRegistry>>,
@@ -332,7 +340,7 @@ impl CompletionsMenu {
selected_item: 0,
filter_task: Task::ready(()),
cancel_filter: Arc::new(AtomicBool::new(false)),
- scroll_handle: UniformListScrollHandle::new(),
+ scroll_handle: scroll_handle.unwrap_or_else(UniformListScrollHandle::new),
scroll_handle_aside: ScrollHandle::new(),
resolve_completions: true,
last_rendered_range: RefCell::new(None).into(),
@@ -354,6 +362,7 @@ impl CompletionsMenu {
choices: &Vec<String>,
selection: Range<Anchor>,
buffer: Entity<Buffer>,
+ scroll_handle: Option<UniformListScrollHandle>,
snippet_sort_order: SnippetSortOrder,
) -> Self {
let completions = choices
@@ -404,7 +413,7 @@ impl CompletionsMenu {
selected_item: 0,
filter_task: Task::ready(()),
cancel_filter: Arc::new(AtomicBool::new(false)),
- scroll_handle: UniformListScrollHandle::new(),
+ scroll_handle: scroll_handle.unwrap_or_else(UniformListScrollHandle::new),
scroll_handle_aside: ScrollHandle::new(),
resolve_completions: false,
show_completion_documentation: false,
@@ -79,12 +79,15 @@ fn create_highlight_endpoints(
let start_ix = ranges
.binary_search_by(|probe| probe.end.cmp(&start, buffer).then(cmp::Ordering::Less))
.unwrap_or_else(|i| i);
+ let end_ix = ranges[start_ix..]
+ .binary_search_by(|probe| {
+ probe.start.cmp(&end, buffer).then(cmp::Ordering::Greater)
+ })
+ .unwrap_or_else(|i| i);
- for range in &ranges[start_ix..] {
- if range.start.cmp(&end, buffer).is_ge() {
- break;
- }
+ highlight_endpoints.reserve(2 * end_ix);
+ for range in &ranges[start_ix..][..end_ix] {
let start = range.start.to_offset(buffer);
let end = range.end.to_offset(buffer);
if start == end {
@@ -108,7 +108,7 @@ use gpui::{
DispatchPhase, Edges, Entity, EntityInputHandler, EventEmitter, FocusHandle, FocusOutEvent,
Focusable, FontId, FontWeight, Global, HighlightStyle, Hsla, KeyContext, Modifiers,
MouseButton, MouseDownEvent, MouseMoveEvent, PaintQuad, ParentElement, Pixels, Render,
- ScrollHandle, SharedString, Size, Stateful, Styled, Subscription, Task, TextStyle,
+ ScrollHandle, SharedString, Size, Stateful, Styled, Subscription, Task, TextRun, TextStyle,
TextStyleRefinement, UTF16Selection, UnderlineStyle, UniformListScrollHandle, WeakEntity,
WeakFocusHandle, Window, div, point, prelude::*, pulsating_between, px, relative, size,
};
@@ -575,7 +575,7 @@ impl Default for EditorStyle {
}
}
-pub fn make_inlay_hints_style(cx: &mut App) -> HighlightStyle {
+pub fn make_inlay_hints_style(cx: &App) -> HighlightStyle {
let show_background = language_settings::language_settings(None, None, cx)
.inlay_hints
.show_background;
@@ -598,7 +598,7 @@ pub fn make_inlay_hints_style(cx: &mut App) -> HighlightStyle {
style
}
-pub fn make_suggestion_styles(cx: &mut App) -> EditPredictionStyles {
+pub fn make_suggestion_styles(cx: &App) -> EditPredictionStyles {
EditPredictionStyles {
insertion: HighlightStyle {
color: Some(cx.theme().status().predictive),
@@ -1249,6 +1249,7 @@ impl NextScrollCursorCenterTopBottom {
pub struct EditorSnapshot {
pub mode: EditorMode,
show_gutter: bool,
+ offset_content: bool,
show_line_numbers: Option<bool>,
show_git_diff_gutter: Option<bool>,
show_code_actions: Option<bool>,
@@ -1825,7 +1826,11 @@ impl Editor {
Editor::new_internal(mode, buffer, project, None, window, cx)
}
- pub fn sticky_headers(&self, cx: &App) -> Option<Vec<OutlineItem<Anchor>>> {
+ pub fn sticky_headers(
+ &self,
+ style: &EditorStyle,
+ cx: &App,
+ ) -> Option<Vec<OutlineItem<Anchor>>> {
let multi_buffer = self.buffer().read(cx);
let multi_buffer_snapshot = multi_buffer.snapshot(cx);
let multi_buffer_visible_start = self
@@ -1843,7 +1848,7 @@ impl Editor {
.outline_items_containing(
Point::new(start_row, 0)..Point::new(end_row, 0),
true,
- self.style().map(|style| style.syntax.as_ref()),
+ Some(style.syntax.as_ref()),
)
.into_iter()
.map(|outline_item| OutlineItem {
@@ -2935,6 +2940,7 @@ impl Editor {
EditorSnapshot {
mode: self.mode.clone(),
show_gutter: self.show_gutter,
+ offset_content: self.offset_content,
show_line_numbers: self.show_line_numbers,
show_git_diff_gutter: self.show_git_diff_gutter,
show_code_actions: self.show_code_actions,
@@ -5882,6 +5888,11 @@ impl Editor {
is_incomplete,
buffer.clone(),
completions.into(),
+ editor
+ .context_menu()
+ .borrow_mut()
+ .as_ref()
+ .map(|menu| menu.primary_scroll_handle()),
display_options,
snippet_sort_order,
languages,
@@ -6890,7 +6901,7 @@ impl Editor {
};
let anchor = self.selections.newest_anchor().head();
- let position = self.to_pixel_point(anchor, &snapshot, window);
+ let position = self.to_pixel_point(anchor, &snapshot, window, cx);
if let (Some(position), Some(last_bounds)) = (position, self.last_bounds) {
self.show_blame_popover(
buffer,
@@ -9203,7 +9214,8 @@ impl Editor {
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
- let line_origin = self.display_to_pixel_point(target_line_end, editor_snapshot, window)?;
+ let line_origin =
+ self.display_to_pixel_point(target_line_end, editor_snapshot, window, cx)?;
let start_point = content_origin - point(scroll_pixel_position.x.into(), Pixels::ZERO);
let mut origin = start_point
@@ -9945,8 +9957,7 @@ impl Editor {
}
pub fn render_context_menu(
- &self,
- style: &EditorStyle,
+ &mut self,
max_height_in_lines: u32,
window: &mut Window,
cx: &mut Context<Editor>,
@@ -9956,7 +9967,9 @@ impl Editor {
if !menu.visible() {
return None;
};
- Some(menu.render(style, max_height_in_lines, window, cx))
+ self.style
+ .as_ref()
+ .map(|style| menu.render(style, max_height_in_lines, window, cx))
}
fn render_context_menu_aside(
@@ -10016,13 +10029,16 @@ impl Editor {
let id = post_inc(&mut self.next_completion_id);
let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order;
- *self.context_menu.borrow_mut() = Some(CodeContextMenu::Completions(
+ let mut context_menu = self.context_menu.borrow_mut();
+ let old_menu = context_menu.take();
+ *context_menu = Some(CodeContextMenu::Completions(
CompletionsMenu::new_snippet_choices(
id,
true,
choices,
selection,
buffer,
+ old_menu.map(|menu| menu.primary_scroll_handle()),
snippet_sort_order,
),
));
@@ -11516,6 +11532,168 @@ impl Editor {
self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut rand::rng()))
}
+ pub fn rotate_selections_forward(
+ &mut self,
+ _: &RotateSelectionsForward,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.rotate_selections(window, cx, false)
+ }
+
+ pub fn rotate_selections_backward(
+ &mut self,
+ _: &RotateSelectionsBackward,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.rotate_selections(window, cx, true)
+ }
+
+ fn rotate_selections(&mut self, window: &mut Window, cx: &mut Context<Self>, reverse: bool) {
+ self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx);
+ let display_snapshot = self.display_snapshot(cx);
+ let selections = self.selections.all::<MultiBufferOffset>(&display_snapshot);
+
+ if selections.len() < 2 {
+ return;
+ }
+
+ let (edits, new_selections) = {
+ let buffer = self.buffer.read(cx).read(cx);
+ let has_selections = selections.iter().any(|s| !s.is_empty());
+ if has_selections {
+ let mut selected_texts: Vec<String> = selections
+ .iter()
+ .map(|selection| {
+ buffer
+ .text_for_range(selection.start..selection.end)
+ .collect()
+ })
+ .collect();
+
+ if reverse {
+ selected_texts.rotate_left(1);
+ } else {
+ selected_texts.rotate_right(1);
+ }
+
+ let mut offset_delta: i64 = 0;
+ let mut new_selections = Vec::new();
+ let edits: Vec<_> = selections
+ .iter()
+ .zip(selected_texts.iter())
+ .map(|(selection, new_text)| {
+ let old_len = (selection.end.0 - selection.start.0) as i64;
+ let new_len = new_text.len() as i64;
+ let adjusted_start =
+ MultiBufferOffset((selection.start.0 as i64 + offset_delta) as usize);
+ let adjusted_end =
+ MultiBufferOffset((adjusted_start.0 as i64 + new_len) as usize);
+
+ new_selections.push(Selection {
+ id: selection.id,
+ start: adjusted_start,
+ end: adjusted_end,
+ reversed: selection.reversed,
+ goal: selection.goal,
+ });
+
+ offset_delta += new_len - old_len;
+ (selection.start..selection.end, new_text.clone())
+ })
+ .collect();
+ (edits, new_selections)
+ } else {
+ let mut all_rows: Vec<u32> = selections
+ .iter()
+ .map(|selection| buffer.offset_to_point(selection.start).row)
+ .collect();
+ all_rows.sort_unstable();
+ all_rows.dedup();
+
+ if all_rows.len() < 2 {
+ return;
+ }
+
+ let line_ranges: Vec<Range<MultiBufferOffset>> = all_rows
+ .iter()
+ .map(|&row| {
+ let start = Point::new(row, 0);
+ let end = Point::new(row, buffer.line_len(MultiBufferRow(row)));
+ buffer.point_to_offset(start)..buffer.point_to_offset(end)
+ })
+ .collect();
+
+ let mut line_texts: Vec<String> = line_ranges
+ .iter()
+ .map(|range| buffer.text_for_range(range.clone()).collect())
+ .collect();
+
+ if reverse {
+ line_texts.rotate_left(1);
+ } else {
+ line_texts.rotate_right(1);
+ }
+
+ let edits = line_ranges
+ .iter()
+ .zip(line_texts.iter())
+ .map(|(range, new_text)| (range.clone(), new_text.clone()))
+ .collect();
+
+ let num_rows = all_rows.len();
+ let row_to_index: std::collections::HashMap<u32, usize> = all_rows
+ .iter()
+ .enumerate()
+ .map(|(i, &row)| (row, i))
+ .collect();
+
+ // Compute new line start offsets after rotation (handles CRLF)
+ let newline_len = line_ranges[1].start.0 - line_ranges[0].end.0;
+ let first_line_start = line_ranges[0].start.0;
+ let mut new_line_starts: Vec<usize> = vec![first_line_start];
+ for text in line_texts.iter().take(num_rows - 1) {
+ let prev_start = *new_line_starts.last().unwrap();
+ new_line_starts.push(prev_start + text.len() + newline_len);
+ }
+
+ let new_selections = selections
+ .iter()
+ .map(|selection| {
+ let point = buffer.offset_to_point(selection.start);
+ let old_index = row_to_index[&point.row];
+ let new_index = if reverse {
+ (old_index + num_rows - 1) % num_rows
+ } else {
+ (old_index + 1) % num_rows
+ };
+ let new_offset =
+ MultiBufferOffset(new_line_starts[new_index] + point.column as usize);
+ Selection {
+ id: selection.id,
+ start: new_offset,
+ end: new_offset,
+ reversed: selection.reversed,
+ goal: selection.goal,
+ }
+ })
+ .collect();
+
+ (edits, new_selections)
+ }
+ };
+
+ self.transact(window, cx, |this, window, cx| {
+ this.buffer.update(cx, |buffer, cx| {
+ buffer.edit(edits, None, cx);
+ });
+ this.change_selections(Default::default(), window, cx, |s| {
+ s.select(new_selections);
+ });
+ });
+ }
+
fn manipulate_lines<M>(
&mut self,
window: &mut Window,
@@ -20194,8 +20372,11 @@ impl Editor {
self.style = Some(style);
}
- pub fn style(&self) -> Option<&EditorStyle> {
- self.style.as_ref()
+ pub fn style(&mut self, cx: &App) -> &EditorStyle {
+ if self.style.is_none() {
+ self.style = Some(self.create_style(cx));
+ }
+ self.style.as_ref().unwrap()
}
// Called by the element. This method is not designed to be called outside of the editor
@@ -21845,8 +22026,10 @@ impl Editor {
multi_buffer::Event::DiffHunksToggled => {
self.tasks_update_task = Some(self.refresh_runnables(window, cx));
}
- multi_buffer::Event::LanguageChanged(buffer_id) => {
- self.registered_buffers.remove(&buffer_id);
+ multi_buffer::Event::LanguageChanged(buffer_id, is_fresh_language) => {
+ if !is_fresh_language {
+ self.registered_buffers.remove(&buffer_id);
+ }
jsx_tag_auto_close::refresh_enabled_in_any_buffer(self, multibuffer, cx);
cx.emit(EditorEvent::Reparsed(*buffer_id));
cx.notify();
@@ -22817,22 +23000,24 @@ impl Editor {
}
pub fn to_pixel_point(
- &self,
+ &mut self,
source: multi_buffer::Anchor,
editor_snapshot: &EditorSnapshot,
window: &mut Window,
+ cx: &App,
) -> Option<gpui::Point<Pixels>> {
let source_point = source.to_display_point(editor_snapshot);
- self.display_to_pixel_point(source_point, editor_snapshot, window)
+ self.display_to_pixel_point(source_point, editor_snapshot, window, cx)
}
pub fn display_to_pixel_point(
- &self,
+ &mut self,
source: DisplayPoint,
editor_snapshot: &EditorSnapshot,
window: &mut Window,
+ cx: &App,
) -> Option<gpui::Point<Pixels>> {
- let line_height = self.style()?.text.line_height_in_pixels(window.rem_size());
+ let line_height = self.style(cx).text.line_height_in_pixels(window.rem_size());
let text_layout_details = self.text_layout_details(window);
let scroll_top = text_layout_details
.scroll_anchor
@@ -22896,10 +23081,6 @@ impl Editor {
}
}
- pub fn last_gutter_dimensions(&self) -> &GutterDimensions {
- &self.gutter_dimensions
- }
-
pub fn wait_for_diff_to_load(&self) -> Option<Shared<Task<()>>> {
self.load_diff_task.clone()
}
@@ -22999,6 +23180,57 @@ impl Editor {
// skip any LSP updates for it.
self.active_diagnostics == ActiveDiagnostic::All || !self.mode().is_full()
}
+
+ fn create_style(&self, cx: &App) -> EditorStyle {
+ let settings = ThemeSettings::get_global(cx);
+
+ let mut text_style = match self.mode {
+ EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle {
+ color: cx.theme().colors().editor_foreground,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_fallbacks: settings.ui_font.fallbacks.clone(),
+ font_size: rems(0.875).into(),
+ font_weight: settings.ui_font.weight,
+ line_height: relative(settings.buffer_line_height.value()),
+ ..Default::default()
+ },
+ EditorMode::Full { .. } | EditorMode::Minimap { .. } => TextStyle {
+ color: cx.theme().colors().editor_foreground,
+ font_family: settings.buffer_font.family.clone(),
+ font_features: settings.buffer_font.features.clone(),
+ font_fallbacks: settings.buffer_font.fallbacks.clone(),
+ font_size: settings.buffer_font_size(cx).into(),
+ font_weight: settings.buffer_font.weight,
+ line_height: relative(settings.buffer_line_height.value()),
+ ..Default::default()
+ },
+ };
+ if let Some(text_style_refinement) = &self.text_style_refinement {
+ text_style.refine(text_style_refinement)
+ }
+
+ let background = match self.mode {
+ EditorMode::SingleLine => cx.theme().system().transparent,
+ EditorMode::AutoHeight { .. } => cx.theme().system().transparent,
+ EditorMode::Full { .. } => cx.theme().colors().editor_background,
+ EditorMode::Minimap { .. } => cx.theme().colors().editor_background.opacity(0.7),
+ };
+
+ EditorStyle {
+ background,
+ border: cx.theme().colors().border,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ scrollbar_width: EditorElement::SCROLLBAR_WIDTH,
+ syntax: cx.theme().syntax().clone(),
+ status: cx.theme().status().clone(),
+ inlay_hints_style: make_inlay_hints_style(cx),
+ edit_prediction_styles: make_suggestion_styles(cx),
+ unnecessary_code_fade: settings.unnecessary_code_fade,
+ show_underlines: self.diagnostics_enabled(),
+ }
+ }
}
fn edit_for_markdown_paste<'a>(
@@ -24526,94 +24758,98 @@ impl EditorSnapshot {
self.scroll_anchor.scroll_position(&self.display_snapshot)
}
- fn gutter_dimensions(
+ pub fn gutter_dimensions(
&self,
font_id: FontId,
font_size: Pixels,
- max_line_number_width: Pixels,
+ style: &EditorStyle,
+ window: &mut Window,
cx: &App,
- ) -> Option<GutterDimensions> {
- if !self.show_gutter {
- return None;
- }
+ ) -> GutterDimensions {
+ if self.show_gutter
+ && let Some(ch_width) = cx.text_system().ch_width(font_id, font_size).log_err()
+ && let Some(ch_advance) = cx.text_system().ch_advance(font_id, font_size).log_err()
+ {
+ let show_git_gutter = self.show_git_diff_gutter.unwrap_or_else(|| {
+ matches!(
+ ProjectSettings::get_global(cx).git.git_gutter,
+ GitGutterSetting::TrackedFiles
+ )
+ });
+ let gutter_settings = EditorSettings::get_global(cx).gutter;
+ let show_line_numbers = self
+ .show_line_numbers
+ .unwrap_or(gutter_settings.line_numbers);
+ let line_gutter_width = if show_line_numbers {
+ // Avoid flicker-like gutter resizes when the line number gains another digit by
+ // only resizing the gutter on files with > 10**min_line_number_digits lines.
+ let min_width_for_number_on_gutter =
+ ch_advance * gutter_settings.min_line_number_digits as f32;
+ self.max_line_number_width(style, window)
+ .max(min_width_for_number_on_gutter)
+ } else {
+ 0.0.into()
+ };
- let ch_width = cx.text_system().ch_width(font_id, font_size).log_err()?;
- let ch_advance = cx.text_system().ch_advance(font_id, font_size).log_err()?;
+ let show_runnables = self.show_runnables.unwrap_or(gutter_settings.runnables);
+ let show_breakpoints = self.show_breakpoints.unwrap_or(gutter_settings.breakpoints);
- let show_git_gutter = self.show_git_diff_gutter.unwrap_or_else(|| {
- matches!(
- ProjectSettings::get_global(cx).git.git_gutter,
- GitGutterSetting::TrackedFiles
- )
- });
- let gutter_settings = EditorSettings::get_global(cx).gutter;
- let show_line_numbers = self
- .show_line_numbers
- .unwrap_or(gutter_settings.line_numbers);
- let line_gutter_width = if show_line_numbers {
- // Avoid flicker-like gutter resizes when the line number gains another digit by
- // only resizing the gutter on files with > 10**min_line_number_digits lines.
- let min_width_for_number_on_gutter =
- ch_advance * gutter_settings.min_line_number_digits as f32;
- max_line_number_width.max(min_width_for_number_on_gutter)
- } else {
- 0.0.into()
- };
-
- let show_runnables = self.show_runnables.unwrap_or(gutter_settings.runnables);
- let show_breakpoints = self.show_breakpoints.unwrap_or(gutter_settings.breakpoints);
+ let git_blame_entries_width =
+ self.git_blame_gutter_max_author_length
+ .map(|max_author_length| {
+ let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
+ const MAX_RELATIVE_TIMESTAMP: &str = "60 minutes ago";
- let git_blame_entries_width =
- self.git_blame_gutter_max_author_length
- .map(|max_author_length| {
- let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
- const MAX_RELATIVE_TIMESTAMP: &str = "60 minutes ago";
+ /// The number of characters to dedicate to gaps and margins.
+ const SPACING_WIDTH: usize = 4;
- /// The number of characters to dedicate to gaps and margins.
- const SPACING_WIDTH: usize = 4;
+ let max_char_count = max_author_length.min(renderer.max_author_length())
+ + ::git::SHORT_SHA_LENGTH
+ + MAX_RELATIVE_TIMESTAMP.len()
+ + SPACING_WIDTH;
- let max_char_count = max_author_length.min(renderer.max_author_length())
- + ::git::SHORT_SHA_LENGTH
- + MAX_RELATIVE_TIMESTAMP.len()
- + SPACING_WIDTH;
+ ch_advance * max_char_count
+ });
- ch_advance * max_char_count
- });
+ let is_singleton = self.buffer_snapshot().is_singleton();
+
+ let mut left_padding = git_blame_entries_width.unwrap_or(Pixels::ZERO);
+ left_padding += if !is_singleton {
+ ch_width * 4.0
+ } else if show_runnables || show_breakpoints {
+ ch_width * 3.0
+ } else if show_git_gutter && show_line_numbers {
+ ch_width * 2.0
+ } else if show_git_gutter || show_line_numbers {
+ ch_width
+ } else {
+ px(0.)
+ };
- let is_singleton = self.buffer_snapshot().is_singleton();
-
- let mut left_padding = git_blame_entries_width.unwrap_or(Pixels::ZERO);
- left_padding += if !is_singleton {
- ch_width * 4.0
- } else if show_runnables || show_breakpoints {
- ch_width * 3.0
- } else if show_git_gutter && show_line_numbers {
- ch_width * 2.0
- } else if show_git_gutter || show_line_numbers {
- ch_width
- } else {
- px(0.)
- };
+ let shows_folds = is_singleton && gutter_settings.folds;
- let shows_folds = is_singleton && gutter_settings.folds;
+ let right_padding = if shows_folds && show_line_numbers {
+ ch_width * 4.0
+ } else if shows_folds || (!is_singleton && show_line_numbers) {
+ ch_width * 3.0
+ } else if show_line_numbers {
+ ch_width
+ } else {
+ px(0.)
+ };
- let right_padding = if shows_folds && show_line_numbers {
- ch_width * 4.0
- } else if shows_folds || (!is_singleton && show_line_numbers) {
- ch_width * 3.0
- } else if show_line_numbers {
- ch_width
+ GutterDimensions {
+ left_padding,
+ right_padding,
+ width: line_gutter_width + left_padding + right_padding,
+ margin: GutterDimensions::default_gutter_margin(font_id, font_size, cx),
+ git_blame_entries_width,
+ }
+ } else if self.offset_content {
+ GutterDimensions::default_with_margin(font_id, font_size, cx)
} else {
- px(0.)
- };
-
- Some(GutterDimensions {
- left_padding,
- right_padding,
- width: line_gutter_width + left_padding + right_padding,
- margin: GutterDimensions::default_gutter_margin(font_id, font_size, cx),
- git_blame_entries_width,
- })
+ GutterDimensions::default()
+ }
}
pub fn render_crease_toggle(
@@ -24696,6 +24932,28 @@ impl EditorSnapshot {
None
}
}
+
+ pub fn max_line_number_width(&self, style: &EditorStyle, window: &mut Window) -> Pixels {
+ let digit_count = self.widest_line_number().ilog10() + 1;
+ column_pixels(style, digit_count as usize, window)
+ }
+}
+
+pub fn column_pixels(style: &EditorStyle, column: usize, window: &Window) -> Pixels {
+ let font_size = style.text.font_size.to_pixels(window.rem_size());
+ let layout = window.text_system().shape_line(
+ SharedString::from(" ".repeat(column)),
+ font_size,
+ &[TextRun {
+ len: column,
+ font: style.text.font(),
+ color: Hsla::default(),
+ ..Default::default()
+ }],
+ None,
+ );
+
+ layout.width
}
impl Deref for EditorSnapshot {
@@ -24776,57 +25034,7 @@ impl Focusable for Editor {
impl Render for Editor {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let settings = ThemeSettings::get_global(cx);
-
- let mut text_style = match self.mode {
- EditorMode::SingleLine | EditorMode::AutoHeight { .. } => TextStyle {
- color: cx.theme().colors().editor_foreground,
- font_family: settings.ui_font.family.clone(),
- font_features: settings.ui_font.features.clone(),
- font_fallbacks: settings.ui_font.fallbacks.clone(),
- font_size: rems(0.875).into(),
- font_weight: settings.ui_font.weight,
- line_height: relative(settings.buffer_line_height.value()),
- ..Default::default()
- },
- EditorMode::Full { .. } | EditorMode::Minimap { .. } => TextStyle {
- color: cx.theme().colors().editor_foreground,
- font_family: settings.buffer_font.family.clone(),
- font_features: settings.buffer_font.features.clone(),
- font_fallbacks: settings.buffer_font.fallbacks.clone(),
- font_size: settings.buffer_font_size(cx).into(),
- font_weight: settings.buffer_font.weight,
- line_height: relative(settings.buffer_line_height.value()),
- ..Default::default()
- },
- };
- if let Some(text_style_refinement) = &self.text_style_refinement {
- text_style.refine(text_style_refinement)
- }
-
- let background = match self.mode {
- EditorMode::SingleLine => cx.theme().system().transparent,
- EditorMode::AutoHeight { .. } => cx.theme().system().transparent,
- EditorMode::Full { .. } => cx.theme().colors().editor_background,
- EditorMode::Minimap { .. } => cx.theme().colors().editor_background.opacity(0.7),
- };
-
- EditorElement::new(
- &cx.entity(),
- EditorStyle {
- background,
- border: cx.theme().colors().border,
- local_player: cx.theme().players().local(),
- text: text_style,
- scrollbar_width: EditorElement::SCROLLBAR_WIDTH,
- syntax: cx.theme().syntax().clone(),
- status: cx.theme().status().clone(),
- inlay_hints_style: make_inlay_hints_style(cx),
- edit_prediction_styles: make_suggestion_styles(cx),
- unnecessary_code_fade: ThemeSettings::get_global(cx).unnecessary_code_fade,
- show_underlines: self.diagnostics_enabled(),
- },
- )
+ EditorElement::new(&cx.entity(), self.create_style(cx))
}
}
@@ -2218,10 +2218,9 @@ async fn test_move_start_of_paragraph_end_of_paragraph(cx: &mut TestAppContext)
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let line_height = cx.editor(|editor, window, _| {
+ let line_height = cx.update_editor(|editor, window, cx| {
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size())
});
@@ -2334,10 +2333,9 @@ async fn test_move_start_of_paragraph_end_of_paragraph(cx: &mut TestAppContext)
async fn test_scroll_page_up_page_down(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let line_height = cx.editor(|editor, window, _| {
+ let line_height = cx.update_editor(|editor, window, cx| {
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size())
});
@@ -2400,8 +2398,7 @@ async fn test_autoscroll(cx: &mut TestAppContext) {
let line_height = cx.update_editor(|editor, window, cx| {
editor.set_vertical_scroll_margin(2, cx);
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size())
});
@@ -2480,10 +2477,9 @@ async fn test_move_page_up_page_down(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let line_height = cx.editor(|editor, window, _cx| {
+ let line_height = cx.update_editor(|editor, window, cx| {
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size())
});
@@ -5777,6 +5773,116 @@ fn test_duplicate_line(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_rotate_selections(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+
+ // Rotate text selections (horizontal)
+ cx.set_state("x=«1ˇ», y=«2ˇ», z=«3ˇ»");
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_forward(&RotateSelectionsForward, window, cx)
+ });
+ cx.assert_editor_state("x=«3ˇ», y=«1ˇ», z=«2ˇ»");
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_backward(&RotateSelectionsBackward, window, cx)
+ });
+ cx.assert_editor_state("x=«1ˇ», y=«2ˇ», z=«3ˇ»");
+
+ // Rotate text selections (vertical)
+ cx.set_state(indoc! {"
+ x=«1ˇ»
+ y=«2ˇ»
+ z=«3ˇ»
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_forward(&RotateSelectionsForward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ x=«3ˇ»
+ y=«1ˇ»
+ z=«2ˇ»
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_backward(&RotateSelectionsBackward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ x=«1ˇ»
+ y=«2ˇ»
+ z=«3ˇ»
+ "});
+
+ // Rotate text selections (vertical, different lengths)
+ cx.set_state(indoc! {"
+ x=\"«ˇ»\"
+ y=\"«aˇ»\"
+ z=\"«aaˇ»\"
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_forward(&RotateSelectionsForward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ x=\"«aaˇ»\"
+ y=\"«ˇ»\"
+ z=\"«aˇ»\"
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_backward(&RotateSelectionsBackward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ x=\"«ˇ»\"
+ y=\"«aˇ»\"
+ z=\"«aaˇ»\"
+ "});
+
+ // Rotate whole lines (cursor positions preserved)
+ cx.set_state(indoc! {"
+ ˇline123
+ liˇne23
+ line3ˇ
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_forward(&RotateSelectionsForward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ line3ˇ
+ ˇline123
+ liˇne23
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_backward(&RotateSelectionsBackward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ ˇline123
+ liˇne23
+ line3ˇ
+ "});
+
+ // Rotate whole lines, multiple cursors per line (positions preserved)
+ cx.set_state(indoc! {"
+ ˇliˇne123
+ ˇline23
+ ˇline3
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_forward(&RotateSelectionsForward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ ˇline3
+ ˇliˇne123
+ ˇline23
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.rotate_selections_backward(&RotateSelectionsBackward, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ ˇliˇne123
+ ˇline23
+ ˇline3
+ "});
+}
+
#[gpui::test]
fn test_move_line_up_down(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -28201,7 +28307,8 @@ async fn test_sticky_scroll(cx: &mut TestAppContext) {
let mut sticky_headers = |offset: ScrollOffset| {
cx.update_editor(|e, window, cx| {
e.scroll(gpui::Point { x: 0., y: offset }, None, window, cx);
- EditorElement::sticky_headers(&e, &e.snapshot(window, cx), cx)
+ let style = e.style(cx).clone();
+ EditorElement::sticky_headers(&e, &e.snapshot(window, cx), &style, cx)
.into_iter()
.map(
|StickyHeader {
@@ -28255,10 +28362,9 @@ async fn test_scroll_by_clicking_sticky_header(cx: &mut TestAppContext) {
});
let mut cx = EditorTestContext::new(cx).await;
- let line_height = cx.editor(|editor, window, _cx| {
+ let line_height = cx.update_editor(|editor, window, cx| {
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size())
});
@@ -11,6 +11,7 @@ use crate::{
SelectedTextHighlight, Selection, SelectionDragState, SelectionEffects, SizingBehavior,
SoftWrap, StickyHeaderExcerpt, ToPoint, ToggleFold, ToggleFoldAll,
code_context_menus::{CodeActionsMenu, MENU_ASIDE_MAX_WIDTH, MENU_ASIDE_MIN_WIDTH, MENU_GAP},
+ column_pixels,
display_map::{
Block, BlockContext, BlockStyle, ChunkRendererId, DisplaySnapshot, EditorMargins,
HighlightKey, HighlightedChunk, ToDisplayPoint,
@@ -253,6 +254,8 @@ impl EditorElement {
register_action(editor, window, Editor::sort_lines_case_insensitive);
register_action(editor, window, Editor::reverse_lines);
register_action(editor, window, Editor::shuffle_lines);
+ register_action(editor, window, Editor::rotate_selections_forward);
+ register_action(editor, window, Editor::rotate_selections_backward);
register_action(editor, window, Editor::convert_indentation_to_spaces);
register_action(editor, window, Editor::convert_indentation_to_tabs);
register_action(editor, window, Editor::convert_to_upper_case);
@@ -2267,7 +2270,8 @@ impl EditorElement {
};
let padding = ProjectSettings::get_global(cx).diagnostics.inline.padding as f32 * em_width;
- let min_x = self.column_pixels(
+ let min_x = column_pixels(
+ &self.style,
ProjectSettings::get_global(cx)
.diagnostics
.inline
@@ -2570,7 +2574,8 @@ impl EditorElement {
let padded_line_end = line_end + padding;
- let min_column_in_pixels = self.column_pixels(
+ let min_column_in_pixels = column_pixels(
+ &self.style,
ProjectSettings::get_global(cx).git.inline_blame.min_column as usize,
window,
);
@@ -2794,7 +2799,7 @@ impl EditorElement {
.enumerate()
.filter_map(|(i, indent_guide)| {
let single_indent_width =
- self.column_pixels(indent_guide.tab_size as usize, window);
+ column_pixels(&self.style, indent_guide.tab_size as usize, window);
let total_width = single_indent_width * indent_guide.depth as f32;
let start_x = Pixels::from(
ScrollOffset::from(content_origin.x + total_width)
@@ -2851,7 +2856,7 @@ impl EditorElement {
.wrap_guides(cx)
.into_iter()
.flat_map(|(guide, active)| {
- let wrap_position = self.column_pixels(guide, window);
+ let wrap_position = column_pixels(&self.style, guide, window);
let wrap_guide_x = wrap_position + horizontal_offset;
let display_wrap_guide = wrap_guide_x >= content_origin
&& wrap_guide_x <= hitbox.bounds.right() - vertical_scrollbar_width;
@@ -4617,6 +4622,7 @@ impl EditorElement {
gutter_dimensions: &GutterDimensions,
gutter_hitbox: &Hitbox,
text_hitbox: &Hitbox,
+ style: &EditorStyle,
window: &mut Window,
cx: &mut App,
) -> Option<StickyHeaders> {
@@ -4624,7 +4630,7 @@ impl EditorElement {
.show_line_numbers
.unwrap_or_else(|| EditorSettings::get_global(cx).gutter.line_numbers);
- let rows = Self::sticky_headers(self.editor.read(cx), snapshot, cx);
+ let rows = Self::sticky_headers(self.editor.read(cx), snapshot, style, cx);
let mut lines = Vec::<StickyHeaderLine>::new();
@@ -4683,6 +4689,7 @@ impl EditorElement {
pub(crate) fn sticky_headers(
editor: &Editor,
snapshot: &EditorSnapshot,
+ style: &EditorStyle,
cx: &App,
) -> Vec<StickyHeader> {
let scroll_top = snapshot.scroll_position().y;
@@ -4690,7 +4697,7 @@ impl EditorElement {
let mut end_rows = Vec::<DisplayRow>::new();
let mut rows = Vec::<StickyHeader>::new();
- let items = editor.sticky_headers(cx).unwrap_or_default();
+ let items = editor.sticky_headers(style, cx).unwrap_or_default();
for item in items {
let start_point = item.range.start.to_point(snapshot.buffer_snapshot());
@@ -5253,7 +5260,7 @@ impl EditorElement {
) -> Option<AnyElement> {
let max_height_in_lines = ((height - POPOVER_Y_PADDING) / line_height).floor() as u32;
self.editor.update(cx, |editor, cx| {
- editor.render_context_menu(&self.style, max_height_in_lines, window, cx)
+ editor.render_context_menu(max_height_in_lines, window, cx)
})
}
@@ -5280,16 +5287,18 @@ impl EditorElement {
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
- let position = self.editor.update(cx, |editor, _cx| {
+ let position = self.editor.update(cx, |editor, cx| {
let visible_start_point = editor.display_to_pixel_point(
DisplayPoint::new(visible_range.start, 0),
editor_snapshot,
window,
+ cx,
)?;
let visible_end_point = editor.display_to_pixel_point(
DisplayPoint::new(visible_range.end, 0),
editor_snapshot,
window,
+ cx,
)?;
let mouse_context_menu = editor.mouse_context_menu.as_ref()?;
@@ -5297,7 +5306,8 @@ impl EditorElement {
MenuPosition::PinnedToScreen(point) => (None, point),
MenuPosition::PinnedToEditor { source, offset } => {
let source_display_point = source.to_display_point(editor_snapshot);
- let source_point = editor.to_pixel_point(source, editor_snapshot, window)?;
+ let source_point =
+ editor.to_pixel_point(source, editor_snapshot, window, cx)?;
let position = content_origin + source_point + offset;
(Some(source_display_point), position)
}
@@ -7771,29 +7781,6 @@ impl EditorElement {
});
}
- fn column_pixels(&self, column: usize, window: &Window) -> Pixels {
- let style = &self.style;
- let font_size = style.text.font_size.to_pixels(window.rem_size());
- let layout = window.text_system().shape_line(
- SharedString::from(" ".repeat(column)),
- font_size,
- &[TextRun {
- len: column,
- font: style.text.font(),
- color: Hsla::default(),
- ..Default::default()
- }],
- None,
- );
-
- layout.width
- }
-
- fn max_line_number_width(&self, snapshot: &EditorSnapshot, window: &mut Window) -> Pixels {
- let digit_count = snapshot.widest_line_number().ilog10() + 1;
- self.column_pixels(digit_count as usize, window)
- }
-
fn shape_line_number(
&self,
text: SharedString,
@@ -8941,8 +8928,6 @@ impl Element for EditorElement {
max_lines,
} => {
let editor_handle = cx.entity();
- let max_line_number_width =
- self.max_line_number_width(&editor.snapshot(window, cx), window);
window.request_measured_layout(
Style::default(),
move |known_dimensions, available_space, window, cx| {
@@ -8952,7 +8937,6 @@ impl Element for EditorElement {
editor,
min_lines,
max_lines,
- max_line_number_width,
known_dimensions,
available_space.width,
window,
@@ -9039,15 +9023,10 @@ impl Element for EditorElement {
.gutter_dimensions(
font_id,
font_size,
- self.max_line_number_width(&snapshot, window),
+ style,
+ window,
cx,
- )
- .or_else(|| {
- self.editor.read(cx).offset_content.then(|| {
- GutterDimensions::default_with_margin(font_id, font_size, cx)
- })
- })
- .unwrap_or_default();
+ );
let text_width = bounds.size.width - gutter_dimensions.width;
let settings = EditorSettings::get_global(cx);
@@ -9738,6 +9717,7 @@ impl Element for EditorElement {
&gutter_dimensions,
&gutter_hitbox,
&text_hitbox,
+ &style,
window,
cx,
)
@@ -11454,7 +11434,6 @@ fn compute_auto_height_layout(
editor: &mut Editor,
min_lines: usize,
max_lines: Option<usize>,
- max_line_number_width: Pixels,
known_dimensions: Size<Option<Pixels>>,
available_width: AvailableSpace,
window: &mut Window,
@@ -11478,14 +11457,7 @@ fn compute_auto_height_layout(
let em_width = window.text_system().em_width(font_id, font_size).unwrap();
let mut snapshot = editor.snapshot(window, cx);
- let gutter_dimensions = snapshot
- .gutter_dimensions(font_id, font_size, max_line_number_width, cx)
- .or_else(|| {
- editor
- .offset_content
- .then(|| GutterDimensions::default_with_margin(font_id, font_size, cx))
- })
- .unwrap_or_default();
+ let gutter_dimensions = snapshot.gutter_dimensions(font_id, font_size, style, window, cx);
editor.gutter_dimensions = gutter_dimensions;
let text_width = width - gutter_dimensions.width;
@@ -11548,7 +11520,7 @@ mod tests {
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
- let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone());
+ let style = cx.update(|_, cx| editor.update(cx, |editor, cx| editor.style(cx).clone()));
for x in 1..=100 {
let (_, state) = cx.draw(
@@ -11576,7 +11548,7 @@ mod tests {
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
- let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone());
+ let style = cx.update(|_, cx| editor.update(cx, |editor, cx| editor.style(cx).clone()));
for x in 1..=100 {
let (_, state) = cx.draw(
@@ -11601,7 +11573,7 @@ mod tests {
});
let editor = window.root(cx).unwrap();
- let style = cx.update(|cx| editor.read(cx).style().unwrap().clone());
+ let style = editor.update(cx, |editor, cx| editor.style(cx).clone());
let line_height = window
.update(cx, |_, window, _| {
style.text.line_height_in_pixels(window.rem_size())
@@ -11749,7 +11721,7 @@ mod tests {
});
let editor = window.root(cx).unwrap();
- let style = cx.update(|cx| editor.read(cx).style().unwrap().clone());
+ let style = editor.update(cx, |editor, cx| editor.style(cx).clone());
let line_height = window
.update(cx, |_, window, _| {
style.text.line_height_in_pixels(window.rem_size())
@@ -11876,7 +11848,7 @@ mod tests {
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
- let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone());
+ let style = cx.update(|_, cx| editor.update(cx, |editor, cx| editor.style(cx).clone()));
window
.update(cx, |editor, window, cx| {
@@ -11947,7 +11919,7 @@ mod tests {
});
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
- let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone());
+ let style = cx.update(|_, cx| editor.update(cx, |editor, cx| editor.style(cx).clone()));
window
.update(cx, |editor, window, cx| {
editor.set_placeholder_text("hello", window, cx);
@@ -12187,7 +12159,7 @@ mod tests {
let cx = &mut VisualTestContext::from_window(*window, cx);
let editor = window.root(cx).unwrap();
- let style = cx.update(|_, cx| editor.read(cx).style().unwrap().clone());
+ let style = editor.update(cx, |editor, cx| editor.style(cx).clone());
window
.update(cx, |editor, _, cx| {
editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
@@ -59,7 +59,7 @@ impl MouseContextMenu {
x: editor.gutter_dimensions.width,
y: Pixels::ZERO,
};
- let source_position = editor.to_pixel_point(source, &editor_snapshot, window)?;
+ let source_position = editor.to_pixel_point(source, &editor_snapshot, window, cx)?;
let menu_position = MenuPosition::PinnedToEditor {
source,
offset: position - (source_position + content_origin),
@@ -280,7 +280,11 @@ pub fn deploy_context_menu(
"Copy Permalink",
Box::new(CopyPermalinkToLine),
)
- .action_disabled_when(!has_git_repo, "File History", Box::new(git::FileHistory));
+ .action_disabled_when(
+ !has_git_repo,
+ "View File History",
+ Box::new(git::FileHistory),
+ );
match focus {
Some(focus) => builder.context(focus),
None => builder,
@@ -283,8 +283,7 @@ impl EditorTestContext {
.head();
let pixel_position = editor.pixel_position_of_newest_cursor.unwrap();
let line_height = editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size());
let snapshot = editor.snapshot(window, cx);
@@ -16,4 +16,8 @@ pub struct InlineAssistantV2FeatureFlag;
impl FeatureFlag for InlineAssistantV2FeatureFlag {
const NAME: &'static str = "inline-assistant-v2";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
}
@@ -232,14 +232,12 @@ impl From<Oid> for usize {
#[derive(Copy, Clone, Debug)]
pub enum RunHook {
PreCommit,
- PrePush,
}
impl RunHook {
pub fn as_str(&self) -> &str {
match self {
Self::PreCommit => "pre-commit",
- Self::PrePush => "pre-push",
}
}
@@ -250,7 +248,6 @@ impl RunHook {
pub fn from_proto(value: i32) -> Option<Self> {
match value {
0 => Some(Self::PreCommit),
- 1 => Some(Self::PrePush),
_ => None,
}
}
@@ -652,6 +652,7 @@ pub struct RealGitRepository {
pub repository: Arc<Mutex<git2::Repository>>,
pub system_git_binary_path: Option<PathBuf>,
pub any_git_binary_path: PathBuf,
+ any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
executor: BackgroundExecutor,
}
@@ -670,6 +671,7 @@ impl RealGitRepository {
system_git_binary_path,
any_git_binary_path,
executor,
+ any_git_binary_help_output: Arc::new(Mutex::new(None)),
})
}
@@ -680,6 +682,27 @@ impl RealGitRepository {
.context("failed to read git work directory")
.map(Path::to_path_buf)
}
+
+ async fn any_git_binary_help_output(&self) -> SharedString {
+ if let Some(output) = self.any_git_binary_help_output.lock().clone() {
+ return output;
+ }
+ let git_binary_path = self.any_git_binary_path.clone();
+ let executor = self.executor.clone();
+ let working_directory = self.working_directory();
+ let output: SharedString = self
+ .executor
+ .spawn(async move {
+ GitBinary::new(git_binary_path, working_directory?, executor)
+ .run(["help", "-a"])
+ .await
+ })
+ .await
+ .unwrap_or_default()
+ .into();
+ *self.any_git_binary_help_output.lock() = Some(output.clone());
+ output
+ }
}
#[derive(Clone, Debug)]
@@ -2290,18 +2313,50 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
+ let repository = self.repository.clone();
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
- self.executor
- .spawn(async move {
- let working_directory = working_directory?;
- let git = GitBinary::new(git_binary_path, working_directory, executor)
- .envs(HashMap::clone(&env));
- git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
- .await?;
- Ok(())
- })
- .boxed()
+ let help_output = self.any_git_binary_help_output();
+
+ async move {
+ let working_directory = working_directory?;
+ if !help_output
+ .await
+ .lines()
+ .any(|line| line.trim().starts_with("hook "))
+ {
+ let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
+ if hook_abs_path.is_file() {
+ let output = self
+ .executor
+ .spawn(
+ new_smol_command(&hook_abs_path)
+ .envs(env.iter())
+ .current_dir(&working_directory)
+ .output(),
+ )
+ .await?;
+
+ if !output.status.success() {
+ return Err(GitBinaryCommandError {
+ stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
+ stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
+ status: output.status,
+ }
+ .into());
+ }
+ }
+
+ return Ok(());
+ }
+
+ let git = GitBinary::new(git_binary_path, working_directory, executor)
+ .envs(HashMap::clone(&env));
+ git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
+ .await?;
+ Ok(())
+ }
+ .boxed()
}
}
@@ -911,7 +911,7 @@ impl PickerDelegate for BranchListDelegate {
});
Some(
- ListItem::new(SharedString::from(format!("vcs-menu-{ix}")))
+ ListItem::new(format!("vcs-menu-{ix}"))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -1,9 +1,7 @@
use anyhow::{Context as _, Result};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle};
-use editor::{
- Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, multibuffer_context_lines,
-};
+use editor::{Editor, EditorEvent, ExcerptRange, MultiBuffer, multibuffer_context_lines};
use git::repository::{CommitDetails, CommitDiff, RepoPath};
use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url};
use gpui::{
@@ -13,7 +11,7 @@ use gpui::{
};
use language::{
Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, OffsetRangeExt as _,
- ReplicaId, Rope, TextBuffer,
+ Point, ReplicaId, Rope, TextBuffer,
};
use multi_buffer::PathKey;
use project::{Project, WorktreeId, git_store::Repository};
@@ -70,6 +68,7 @@ struct GitBlob {
display_name: Arc<str>,
}
+const COMMIT_MESSAGE_SORT_PREFIX: u64 = 0;
const FILE_NAMESPACE_SORT_PREFIX: u64 = 1;
impl CommitView {
@@ -147,15 +146,71 @@ impl CommitView {
) -> Self {
let language_registry = project.read(cx).languages().clone();
let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadOnly));
+
+ let message_buffer = cx.new(|cx| {
+ let mut buffer = Buffer::local(commit.message.clone(), cx);
+ buffer.set_capability(Capability::ReadOnly, cx);
+ buffer
+ });
+
+ multibuffer.update(cx, |multibuffer, cx| {
+ let snapshot = message_buffer.read(cx).snapshot();
+ let full_range = Point::zero()..snapshot.max_point();
+ let range = ExcerptRange {
+ context: full_range.clone(),
+ primary: full_range,
+ };
+ multibuffer.set_excerpt_ranges_for_path(
+ PathKey::with_sort_prefix(
+ COMMIT_MESSAGE_SORT_PREFIX,
+ RelPath::unix("commit message").unwrap().into(),
+ ),
+ message_buffer.clone(),
+ &snapshot,
+ vec![range],
+ cx,
+ )
+ });
+
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
editor.disable_inline_diagnostics();
+ editor.set_show_breakpoints(false, cx);
editor.set_expand_all_diff_hunks(cx);
+ editor.disable_header_for_buffer(message_buffer.read(cx).remote_id(), cx);
+ editor.disable_indent_guides_for_buffer(message_buffer.read(cx).remote_id(), cx);
+
+ editor.insert_blocks(
+ [BlockProperties {
+ placement: BlockPlacement::Above(editor::Anchor::min()),
+ height: Some(1),
+ style: BlockStyle::Sticky,
+ render: Arc::new(|_| gpui::Empty.into_any_element()),
+ priority: 0,
+ }]
+ .into_iter()
+ .chain(
+ editor
+ .buffer()
+ .read(cx)
+ .buffer_anchor_to_anchor(&message_buffer, Anchor::MAX, cx)
+ .map(|anchor| BlockProperties {
+ placement: BlockPlacement::Below(anchor),
+ height: Some(1),
+ style: BlockStyle::Sticky,
+ render: Arc::new(|_| gpui::Empty.into_any_element()),
+ priority: 0,
+ }),
+ ),
+ None,
+ cx,
+ );
editor
});
+
let commit_sha = Arc::<str>::from(commit.sha.as_ref());
let first_worktree_id = project
@@ -165,7 +220,6 @@ impl CommitView {
.map(|worktree| worktree.read(cx).id());
let repository_clone = repository.clone();
- let commit_message = commit.message.clone();
cx.spawn(async move |this, cx| {
for file in commit_diff.files {
@@ -227,59 +281,6 @@ impl CommitView {
})?;
}
- let message_buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(commit_message, cx);
- buffer.set_capability(Capability::ReadOnly, cx);
- buffer
- })?;
-
- this.update(cx, |this, cx| {
- this.multibuffer.update(cx, |multibuffer, cx| {
- let range = ExcerptRange {
- context: Anchor::MIN..Anchor::MAX,
- primary: Anchor::MIN..Anchor::MAX,
- };
- multibuffer.insert_excerpts_after(
- ExcerptId::min(),
- message_buffer.clone(),
- [range],
- cx,
- )
- });
-
- this.editor.update(cx, |editor, cx| {
- editor.disable_header_for_buffer(message_buffer.read(cx).remote_id(), cx);
- editor
- .disable_indent_guides_for_buffer(message_buffer.read(cx).remote_id(), cx);
-
- editor.insert_blocks(
- [BlockProperties {
- placement: BlockPlacement::Above(editor::Anchor::min()),
- height: Some(1),
- style: BlockStyle::Sticky,
- render: Arc::new(|_| gpui::Empty.into_any_element()),
- priority: 0,
- }]
- .into_iter()
- .chain(
- editor
- .buffer()
- .read(cx)
- .buffer_anchor_to_anchor(&message_buffer, Anchor::MAX, cx)
- .map(|anchor| BlockProperties {
- placement: BlockPlacement::Below(anchor),
- height: Some(1),
- style: BlockStyle::Sticky,
- render: Arc::new(|_| gpui::Empty.into_any_element()),
- priority: 0,
- }),
- ),
- None,
- cx,
- )
- });
- })?;
-
anyhow::Ok(())
})
.detach();
@@ -416,12 +417,23 @@ impl CommitView {
None
};
+ let gutter_width = self.editor.update(cx, |editor, cx| {
+ let snapshot = editor.snapshot(window, cx);
+ let style = editor.style(cx);
+ let font_id = window.text_system().resolve_font(&style.text.font());
+ let font_size = style.text.font_size.to_pixels(window.rem_size());
+ snapshot
+ .gutter_dimensions(font_id, font_size, style, window, cx)
+ .full_width()
+ });
+
h_flex()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
+ .w_full()
.child(
h_flex()
- .w(self.editor.read(cx).last_gutter_dimensions().full_width())
+ .w(gutter_width)
.justify_center()
.child(self.render_commit_avatar(&commit.sha, rems_from_px(48.), window, cx)),
)
@@ -1010,7 +1022,9 @@ impl Render for CommitView {
.size_full()
.bg(cx.theme().colors().editor_background)
.child(self.render_header(window, cx))
- .child(div().flex_grow().child(self.editor.clone()))
+ .when(!self.editor.read(cx).is_empty(cx), |this| {
+ this.child(div().flex_grow().child(self.editor.clone()))
+ })
}
}
@@ -108,7 +108,7 @@ impl FileDiffView {
for buffer in [&old_buffer, &new_buffer] {
cx.subscribe(buffer, move |this, _, event, _| match event {
language::BufferEvent::Edited
- | language::BufferEvent::LanguageChanged
+ | language::BufferEvent::LanguageChanged(_)
| language::BufferEvent::Reparsed => {
this.buffer_changes_tx.send(()).ok();
}
@@ -13,6 +13,7 @@ use agent_settings::AgentSettings;
use anyhow::Context as _;
use askpass::AskPassDelegate;
use cloud_llm_client::CompletionIntent;
+use collections::{BTreeMap, HashMap, HashSet};
use db::kvp::KEY_VALUE_STORE;
use editor::{
Direction, Editor, EditorElement, EditorMode, MultiBuffer, MultiBufferOffset,
@@ -33,10 +34,11 @@ use git::{
TrashUntrackedFiles, UnstageAll,
};
use gpui::{
- Action, AsyncApp, AsyncWindowContext, ClickEvent, Corner, DismissEvent, Entity, EventEmitter,
- FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior, ListSizingBehavior,
- MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy, Subscription, Task,
- UniformListScrollHandle, WeakEntity, actions, anchored, deferred, uniform_list,
+ Action, AsyncApp, AsyncWindowContext, Bounds, ClickEvent, Corner, DismissEvent, Entity,
+ EventEmitter, FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior,
+ ListSizingBehavior, MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy,
+ Subscription, Task, UniformListScrollHandle, WeakEntity, actions, anchored, deferred, point,
+ size, uniform_list,
};
use itertools::Itertools;
use language::{Buffer, File};
@@ -60,12 +62,13 @@ use settings::{Settings, SettingsStore, StatusStyle};
use std::future::Future;
use std::ops::Range;
use std::path::Path;
-use std::{collections::HashSet, sync::Arc, time::Duration, usize};
+use std::{sync::Arc, time::Duration, usize};
use strum::{IntoEnumIterator, VariantNames};
use time::OffsetDateTime;
use ui::{
- ButtonLike, Checkbox, CommonAnimationExt, ContextMenu, ElevationIndex, PopoverMenu, ScrollAxes,
- Scrollbars, SplitButton, Tooltip, WithScrollbar, prelude::*,
+ ButtonLike, Checkbox, CommonAnimationExt, ContextMenu, ElevationIndex, IndentGuideColors,
+ PopoverMenu, RenderedIndentGuide, ScrollAxes, Scrollbars, SplitButton, Tooltip, WithScrollbar,
+ prelude::*,
};
use util::paths::PathStyle;
use util::{ResultExt, TryFutureExt, maybe};
@@ -92,6 +95,8 @@ actions!(
ToggleFillCoAuthors,
/// Toggles sorting entries by path vs status.
ToggleSortByPath,
+ /// Toggles showing entries in tree vs flat view.
+ ToggleTreeView,
]
);
@@ -122,6 +127,7 @@ struct GitMenuState {
has_new_changes: bool,
sort_by_path: bool,
has_stash_items: bool,
+ tree_view: bool,
}
fn git_panel_context_menu(
@@ -166,20 +172,34 @@ fn git_panel_context_menu(
)
.separator()
.entry(
- if state.sort_by_path {
- "Sort by Status"
+ if state.tree_view {
+ "Flat View"
} else {
- "Sort by Path"
+ "Tree View"
},
- Some(Box::new(ToggleSortByPath)),
- move |window, cx| window.dispatch_action(Box::new(ToggleSortByPath), cx),
+ Some(Box::new(ToggleTreeView)),
+ move |window, cx| window.dispatch_action(Box::new(ToggleTreeView), cx),
)
+ .when(!state.tree_view, |this| {
+ this.entry(
+ if state.sort_by_path {
+ "Sort by Status"
+ } else {
+ "Sort by Path"
+ },
+ Some(Box::new(ToggleSortByPath)),
+ move |window, cx| window.dispatch_action(Box::new(ToggleSortByPath), cx),
+ )
+ })
})
}
const GIT_PANEL_KEY: &str = "GitPanel";
const UPDATE_DEBOUNCE: Duration = Duration::from_millis(50);
+// TODO: We should revise this part. It seems the indentation width is not aligned with the one in project panel
+const TREE_INDENT: f32 = 12.0;
+const TREE_INDENT_GUIDE_OFFSET: f32 = 16.0;
pub fn register(workspace: &mut Workspace) {
workspace.register_action(|workspace, _: &ToggleFocus, window, cx| {
@@ -204,7 +224,7 @@ struct SerializedGitPanel {
signoff_enabled: bool,
}
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
enum Section {
Conflict,
Tracked,
@@ -240,6 +260,8 @@ impl GitHeaderEntry {
#[derive(Debug, PartialEq, Eq, Clone)]
enum GitListEntry {
Status(GitStatusEntry),
+ TreeStatus(GitTreeStatusEntry),
+ Directory(GitTreeDirEntry),
Header(GitHeaderEntry),
}
@@ -247,11 +269,250 @@ impl GitListEntry {
fn status_entry(&self) -> Option<&GitStatusEntry> {
match self {
GitListEntry::Status(entry) => Some(entry),
+ GitListEntry::TreeStatus(entry) => Some(&entry.entry),
_ => None,
}
}
}
+enum GitPanelViewMode {
+ Flat,
+ Tree(TreeViewState),
+}
+
+impl GitPanelViewMode {
+ fn from_settings(cx: &App) -> Self {
+ if GitPanelSettings::get_global(cx).tree_view {
+ GitPanelViewMode::Tree(TreeViewState::default())
+ } else {
+ GitPanelViewMode::Flat
+ }
+ }
+
+ fn tree_state(&self) -> Option<&TreeViewState> {
+ match self {
+ GitPanelViewMode::Tree(state) => Some(state),
+ GitPanelViewMode::Flat => None,
+ }
+ }
+
+ fn tree_state_mut(&mut self) -> Option<&mut TreeViewState> {
+ match self {
+ GitPanelViewMode::Tree(state) => Some(state),
+ GitPanelViewMode::Flat => None,
+ }
+ }
+}
+
+#[derive(Default)]
+struct TreeViewState {
+ // Maps visible index to actual entry index.
+ // Length equals the number of visible entries.
+ // This is needed because some entries (like collapsed directories) may be hidden.
+ logical_indices: Vec<usize>,
+ expanded_dirs: HashMap<TreeKey, bool>,
+ directory_descendants: HashMap<TreeKey, Vec<GitStatusEntry>>,
+}
+
+impl TreeViewState {
+ fn build_tree_entries(
+ &mut self,
+ section: Section,
+ mut entries: Vec<GitStatusEntry>,
+ repo: &Repository,
+ seen_directories: &mut HashSet<TreeKey>,
+ optimistic_staging: &HashMap<RepoPath, bool>,
+ ) -> Vec<(GitListEntry, bool)> {
+ if entries.is_empty() {
+ return Vec::new();
+ }
+
+ entries.sort_by(|a, b| a.repo_path.cmp(&b.repo_path));
+
+ let mut root = TreeNode::default();
+ for entry in entries {
+ let components: Vec<&str> = entry.repo_path.components().collect();
+ if components.is_empty() {
+ root.files.push(entry);
+ continue;
+ }
+
+ let mut current = &mut root;
+ let mut current_path = String::new();
+
+ for (ix, component) in components.iter().enumerate() {
+ if ix == components.len() - 1 {
+ current.files.push(entry.clone());
+ } else {
+ if !current_path.is_empty() {
+ current_path.push('/');
+ }
+ current_path.push_str(component);
+ let dir_path = RepoPath::new(¤t_path)
+ .expect("repo path from status entry component");
+
+ let component = SharedString::from(component.to_string());
+
+ current = current
+ .children
+ .entry(component.clone())
+ .or_insert_with(|| TreeNode {
+ name: component,
+ path: Some(dir_path),
+ ..Default::default()
+ });
+ }
+ }
+ }
+
+ let (flattened, _) = self.flatten_tree(
+ &root,
+ section,
+ 0,
+ repo,
+ seen_directories,
+ optimistic_staging,
+ );
+ flattened
+ }
+
+ fn flatten_tree(
+ &mut self,
+ node: &TreeNode,
+ section: Section,
+ depth: usize,
+ repo: &Repository,
+ seen_directories: &mut HashSet<TreeKey>,
+ optimistic_staging: &HashMap<RepoPath, bool>,
+ ) -> (Vec<(GitListEntry, bool)>, Vec<GitStatusEntry>) {
+ let mut all_statuses = Vec::new();
+ let mut flattened = Vec::new();
+
+ for child in node.children.values() {
+ let (terminal, name) = Self::compact_directory_chain(child);
+ let Some(path) = terminal.path.clone().or_else(|| child.path.clone()) else {
+ continue;
+ };
+ let (child_flattened, mut child_statuses) = self.flatten_tree(
+ terminal,
+ section,
+ depth + 1,
+ repo,
+ seen_directories,
+ optimistic_staging,
+ );
+ let key = TreeKey { section, path };
+ let expanded = *self.expanded_dirs.get(&key).unwrap_or(&true);
+ self.expanded_dirs.entry(key.clone()).or_insert(true);
+ seen_directories.insert(key.clone());
+
+ let staged_count = child_statuses
+ .iter()
+ .filter(|entry| Self::is_entry_staged(entry, repo, optimistic_staging))
+ .count();
+ let staged_state =
+ GitPanel::toggle_state_for_counts(staged_count, child_statuses.len());
+
+ self.directory_descendants
+ .insert(key.clone(), child_statuses.clone());
+
+ flattened.push((
+ GitListEntry::Directory(GitTreeDirEntry {
+ key,
+ name,
+ depth,
+ staged_state,
+ expanded,
+ }),
+ true,
+ ));
+
+ if expanded {
+ flattened.extend(child_flattened);
+ } else {
+ flattened.extend(child_flattened.into_iter().map(|(child, _)| (child, false)));
+ }
+
+ all_statuses.append(&mut child_statuses);
+ }
+
+ for file in &node.files {
+ all_statuses.push(file.clone());
+ flattened.push((
+ GitListEntry::TreeStatus(GitTreeStatusEntry {
+ entry: file.clone(),
+ depth,
+ }),
+ true,
+ ));
+ }
+
+ (flattened, all_statuses)
+ }
+
+ fn compact_directory_chain(mut node: &TreeNode) -> (&TreeNode, SharedString) {
+ let mut parts = vec![node.name.clone()];
+ while node.files.is_empty() && node.children.len() == 1 {
+ let Some(child) = node.children.values().next() else {
+ continue;
+ };
+ if child.path.is_none() {
+ break;
+ }
+ parts.push(child.name.clone());
+ node = child;
+ }
+ let name = parts.join("/");
+ (node, SharedString::from(name))
+ }
+
+ fn is_entry_staged(
+ entry: &GitStatusEntry,
+ repo: &Repository,
+ optimistic_staging: &HashMap<RepoPath, bool>,
+ ) -> bool {
+ if let Some(optimistic) = optimistic_staging.get(&entry.repo_path) {
+ return *optimistic;
+ }
+ repo.pending_ops_for_path(&entry.repo_path)
+ .map(|ops| ops.staging() || ops.staged())
+ .or_else(|| {
+ repo.status_for_path(&entry.repo_path)
+ .and_then(|status| status.status.staging().as_bool())
+ })
+ .unwrap_or(entry.staging.has_staged())
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+struct GitTreeStatusEntry {
+ entry: GitStatusEntry,
+ depth: usize,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+struct TreeKey {
+ section: Section,
+ path: RepoPath,
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+struct GitTreeDirEntry {
+ key: TreeKey,
+ name: SharedString,
+ depth: usize,
+ staged_state: ToggleState,
+ expanded: bool,
+}
+
+#[derive(Default)]
+struct TreeNode {
+ name: SharedString,
+ path: Option<RepoPath>,
+ children: BTreeMap<SharedString, TreeNode>,
+ files: Vec<GitStatusEntry>,
+}
+
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct GitStatusEntry {
pub(crate) repo_path: RepoPath,
@@ -345,12 +606,15 @@ pub struct GitPanel {
add_coauthors: bool,
generate_commit_message_task: Option<Task<Option<()>>>,
entries: Vec<GitListEntry>,
+ view_mode: GitPanelViewMode,
+ entries_indices: HashMap<RepoPath, usize>,
single_staged_entry: Option<GitStatusEntry>,
single_tracked_entry: Option<GitStatusEntry>,
focus_handle: FocusHandle,
fs: Arc<dyn Fs>,
new_count: usize,
entry_count: usize,
+ changes_count: usize,
new_staged_count: usize,
pending_commit: Option<Task<()>>,
amend_pending: bool,
@@ -374,6 +638,7 @@ pub struct GitPanel {
local_committer_task: Option<Task<()>>,
bulk_staging: Option<BulkStaging>,
stash_entries: GitStash,
+ optimistic_staging: HashMap<RepoPath, bool>,
_settings_subscription: Subscription,
}
@@ -433,14 +698,19 @@ impl GitPanel {
cx.on_focus(&focus_handle, window, Self::focus_in).detach();
let mut was_sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
+ let mut was_tree_view = GitPanelSettings::get_global(cx).tree_view;
cx.observe_global_in::<SettingsStore>(window, move |this, window, cx| {
- let is_sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
- if is_sort_by_path != was_sort_by_path {
- this.entries.clear();
+ let sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
+ let tree_view = GitPanelSettings::get_global(cx).tree_view;
+ if tree_view != was_tree_view {
+ this.view_mode = GitPanelViewMode::from_settings(cx);
+ }
+ if sort_by_path != was_sort_by_path || tree_view != was_tree_view {
this.bulk_staging.take();
this.update_visible_entries(window, cx);
}
- was_sort_by_path = is_sort_by_path
+ was_sort_by_path = sort_by_path;
+ was_tree_view = tree_view;
})
.detach();
@@ -506,10 +776,13 @@ impl GitPanel {
add_coauthors: true,
generate_commit_message_task: None,
entries: Vec::new(),
+ view_mode: GitPanelViewMode::from_settings(cx),
+ entries_indices: HashMap::default(),
focus_handle: cx.focus_handle(),
fs,
new_count: 0,
new_staged_count: 0,
+ changes_count: 0,
pending_commit: None,
amend_pending: false,
original_commit_message: None,
@@ -535,6 +808,7 @@ impl GitPanel {
entry_count: 0,
bulk_staging: None,
stash_entries: Default::default(),
+ optimistic_staging: HashMap::default(),
_settings_subscription,
};
@@ -543,51 +817,8 @@ impl GitPanel {
})
}
- pub fn entry_by_path(&self, path: &RepoPath, cx: &App) -> Option<usize> {
- if GitPanelSettings::get_global(cx).sort_by_path {
- return self
- .entries
- .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path))
- .ok();
- }
-
- if self.conflicted_count > 0 {
- let conflicted_start = 1;
- if let Ok(ix) = self.entries[conflicted_start..conflicted_start + self.conflicted_count]
- .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path))
- {
- return Some(conflicted_start + ix);
- }
- }
- if self.tracked_count > 0 {
- let tracked_start = if self.conflicted_count > 0 {
- 1 + self.conflicted_count
- } else {
- 0
- } + 1;
- if let Ok(ix) = self.entries[tracked_start..tracked_start + self.tracked_count]
- .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path))
- {
- return Some(tracked_start + ix);
- }
- }
- if self.new_count > 0 {
- let untracked_start = if self.conflicted_count > 0 {
- 1 + self.conflicted_count
- } else {
- 0
- } + if self.tracked_count > 0 {
- 1 + self.tracked_count
- } else {
- 0
- } + 1;
- if let Ok(ix) = self.entries[untracked_start..untracked_start + self.new_count]
- .binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(path))
- {
- return Some(untracked_start + ix);
- }
- }
- None
+ pub fn entry_by_path(&self, path: &RepoPath) -> Option<usize> {
+ self.entries_indices.get(path).copied()
}
pub fn select_entry_by_path(
@@ -602,7 +833,7 @@ impl GitPanel {
let Some(repo_path) = git_repo.read(cx).project_path_to_repo_path(&path, cx) else {
return;
};
- let Some(ix) = self.entry_by_path(&repo_path, cx) else {
+ let Some(ix) = self.entry_by_path(&repo_path) else {
return;
};
self.selected_entry = Some(ix);
@@ -702,9 +933,15 @@ impl GitPanel {
cx.notify();
}
+ fn first_status_entry_index(&self) -> Option<usize> {
+ self.entries
+ .iter()
+ .position(|entry| entry.status_entry().is_some())
+ }
+
fn select_first(&mut self, _: &SelectFirst, _window: &mut Window, cx: &mut Context<Self>) {
- if !self.entries.is_empty() {
- self.selected_entry = Some(1);
+ if let Some(first_entry) = self.first_status_entry_index() {
+ self.selected_entry = Some(first_entry);
self.scroll_to_selected_entry(cx);
}
}
@@ -791,7 +1028,7 @@ impl GitPanel {
.as_ref()
.is_some_and(|active_repository| active_repository.read(cx).status_summary().count > 0);
if have_entries && self.selected_entry.is_none() {
- self.selected_entry = Some(1);
+ self.selected_entry = self.first_status_entry_index();
self.scroll_to_selected_entry(cx);
cx.notify();
}
@@ -1318,6 +1555,37 @@ impl GitPanel {
.detach();
}
+ fn is_entry_staged(&self, entry: &GitStatusEntry, repo: &Repository) -> bool {
+ // Checking for current staged/unstaged file status is a chained operation:
+ // 1. first, we check for any pending operation recorded in repository
+ // 2. if there are no pending ops either running or finished, we then ask the repository
+ // for the most up-to-date file status read from disk - we do this since `entry` arg to this function `render_entry`
+ // is likely to be staled, and may lead to weird artifacts in the form of subsecond auto-uncheck/check on
+ // the checkbox's state (or flickering) which is undesirable.
+ // 3. finally, if there is no info about this `entry` in the repo, we fall back to whatever status is encoded
+ // in `entry` arg.
+ if let Some(optimistic) = self.optimistic_staging.get(&entry.repo_path) {
+ return *optimistic;
+ }
+ repo.pending_ops_for_path(&entry.repo_path)
+ .map(|ops| ops.staging() || ops.staged())
+ .or_else(|| {
+ repo.status_for_path(&entry.repo_path)
+ .and_then(|status| status.status.staging().as_bool())
+ })
+ .unwrap_or(entry.staging.has_staged())
+ }
+
+ fn toggle_state_for_counts(staged_count: usize, total: usize) -> ToggleState {
+ if staged_count == 0 || total == 0 {
+ ToggleState::Unselected
+ } else if staged_count == total {
+ ToggleState::Selected
+ } else {
+ ToggleState::Indeterminate
+ }
+ }
+
pub fn stage_all(&mut self, _: &StageAll, _window: &mut Window, cx: &mut Context<Self>) {
self.change_all_files_stage(true, cx);
}
@@ -1332,50 +1600,92 @@ impl GitPanel {
_window: &mut Window,
cx: &mut Context<Self>,
) {
- let Some(active_repository) = self.active_repository.as_ref() else {
+ let Some(active_repository) = self.active_repository.clone() else {
return;
};
- let repo = active_repository.read(cx);
- let (stage, repo_paths) = match entry {
- GitListEntry::Status(status_entry) => {
- let repo_paths = vec![status_entry.clone()];
- let stage = if repo
- .pending_ops_for_path(&status_entry.repo_path)
- .map(|ops| ops.staging() || ops.staged())
- .or_else(|| {
- repo.status_for_path(&status_entry.repo_path)
- .map(|status| status.status.staging().has_staged())
- })
- .unwrap_or(status_entry.staging.has_staged())
- {
- if let Some(op) = self.bulk_staging.clone()
- && op.anchor == status_entry.repo_path
- {
- self.bulk_staging = None;
- }
- false
- } else {
- self.set_bulk_staging_anchor(status_entry.repo_path.clone(), cx);
- true
- };
- (stage, repo_paths)
- }
- GitListEntry::Header(section) => {
- let goal_staged_state = !self.header_state(section.header).selected();
- let entries = self
- .entries
- .iter()
- .filter_map(|entry| entry.status_entry())
- .filter(|status_entry| {
- section.contains(status_entry, repo)
- && status_entry.staging.as_bool() != Some(goal_staged_state)
- })
- .cloned()
- .collect::<Vec<_>>();
+ let mut set_anchor: Option<RepoPath> = None;
+ let mut clear_anchor = None;
+
+ let (stage, repo_paths) = {
+ let repo = active_repository.read(cx);
+ match entry {
+ GitListEntry::Status(status_entry) => {
+ let repo_paths = vec![status_entry.clone()];
+ let stage = if self.is_entry_staged(status_entry, &repo) {
+ if let Some(op) = self.bulk_staging.clone()
+ && op.anchor == status_entry.repo_path
+ {
+ clear_anchor = Some(op.anchor);
+ }
+ false
+ } else {
+ set_anchor = Some(status_entry.repo_path.clone());
+ true
+ };
+ (stage, repo_paths)
+ }
+ GitListEntry::TreeStatus(status_entry) => {
+ let repo_paths = vec![status_entry.entry.clone()];
+ let stage = if self.is_entry_staged(&status_entry.entry, &repo) {
+ if let Some(op) = self.bulk_staging.clone()
+ && op.anchor == status_entry.entry.repo_path
+ {
+ clear_anchor = Some(op.anchor);
+ }
+ false
+ } else {
+ set_anchor = Some(status_entry.entry.repo_path.clone());
+ true
+ };
+ (stage, repo_paths)
+ }
+ GitListEntry::Header(section) => {
+ let goal_staged_state = !self.header_state(section.header).selected();
+ let entries = self
+ .entries
+ .iter()
+ .filter_map(|entry| entry.status_entry())
+ .filter(|status_entry| {
+ section.contains(status_entry, &repo)
+ && status_entry.staging.as_bool() != Some(goal_staged_state)
+ })
+ .cloned()
+ .collect::<Vec<_>>();
- (goal_staged_state, entries)
+ (goal_staged_state, entries)
+ }
+ GitListEntry::Directory(entry) => {
+ let goal_staged_state = entry.staged_state != ToggleState::Selected;
+ let entries = self
+ .view_mode
+ .tree_state()
+ .and_then(|state| state.directory_descendants.get(&entry.key))
+ .cloned()
+ .unwrap_or_default()
+ .into_iter()
+ .filter(|status_entry| {
+ self.is_entry_staged(status_entry, &repo) != goal_staged_state
+ })
+ .collect::<Vec<_>>();
+ (goal_staged_state, entries)
+ }
}
};
+ if let Some(anchor) = clear_anchor {
+ if let Some(op) = self.bulk_staging.clone()
+ && op.anchor == anchor
+ {
+ self.bulk_staging = None;
+ }
+ }
+ if let Some(anchor) = set_anchor {
+ self.set_bulk_staging_anchor(anchor, cx);
+ }
+
+ let repo = active_repository.read(cx);
+ self.apply_optimistic_stage(&repo_paths, stage, &repo);
+ cx.notify();
+
self.change_file_stage(stage, repo_paths, cx);
}
@@ -1420,6 +1730,81 @@ impl GitPanel {
.detach();
}
+ fn apply_optimistic_stage(
+ &mut self,
+ entries: &[GitStatusEntry],
+ stage: bool,
+ repo: &Repository,
+ ) {
+ // This “optimistic” pass keeps all checkboxes—files, folders, and section headers—visually in sync the moment you click,
+ // even though `change_file_stage` is still talking to the repository in the background.
+ // Before, the UI would wait for Git, causing checkbox flicker or stale parent states;
+ // Now, users see instant feedback and accurate parent/child tri-states while the async staging operation completes.
+ //
+ // Description:
+ // It records the desired state in `self.optimistic_staging` (a map from path → bool),
+ // walks the rendered entries, and swaps their `staging` flags based on that map.
+ // In tree view it also recomputes every directory’s tri-state checkbox using the updated child data,
+ // so parent folders flip between selected/indeterminate/empty in the same frame.
+ let new_stage = if stage {
+ StageStatus::Staged
+ } else {
+ StageStatus::Unstaged
+ };
+
+ self.optimistic_staging
+ .extend(entries.iter().map(|entry| (entry.repo_path.clone(), stage)));
+
+ let staged_states: HashMap<TreeKey, ToggleState> = self
+ .view_mode
+ .tree_state()
+ .map(|state| state.directory_descendants.iter())
+ .into_iter()
+ .flatten()
+ .map(|(key, descendants)| {
+ let staged_count = descendants
+ .iter()
+ .filter(|entry| self.is_entry_staged(entry, repo))
+ .count();
+ (
+ key.clone(),
+ Self::toggle_state_for_counts(staged_count, descendants.len()),
+ )
+ })
+ .collect();
+
+ for list_entry in &mut self.entries {
+ match list_entry {
+ GitListEntry::Status(status) => {
+ if self
+ .optimistic_staging
+ .get(&status.repo_path)
+ .is_some_and(|s| *s == stage)
+ {
+ status.staging = new_stage;
+ }
+ }
+ GitListEntry::TreeStatus(status) => {
+ if self
+ .optimistic_staging
+ .get(&status.entry.repo_path)
+ .is_some_and(|s| *s == stage)
+ {
+ status.entry.staging = new_stage;
+ }
+ }
+ GitListEntry::Directory(dir) => {
+ if let Some(state) = staged_states.get(&dir.key) {
+ dir.staged_state = *state;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ self.update_counts(repo);
+ }
+
pub fn total_staged_count(&self) -> usize {
self.tracked_staged_count + self.new_staged_count + self.conflicted_staged_count
}
@@ -2690,6 +3075,29 @@ impl GitPanel {
}
}
+ fn toggle_tree_view(&mut self, _: &ToggleTreeView, _: &mut Window, cx: &mut Context<Self>) {
+ let current_setting = GitPanelSettings::get_global(cx).tree_view;
+ if let Some(workspace) = self.workspace.upgrade() {
+ let workspace = workspace.read(cx);
+ let fs = workspace.app_state().fs.clone();
+ cx.update_global::<SettingsStore, _>(|store, _cx| {
+ store.update_settings_file(fs, move |settings, _cx| {
+ settings.git_panel.get_or_insert_default().tree_view = Some(!current_setting);
+ });
+ })
+ }
+ }
+
+ fn toggle_directory(&mut self, key: &TreeKey, window: &mut Window, cx: &mut Context<Self>) {
+ if let Some(state) = self.view_mode.tree_state_mut() {
+ let expanded = state.expanded_dirs.entry(key.clone()).or_insert(true);
+ *expanded = !*expanded;
+ self.update_visible_entries(window, cx);
+ } else {
+ util::debug_panic!("Attempted to toggle directory in flat Git Panel state");
+ }
+ }
+
fn fill_co_authors(&mut self, message: &mut String, cx: &mut Context<Self>) {
const CO_AUTHOR_PREFIX: &str = "Co-authored-by: ";
@@ -2799,27 +3207,34 @@ impl GitPanel {
let bulk_staging = self.bulk_staging.take();
let last_staged_path_prev_index = bulk_staging
.as_ref()
- .and_then(|op| self.entry_by_path(&op.anchor, cx));
+ .and_then(|op| self.entry_by_path(&op.anchor));
self.entries.clear();
+ self.entries_indices.clear();
self.single_staged_entry.take();
self.single_tracked_entry.take();
self.conflicted_count = 0;
self.conflicted_staged_count = 0;
+ self.changes_count = 0;
self.new_count = 0;
self.tracked_count = 0;
self.new_staged_count = 0;
self.tracked_staged_count = 0;
self.entry_count = 0;
+ self.max_width_item_index = None;
let sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
+ let is_tree_view = matches!(self.view_mode, GitPanelViewMode::Tree(_));
+ let group_by_status = is_tree_view || !sort_by_path;
let mut changed_entries = Vec::new();
let mut new_entries = Vec::new();
let mut conflict_entries = Vec::new();
let mut single_staged_entry = None;
let mut staged_count = 0;
- let mut max_width_item: Option<(RepoPath, usize)> = None;
+ let mut seen_directories = HashSet::default();
+ let mut max_width_estimate = 0usize;
+ let mut max_width_item_index = None;
let Some(repo) = self.active_repository.as_ref() else {
// Just clear entries if no repository is active.
@@ -2832,6 +3247,7 @@ impl GitPanel {
self.stash_entries = repo.cached_stash();
for entry in repo.cached_status() {
+ self.changes_count += 1;
let is_conflict = repo.had_conflict_on_last_merge_head_change(&entry.repo_path);
let is_new = entry.status.is_created();
let staging = entry.status.staging();
@@ -2856,26 +3272,9 @@ impl GitPanel {
single_staged_entry = Some(entry.clone());
}
- let width_estimate = Self::item_width_estimate(
- entry.parent_dir(path_style).map(|s| s.len()).unwrap_or(0),
- entry.display_name(path_style).len(),
- );
-
- match max_width_item.as_mut() {
- Some((repo_path, estimate)) => {
- if width_estimate > *estimate {
- *repo_path = entry.repo_path.clone();
- *estimate = width_estimate;
- }
- }
- None => max_width_item = Some((entry.repo_path.clone(), width_estimate)),
- }
-
- if sort_by_path {
- changed_entries.push(entry);
- } else if is_conflict {
+ if group_by_status && is_conflict {
conflict_entries.push(entry);
- } else if is_new {
+ } else if group_by_status && is_new {
new_entries.push(entry);
} else {
changed_entries.push(entry);
@@ -2910,52 +3309,126 @@ impl GitPanel {
self.single_tracked_entry = changed_entries.first().cloned();
}
- if !conflict_entries.is_empty() {
- self.entries.push(GitListEntry::Header(GitHeaderEntry {
- header: Section::Conflict,
- }));
- self.entries
- .extend(conflict_entries.into_iter().map(GitListEntry::Status));
+ let mut push_entry =
+ |this: &mut Self,
+ entry: GitListEntry,
+ is_visible: bool,
+ logical_indices: Option<&mut Vec<usize>>| {
+ if let Some(estimate) =
+ this.width_estimate_for_list_entry(is_tree_view, &entry, path_style)
+ {
+ if estimate > max_width_estimate {
+ max_width_estimate = estimate;
+ max_width_item_index = Some(this.entries.len());
+ }
+ }
+
+ if let Some(repo_path) = entry.status_entry().map(|status| status.repo_path.clone())
+ {
+ this.entries_indices.insert(repo_path, this.entries.len());
+ }
+
+ if let (Some(indices), true) = (logical_indices, is_visible) {
+ indices.push(this.entries.len());
+ }
+
+ this.entries.push(entry);
+ };
+
+ macro_rules! take_section_entries {
+ () => {
+ [
+ (Section::Conflict, std::mem::take(&mut conflict_entries)),
+ (Section::Tracked, std::mem::take(&mut changed_entries)),
+ (Section::New, std::mem::take(&mut new_entries)),
+ ]
+ };
}
- if !changed_entries.is_empty() {
- if !sort_by_path {
- self.entries.push(GitListEntry::Header(GitHeaderEntry {
- header: Section::Tracked,
- }));
+ match &mut self.view_mode {
+ GitPanelViewMode::Tree(tree_state) => {
+ tree_state.logical_indices.clear();
+ tree_state.directory_descendants.clear();
+
+ // This is just to get around the borrow checker
+ // because push_entry mutably borrows self
+ let mut tree_state = std::mem::take(tree_state);
+
+ for (section, entries) in take_section_entries!() {
+ if entries.is_empty() {
+ continue;
+ }
+
+ push_entry(
+ self,
+ GitListEntry::Header(GitHeaderEntry { header: section }),
+ true,
+ Some(&mut tree_state.logical_indices),
+ );
+
+ for (entry, is_visible) in tree_state.build_tree_entries(
+ section,
+ entries,
+ &repo,
+ &mut seen_directories,
+ &self.optimistic_staging,
+ ) {
+ push_entry(
+ self,
+ entry,
+ is_visible,
+ Some(&mut tree_state.logical_indices),
+ );
+ }
+ }
+
+ tree_state
+ .expanded_dirs
+ .retain(|key, _| seen_directories.contains(key));
+ self.view_mode = GitPanelViewMode::Tree(tree_state);
}
- self.entries
- .extend(changed_entries.into_iter().map(GitListEntry::Status));
- }
- if !new_entries.is_empty() {
- self.entries.push(GitListEntry::Header(GitHeaderEntry {
- header: Section::New,
- }));
- self.entries
- .extend(new_entries.into_iter().map(GitListEntry::Status));
- }
+ GitPanelViewMode::Flat => {
+ for (section, entries) in take_section_entries!() {
+ if entries.is_empty() {
+ continue;
+ }
- if let Some((repo_path, _)) = max_width_item {
- self.max_width_item_index = self.entries.iter().position(|entry| match entry {
- GitListEntry::Status(git_status_entry) => git_status_entry.repo_path == repo_path,
- GitListEntry::Header(_) => false,
- });
+ if section != Section::Tracked || !sort_by_path {
+ push_entry(
+ self,
+ GitListEntry::Header(GitHeaderEntry { header: section }),
+ true,
+ None,
+ );
+ }
+
+ for entry in entries {
+ push_entry(self, GitListEntry::Status(entry), true, None);
+ }
+ }
+ }
}
+ self.max_width_item_index = max_width_item_index;
+
self.update_counts(repo);
+ let visible_paths: HashSet<RepoPath> = self
+ .entries
+ .iter()
+ .filter_map(|entry| entry.status_entry().map(|e| e.repo_path.clone()))
+ .collect();
+ self.optimistic_staging
+ .retain(|path, _| visible_paths.contains(path));
let bulk_staging_anchor_new_index = bulk_staging
.as_ref()
.filter(|op| op.repo_id == repo.id)
- .and_then(|op| self.entry_by_path(&op.anchor, cx));
+ .and_then(|op| self.entry_by_path(&op.anchor));
if bulk_staging_anchor_new_index == last_staged_path_prev_index
&& let Some(index) = bulk_staging_anchor_new_index
&& let Some(entry) = self.entries.get(index)
&& let Some(entry) = entry.status_entry()
- && repo
- .pending_ops_for_path(&entry.repo_path)
- .map(|ops| ops.staging() || ops.staged())
- .unwrap_or(entry.staging.has_staged())
+ && self.is_entry_staged(entry, &repo)
{
self.bulk_staging = bulk_staging;
}
@@ -24,6 +24,7 @@ pub struct GitPanelSettings {
pub fallback_branch_name: String,
pub sort_by_path: bool,
pub collapse_untracked_diff: bool,
+ pub tree_view: bool,
}
impl ScrollbarVisibility for GitPanelSettings {
@@ -56,6 +57,7 @@ impl Settings for GitPanelSettings {
fallback_branch_name: git_panel.fallback_branch_name.unwrap(),
sort_by_path: git_panel.sort_by_path.unwrap(),
collapse_untracked_diff: git_panel.collapse_untracked_diff.unwrap(),
+ tree_view: git_panel.tree_view.unwrap(),
}
}
}
@@ -220,7 +220,7 @@ impl PickerDelegate for PickerPromptDelegate {
let shortened_option = util::truncate_and_trailoff(&hit.string, self.max_match_length);
Some(
- ListItem::new(SharedString::from(format!("picker-prompt-menu-{ix}")))
+ ListItem::new(format!("picker-prompt-menu-{ix}"))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -644,7 +644,10 @@ impl ProjectDiff {
}
fn sort_prefix(repo: &Repository, repo_path: &RepoPath, status: FileStatus, cx: &App) -> u64 {
- if GitPanelSettings::get_global(cx).sort_by_path {
+ let settings = GitPanelSettings::get_global(cx);
+
+ // Tree view can only sort by path
+ if settings.sort_by_path || settings.tree_view {
TRACKED_SORT_PREFIX
} else if repo.had_conflict_on_last_merge_head_change(repo_path) {
CONFLICT_SORT_PREFIX
@@ -464,7 +464,7 @@ impl PickerDelegate for StashListDelegate {
);
Some(
- ListItem::new(SharedString::from(format!("stash-{ix}")))
+ ListItem::new(format!("stash-{ix}"))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -170,7 +170,7 @@ impl TextDiffView {
cx.subscribe(&source_buffer, move |this, _, event, _| match event {
language::BufferEvent::Edited
- | language::BufferEvent::LanguageChanged
+ | language::BufferEvent::LanguageChanged(_)
| language::BufferEvent::Reparsed => {
this.buffer_changes_tx.send(()).ok();
}
@@ -665,7 +665,7 @@ impl PickerDelegate for WorktreeListDelegate {
};
Some(
- ListItem::new(SharedString::from(format!("worktree-menu-{ix}")))
+ ListItem::new(format!("worktree-menu-{ix}"))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
@@ -11,7 +11,7 @@ GPUI is still in active development as we work on the Zed code editor, and is st
gpui = { version = "*" }
```
- - [Ownership and data flow](src/_ownership_and_data_flow.rs)
+ - [Ownership and data flow](_ownership_and_data_flow)
Everything in GPUI starts with an `Application`. You can create one with `Application::new()`, and kick off your application by passing a callback to `Application::run()`. Inside this callback, you can create a new window with `App::open_window()`, and register your first root view. See [gpui.rs](https://www.gpui.rs/) for a complete example.
@@ -1,7 +1,7 @@
use gpui::{
Application, Background, Bounds, ColorSpace, Context, MouseDownEvent, Path, PathBuilder,
- PathStyle, Pixels, Point, Render, SharedString, StrokeOptions, Window, WindowOptions, canvas,
- div, linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size,
+ PathStyle, Pixels, Point, Render, StrokeOptions, Window, WindowOptions, canvas, div,
+ linear_color_stop, linear_gradient, point, prelude::*, px, quad, rgb, size,
};
struct PaintingViewer {
@@ -309,7 +309,7 @@ fn button(
on_click: impl Fn(&mut PaintingViewer, &mut Context<PaintingViewer>) + 'static,
) -> impl IntoElement {
div()
- .id(SharedString::from(text.to_string()))
+ .id(text.to_string())
.child(text.to_string())
.bg(gpui::black())
.text_color(gpui::white())
@@ -1,6 +1,6 @@
use gpui::{
- App, Application, Bounds, Context, KeyBinding, PromptButton, PromptLevel, SharedString, Timer,
- Window, WindowBounds, WindowKind, WindowOptions, actions, div, prelude::*, px, rgb, size,
+ App, Application, Bounds, Context, KeyBinding, PromptButton, PromptLevel, Timer, Window,
+ WindowBounds, WindowKind, WindowOptions, actions, div, prelude::*, px, rgb, size,
};
struct SubWindow {
@@ -9,7 +9,7 @@ struct SubWindow {
fn button(text: &str, on_click: impl Fn(&mut Window, &mut App) + 'static) -> impl IntoElement {
div()
- .id(SharedString::from(text.to_string()))
+ .id(text.to_string())
.flex_none()
.px_2()
.bg(rgb(0xf7f7f7))
@@ -43,6 +43,50 @@ pub(crate) const KEYRING_LABEL: &str = "zed-github-account";
const FILE_PICKER_PORTAL_MISSING: &str =
"Couldn't open file picker due to missing xdg-desktop-portal implementation.";
+#[cfg(any(feature = "x11", feature = "wayland"))]
+pub trait ResultExt {
+ type Ok;
+
+ fn notify_err(self, msg: &'static str) -> Self::Ok;
+}
+
+#[cfg(any(feature = "x11", feature = "wayland"))]
+impl<T> ResultExt for anyhow::Result<T> {
+ type Ok = T;
+
+ fn notify_err(self, msg: &'static str) -> T {
+ match self {
+ Ok(v) => v,
+ Err(e) => {
+ use ashpd::desktop::notification::{Notification, NotificationProxy, Priority};
+ use futures::executor::block_on;
+
+ let proxy = block_on(NotificationProxy::new()).expect(msg);
+
+ let notification_id = "dev.zed.Oops";
+ block_on(
+ proxy.add_notification(
+ notification_id,
+ Notification::new("Zed failed to launch")
+ .body(Some(
+ format!(
+ "{e:?}. See https://zed.dev/docs/linux for troubleshooting steps."
+ )
+ .as_str(),
+ ))
+ .priority(Priority::High)
+ .icon(ashpd::desktop::Icon::with_names(&[
+ "dialog-question-symbolic",
+ ])),
+ )
+ ).expect(msg);
+
+ panic!("{msg}");
+ }
+ }
+ }
+}
+
pub trait LinuxClient {
fn compositor_name(&self) -> &'static str;
fn with_common<R>(&self, f: impl FnOnce(&mut LinuxCommon) -> R) -> R;
@@ -605,8 +649,9 @@ pub(super) fn open_uri_internal(
.activation_token(activation_token.clone().map(ashpd::ActivationToken::from))
.send_uri(&uri)
.await
+ .and_then(|e| e.response())
{
- Ok(_) => return,
+ Ok(()) => return,
Err(e) => log::error!("Failed to open with dbus: {}", e),
}
@@ -17,7 +17,7 @@ use collections::HashMap;
use filedescriptor::Pipe;
use http_client::Url;
use smallvec::SmallVec;
-use util::ResultExt;
+use util::ResultExt as _;
use wayland_backend::client::ObjectId;
use wayland_backend::protocol::WEnum;
use wayland_client::event_created_child;
@@ -76,8 +76,8 @@ use crate::{
FileDropEvent, ForegroundExecutor, KeyDownEvent, KeyUpEvent, Keystroke, LinuxCommon,
LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
- PlatformInput, PlatformKeyboardLayout, Point, SCROLL_LINES, ScrollDelta, ScrollWheelEvent,
- Size, TouchPhase, WindowParams, point, px, size,
+ PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
+ ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
};
use crate::{
LinuxDispatcher, RunnableVariant, TaskTiming,
@@ -531,7 +531,8 @@ impl WaylandClient {
})
.unwrap();
- let gpu_context = BladeContext::new().expect("Unable to init GPU context");
+ // This could be unified with the notification handling in zed/main:fail_to_open_window.
+ let gpu_context = BladeContext::new().notify_err("Unable to init GPU context");
let seat = seat.unwrap();
let globals = Globals::new(
@@ -1,4 +1,4 @@
-use crate::{Capslock, LinuxDispatcher, RunnableVariant, TaskTiming, xcb_flush};
+use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
use anyhow::{Context as _, anyhow};
use ashpd::WindowIdentifier;
use calloop::{
@@ -18,7 +18,7 @@ use std::{
rc::{Rc, Weak},
time::{Duration, Instant},
};
-use util::ResultExt;
+use util::ResultExt as _;
use x11rb::{
connection::{Connection, RequestConnection},
@@ -437,7 +437,7 @@ impl X11Client {
.to_string();
let keyboard_layout = LinuxKeyboardLayout::new(layout_name.into());
- let gpu_context = BladeContext::new().context("Unable to init GPU context")?;
+ let gpu_context = BladeContext::new().notify_err("Unable to init GPU context");
let resource_database = x11rb::resource_manager::new_from_default(&xcb_connection)
.context("Failed to create resource database")?;
@@ -5084,6 +5084,18 @@ impl From<SharedString> for ElementId {
}
}
+impl From<String> for ElementId {
+ fn from(name: String) -> Self {
+ ElementId::Name(name.into())
+ }
+}
+
+impl From<Arc<str>> for ElementId {
+ fn from(name: Arc<str>) -> Self {
+ ElementId::Name(name.into())
+ }
+}
+
impl From<Arc<std::path::Path>> for ElementId {
fn from(path: Arc<std::path::Path>) -> Self {
ElementId::Path(path)
@@ -28,7 +28,6 @@ http-body.workspace = true
http.workspace = true
log.workspace = true
parking_lot.workspace = true
-reqwest.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_urlencoded.workspace = true
@@ -88,17 +88,6 @@ impl From<&'static str> for AsyncBody {
}
}
-impl TryFrom<reqwest::Body> for AsyncBody {
- type Error = anyhow::Error;
-
- fn try_from(value: reqwest::Body) -> Result<Self, Self::Error> {
- value
- .as_bytes()
- .ok_or_else(|| anyhow::anyhow!("Underlying data is a stream"))
- .map(|bytes| Self::from_bytes(Bytes::copy_from_slice(bytes)))
- }
-}
-
impl<T: Into<Self>> From<Option<T>> for AsyncBody {
fn from(body: Option<T>) -> Self {
match body {
@@ -8,10 +8,7 @@ use derive_more::Deref;
use http::HeaderValue;
pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
-use futures::{
- FutureExt as _,
- future::{self, BoxFuture},
-};
+use futures::future::BoxFuture;
use parking_lot::Mutex;
use serde::Serialize;
use std::sync::Arc;
@@ -110,14 +107,6 @@ pub trait HttpClient: 'static + Send + Sync {
fn as_fake(&self) -> &FakeHttpClient {
panic!("called as_fake on {}", type_name::<Self>())
}
-
- fn send_multipart_form<'a>(
- &'a self,
- _url: &str,
- _request: reqwest::multipart::Form,
- ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> {
- future::ready(Err(anyhow!("not implemented"))).boxed()
- }
}
/// An [`HttpClient`] that may have a proxy.
@@ -165,14 +154,6 @@ impl HttpClient for HttpClientWithProxy {
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
-
- fn send_multipart_form<'a>(
- &'a self,
- url: &str,
- form: reqwest::multipart::Form,
- ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> {
- self.client.send_multipart_form(url, form)
- }
}
/// An [`HttpClient`] that has a base URL.
@@ -306,14 +287,6 @@ impl HttpClient for HttpClientWithUrl {
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
-
- fn send_multipart_form<'a>(
- &'a self,
- url: &str,
- request: reqwest::multipart::Form,
- ) -> BoxFuture<'a, anyhow::Result<Response<AsyncBody>>> {
- self.client.send_multipart_form(url, request)
- }
}
pub fn read_proxy_from_env() -> Option<Url> {
@@ -49,6 +49,7 @@ pub enum IconName {
BoltOutlined,
Book,
BookCopy,
+ Box,
CaseSensitive,
Chat,
Check,
@@ -1,8 +1,8 @@
pub mod row_chunk;
use crate::{
- DebuggerTextObject, LanguageScope, Outline, OutlineConfig, RunnableCapture, RunnableTag,
- TextObject, TreeSitterOptions,
+ DebuggerTextObject, LanguageScope, Outline, OutlineConfig, PLAIN_TEXT, RunnableCapture,
+ RunnableTag, TextObject, TreeSitterOptions,
diagnostic_set::{DiagnosticEntry, DiagnosticEntryRef, DiagnosticGroup},
language_settings::{LanguageSettings, language_settings},
outline::OutlineItem,
@@ -353,7 +353,8 @@ pub enum BufferEvent {
/// The buffer is in need of a reload
ReloadNeeded,
/// The buffer's language was changed.
- LanguageChanged,
+ /// The boolean indicates whether this buffer did not have a language before, but does now.
+ LanguageChanged(bool),
/// The buffer's syntax trees were updated.
Reparsed,
/// The buffer's diagnostics were updated.
@@ -1386,10 +1387,12 @@ impl Buffer {
) {
self.non_text_state_update_count += 1;
self.syntax_map.lock().clear(&self.text);
- self.language = language;
+ let old_language = std::mem::replace(&mut self.language, language);
self.was_changed();
self.reparse(cx, may_block);
- cx.emit(BufferEvent::LanguageChanged);
+ let has_fresh_language =
+ self.language.is_some() && old_language.is_none_or(|old| old == *PLAIN_TEXT);
+ cx.emit(BufferEvent::LanguageChanged(has_fresh_language));
}
/// Assign a language registry to the buffer. This allows the buffer to retrieve
@@ -136,6 +136,46 @@ pub static PLAIN_TEXT: LazyLock<Arc<Language>> = LazyLock::new(|| {
path_suffixes: vec!["txt".to_owned()],
first_line_pattern: None,
},
+ brackets: BracketPairConfig {
+ pairs: vec![
+ BracketPair {
+ start: "(".to_string(),
+ end: ")".to_string(),
+ close: true,
+ surround: true,
+ newline: false,
+ },
+ BracketPair {
+ start: "[".to_string(),
+ end: "]".to_string(),
+ close: true,
+ surround: true,
+ newline: false,
+ },
+ BracketPair {
+ start: "{".to_string(),
+ end: "}".to_string(),
+ close: true,
+ surround: true,
+ newline: false,
+ },
+ BracketPair {
+ start: "\"".to_string(),
+ end: "\"".to_string(),
+ close: true,
+ surround: true,
+ newline: false,
+ },
+ BracketPair {
+ start: "'".to_string(),
+ end: "'".to_string(),
+ close: true,
+ surround: true,
+ newline: false,
+ },
+ ],
+ disabled_scopes_by_bracket_ix: Default::default(),
+ },
..Default::default()
},
None,
@@ -1,5 +1,5 @@
use std::ops::Range;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Context as _, Result};
@@ -174,7 +174,32 @@ impl DynLspInstaller for ExtensionLspAdapter {
)
.await?;
- let path = self.extension.path_from_extension(command.command.as_ref());
+ // on windows, extensions might produce weird paths
+ // that start with a leading slash due to WASI
+ // requiring that for PWD and friends so account for
+ // that here and try to transform those paths back
+ // to windows paths
+ //
+ // if we don't do this, std will interpret the path as relative,
+ // which changes join behavior
+ let command_path: &Path = if cfg!(windows)
+ && let Some(command) = command.command.to_str()
+ {
+ let mut chars = command.chars();
+ if chars.next().is_some_and(|c| c == '/')
+ && chars.next().is_some_and(|c| c.is_ascii_alphabetic())
+ && chars.next().is_some_and(|c| c == ':')
+ && chars.next().is_some_and(|c| c == '\\' || c == '/')
+ {
+ // looks like a windows path with a leading slash, so strip it
+ command.strip_prefix('/').unwrap().as_ref()
+ } else {
+ command.as_ref()
+ }
+ } else {
+ command.command.as_ref()
+ };
+ let path = self.extension.path_from_extension(command_path);
// TODO: This should now be done via the `zed::make_file_executable` function in
// Zed extension API, but we're leaving these existing usages in place temporarily
@@ -193,7 +218,32 @@ impl DynLspInstaller for ExtensionLspAdapter {
Ok(LanguageServerBinary {
path,
- arguments: command.args.into_iter().map(|arg| arg.into()).collect(),
+ arguments: command
+ .args
+ .into_iter()
+ .map(|arg| {
+ // on windows, extensions might produce weird paths
+ // that start with a leading slash due to WASI
+ // requiring that for PWD and friends so account for
+ // that here and try to transform those paths back
+ // to windows paths
+ if cfg!(windows) {
+ let mut chars = arg.chars();
+ if chars.next().is_some_and(|c| c == '/')
+ && chars.next().is_some_and(|c| c.is_ascii_alphabetic())
+ && chars.next().is_some_and(|c| c == ':')
+ && chars.next().is_some_and(|c| c == '\\' || c == '/')
+ {
+ // looks like a windows path with a leading slash, so strip it
+ arg.strip_prefix('/').unwrap().into()
+ } else {
+ arg.into()
+ }
+ } else {
+ arg.into()
+ }
+ })
+ .collect(),
env: Some(command.env.into_iter().collect()),
})
})
@@ -332,9 +332,11 @@ pub fn into_deepseek(
model: &deepseek::Model,
max_output_tokens: Option<u64>,
) -> deepseek::Request {
- let is_reasoner = *model == deepseek::Model::Reasoner;
+ let is_reasoner = model == &deepseek::Model::Reasoner;
let mut messages = Vec::new();
+ let mut current_reasoning: Option<String> = None;
+
for message in request.messages {
for content in message.content {
match content {
@@ -343,10 +345,14 @@ pub fn into_deepseek(
Role::Assistant => deepseek::RequestMessage::Assistant {
content: Some(text),
tool_calls: Vec::new(),
+ reasoning_content: current_reasoning.take(),
},
Role::System => deepseek::RequestMessage::System { content: text },
}),
- MessageContent::Thinking { .. } => {}
+ MessageContent::Thinking { text, .. } => {
+ // Accumulate reasoning content for next assistant message
+ current_reasoning.get_or_insert_default().push_str(&text);
+ }
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::ToolUse(tool_use) => {
@@ -369,6 +375,7 @@ pub fn into_deepseek(
messages.push(deepseek::RequestMessage::Assistant {
content: None,
tool_calls: vec![tool_call],
+ reasoning_content: current_reasoning.take(),
});
}
}
@@ -126,11 +126,11 @@ impl LspInstaller for EsLintLspAdapter {
}
self.node
- .run_npm_subcommand(&repo_root, "install", &[])
+ .run_npm_subcommand(Some(&repo_root), "install", &[])
.await?;
self.node
- .run_npm_subcommand(&repo_root, "run-script", &["compile"])
+ .run_npm_subcommand(Some(&repo_root), "run-script", &["compile"])
.await?;
}
@@ -73,7 +73,9 @@ impl LspInstaller for GoLspAdapter {
delegate.show_notification(NOTIFICATION_MESSAGE, cx);
})?
}
- anyhow::bail!("cannot install gopls");
+ anyhow::bail!(
+ "Could not install the Go language server `gopls`, because `go` was not found."
+ );
}
let release =
@@ -47,45 +47,10 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
-; Parameters
-
-(required_parameter
- (identifier) @variable.parameter)
-
-(required_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(optional_parameter
- (identifier) @variable.parameter)
-
-(optional_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(catch_clause
- parameter: (identifier) @variable.parameter)
-
-(index_signature
- name: (identifier) @variable.parameter)
-
-(arrow_function
- parameter: (identifier) @variable.parameter)
-
; Special identifiers
-;
-(class_declaration
- (type_identifier) @type.class)
-
-(extends_clause
- value: (identifier) @type.class)
+((identifier) @type
+ (#match? @type "^[A-Z]"))
(type_identifier) @type
(predefined_type) @type.builtin
@@ -1,3 +1,2 @@
(tag_name) @keyword.jsdoc
(type) @type.jsdoc
-(identifier) @variable.jsdoc
@@ -4,9 +4,10 @@ path_suffixes = ["json", "flake.lock"]
line_comments = ["// "]
autoclose_before = ",]}"
brackets = [
- { start = "{", end = "}", close = true, newline = true },
- { start = "[", end = "]", close = true, newline = true },
- { start = "\"", end = "\"", close = true, newline = false, not_in = ["string"] },
+ { start = "{", end = "}", close = true, surround = true, newline = true },
+ { start = "[", end = "]", close = true, surround = true, newline = true },
+ { start = "(", end = ")", close = true, surround = true, newline = false },
+ { start = "\"", end = "\"", close = true, surround = true, newline = false, not_in = ["string"] },
]
tab_size = 2
prettier_parser_name = "json"
@@ -4,9 +4,10 @@ path_suffixes = ["jsonc", "bun.lock", "tsconfig.json", "pyrightconfig.json"]
line_comments = ["// "]
autoclose_before = ",]}"
brackets = [
- { start = "{", end = "}", close = true, newline = true },
- { start = "[", end = "]", close = true, newline = true },
- { start = "\"", end = "\"", close = true, newline = false, not_in = ["string"] },
+ { start = "{", end = "}", close = true, surround = true, newline = true },
+ { start = "[", end = "]", close = true, surround = true, newline = true },
+ { start = "(", end = ")", close = true, surround = true, newline = false },
+ { start = "\"", end = "\"", close = true, surround = true, newline = false, not_in = ["string"] },
]
tab_size = 2
prettier_parser_name = "jsonc"
@@ -1344,7 +1344,7 @@ impl ToolchainLister for PythonToolchainProvider {
ShellKind::Fish => Some(format!("\"{pyenv}\" shell - fish {version}")),
ShellKind::Posix => Some(format!("\"{pyenv}\" shell - sh {version}")),
ShellKind::Nushell => Some(format!("^\"{pyenv}\" shell - nu {version}")),
- ShellKind::PowerShell => None,
+ ShellKind::PowerShell | ShellKind::Pwsh => None,
ShellKind::Csh => None,
ShellKind::Tcsh => None,
ShellKind::Cmd => None,
@@ -140,13 +140,6 @@ impl LspAdapter for TailwindLspAdapter {
) -> Result<Option<serde_json::Value>> {
Ok(Some(json!({
"provideFormatter": true,
- "userLanguages": {
- "html": "html",
- "css": "css",
- "javascript": "javascript",
- "typescript": "typescript",
- "typescriptreact": "typescriptreact",
- },
})))
}
@@ -167,8 +160,18 @@ impl LspAdapter for TailwindLspAdapter {
tailwind_user_settings["emmetCompletions"] = Value::Bool(true);
}
+ if tailwind_user_settings.get("includeLanguages").is_none() {
+ tailwind_user_settings["includeLanguages"] = json!({
+ "html": "html",
+ "css": "css",
+ "javascript": "javascript",
+ "typescript": "typescript",
+ "typescriptreact": "typescriptreact",
+ });
+ }
+
Ok(json!({
- "tailwindCSS": tailwind_user_settings,
+ "tailwindCSS": tailwind_user_settings
}))
}
@@ -47,68 +47,13 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
-; Parameters
-
-(required_parameter
- (identifier) @variable.parameter)
-
-(required_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(optional_parameter
- (identifier) @variable.parameter)
-
-(optional_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(catch_clause
- parameter: (identifier) @variable.parameter)
-
-(index_signature
- name: (identifier) @variable.parameter)
-
-(arrow_function
- parameter: (identifier) @variable.parameter)
-
-(type_predicate
- name: (identifier) @variable.parameter)
-
; Special identifiers
-(type_annotation) @type
+((identifier) @type
+ (#match? @type "^[A-Z]"))
(type_identifier) @type
(predefined_type) @type.builtin
-(type_alias_declaration
- (type_identifier) @type)
-
-(type_alias_declaration
- value: (_
- (type_identifier) @type))
-
-(interface_declaration
- (type_identifier) @type)
-
-(class_declaration
- (type_identifier) @type.class)
-
-(extends_clause
- value: (identifier) @type.class)
-
-(extends_type_clause
- type: (type_identifier) @type)
-
-(implements_clause
- (type_identifier) @type)
-
([
(identifier)
(shorthand_property_identifier)
@@ -286,42 +231,8 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
-(type_parameters
- "<" @punctuation.bracket
- ">" @punctuation.bracket)
-
(decorator "@" @punctuation.special)
-(union_type
- ("|") @punctuation.special)
-
-(intersection_type
- ("&") @punctuation.special)
-
-(type_annotation
- (":") @punctuation.special)
-
-(index_signature
- (":") @punctuation.special)
-
-(type_predicate_annotation
- (":") @punctuation.special)
-
-(public_field_definition
- ("?") @punctuation.special)
-
-(property_signature
- ("?") @punctuation.special)
-
-(method_signature
- ("?") @punctuation.special)
-
-(optional_parameter
- ([
- "?"
- ":"
- ]) @punctuation.special)
-
; Keywords
[ "abstract"
@@ -4,33 +4,11 @@
; Special identifiers
-(type_annotation) @type
-
+((identifier) @type
+ (#match? @type "^[A-Z]"))
(type_identifier) @type
(predefined_type) @type.builtin
-(type_alias_declaration
- (type_identifier) @type)
-
-(type_alias_declaration
- value: (_
- (type_identifier) @type))
-
-(interface_declaration
- (type_identifier) @type)
-
-(class_declaration
- (type_identifier) @type.class)
-
-(extends_clause
- value: (identifier) @type.class)
-
-(extends_type_clause
- type: (type_identifier) @type)
-
-(implements_clause
- (type_identifier) @type)
-
;; Enables ts-pretty-errors
;; The Lsp returns "snippets" of typescript, which are not valid typescript in totality,
;; but should still be highlighted
@@ -136,40 +114,6 @@
(arrow_function) @function
-; Parameters
-
-(required_parameter
- (identifier) @variable.parameter)
-
-(required_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(optional_parameter
- (identifier) @variable.parameter)
-
-(optional_parameter
- (_
- ([
- (identifier)
- (shorthand_property_identifier_pattern)
- ]) @variable.parameter))
-
-(catch_clause
- parameter: (identifier) @variable.parameter)
-
-(index_signature
- name: (identifier) @variable.parameter)
-
-(arrow_function
- parameter: (identifier) @variable.parameter)
-
-(type_predicate
- name: (identifier) @variable.parameter)
-
; Literals
(this) @variable.special
@@ -300,42 +244,8 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
-(type_parameters
- "<" @punctuation.bracket
- ">" @punctuation.bracket)
-
(decorator "@" @punctuation.special)
-(union_type
- ("|") @punctuation.special)
-
-(intersection_type
- ("&") @punctuation.special)
-
-(type_annotation
- (":") @punctuation.special)
-
-(index_signature
- (":") @punctuation.special)
-
-(type_predicate_annotation
- (":") @punctuation.special)
-
-(public_field_definition
- ("?") @punctuation.special)
-
-(property_signature
- ("?") @punctuation.special)
-
-(method_signature
- ("?") @punctuation.special)
-
-(optional_parameter
- ([
- "?"
- ":"
- ]) @punctuation.special)
-
; Keywords
[
@@ -378,7 +378,7 @@ impl Render for LivekitWindow {
.when_some(state.audio_output_stream.as_ref(), |el, state| {
el.child(
button()
- .id(SharedString::from(identity.0.clone()))
+ .id(identity.0.clone())
.child(if state.0.is_enabled() {
"Deafen"
} else {
@@ -129,7 +129,7 @@ pub enum Event {
transaction_id: TransactionId,
},
Reloaded,
- LanguageChanged(BufferId),
+ LanguageChanged(BufferId, bool),
Reparsed(BufferId),
Saved,
FileHandleChanged,
@@ -2294,7 +2294,9 @@ impl MultiBuffer {
BufferEvent::Saved => Event::Saved,
BufferEvent::FileHandleChanged => Event::FileHandleChanged,
BufferEvent::Reloaded => Event::Reloaded,
- BufferEvent::LanguageChanged => Event::LanguageChanged(buffer_id),
+ BufferEvent::LanguageChanged(has_language) => {
+ Event::LanguageChanged(buffer_id, *has_language)
+ }
BufferEvent::Reparsed => Event::Reparsed(buffer_id),
BufferEvent::DiagnosticsUpdated => Event::DiagnosticsUpdated,
BufferEvent::CapabilityChanged => {
@@ -206,14 +206,14 @@ impl NodeRuntime {
pub async fn run_npm_subcommand(
&self,
- directory: &Path,
+ directory: Option<&Path>,
subcommand: &str,
args: &[&str],
) -> Result<Output> {
let http = self.0.lock().await.http.clone();
self.instance()
.await
- .run_npm_subcommand(Some(directory), http.proxy(), subcommand, args)
+ .run_npm_subcommand(directory, http.proxy(), subcommand, args)
.await
}
@@ -283,7 +283,7 @@ impl NodeRuntime {
]);
// This is also wrong because the directory is wrong.
- self.run_npm_subcommand(directory, "install", &arguments)
+ self.run_npm_subcommand(Some(directory), "install", &arguments)
.await?;
Ok(())
}
@@ -559,7 +559,10 @@ impl NodeRuntimeTrait for ManagedNodeRuntime {
command.env("PATH", env_path);
command.env(NODE_CA_CERTS_ENV_VAR, node_ca_certs);
command.arg(npm_file).arg(subcommand);
- command.args(["--cache".into(), self.installation_path.join("cache")]);
+ command.arg(format!(
+ "--cache={}",
+ self.installation_path.join("cache").display()
+ ));
command.args([
"--userconfig".into(),
self.installation_path.join("blank_user_npmrc"),
@@ -703,7 +706,10 @@ impl NodeRuntimeTrait for SystemNodeRuntime {
.env("PATH", path)
.env(NODE_CA_CERTS_ENV_VAR, node_ca_certs)
.arg(subcommand)
- .args(["--cache".into(), self.scratch_dir.join("cache")])
+ .arg(format!(
+ "--cache={}",
+ self.scratch_dir.join("cache").display()
+ ))
.args(args);
configure_npm_command(&mut command, directory, proxy);
let output = command.output().await?;
@@ -408,6 +408,12 @@ pub fn remote_servers_dir() -> &'static PathBuf {
REMOTE_SERVERS_DIR.get_or_init(|| data_dir().join("remote_servers"))
}
+/// Returns the path to the directory where the devcontainer CLI is installed.
+pub fn devcontainer_dir() -> &'static PathBuf {
+ static DEVCONTAINER_DIR: OnceLock<PathBuf> = OnceLock::new();
+ DEVCONTAINER_DIR.get_or_init(|| data_dir().join("devcontainer"))
+}
+
/// Returns the relative path to a `.zed` folder within a project.
pub fn local_settings_folder_name() -> &'static str {
".zed"
@@ -97,6 +97,18 @@ pub trait PickerDelegate: Sized + 'static {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
);
+
+ /// Called before the picker handles `SelectPrevious` or `SelectNext`. Return `Some(query)` to
+ /// set a new query and prevent the default selection behavior.
+ fn select_history(
+ &mut self,
+ _direction: Direction,
+ _query: &str,
+ _window: &mut Window,
+ _cx: &mut App,
+ ) -> Option<String> {
+ None
+ }
fn can_select(
&mut self,
_ix: usize,
@@ -448,6 +460,14 @@ impl<D: PickerDelegate> Picker<D> {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let query = self.query(cx);
+ if let Some(query) = self
+ .delegate
+ .select_history(Direction::Down, &query, window, cx)
+ {
+ self.set_query(query, window, cx);
+ return;
+ }
let count = self.delegate.match_count();
if count > 0 {
let index = self.delegate.selected_index();
@@ -467,6 +487,14 @@ impl<D: PickerDelegate> Picker<D> {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let query = self.query(cx);
+ if let Some(query) = self
+ .delegate
+ .select_history(Direction::Up, &query, window, cx)
+ {
+ self.set_query(query, window, cx);
+ return;
+ }
let count = self.delegate.match_count();
if count > 0 {
let index = self.delegate.selected_index();
@@ -137,6 +137,7 @@ pub struct AgentServerStore {
state: AgentServerStoreState,
external_agents: HashMap<ExternalAgentServerName, Box<dyn ExternalAgentServer>>,
agent_icons: HashMap<ExternalAgentServerName, SharedString>,
+ agent_display_names: HashMap<ExternalAgentServerName, SharedString>,
}
pub struct AgentServersUpdated;
@@ -155,6 +156,7 @@ mod ext_agent_tests {
state: AgentServerStoreState::Collab,
external_agents: HashMap::default(),
agent_icons: HashMap::default(),
+ agent_display_names: HashMap::default(),
}
}
@@ -258,6 +260,7 @@ impl AgentServerStore {
self.external_agents.retain(|name, agent| {
if agent.downcast_mut::<LocalExtensionArchiveAgent>().is_some() {
self.agent_icons.remove(name);
+ self.agent_display_names.remove(name);
false
} else {
// Keep the hardcoded external agents that don't come from extensions
@@ -275,6 +278,12 @@ impl AgentServerStore {
for (ext_id, manifest) in manifests {
for (agent_name, agent_entry) in &manifest.agent_servers {
// Store absolute icon path if provided, resolving symlinks for dev extensions
+ // Store display name from manifest
+ self.agent_display_names.insert(
+ ExternalAgentServerName(agent_name.clone().into()),
+ SharedString::from(agent_entry.name.clone()),
+ );
+
let icon_path = if let Some(icon) = &agent_entry.icon {
let icon_path = extensions_dir.join(ext_id).join(icon);
// Canonicalize to resolve symlinks (dev extensions are symlinked)
@@ -310,6 +319,12 @@ impl AgentServerStore {
let mut agents = vec![];
for (ext_id, manifest) in manifests {
for (agent_name, agent_entry) in &manifest.agent_servers {
+ // Store display name from manifest
+ self.agent_display_names.insert(
+ ExternalAgentServerName(agent_name.clone().into()),
+ SharedString::from(agent_entry.name.clone()),
+ );
+
// Store absolute icon path if provided, resolving symlinks for dev extensions
let icon = if let Some(icon) = &agent_entry.icon {
let icon_path = extensions_dir.join(ext_id).join(icon);
@@ -369,6 +384,10 @@ impl AgentServerStore {
self.agent_icons.get(name).cloned()
}
+ pub fn agent_display_name(&self, name: &ExternalAgentServerName) -> Option<SharedString> {
+ self.agent_display_names.get(name).cloned()
+ }
+
pub fn init_remote(session: &AnyProtoClient) {
session.add_entity_message_handler(Self::handle_external_agents_updated);
session.add_entity_message_handler(Self::handle_loading_status_updated);
@@ -559,6 +578,7 @@ impl AgentServerStore {
},
external_agents: Default::default(),
agent_icons: Default::default(),
+ agent_display_names: Default::default(),
};
if let Some(_events) = extension::ExtensionEvents::try_global(cx) {}
this.agent_servers_settings_changed(cx);
@@ -609,6 +629,7 @@ impl AgentServerStore {
},
external_agents: external_agents.into_iter().collect(),
agent_icons: HashMap::default(),
+ agent_display_names: HashMap::default(),
}
}
@@ -617,6 +638,7 @@ impl AgentServerStore {
state: AgentServerStoreState::Collab,
external_agents: Default::default(),
agent_icons: Default::default(),
+ agent_display_names: Default::default(),
}
}
@@ -2040,6 +2062,7 @@ mod extension_agent_tests {
state: AgentServerStoreState::Collab,
external_agents: HashMap::default(),
agent_icons: HashMap::default(),
+ agent_display_names: HashMap::default(),
};
// Seed with extension agents (contain ": ") and custom agents (don't contain ": ")
@@ -620,9 +620,21 @@ impl LocalBufferStore {
let load_file = worktree.update(cx, |worktree, cx| worktree.load_file(path.as_ref(), cx));
cx.spawn(async move |this, cx| {
let path = path.clone();
- let buffer = match load_file.await.with_context(|| {
- format!("Could not open path: {}", path.display(PathStyle::local()))
- }) {
+ let single_file_path = cx.update(|cx| {
+ if worktree.read(cx).is_single_file() {
+ Some(worktree.read(cx).abs_path())
+ } else {
+ None
+ }
+ })?;
+ let path_string = single_file_path
+ .as_ref()
+ .map(|path| path.to_string_lossy())
+ .unwrap_or_else(|| path.display(PathStyle::local()));
+ let buffer = match load_file
+ .await
+ .with_context(|| format!("Opening path \"{path_string}\""))
+ {
Ok(loaded) => {
let reservation = cx.reserve_entity::<Buffer>()?;
let buffer_id = BufferId::from(reservation.entity_id().as_non_zero_u64());
@@ -1129,7 +1141,7 @@ impl BufferStore {
})
.log_err();
}
- BufferEvent::LanguageChanged => {}
+ BufferEvent::LanguageChanged(_) => {}
_ => {}
}
}
@@ -411,11 +411,11 @@ impl ContextServerStore {
) {
self.stop_server(&id, cx).log_err();
}
-
let task = cx.spawn({
let id = server.id();
let server = server.clone();
let configuration = configuration.clone();
+
async move |this, cx| {
match server.clone().start(cx).await {
Ok(_) => {
@@ -1451,7 +1451,7 @@ impl GitStore {
match event {
BufferStoreEvent::BufferAdded(buffer) => {
cx.subscribe(buffer, |this, buffer, event, cx| {
- if let BufferEvent::LanguageChanged = event {
+ if let BufferEvent::LanguageChanged(_) = event {
let buffer_id = buffer.read(cx).remote_id();
if let Some(diff_state) = this.diffs.get(&buffer_id) {
diff_state.update(cx, |diff_state, cx| {
@@ -4692,11 +4692,9 @@ impl Repository {
});
let this = cx.weak_entity();
- let rx = self.run_hook(RunHook::PrePush, cx);
self.send_job(
Some(format!("git push {} {} {}", args, remote, branch).into()),
move |git_repo, mut cx| async move {
- rx.await??;
match git_repo {
RepositoryState::Local(LocalRepositoryState {
backend,
@@ -1,118 +0,0 @@
-use std::{path::Path, sync::Arc};
-
-use gpui::{EventEmitter, FocusHandle, Focusable};
-use ui::{
- App, Button, ButtonCommon, ButtonStyle, Clickable, Context, FluentBuilder, InteractiveElement,
- KeyBinding, Label, LabelCommon, LabelSize, ParentElement, Render, SharedString, Styled as _,
- Window, h_flex, v_flex,
-};
-use zed_actions::workspace::OpenWithSystem;
-
-use crate::Item;
-
-/// A view to display when a certain buffer fails to open.
-#[derive(Debug)]
-pub struct InvalidItemView {
- /// Which path was attempted to open.
- pub abs_path: Arc<Path>,
- /// An error message, happened when opening the buffer.
- pub error: SharedString,
- is_local: bool,
- focus_handle: FocusHandle,
-}
-
-impl InvalidItemView {
- pub fn new(
- abs_path: &Path,
- is_local: bool,
- e: &anyhow::Error,
- _: &mut Window,
- cx: &mut App,
- ) -> Self {
- Self {
- is_local,
- abs_path: Arc::from(abs_path),
- error: format!("{}", e.root_cause()).into(),
- focus_handle: cx.focus_handle(),
- }
- }
-}
-
-impl Item for InvalidItemView {
- type Event = ();
-
- fn tab_content_text(&self, mut detail: usize, _: &App) -> SharedString {
- // Ensure we always render at least the filename.
- detail += 1;
-
- let path = self.abs_path.as_ref();
-
- let mut prefix = path;
- while detail > 0 {
- if let Some(parent) = prefix.parent() {
- prefix = parent;
- detail -= 1;
- } else {
- break;
- }
- }
-
- let path = if detail > 0 {
- path
- } else {
- path.strip_prefix(prefix).unwrap_or(path)
- };
-
- SharedString::new(path.to_string_lossy())
- }
-}
-
-impl EventEmitter<()> for InvalidItemView {}
-
-impl Focusable for InvalidItemView {
- fn focus_handle(&self, _: &App) -> FocusHandle {
- self.focus_handle.clone()
- }
-}
-
-impl Render for InvalidItemView {
- fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl gpui::IntoElement {
- let abs_path = self.abs_path.clone();
- v_flex()
- .size_full()
- .track_focus(&self.focus_handle(cx))
- .flex_none()
- .justify_center()
- .overflow_hidden()
- .key_context("InvalidBuffer")
- .child(
- h_flex().size_full().justify_center().child(
- v_flex()
- .justify_center()
- .gap_2()
- .child(h_flex().justify_center().child("Could not open file"))
- .child(
- h_flex()
- .justify_center()
- .child(Label::new(self.error.clone()).size(LabelSize::Small)),
- )
- .when(self.is_local, |contents| {
- contents.child(
- h_flex().justify_center().child(
- Button::new("open-with-system", "Open in Default App")
- .on_click(move |_, _, cx| {
- cx.open_with_system(&abs_path);
- })
- .style(ButtonStyle::Outlined)
- .key_binding(KeyBinding::for_action(
- &OpenWithSystem,
- window,
- cx,
- )),
- ),
- )
- }),
- ),
- )
- }
-}
@@ -219,7 +219,7 @@ struct UnifiedLanguageServer {
project_roots: HashSet<Arc<RelPath>>,
}
-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct LanguageServerSeed {
worktree_id: WorktreeId,
name: LanguageServerName,
@@ -1142,7 +1142,7 @@ impl ProjectPanel {
)
.when(has_git_repo, |menu| {
menu.separator()
- .action("File History", Box::new(git::FileHistory))
+ .action("View File History", Box::new(git::FileHistory))
})
.when(!should_hide_rename, |menu| {
menu.separator().action("Rename", Box::new(Rename))
@@ -1663,12 +1663,20 @@ impl ProjectPanel {
let edit_state = self.state.edit_state.as_mut()?;
let worktree_id = edit_state.worktree_id;
let is_new_entry = edit_state.is_new_entry();
- let filename = self.filename_editor.read(cx).text(cx);
+ let mut filename = self.filename_editor.read(cx).text(cx);
+ let path_style = self.project.read(cx).path_style(cx);
+ if path_style.is_windows() {
+ // on windows, trailing dots are ignored in paths
+ // this can cause project panel to create a new entry with a trailing dot
+ // while the actual one without the dot gets populated by the file watcher
+ while let Some(trimmed) = filename.strip_suffix('.') {
+ filename = trimmed.to_string();
+ }
+ }
if filename.trim().is_empty() {
return None;
}
- let path_style = self.project.read(cx).path_style(cx);
let filename_indicates_dir = if path_style.is_windows() {
filename.ends_with('/') || filename.ends_with('\\')
} else {
@@ -6612,6 +6612,74 @@ async fn test_create_entries_without_selection_hide_root(cx: &mut gpui::TestAppC
);
}
+#[cfg(windows)]
+#[gpui::test]
+async fn test_create_entry_with_trailing_dot_windows(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "dir1": {
+ "file1.txt": "",
+ },
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
+ let cx = &mut VisualTestContext::from_window(*workspace, cx);
+
+ let panel = workspace
+ .update(cx, |workspace, window, cx| {
+ let panel = ProjectPanel::new(workspace, window, cx);
+ workspace.add_panel(panel.clone(), window, cx);
+ panel
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ #[rustfmt::skip]
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..20, cx),
+ &[
+ "v root",
+ " > dir1",
+ ],
+ "Initial state with nothing selected"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.new_file(&NewFile, window, cx);
+ });
+ cx.run_until_parked();
+ panel.update_in(cx, |panel, window, cx| {
+ assert!(panel.filename_editor.read(cx).is_focused(window));
+ });
+ panel
+ .update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("foo.", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ })
+ .await
+ .unwrap();
+ cx.run_until_parked();
+ #[rustfmt::skip]
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..20, cx),
+ &[
+ "v root",
+ " > dir1",
+ " foo <== selected <== marked",
+ ],
+ "A new file is created under the root directory without the trailing dot"
+ );
+}
+
#[gpui::test]
async fn test_highlight_entry_for_external_drag(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -580,7 +580,7 @@ message GitCreateWorktree {
message RunGitHook {
enum GitHook {
PRE_COMMIT = 0;
- PRE_PUSH = 1;
+ reserved 1;
}
uint64 project_id = 1;
@@ -16,6 +16,7 @@ doctest = false
anyhow.workspace = true
askpass.workspace = true
auto_update.workspace = true
+db.workspace = true
editor.workspace = true
extension_host.workspace = true
file_finder.workspace = true
@@ -26,6 +27,7 @@ language.workspace = true
log.workspace = true
markdown.workspace = true
menu.workspace = true
+node_runtime.workspace = true
ordered-float.workspace = true
paths.workspace = true
picker.workspace = true
@@ -34,6 +36,7 @@ release_channel.workspace = true
remote.workspace = true
semver.workspace = true
serde.workspace = true
+serde_json.workspace = true
settings.workspace = true
smol.workspace = true
task.workspace = true
@@ -42,6 +45,7 @@ theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
+worktree.workspace = true
zed_actions.workspace = true
indoc.workspace = true
@@ -0,0 +1,295 @@
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+
+use gpui::AsyncWindowContext;
+use node_runtime::NodeRuntime;
+use serde::Deserialize;
+use settings::DevContainerConnection;
+use smol::fs;
+use workspace::Workspace;
+
+use crate::remote_connections::Connection;
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct DevContainerUp {
+ _outcome: String,
+ container_id: String,
+ _remote_user: String,
+ remote_workspace_folder: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct DevContainerConfiguration {
+ name: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct DevContainerConfigurationOutput {
+ configuration: DevContainerConfiguration,
+}
+
+#[cfg(not(target_os = "windows"))]
+fn dev_container_cli() -> String {
+ "devcontainer".to_string()
+}
+
+#[cfg(target_os = "windows")]
+fn dev_container_cli() -> String {
+ "devcontainer.cmd".to_string()
+}
+
+async fn check_for_docker() -> Result<(), DevContainerError> {
+ let mut command = util::command::new_smol_command("docker");
+ command.arg("--version");
+
+ match command.output().await {
+ Ok(_) => Ok(()),
+ Err(e) => {
+ log::error!("Unable to find docker in $PATH: {:?}", e);
+ Err(DevContainerError::DockerNotAvailable)
+ }
+ }
+}
+
+async fn ensure_devcontainer_cli(node_runtime: NodeRuntime) -> Result<PathBuf, DevContainerError> {
+ let mut command = util::command::new_smol_command(&dev_container_cli());
+ command.arg("--version");
+
+ if let Err(e) = command.output().await {
+ log::error!(
+ "Unable to find devcontainer CLI in $PATH. Checking for a zed installed version. Error: {:?}",
+ e
+ );
+
+ let datadir_cli_path = paths::devcontainer_dir()
+ .join("node_modules")
+ .join(".bin")
+ .join(&dev_container_cli());
+
+ let mut command =
+ util::command::new_smol_command(&datadir_cli_path.as_os_str().display().to_string());
+ command.arg("--version");
+
+ if let Err(e) = command.output().await {
+ log::error!(
+ "Unable to find devcontainer CLI in Data dir. Will try to install. Error: {:?}",
+ e
+ );
+ } else {
+ log::info!("Found devcontainer CLI in Data dir");
+ return Ok(datadir_cli_path.clone());
+ }
+
+ if let Err(e) = fs::create_dir_all(paths::devcontainer_dir()).await {
+ log::error!("Unable to create devcontainer directory. Error: {:?}", e);
+ return Err(DevContainerError::DevContainerCliNotAvailable);
+ }
+
+ if let Err(e) = node_runtime
+ .npm_install_packages(
+ &paths::devcontainer_dir(),
+ &[("@devcontainers/cli", "latest")],
+ )
+ .await
+ {
+ log::error!(
+ "Unable to install devcontainer CLI to data directory. Error: {:?}",
+ e
+ );
+ return Err(DevContainerError::DevContainerCliNotAvailable);
+ };
+
+ let mut command = util::command::new_smol_command(&datadir_cli_path.display().to_string());
+ command.arg("--version");
+ if let Err(e) = command.output().await {
+ log::error!(
+ "Unable to find devcontainer cli after NPM install. Error: {:?}",
+ e
+ );
+ Err(DevContainerError::DevContainerCliNotAvailable)
+ } else {
+ Ok(datadir_cli_path)
+ }
+ } else {
+ log::info!("Found devcontainer cli on $PATH, using it");
+ Ok(PathBuf::from(&dev_container_cli()))
+ }
+}
+
+async fn devcontainer_up(
+ path_to_cli: &PathBuf,
+ path: Arc<Path>,
+) -> Result<DevContainerUp, DevContainerError> {
+ let mut command = util::command::new_smol_command(path_to_cli.display().to_string());
+ command.arg("up");
+ command.arg("--workspace-folder");
+ command.arg(path.display().to_string());
+
+ match command.output().await {
+ Ok(output) => {
+ if output.status.success() {
+ let raw = String::from_utf8_lossy(&output.stdout);
+ serde_json::from_str::<DevContainerUp>(&raw).map_err(|e| {
+ log::error!(
+ "Unable to parse response from 'devcontainer up' command, error: {:?}",
+ e
+ );
+ DevContainerError::DevContainerParseFailed
+ })
+ } else {
+ log::error!(
+ "Non-success status running devcontainer up for workspace: out: {:?}, err: {:?}",
+ String::from_utf8_lossy(&output.stdout),
+ String::from_utf8_lossy(&output.stderr)
+ );
+ Err(DevContainerError::DevContainerUpFailed)
+ }
+ }
+ Err(e) => {
+ log::error!("Error running devcontainer up: {:?}", e);
+ Err(DevContainerError::DevContainerUpFailed)
+ }
+ }
+}
+
+async fn devcontainer_read_configuration(
+ path_to_cli: &PathBuf,
+ path: Arc<Path>,
+) -> Result<DevContainerConfigurationOutput, DevContainerError> {
+ let mut command = util::command::new_smol_command(path_to_cli.display().to_string());
+ command.arg("read-configuration");
+ command.arg("--workspace-folder");
+ command.arg(path.display().to_string());
+ match command.output().await {
+ Ok(output) => {
+ if output.status.success() {
+ let raw = String::from_utf8_lossy(&output.stdout);
+ serde_json::from_str::<DevContainerConfigurationOutput>(&raw).map_err(|e| {
+ log::error!(
+ "Unable to parse response from 'devcontainer read-configuration' command, error: {:?}",
+ e
+ );
+ DevContainerError::DevContainerParseFailed
+ })
+ } else {
+ log::error!(
+ "Non-success status running devcontainer read-configuration for workspace: out: {:?}, err: {:?}",
+ String::from_utf8_lossy(&output.stdout),
+ String::from_utf8_lossy(&output.stderr)
+ );
+ Err(DevContainerError::DevContainerUpFailed)
+ }
+ }
+ Err(e) => {
+ log::error!("Error running devcontainer read-configuration: {:?}", e);
+ Err(DevContainerError::DevContainerUpFailed)
+ }
+ }
+}
+
+// Name the project with two fallbacks
+async fn get_project_name(
+ path_to_cli: &PathBuf,
+ path: Arc<Path>,
+ remote_workspace_folder: String,
+ container_id: String,
+) -> Result<String, DevContainerError> {
+ if let Ok(dev_container_configuration) =
+ devcontainer_read_configuration(path_to_cli, path).await
+ && let Some(name) = dev_container_configuration.configuration.name
+ {
+ // Ideally, name the project after the name defined in devcontainer.json
+ Ok(name)
+ } else {
+ // Otherwise, name the project after the remote workspace folder name
+ Ok(Path::new(&remote_workspace_folder)
+ .file_name()
+ .and_then(|name| name.to_str())
+ .map(|string| string.into())
+ // Finally, name the project after the container ID as a last resort
+ .unwrap_or_else(|| container_id.clone()))
+ }
+}
+
+fn project_directory(cx: &mut AsyncWindowContext) -> Option<Arc<Path>> {
+ let Some(workspace) = cx.window_handle().downcast::<Workspace>() else {
+ return None;
+ };
+
+ match workspace.update(cx, |workspace, _, cx| {
+ workspace.project().read(cx).active_project_directory(cx)
+ }) {
+ Ok(dir) => dir,
+ Err(e) => {
+ log::error!("Error getting project directory from workspace: {:?}", e);
+ None
+ }
+ }
+}
+
+pub(crate) async fn start_dev_container(
+ cx: &mut AsyncWindowContext,
+ node_runtime: NodeRuntime,
+) -> Result<(Connection, String), DevContainerError> {
+ check_for_docker().await?;
+
+ let path_to_devcontainer_cli = ensure_devcontainer_cli(node_runtime).await?;
+
+ let Some(directory) = project_directory(cx) else {
+ return Err(DevContainerError::DevContainerNotFound);
+ };
+
+ if let Ok(DevContainerUp {
+ container_id,
+ remote_workspace_folder,
+ ..
+ }) = devcontainer_up(&path_to_devcontainer_cli, directory.clone()).await
+ {
+ let project_name = get_project_name(
+ &path_to_devcontainer_cli,
+ directory,
+ remote_workspace_folder.clone(),
+ container_id.clone(),
+ )
+ .await?;
+
+ let connection = Connection::DevContainer(DevContainerConnection {
+ name: project_name.into(),
+ container_id: container_id.into(),
+ });
+
+ Ok((connection, remote_workspace_folder))
+ } else {
+ Err(DevContainerError::DevContainerUpFailed)
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum DevContainerError {
+ DockerNotAvailable,
+ DevContainerCliNotAvailable,
+ DevContainerUpFailed,
+ DevContainerNotFound,
+ DevContainerParseFailed,
+}
+
+#[cfg(test)]
+mod test {
+
+ use crate::dev_container::DevContainerUp;
+
+ #[test]
+ fn should_parse_from_devcontainer_json() {
+ let json = r#"{"outcome":"success","containerId":"826abcac45afd412abff083ab30793daff2f3c8ce2c831df728baf39933cb37a","remoteUser":"vscode","remoteWorkspaceFolder":"/workspaces/zed"}"#;
+ let up: DevContainerUp = serde_json::from_str(json).unwrap();
+ assert_eq!(up._outcome, "success");
+ assert_eq!(
+ up.container_id,
+ "826abcac45afd412abff083ab30793daff2f3c8ce2c831df728baf39933cb37a"
+ );
+ assert_eq!(up._remote_user, "vscode");
+ assert_eq!(up.remote_workspace_folder, "/workspaces/zed");
+ }
+}
@@ -0,0 +1,106 @@
+use db::kvp::KEY_VALUE_STORE;
+use gpui::{SharedString, Window};
+use project::{Project, WorktreeId};
+use std::sync::LazyLock;
+use ui::prelude::*;
+use util::rel_path::RelPath;
+use workspace::Workspace;
+use workspace::notifications::NotificationId;
+use workspace::notifications::simple_message_notification::MessageNotification;
+use worktree::UpdatedEntriesSet;
+
+const DEV_CONTAINER_SUGGEST_KEY: &str = "dev_container_suggest_dismissed";
+
+fn devcontainer_path() -> &'static RelPath {
+ static PATH: LazyLock<&'static RelPath> =
+ LazyLock::new(|| RelPath::unix(".devcontainer").expect("valid path"));
+ *PATH
+}
+
+fn project_devcontainer_key(project_path: &str) -> String {
+ format!("{}_{}", DEV_CONTAINER_SUGGEST_KEY, project_path)
+}
+
+pub fn suggest_on_worktree_updated(
+ worktree_id: WorktreeId,
+ updated_entries: &UpdatedEntriesSet,
+ project: &gpui::Entity<Project>,
+ window: &mut Window,
+ cx: &mut Context<Workspace>,
+) {
+ let devcontainer_updated = updated_entries
+ .iter()
+ .any(|(path, _, _)| path.as_ref() == devcontainer_path());
+
+ if !devcontainer_updated {
+ return;
+ }
+
+ let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
+ return;
+ };
+
+ let worktree = worktree.read(cx);
+
+ if !worktree.is_local() {
+ return;
+ }
+
+ let has_devcontainer = worktree
+ .entry_for_path(devcontainer_path())
+ .is_some_and(|entry| entry.is_dir());
+
+ if !has_devcontainer {
+ return;
+ }
+
+ let abs_path = worktree.abs_path();
+ let project_path = abs_path.to_string_lossy().to_string();
+ let key_for_dismiss = project_devcontainer_key(&project_path);
+
+ let already_dismissed = KEY_VALUE_STORE
+ .read_kvp(&key_for_dismiss)
+ .ok()
+ .flatten()
+ .is_some();
+
+ if already_dismissed {
+ return;
+ }
+
+ cx.on_next_frame(window, move |workspace, _window, cx| {
+ struct DevContainerSuggestionNotification;
+
+ let notification_id = NotificationId::composite::<DevContainerSuggestionNotification>(
+ SharedString::from(project_path.clone()),
+ );
+
+ workspace.show_notification(notification_id, cx, |cx| {
+ cx.new(move |cx| {
+ MessageNotification::new(
+ "This project contains a Dev Container configuration file. Would you like to re-open it in a container?",
+ cx,
+ )
+ .primary_message("Yes, Open in Container")
+ .primary_icon(IconName::Check)
+ .primary_icon_color(Color::Success)
+ .primary_on_click({
+ move |window, cx| {
+ window.dispatch_action(Box::new(zed_actions::OpenDevContainer), cx);
+ }
+ })
+ .secondary_message("Don't Show Again")
+ .secondary_icon(IconName::Close)
+ .secondary_icon_color(Color::Error)
+ .secondary_on_click({
+ move |_window, cx| {
+ let key = key_for_dismiss.clone();
+ db::write_and_log(cx, move || {
+ KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
+ });
+ }
+ })
+ })
+ });
+ });
+}
@@ -1,8 +1,12 @@
+mod dev_container;
+mod dev_container_suggest;
pub mod disconnected_overlay;
mod remote_connections;
mod remote_servers;
mod ssh_config;
+use std::path::PathBuf;
+
#[cfg(target_os = "windows")]
mod wsl_picker;
@@ -31,7 +35,7 @@ use workspace::{
WORKSPACE_DB, Workspace, WorkspaceId, notifications::DetachAndPromptErr,
with_active_or_new_workspace,
};
-use zed_actions::{OpenRecent, OpenRemote};
+use zed_actions::{OpenDevContainer, OpenRecent, OpenRemote};
pub fn init(cx: &mut App) {
#[cfg(target_os = "windows")]
@@ -161,6 +165,95 @@ pub fn init(cx: &mut App) {
});
cx.observe_new(DisconnectedOverlay::register).detach();
+
+ cx.on_action(|_: &OpenDevContainer, cx| {
+ with_active_or_new_workspace(cx, move |workspace, window, cx| {
+ let app_state = workspace.app_state().clone();
+ let replace_window = window.window_handle().downcast::<Workspace>();
+
+ cx.spawn_in(window, async move |_, mut cx| {
+ let (connection, starting_dir) = match dev_container::start_dev_container(
+ &mut cx,
+ app_state.node_runtime.clone(),
+ )
+ .await
+ {
+ Ok((c, s)) => (c, s),
+ Err(e) => {
+ log::error!("Failed to start Dev Container: {:?}", e);
+ cx.prompt(
+ gpui::PromptLevel::Critical,
+ "Failed to start Dev Container",
+ Some(&format!("{:?}", e)),
+ &["Ok"],
+ )
+ .await
+ .ok();
+ return;
+ }
+ };
+
+ let result = open_remote_project(
+ connection.into(),
+ vec![starting_dir].into_iter().map(PathBuf::from).collect(),
+ app_state,
+ OpenOptions {
+ replace_window,
+ ..OpenOptions::default()
+ },
+ &mut cx,
+ )
+ .await;
+
+ if let Err(e) = result {
+ log::error!("Failed to connect: {e:#}");
+ cx.prompt(
+ gpui::PromptLevel::Critical,
+ "Failed to connect",
+ Some(&e.to_string()),
+ &["Ok"],
+ )
+ .await
+ .ok();
+ }
+ })
+ .detach();
+
+ let fs = workspace.project().read(cx).fs().clone();
+ let handle = cx.entity().downgrade();
+ workspace.toggle_modal(window, cx, |window, cx| {
+ RemoteServerProjects::new_dev_container(fs, window, handle, cx)
+ });
+ });
+ });
+
+ // Subscribe to worktree additions to suggest opening the project in a dev container
+ cx.observe_new(
+ |workspace: &mut Workspace, window: Option<&mut Window>, cx: &mut Context<Workspace>| {
+ let Some(window) = window else {
+ return;
+ };
+ cx.subscribe_in(
+ workspace.project(),
+ window,
+ move |_, project, event, window, cx| {
+ if let project::Event::WorktreeUpdatedEntries(worktree_id, updated_entries) =
+ event
+ {
+ dev_container_suggest::suggest_on_worktree_updated(
+ *worktree_id,
+ updated_entries,
+ project,
+ window,
+ cx,
+ );
+ }
+ },
+ )
+ .detach();
+ },
+ )
+ .detach();
}
#[cfg(target_os = "windows")]
@@ -609,6 +702,7 @@ impl PickerDelegate for RecentProjectsDelegate {
Icon::new(match options {
RemoteConnectionOptions::Ssh { .. } => IconName::Server,
RemoteConnectionOptions::Wsl { .. } => IconName::Linux,
+ RemoteConnectionOptions::Docker(_) => IconName::Box,
})
.color(Color::Muted)
.into_any_element()
@@ -18,16 +18,16 @@ use language::{CursorShape, Point};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use release_channel::ReleaseChannel;
use remote::{
- ConnectionIdentifier, RemoteClient, RemoteConnection, RemoteConnectionOptions, RemotePlatform,
- SshConnectionOptions,
+ ConnectionIdentifier, DockerConnectionOptions, RemoteClient, RemoteConnection,
+ RemoteConnectionOptions, RemotePlatform, SshConnectionOptions,
};
use semver::Version;
pub use settings::SshConnection;
-use settings::{ExtendingVec, RegisterSetting, Settings, WslConnection};
+use settings::{DevContainerConnection, ExtendingVec, RegisterSetting, Settings, WslConnection};
use theme::ThemeSettings;
use ui::{
- ActiveTheme, Color, CommonAnimationExt, Context, Icon, IconName, IconSize, InteractiveElement,
- IntoElement, Label, LabelCommon, Styled, Window, prelude::*,
+ ActiveTheme, Color, CommonAnimationExt, Context, InteractiveElement, IntoElement, KeyBinding,
+ LabelCommon, ListItem, Styled, Window, prelude::*,
};
use util::paths::PathWithPosition;
use workspace::{AppState, ModalView, Workspace};
@@ -85,6 +85,7 @@ impl SshSettings {
pub enum Connection {
Ssh(SshConnection),
Wsl(WslConnection),
+ DevContainer(DevContainerConnection),
}
impl From<Connection> for RemoteConnectionOptions {
@@ -92,6 +93,13 @@ impl From<Connection> for RemoteConnectionOptions {
match val {
Connection::Ssh(conn) => RemoteConnectionOptions::Ssh(conn.into()),
Connection::Wsl(conn) => RemoteConnectionOptions::Wsl(conn.into()),
+ Connection::DevContainer(conn) => {
+ RemoteConnectionOptions::Docker(DockerConnectionOptions {
+ name: conn.name.to_string(),
+ container_id: conn.container_id.to_string(),
+ upload_binary_over_docker_exec: false,
+ })
+ }
}
}
}
@@ -123,6 +131,7 @@ pub struct RemoteConnectionPrompt {
connection_string: SharedString,
nickname: Option<SharedString>,
is_wsl: bool,
+ is_devcontainer: bool,
status_message: Option<SharedString>,
prompt: Option<(Entity<Markdown>, oneshot::Sender<EncryptedPassword>)>,
cancellation: Option<oneshot::Sender<()>>,
@@ -148,6 +157,7 @@ impl RemoteConnectionPrompt {
connection_string: String,
nickname: Option<String>,
is_wsl: bool,
+ is_devcontainer: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -155,6 +165,7 @@ impl RemoteConnectionPrompt {
connection_string: connection_string.into(),
nickname: nickname.map(|nickname| nickname.into()),
is_wsl,
+ is_devcontainer,
editor: cx.new(|cx| Editor::single_line(window, cx)),
status_message: None,
cancellation: None,
@@ -244,17 +255,16 @@ impl Render for RemoteConnectionPrompt {
v_flex()
.key_context("PasswordPrompt")
- .py_2()
- .px_3()
+ .p_2()
.size_full()
.text_buffer(cx)
.when_some(self.status_message.clone(), |el, status_message| {
el.child(
h_flex()
- .gap_1()
+ .gap_2()
.child(
Icon::new(IconName::ArrowCircle)
- .size(IconSize::Medium)
+ .color(Color::Muted)
.with_rotate_animation(2),
)
.child(
@@ -287,15 +297,28 @@ impl RemoteConnectionModal {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let (connection_string, nickname, is_wsl) = match connection_options {
- RemoteConnectionOptions::Ssh(options) => {
- (options.connection_string(), options.nickname.clone(), false)
+ let (connection_string, nickname, is_wsl, is_devcontainer) = match connection_options {
+ RemoteConnectionOptions::Ssh(options) => (
+ options.connection_string(),
+ options.nickname.clone(),
+ false,
+ false,
+ ),
+ RemoteConnectionOptions::Wsl(options) => {
+ (options.distro_name.clone(), None, true, false)
}
- RemoteConnectionOptions::Wsl(options) => (options.distro_name.clone(), None, true),
+ RemoteConnectionOptions::Docker(options) => (options.name.clone(), None, false, true),
};
Self {
prompt: cx.new(|cx| {
- RemoteConnectionPrompt::new(connection_string, nickname, is_wsl, window, cx)
+ RemoteConnectionPrompt::new(
+ connection_string,
+ nickname,
+ is_wsl,
+ is_devcontainer,
+ window,
+ cx,
+ )
}),
finished: false,
paths,
@@ -328,6 +351,7 @@ pub(crate) struct SshConnectionHeader {
pub(crate) paths: Vec<PathBuf>,
pub(crate) nickname: Option<SharedString>,
pub(crate) is_wsl: bool,
+ pub(crate) is_devcontainer: bool,
}
impl RenderOnce for SshConnectionHeader {
@@ -343,9 +367,12 @@ impl RenderOnce for SshConnectionHeader {
(self.connection_string, None)
};
- let icon = match self.is_wsl {
- true => IconName::Linux,
- false => IconName::Server,
+ let icon = if self.is_wsl {
+ IconName::Linux
+ } else if self.is_devcontainer {
+ IconName::Box
+ } else {
+ IconName::Server
};
h_flex()
@@ -388,6 +415,7 @@ impl Render for RemoteConnectionModal {
let nickname = self.prompt.read(cx).nickname.clone();
let connection_string = self.prompt.read(cx).connection_string.clone();
let is_wsl = self.prompt.read(cx).is_wsl;
+ let is_devcontainer = self.prompt.read(cx).is_devcontainer;
let theme = cx.theme().clone();
let body_color = theme.colors().editor_background;
@@ -407,18 +435,34 @@ impl Render for RemoteConnectionModal {
connection_string,
nickname,
is_wsl,
+ is_devcontainer,
}
.render(window, cx),
)
.child(
div()
.w_full()
- .rounded_b_lg()
.bg(body_color)
- .border_t_1()
+ .border_y_1()
.border_color(theme.colors().border_variant)
.child(self.prompt.clone()),
)
+ .child(
+ div().w_full().py_1().child(
+ ListItem::new("li-devcontainer-go-back")
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(Icon::new(IconName::Close).color(Color::Muted))
+ .child(Label::new("Cancel"))
+ .end_slot(
+ KeyBinding::for_action_in(&menu::Cancel, &self.focus_handle(cx), cx)
+ .size(rems_from_px(12.)),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.dismiss(&menu::Cancel, window, cx);
+ })),
+ ),
+ )
}
}
@@ -671,6 +715,9 @@ pub async fn open_remote_project(
match connection_options {
RemoteConnectionOptions::Ssh(_) => "Failed to connect over SSH",
RemoteConnectionOptions::Wsl(_) => "Failed to connect to WSL",
+ RemoteConnectionOptions::Docker(_) => {
+ "Failed to connect to Dev Container"
+ }
},
Some(&format!("{e:#}")),
&["Retry", "Cancel"],
@@ -727,6 +774,9 @@ pub async fn open_remote_project(
match connection_options {
RemoteConnectionOptions::Ssh(_) => "Failed to connect over SSH",
RemoteConnectionOptions::Wsl(_) => "Failed to connect to WSL",
+ RemoteConnectionOptions::Docker(_) => {
+ "Failed to connect to Dev Container"
+ }
},
Some(&format!("{e:#}")),
&["Retry", "Cancel"],
@@ -1,4 +1,5 @@
use crate::{
+ dev_container::start_dev_container,
remote_connections::{
Connection, RemoteConnectionModal, RemoteConnectionPrompt, SshConnection,
SshConnectionHeader, SshSettings, connect, determine_paths_with_positions,
@@ -24,7 +25,7 @@ use remote::{
remote_client::ConnectionIdentifier,
};
use settings::{
- RemoteSettingsContent, Settings as _, SettingsStore, SshProject, update_settings_file,
+ RemoteProject, RemoteSettingsContent, Settings as _, SettingsStore, update_settings_file,
watch_config_file,
};
use smol::stream::StreamExt as _;
@@ -39,12 +40,13 @@ use std::{
},
};
use ui::{
- IconButtonShape, List, ListItem, ListSeparator, Modal, ModalHeader, Navigable, NavigableEntry,
- Section, Tooltip, WithScrollbar, prelude::*,
+ CommonAnimationExt, IconButtonShape, KeyBinding, List, ListItem, ListSeparator, Modal,
+ ModalHeader, Navigable, NavigableEntry, Section, Tooltip, WithScrollbar, prelude::*,
};
use util::{
ResultExt,
paths::{PathStyle, RemotePathBuf},
+ rel_path::RelPath,
};
use workspace::{
ModalView, OpenOptions, Toast, Workspace,
@@ -85,6 +87,39 @@ impl CreateRemoteServer {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+enum DevContainerCreationProgress {
+ Initial,
+ Creating,
+ Error(String),
+}
+
+#[derive(Clone)]
+struct CreateRemoteDevContainer {
+ // 3 Navigable Options
+ // - Create from devcontainer.json
+ // - Edit devcontainer.json
+ // - Go back
+ entries: [NavigableEntry; 3],
+ progress: DevContainerCreationProgress,
+}
+
+impl CreateRemoteDevContainer {
+ fn new(window: &mut Window, cx: &mut Context<RemoteServerProjects>) -> Self {
+ let entries = std::array::from_fn(|_| NavigableEntry::focusable(cx));
+ entries[0].focus_handle.focus(window);
+ Self {
+ entries,
+ progress: DevContainerCreationProgress::Initial,
+ }
+ }
+
+ fn progress(&mut self, progress: DevContainerCreationProgress) -> Self {
+ self.progress = progress;
+ self.clone()
+ }
+}
+
#[cfg(target_os = "windows")]
struct AddWslDistro {
picker: Entity<Picker<crate::wsl_picker::WslPickerDelegate>>,
@@ -207,6 +242,11 @@ impl ProjectPicker {
RemoteConnectionOptions::Wsl(connection) => ProjectPickerData::Wsl {
distro_name: connection.distro_name.clone().into(),
},
+ RemoteConnectionOptions::Docker(_) => ProjectPickerData::Ssh {
+ // Not implemented as a project picker at this time
+ connection_string: "".into(),
+ nickname: None,
+ },
};
let _path_task = cx
.spawn_in(window, {
@@ -259,7 +299,7 @@ impl ProjectPicker {
.as_mut()
.and_then(|connections| connections.get_mut(index.0))
{
- server.projects.insert(SshProject { paths });
+ server.projects.insert(RemoteProject { paths });
};
}
ServerIndex::Wsl(index) => {
@@ -269,7 +309,7 @@ impl ProjectPicker {
.as_mut()
.and_then(|connections| connections.get_mut(index.0))
{
- server.projects.insert(SshProject { paths });
+ server.projects.insert(RemoteProject { paths });
};
}
}
@@ -349,6 +389,7 @@ impl gpui::Render for ProjectPicker {
paths: Default::default(),
nickname: nickname.clone(),
is_wsl: false,
+ is_devcontainer: false,
}
.render(window, cx),
ProjectPickerData::Wsl { distro_name } => SshConnectionHeader {
@@ -356,6 +397,7 @@ impl gpui::Render for ProjectPicker {
paths: Default::default(),
nickname: None,
is_wsl: true,
+ is_devcontainer: false,
}
.render(window, cx),
})
@@ -406,7 +448,7 @@ impl From<WslServerIndex> for ServerIndex {
enum RemoteEntry {
Project {
open_folder: NavigableEntry,
- projects: Vec<(NavigableEntry, SshProject)>,
+ projects: Vec<(NavigableEntry, RemoteProject)>,
configure: NavigableEntry,
connection: Connection,
index: ServerIndex,
@@ -440,6 +482,7 @@ impl RemoteEntry {
struct DefaultState {
scroll_handle: ScrollHandle,
add_new_server: NavigableEntry,
+ add_new_devcontainer: NavigableEntry,
add_new_wsl: NavigableEntry,
servers: Vec<RemoteEntry>,
}
@@ -448,6 +491,7 @@ impl DefaultState {
fn new(ssh_config_servers: &BTreeSet<SharedString>, cx: &mut App) -> Self {
let handle = ScrollHandle::new();
let add_new_server = NavigableEntry::new(&handle, cx);
+ let add_new_devcontainer = NavigableEntry::new(&handle, cx);
let add_new_wsl = NavigableEntry::new(&handle, cx);
let ssh_settings = SshSettings::get_global(cx);
@@ -517,6 +561,7 @@ impl DefaultState {
Self {
scroll_handle: handle,
add_new_server,
+ add_new_devcontainer,
add_new_wsl,
servers,
}
@@ -552,6 +597,7 @@ enum Mode {
EditNickname(EditNicknameState),
ProjectPicker(Entity<ProjectPicker>),
CreateRemoteServer(CreateRemoteServer),
+ CreateRemoteDevContainer(CreateRemoteDevContainer),
#[cfg(target_os = "windows")]
AddWslDistro(AddWslDistro),
}
@@ -598,6 +644,27 @@ impl RemoteServerProjects {
)
}
+ /// Creates a new RemoteServerProjects modal that opens directly in dev container creation mode.
+ /// Used when suggesting dev container connection from toast notification.
+ pub fn new_dev_container(
+ fs: Arc<dyn Fs>,
+ window: &mut Window,
+ workspace: WeakEntity<Workspace>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ Self::new_inner(
+ Mode::CreateRemoteDevContainer(
+ CreateRemoteDevContainer::new(window, cx)
+ .progress(DevContainerCreationProgress::Creating),
+ ),
+ false,
+ fs,
+ window,
+ workspace,
+ cx,
+ )
+ }
+
fn new_inner(
mode: Mode,
create_new_window: bool,
@@ -703,6 +770,7 @@ impl RemoteServerProjects {
connection_options.connection_string(),
connection_options.nickname.clone(),
false,
+ false,
window,
cx,
)
@@ -778,6 +846,7 @@ impl RemoteServerProjects {
connection_options.distro_name.clone(),
None,
true,
+ false,
window,
cx,
)
@@ -862,6 +931,15 @@ impl RemoteServerProjects {
cx.notify();
}
+ fn view_in_progress_dev_container(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.mode = Mode::CreateRemoteDevContainer(
+ CreateRemoteDevContainer::new(window, cx)
+ .progress(DevContainerCreationProgress::Creating),
+ );
+ self.focus_handle(cx).focus(window);
+ cx.notify();
+ }
+
fn create_remote_project(
&mut self,
index: ServerIndex,
@@ -981,6 +1059,7 @@ impl RemoteServerProjects {
self.create_ssh_server(state.address_editor.clone(), window, cx);
}
+ Mode::CreateRemoteDevContainer(_) => {}
Mode::EditNickname(state) => {
let text = Some(state.editor.read(cx).text(cx)).filter(|text| !text.is_empty());
let index = state.index;
@@ -1024,14 +1103,14 @@ impl RemoteServerProjects {
}
}
- fn render_ssh_connection(
+ fn render_remote_connection(
&mut self,
ix: usize,
- ssh_server: RemoteEntry,
+ remote_server: RemoteEntry,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let connection = ssh_server.connection().into_owned();
+ let connection = remote_server.connection().into_owned();
let (main_label, aux_label, is_wsl) = match &connection {
Connection::Ssh(connection) => {
@@ -1045,6 +1124,9 @@ impl RemoteServerProjects {
Connection::Wsl(wsl_connection_options) => {
(wsl_connection_options.distro_name.clone(), None, true)
}
+ Connection::DevContainer(dev_container_options) => {
+ (dev_container_options.name.clone(), None, false)
+ }
};
v_flex()
.w_full()
@@ -1082,7 +1164,7 @@ impl RemoteServerProjects {
}),
),
)
- .child(match &ssh_server {
+ .child(match &remote_server {
RemoteEntry::Project {
open_folder,
projects,
@@ -1094,9 +1176,9 @@ impl RemoteServerProjects {
List::new()
.empty_message("No projects.")
.children(projects.iter().enumerate().map(|(pix, p)| {
- v_flex().gap_0p5().child(self.render_ssh_project(
+ v_flex().gap_0p5().child(self.render_remote_project(
index,
- ssh_server.clone(),
+ remote_server.clone(),
pix,
p,
window,
@@ -1222,12 +1304,12 @@ impl RemoteServerProjects {
})
}
- fn render_ssh_project(
+ fn render_remote_project(
&mut self,
server_ix: ServerIndex,
server: RemoteEntry,
ix: usize,
- (navigation, project): &(NavigableEntry, SshProject),
+ (navigation, project): &(NavigableEntry, RemoteProject),
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
@@ -1372,7 +1454,7 @@ impl RemoteServerProjects {
fn delete_remote_project(
&mut self,
server: ServerIndex,
- project: &SshProject,
+ project: &RemoteProject,
cx: &mut Context<Self>,
) {
match server {
@@ -1388,7 +1470,7 @@ impl RemoteServerProjects {
fn delete_ssh_project(
&mut self,
server: SshServerIndex,
- project: &SshProject,
+ project: &RemoteProject,
cx: &mut Context<Self>,
) {
let project = project.clone();
@@ -1406,7 +1488,7 @@ impl RemoteServerProjects {
fn delete_wsl_project(
&mut self,
server: WslServerIndex,
- project: &SshProject,
+ project: &RemoteProject,
cx: &mut Context<Self>,
) {
let project = project.clone();
@@ -1451,6 +1533,342 @@ impl RemoteServerProjects {
});
}
+ fn edit_in_dev_container_json(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ let Some(workspace) = self.workspace.upgrade() else {
+ cx.emit(DismissEvent);
+ cx.notify();
+ return;
+ };
+
+ workspace.update(cx, |workspace, cx| {
+ let project = workspace.project().clone();
+
+ let worktree = project
+ .read(cx)
+ .visible_worktrees(cx)
+ .find_map(|tree| tree.read(cx).root_entry()?.is_dir().then_some(tree));
+
+ if let Some(worktree) = worktree {
+ let tree_id = worktree.read(cx).id();
+ let devcontainer_path = RelPath::unix(".devcontainer/devcontainer.json").unwrap();
+ cx.spawn_in(window, async move |workspace, cx| {
+ workspace
+ .update_in(cx, |workspace, window, cx| {
+ workspace.open_path(
+ (tree_id, devcontainer_path),
+ None,
+ true,
+ window,
+ cx,
+ )
+ })?
+ .await
+ })
+ .detach();
+ } else {
+ return;
+ }
+ });
+ cx.emit(DismissEvent);
+ cx.notify();
+ }
+
+ fn open_dev_container(&self, window: &mut Window, cx: &mut Context<Self>) {
+ let Some(app_state) = self
+ .workspace
+ .read_with(cx, |workspace, _| workspace.app_state().clone())
+ .log_err()
+ else {
+ return;
+ };
+
+ let replace_window = window.window_handle().downcast::<Workspace>();
+
+ cx.spawn_in(window, async move |entity, cx| {
+ let (connection, starting_dir) =
+ match start_dev_container(cx, app_state.node_runtime.clone()).await {
+ Ok((c, s)) => (c, s),
+ Err(e) => {
+ log::error!("Failed to start dev container: {:?}", e);
+ entity
+ .update_in(cx, |remote_server_projects, window, cx| {
+ remote_server_projects.mode = Mode::CreateRemoteDevContainer(
+ CreateRemoteDevContainer::new(window, cx).progress(
+ DevContainerCreationProgress::Error(format!("{:?}", e)),
+ ),
+ );
+ })
+ .log_err();
+ return;
+ }
+ };
+ entity
+ .update(cx, |_, cx| {
+ cx.emit(DismissEvent);
+ })
+ .log_err();
+
+ let result = open_remote_project(
+ connection.into(),
+ vec![starting_dir].into_iter().map(PathBuf::from).collect(),
+ app_state,
+ OpenOptions {
+ replace_window,
+ ..OpenOptions::default()
+ },
+ cx,
+ )
+ .await;
+ if let Err(e) = result {
+ log::error!("Failed to connect: {e:#}");
+ cx.prompt(
+ gpui::PromptLevel::Critical,
+ "Failed to connect",
+ Some(&e.to_string()),
+ &["Ok"],
+ )
+ .await
+ .ok();
+ }
+ })
+ .detach();
+ }
+
+ fn render_create_dev_container(
+ &self,
+ state: &CreateRemoteDevContainer,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> impl IntoElement {
+ match &state.progress {
+ DevContainerCreationProgress::Error(message) => {
+ self.focus_handle(cx).focus(window);
+ return div()
+ .track_focus(&self.focus_handle(cx))
+ .size_full()
+ .child(
+ v_flex()
+ .py_1()
+ .child(
+ ListItem::new("Error")
+ .inset(true)
+ .selectable(false)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(Icon::new(IconName::XCircle).color(Color::Error))
+ .child(Label::new("Error Creating Dev Container:"))
+ .child(Label::new(message).buffer_font(cx)),
+ )
+ .child(ListSeparator)
+ .child(
+ div()
+ .id("devcontainer-go-back")
+ .track_focus(&state.entries[0].focus_handle)
+ .on_action(cx.listener(
+ |this, _: &menu::Confirm, window, cx| {
+ this.mode =
+ Mode::default_mode(&this.ssh_config_servers, cx);
+ cx.focus_self(window);
+ cx.notify();
+ },
+ ))
+ .child(
+ ListItem::new("li-devcontainer-go-back")
+ .toggle_state(
+ state.entries[0]
+ .focus_handle
+ .contains_focused(window, cx),
+ )
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(
+ Icon::new(IconName::ArrowLeft).color(Color::Muted),
+ )
+ .child(Label::new("Go Back"))
+ .end_slot(
+ KeyBinding::for_action_in(
+ &menu::Cancel,
+ &self.focus_handle,
+ cx,
+ )
+ .size(rems_from_px(12.)),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ let state =
+ CreateRemoteDevContainer::new(window, cx);
+ this.mode = Mode::CreateRemoteDevContainer(state);
+
+ cx.notify();
+ })),
+ ),
+ ),
+ )
+ .into_any_element();
+ }
+ _ => {}
+ };
+
+ let mut view = Navigable::new(
+ div()
+ .track_focus(&self.focus_handle(cx))
+ .size_full()
+ .child(
+ v_flex()
+ .pb_1()
+ .child(
+ ModalHeader::new()
+ .child(Headline::new("Dev Containers").size(HeadlineSize::XSmall)),
+ )
+ .child(ListSeparator)
+ .child(
+ div()
+ .id("confirm-create-from-devcontainer-json")
+ .track_focus(&state.entries[0].focus_handle)
+ .on_action(cx.listener({
+ move |this, _: &menu::Confirm, window, cx| {
+ this.open_dev_container(window, cx);
+ this.view_in_progress_dev_container(window, cx);
+ }
+ }))
+ .map(|this| {
+ if state.progress == DevContainerCreationProgress::Creating {
+ this.child(
+ ListItem::new("creating")
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .disabled(true)
+ .start_slot(
+ Icon::new(IconName::ArrowCircle)
+ .color(Color::Muted)
+ .with_rotate_animation(2),
+ )
+ .child(
+ h_flex()
+ .opacity(0.6)
+ .gap_1()
+ .child(Label::new("Creating From"))
+ .child(
+ Label::new("devcontainer.json")
+ .buffer_font(cx),
+ )
+ .child(LoadingLabel::new("")),
+ ),
+ )
+ } else {
+ this.child(
+ ListItem::new(
+ "li-confirm-create-from-devcontainer-json",
+ )
+ .toggle_state(
+ state.entries[0]
+ .focus_handle
+ .contains_focused(window, cx),
+ )
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(
+ Icon::new(IconName::Plus).color(Color::Muted),
+ )
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Label::new("Open or Create New From"))
+ .child(
+ Label::new("devcontainer.json")
+ .buffer_font(cx),
+ ),
+ )
+ .on_click(
+ cx.listener({
+ move |this, _, window, cx| {
+ this.open_dev_container(window, cx);
+ this.view_in_progress_dev_container(
+ window, cx,
+ );
+ cx.notify();
+ }
+ }),
+ ),
+ )
+ }
+ }),
+ )
+ .child(
+ div()
+ .id("edit-devcontainer-json")
+ .track_focus(&state.entries[1].focus_handle)
+ .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| {
+ this.edit_in_dev_container_json(window, cx);
+ }))
+ .child(
+ ListItem::new("li-edit-devcontainer-json")
+ .toggle_state(
+ state.entries[1]
+ .focus_handle
+ .contains_focused(window, cx),
+ )
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(Icon::new(IconName::Pencil).color(Color::Muted))
+ .child(
+ h_flex().gap_1().child(Label::new("Edit")).child(
+ Label::new("devcontainer.json").buffer_font(cx),
+ ),
+ )
+ .on_click(cx.listener(move |this, _, window, cx| {
+ this.edit_in_dev_container_json(window, cx);
+ })),
+ ),
+ )
+ .child(ListSeparator)
+ .child(
+ div()
+ .id("devcontainer-go-back")
+ .track_focus(&state.entries[2].focus_handle)
+ .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| {
+ this.mode = Mode::default_mode(&this.ssh_config_servers, cx);
+ cx.focus_self(window);
+ cx.notify();
+ }))
+ .child(
+ ListItem::new("li-devcontainer-go-back")
+ .toggle_state(
+ state.entries[2]
+ .focus_handle
+ .contains_focused(window, cx),
+ )
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(
+ Icon::new(IconName::ArrowLeft).color(Color::Muted),
+ )
+ .child(Label::new("Go Back"))
+ .end_slot(
+ KeyBinding::for_action_in(
+ &menu::Cancel,
+ &self.focus_handle,
+ cx,
+ )
+ .size(rems_from_px(12.)),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.mode =
+ Mode::default_mode(&this.ssh_config_servers, cx);
+ cx.focus_self(window);
+ cx.notify()
+ })),
+ ),
+ ),
+ )
+ .into_any_element(),
+ );
+
+ view = view.entry(state.entries[0].clone());
+ view = view.entry(state.entries[1].clone());
+ view = view.entry(state.entries[2].clone());
+
+ view.render(window, cx).into_any_element()
+ }
+
fn render_create_remote_server(
&self,
state: &CreateRemoteServer,
@@ -1571,6 +1989,7 @@ impl RemoteServerProjects {
paths: Default::default(),
nickname: connection.nickname.clone().map(|s| s.into()),
is_wsl: false,
+ is_devcontainer: false,
}
.render(window, cx)
.into_any_element(),
@@ -1579,6 +1998,7 @@ impl RemoteServerProjects {
paths: Default::default(),
nickname: None,
is_wsl: true,
+ is_devcontainer: false,
}
.render(window, cx)
.into_any_element(),
@@ -1917,6 +2337,7 @@ impl RemoteServerProjects {
paths: Default::default(),
nickname,
is_wsl: false,
+ is_devcontainer: false,
}
.render(window, cx),
)
@@ -1998,7 +2419,7 @@ impl RemoteServerProjects {
.track_focus(&state.add_new_server.focus_handle)
.anchor_scroll(state.add_new_server.scroll_anchor.clone())
.child(
- ListItem::new("register-remove-server-button")
+ ListItem::new("register-remote-server-button")
.toggle_state(
state
.add_new_server
@@ -2008,7 +2429,7 @@ impl RemoteServerProjects {
.inset(true)
.spacing(ui::ListItemSpacing::Sparse)
.start_slot(Icon::new(IconName::Plus).color(Color::Muted))
- .child(Label::new("Connect New Server"))
+ .child(Label::new("Connect SSH Server"))
.on_click(cx.listener(|this, _, window, cx| {
let state = CreateRemoteServer::new(window, cx);
this.mode = Mode::CreateRemoteServer(state);
@@ -2023,6 +2444,36 @@ impl RemoteServerProjects {
cx.notify();
}));
+ let connect_dev_container_button = div()
+ .id("connect-new-dev-container")
+ .track_focus(&state.add_new_devcontainer.focus_handle)
+ .anchor_scroll(state.add_new_devcontainer.scroll_anchor.clone())
+ .child(
+ ListItem::new("register-dev-container-button")
+ .toggle_state(
+ state
+ .add_new_devcontainer
+ .focus_handle
+ .contains_focused(window, cx),
+ )
+ .inset(true)
+ .spacing(ui::ListItemSpacing::Sparse)
+ .start_slot(Icon::new(IconName::Plus).color(Color::Muted))
+ .child(Label::new("Connect Dev Container"))
+ .on_click(cx.listener(|this, _, window, cx| {
+ let state = CreateRemoteDevContainer::new(window, cx);
+ this.mode = Mode::CreateRemoteDevContainer(state);
+
+ cx.notify();
+ })),
+ )
+ .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| {
+ let state = CreateRemoteDevContainer::new(window, cx);
+ this.mode = Mode::CreateRemoteDevContainer(state);
+
+ cx.notify();
+ }));
+
#[cfg(target_os = "windows")]
let wsl_connect_button = div()
.id("wsl-connect-new-server")
@@ -2049,13 +2500,30 @@ impl RemoteServerProjects {
cx.notify();
}));
+ let has_open_project = self
+ .workspace
+ .upgrade()
+ .map(|workspace| {
+ workspace
+ .read(cx)
+ .project()
+ .read(cx)
+ .visible_worktrees(cx)
+ .next()
+ .is_some()
+ })
+ .unwrap_or(false);
+
let modal_section = v_flex()
.track_focus(&self.focus_handle(cx))
.id("ssh-server-list")
.overflow_y_scroll()
.track_scroll(&state.scroll_handle)
.size_full()
- .child(connect_button);
+ .child(connect_button)
+ .when(has_open_project, |this| {
+ this.child(connect_dev_container_button)
+ });
#[cfg(target_os = "windows")]
let modal_section = modal_section.child(wsl_connect_button);
@@ -2067,17 +2535,20 @@ impl RemoteServerProjects {
.child(
List::new()
.empty_message(
- v_flex()
+ h_flex()
+ .size_full()
+ .p_2()
+ .justify_center()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
.child(
- div().px_3().child(
- Label::new("No remote servers registered yet.")
- .color(Color::Muted),
- ),
+ Label::new("No remote servers registered yet.")
+ .color(Color::Muted),
)
.into_any_element(),
)
.children(state.servers.iter().enumerate().map(|(ix, connection)| {
- self.render_ssh_connection(ix, connection.clone(), window, cx)
+ self.render_remote_connection(ix, connection.clone(), window, cx)
.into_any_element()
})),
)
@@ -2085,6 +2556,10 @@ impl RemoteServerProjects {
)
.entry(state.add_new_server.clone());
+ if has_open_project {
+ modal_section = modal_section.entry(state.add_new_devcontainer.clone());
+ }
+
if cfg!(target_os = "windows") {
modal_section = modal_section.entry(state.add_new_wsl.clone());
}
@@ -2297,6 +2772,9 @@ impl Render for RemoteServerProjects {
Mode::CreateRemoteServer(state) => self
.render_create_remote_server(state, window, cx)
.into_any_element(),
+ Mode::CreateRemoteDevContainer(state) => self
+ .render_create_dev_container(state, window, cx)
+ .into_any_element(),
Mode::EditNickname(state) => self
.render_edit_nickname(state, window, cx)
.into_any_element(),
@@ -10,5 +10,6 @@ pub use remote_client::{
ConnectionIdentifier, ConnectionState, RemoteClient, RemoteClientDelegate, RemoteClientEvent,
RemoteConnection, RemoteConnectionOptions, RemotePlatform, connect,
};
+pub use transport::docker::DockerConnectionOptions;
pub use transport::ssh::{SshConnectionOptions, SshPortForwardOption};
pub use transport::wsl::WslConnectionOptions;
@@ -3,6 +3,7 @@ use crate::{
protocol::MessageId,
proxy::ProxyLaunchError,
transport::{
+ docker::{DockerConnectionOptions, DockerExecConnection},
ssh::SshRemoteConnection,
wsl::{WslConnectionOptions, WslRemoteConnection},
},
@@ -1042,6 +1043,11 @@ impl ConnectionPool {
.await
.map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
}
+ RemoteConnectionOptions::Docker(opts) => {
+ DockerExecConnection::new(opts, delegate, cx)
+ .await
+ .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>)
+ }
};
cx.update_global(|pool: &mut Self, _| {
@@ -1077,6 +1083,7 @@ impl ConnectionPool {
pub enum RemoteConnectionOptions {
Ssh(SshConnectionOptions),
Wsl(WslConnectionOptions),
+ Docker(DockerConnectionOptions),
}
impl RemoteConnectionOptions {
@@ -1084,6 +1091,7 @@ impl RemoteConnectionOptions {
match self {
RemoteConnectionOptions::Ssh(opts) => opts.host.clone(),
RemoteConnectionOptions::Wsl(opts) => opts.distro_name.clone(),
+ RemoteConnectionOptions::Docker(opts) => opts.name.clone(),
}
}
}
@@ -12,6 +12,7 @@ use gpui::{AppContext as _, AsyncApp, Task};
use rpc::proto::Envelope;
use smol::process::Child;
+pub mod docker;
pub mod ssh;
pub mod wsl;
@@ -64,15 +65,15 @@ fn parse_shell(output: &str, fallback_shell: &str) -> String {
}
fn handle_rpc_messages_over_child_process_stdio(
- mut ssh_proxy_process: Child,
+ mut remote_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncApp,
) -> Task<Result<i32>> {
- let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
- let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
- let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
+ let mut child_stderr = remote_proxy_process.stderr.take().unwrap();
+ let mut child_stdout = remote_proxy_process.stdout.take().unwrap();
+ let mut child_stdin = remote_proxy_process.stdin.take().unwrap();
let mut stdin_buffer = Vec::new();
let mut stdout_buffer = Vec::new();
@@ -156,7 +157,7 @@ fn handle_rpc_messages_over_child_process_stdio(
result.context("stderr")
}
};
- let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
+ let status = remote_proxy_process.status().await?.code().unwrap_or(1);
match result {
Ok(_) => Ok(status),
Err(error) => Err(error),
@@ -0,0 +1,757 @@
+use anyhow::Context;
+use anyhow::Result;
+use anyhow::anyhow;
+use async_trait::async_trait;
+use collections::HashMap;
+use parking_lot::Mutex;
+use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
+use semver::Version as SemanticVersion;
+use std::time::Instant;
+use std::{
+ path::{Path, PathBuf},
+ process::Stdio,
+ sync::Arc,
+};
+use util::ResultExt;
+use util::shell::ShellKind;
+use util::{
+ paths::{PathStyle, RemotePathBuf},
+ rel_path::RelPath,
+};
+
+use futures::channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender};
+use gpui::{App, AppContext, AsyncApp, Task};
+use rpc::proto::Envelope;
+
+use crate::{
+ RemoteClientDelegate, RemoteConnection, RemoteConnectionOptions, RemotePlatform,
+ remote_client::CommandTemplate,
+};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+pub struct DockerConnectionOptions {
+ pub name: String,
+ pub container_id: String,
+ pub upload_binary_over_docker_exec: bool,
+}
+
+pub(crate) struct DockerExecConnection {
+ proxy_process: Mutex<Option<u32>>,
+ remote_dir_for_server: String,
+ remote_binary_relpath: Option<Arc<RelPath>>,
+ connection_options: DockerConnectionOptions,
+ remote_platform: Option<RemotePlatform>,
+ path_style: Option<PathStyle>,
+ shell: Option<String>,
+}
+
+impl DockerExecConnection {
+ pub async fn new(
+ connection_options: DockerConnectionOptions,
+ delegate: Arc<dyn RemoteClientDelegate>,
+ cx: &mut AsyncApp,
+ ) -> Result<Self> {
+ let mut this = Self {
+ proxy_process: Mutex::new(None),
+ remote_dir_for_server: "/".to_string(),
+ remote_binary_relpath: None,
+ connection_options,
+ remote_platform: None,
+ path_style: None,
+ shell: None,
+ };
+ let (release_channel, version, commit) = cx.update(|cx| {
+ (
+ ReleaseChannel::global(cx),
+ AppVersion::global(cx),
+ AppCommitSha::try_global(cx),
+ )
+ })?;
+ let remote_platform = this.check_remote_platform().await?;
+
+ this.path_style = match remote_platform.os {
+ "windows" => Some(PathStyle::Windows),
+ _ => Some(PathStyle::Posix),
+ };
+
+ this.remote_platform = Some(remote_platform);
+
+ this.shell = Some(this.discover_shell().await);
+
+ this.remote_dir_for_server = this.docker_user_home_dir().await?.trim().to_string();
+
+ this.remote_binary_relpath = Some(
+ this.ensure_server_binary(
+ &delegate,
+ release_channel,
+ version,
+ &this.remote_dir_for_server,
+ commit,
+ cx,
+ )
+ .await?,
+ );
+
+ Ok(this)
+ }
+
+ async fn discover_shell(&self) -> String {
+ let default_shell = "sh";
+ match self
+ .run_docker_exec("sh", None, &Default::default(), &["-c", "echo $SHELL"])
+ .await
+ {
+ Ok(shell) => match shell.trim() {
+ "" => {
+ log::error!("$SHELL is not set, falling back to {default_shell}");
+ default_shell.to_owned()
+ }
+ shell => shell.to_owned(),
+ },
+ Err(e) => {
+ log::error!("Failed to get shell: {e}");
+ default_shell.to_owned()
+ }
+ }
+ }
+
+ async fn check_remote_platform(&self) -> Result<RemotePlatform> {
+ let uname = self
+ .run_docker_exec("uname", None, &Default::default(), &["-sm"])
+ .await?;
+ let Some((os, arch)) = uname.split_once(" ") else {
+ anyhow::bail!("unknown uname: {uname:?}")
+ };
+
+ let os = match os.trim() {
+ "Darwin" => "macos",
+ "Linux" => "linux",
+ _ => anyhow::bail!(
+ "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
+ ),
+ };
+ // exclude armv5,6,7 as they are 32-bit.
+ let arch = if arch.starts_with("armv8")
+ || arch.starts_with("armv9")
+ || arch.starts_with("arm64")
+ || arch.starts_with("aarch64")
+ {
+ "aarch64"
+ } else if arch.starts_with("x86") {
+ "x86_64"
+ } else {
+ anyhow::bail!(
+ "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
+ )
+ };
+
+ Ok(RemotePlatform { os, arch })
+ }
+
+ async fn ensure_server_binary(
+ &self,
+ delegate: &Arc<dyn RemoteClientDelegate>,
+ release_channel: ReleaseChannel,
+ version: SemanticVersion,
+ remote_dir_for_server: &str,
+ commit: Option<AppCommitSha>,
+ cx: &mut AsyncApp,
+ ) -> Result<Arc<RelPath>> {
+ let remote_platform = if self.remote_platform.is_some() {
+ self.remote_platform.unwrap()
+ } else {
+ anyhow::bail!("No remote platform defined; cannot proceed.")
+ };
+
+ let version_str = match release_channel {
+ ReleaseChannel::Nightly => {
+ let commit = commit.map(|s| s.full()).unwrap_or_default();
+ format!("{}-{}", version, commit)
+ }
+ ReleaseChannel::Dev => "build".to_string(),
+ _ => version.to_string(),
+ };
+ let binary_name = format!(
+ "zed-remote-server-{}-{}",
+ release_channel.dev_name(),
+ version_str
+ );
+ let dst_path =
+ paths::remote_server_dir_relative().join(RelPath::unix(&binary_name).unwrap());
+
+ #[cfg(debug_assertions)]
+ if let Some(remote_server_path) =
+ super::build_remote_server_from_source(&remote_platform, delegate.as_ref(), cx).await?
+ {
+ let tmp_path = paths::remote_server_dir_relative().join(
+ RelPath::unix(&format!(
+ "download-{}-{}",
+ std::process::id(),
+ remote_server_path.file_name().unwrap().to_string_lossy()
+ ))
+ .unwrap(),
+ );
+ self.upload_local_server_binary(
+ &remote_server_path,
+ &tmp_path,
+ &remote_dir_for_server,
+ delegate,
+ cx,
+ )
+ .await?;
+ self.extract_server_binary(&dst_path, &tmp_path, &remote_dir_for_server, delegate, cx)
+ .await?;
+ return Ok(dst_path);
+ }
+
+ if self
+ .run_docker_exec(
+ &dst_path.display(self.path_style()),
+ Some(&remote_dir_for_server),
+ &Default::default(),
+ &["version"],
+ )
+ .await
+ .is_ok()
+ {
+ return Ok(dst_path);
+ }
+
+ let wanted_version = cx.update(|cx| match release_channel {
+ ReleaseChannel::Nightly => Ok(None),
+ ReleaseChannel::Dev => {
+ anyhow::bail!(
+ "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
+ dst_path
+ )
+ }
+ _ => Ok(Some(AppVersion::global(cx))),
+ })??;
+
+ let tmp_path_gz = paths::remote_server_dir_relative().join(
+ RelPath::unix(&format!(
+ "{}-download-{}.gz",
+ binary_name,
+ std::process::id()
+ ))
+ .unwrap(),
+ );
+ if !self.connection_options.upload_binary_over_docker_exec
+ && let Some(url) = delegate
+ .get_download_url(remote_platform, release_channel, wanted_version.clone(), cx)
+ .await?
+ {
+ match self
+ .download_binary_on_server(&url, &tmp_path_gz, &remote_dir_for_server, delegate, cx)
+ .await
+ {
+ Ok(_) => {
+ self.extract_server_binary(
+ &dst_path,
+ &tmp_path_gz,
+ &remote_dir_for_server,
+ delegate,
+ cx,
+ )
+ .await
+ .context("extracting server binary")?;
+ return Ok(dst_path);
+ }
+ Err(e) => {
+ log::error!(
+ "Failed to download binary on server, attempting to download locally and then upload it the server: {e:#}",
+ )
+ }
+ }
+ }
+
+ let src_path = delegate
+ .download_server_binary_locally(remote_platform, release_channel, wanted_version, cx)
+ .await
+ .context("downloading server binary locally")?;
+ self.upload_local_server_binary(
+ &src_path,
+ &tmp_path_gz,
+ &remote_dir_for_server,
+ delegate,
+ cx,
+ )
+ .await
+ .context("uploading server binary")?;
+ self.extract_server_binary(
+ &dst_path,
+ &tmp_path_gz,
+ &remote_dir_for_server,
+ delegate,
+ cx,
+ )
+ .await
+ .context("extracting server binary")?;
+ Ok(dst_path)
+ }
+
+ async fn docker_user_home_dir(&self) -> Result<String> {
+ let inner_program = self.shell();
+ self.run_docker_exec(
+ &inner_program,
+ None,
+ &Default::default(),
+ &["-c", "echo $HOME"],
+ )
+ .await
+ }
+
+ async fn extract_server_binary(
+ &self,
+ dst_path: &RelPath,
+ tmp_path: &RelPath,
+ remote_dir_for_server: &str,
+ delegate: &Arc<dyn RemoteClientDelegate>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ delegate.set_status(Some("Extracting remote development server"), cx);
+ let server_mode = 0o755;
+
+ let shell_kind = ShellKind::Posix;
+ let orig_tmp_path = tmp_path.display(self.path_style());
+ let server_mode = format!("{:o}", server_mode);
+ let server_mode = shell_kind
+ .try_quote(&server_mode)
+ .context("shell quoting")?;
+ let dst_path = dst_path.display(self.path_style());
+ let dst_path = shell_kind.try_quote(&dst_path).context("shell quoting")?;
+ let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
+ let orig_tmp_path = shell_kind
+ .try_quote(&orig_tmp_path)
+ .context("shell quoting")?;
+ let tmp_path = shell_kind.try_quote(&tmp_path).context("shell quoting")?;
+ format!(
+ "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
+ )
+ } else {
+ let orig_tmp_path = shell_kind
+ .try_quote(&orig_tmp_path)
+ .context("shell quoting")?;
+ format!("chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",)
+ };
+ let args = shell_kind.args_for_shell(false, script.to_string());
+ self.run_docker_exec(
+ "sh",
+ Some(&remote_dir_for_server),
+ &Default::default(),
+ &args,
+ )
+ .await
+ .log_err();
+ Ok(())
+ }
+
+ async fn upload_local_server_binary(
+ &self,
+ src_path: &Path,
+ tmp_path_gz: &RelPath,
+ remote_dir_for_server: &str,
+ delegate: &Arc<dyn RemoteClientDelegate>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ if let Some(parent) = tmp_path_gz.parent() {
+ self.run_docker_exec(
+ "mkdir",
+ Some(remote_dir_for_server),
+ &Default::default(),
+ &["-p", parent.display(self.path_style()).as_ref()],
+ )
+ .await?;
+ }
+
+ let src_stat = smol::fs::metadata(&src_path).await?;
+ let size = src_stat.len();
+
+ let t0 = Instant::now();
+ delegate.set_status(Some("Uploading remote development server"), cx);
+ log::info!(
+ "uploading remote development server to {:?} ({}kb)",
+ tmp_path_gz,
+ size / 1024
+ );
+ self.upload_file(src_path, tmp_path_gz, remote_dir_for_server)
+ .await
+ .context("failed to upload server binary")?;
+ log::info!("uploaded remote development server in {:?}", t0.elapsed());
+ Ok(())
+ }
+
+ async fn upload_file(
+ &self,
+ src_path: &Path,
+ dest_path: &RelPath,
+ remote_dir_for_server: &str,
+ ) -> Result<()> {
+ log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
+
+ let src_path_display = src_path.display().to_string();
+ let dest_path_str = dest_path.display(self.path_style());
+
+ let mut command = util::command::new_smol_command("docker");
+ command.arg("cp");
+ command.arg("-a");
+ command.arg(&src_path_display);
+ command.arg(format!(
+ "{}:{}/{}",
+ &self.connection_options.container_id, remote_dir_for_server, dest_path_str
+ ));
+
+ let output = command.output().await?;
+
+ if output.status.success() {
+ return Ok(());
+ }
+
+ let stderr = String::from_utf8_lossy(&output.stderr);
+ log::debug!(
+ "failed to upload file via docker cp {src_path_display} -> {dest_path_str}: {stderr}",
+ );
+ anyhow::bail!(
+ "failed to upload file via docker cp {} -> {}: {}",
+ src_path_display,
+ dest_path_str,
+ stderr,
+ );
+ }
+
+ async fn run_docker_command(
+ &self,
+ subcommand: &str,
+ args: &[impl AsRef<str>],
+ ) -> Result<String> {
+ let mut command = util::command::new_smol_command("docker");
+ command.arg(subcommand);
+ for arg in args {
+ command.arg(arg.as_ref());
+ }
+ let output = command.output().await?;
+ anyhow::ensure!(
+ output.status.success(),
+ "failed to run command {command:?}: {}",
+ String::from_utf8_lossy(&output.stderr)
+ );
+ Ok(String::from_utf8_lossy(&output.stdout).to_string())
+ }
+
+ async fn run_docker_exec(
+ &self,
+ inner_program: &str,
+ working_directory: Option<&str>,
+ env: &HashMap<String, String>,
+ program_args: &[impl AsRef<str>],
+ ) -> Result<String> {
+ let mut args = match working_directory {
+ Some(dir) => vec!["-w".to_string(), dir.to_string()],
+ None => vec![],
+ };
+
+ for (k, v) in env.iter() {
+ args.push("-e".to_string());
+ let env_declaration = format!("{}={}", k, v);
+ args.push(env_declaration);
+ }
+
+ args.push(self.connection_options.container_id.clone());
+ args.push(inner_program.to_string());
+
+ for arg in program_args {
+ args.push(arg.as_ref().to_owned());
+ }
+ self.run_docker_command("exec", args.as_ref()).await
+ }
+
+ async fn download_binary_on_server(
+ &self,
+ url: &str,
+ tmp_path_gz: &RelPath,
+ remote_dir_for_server: &str,
+ delegate: &Arc<dyn RemoteClientDelegate>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ if let Some(parent) = tmp_path_gz.parent() {
+ self.run_docker_exec(
+ "mkdir",
+ Some(remote_dir_for_server),
+ &Default::default(),
+ &["-p", parent.display(self.path_style()).as_ref()],
+ )
+ .await?;
+ }
+
+ delegate.set_status(Some("Downloading remote development server on host"), cx);
+
+ match self
+ .run_docker_exec(
+ "curl",
+ Some(remote_dir_for_server),
+ &Default::default(),
+ &[
+ "-f",
+ "-L",
+ url,
+ "-o",
+ &tmp_path_gz.display(self.path_style()),
+ ],
+ )
+ .await
+ {
+ Ok(_) => {}
+ Err(e) => {
+ if self
+ .run_docker_exec("which", None, &Default::default(), &["curl"])
+ .await
+ .is_ok()
+ {
+ return Err(e);
+ }
+
+ log::info!("curl is not available, trying wget");
+ match self
+ .run_docker_exec(
+ "wget",
+ Some(remote_dir_for_server),
+ &Default::default(),
+ &[url, "-O", &tmp_path_gz.display(self.path_style())],
+ )
+ .await
+ {
+ Ok(_) => {}
+ Err(e) => {
+ if self
+ .run_docker_exec("which", None, &Default::default(), &["wget"])
+ .await
+ .is_ok()
+ {
+ return Err(e);
+ } else {
+ anyhow::bail!("Neither curl nor wget is available");
+ }
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn kill_inner(&self) -> Result<()> {
+ if let Some(pid) = self.proxy_process.lock().take() {
+ if let Ok(_) = util::command::new_smol_command("kill")
+ .arg(pid.to_string())
+ .spawn()
+ {
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Failed to kill process"))
+ }
+ } else {
+ Ok(())
+ }
+ }
+}
+
+#[async_trait(?Send)]
+impl RemoteConnection for DockerExecConnection {
+ fn has_wsl_interop(&self) -> bool {
+ false
+ }
+ fn start_proxy(
+ &self,
+ unique_identifier: String,
+ reconnect: bool,
+ incoming_tx: UnboundedSender<Envelope>,
+ outgoing_rx: UnboundedReceiver<Envelope>,
+ connection_activity_tx: Sender<()>,
+ delegate: Arc<dyn RemoteClientDelegate>,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<i32>> {
+ // We'll try connecting anew every time we open a devcontainer, so proactively try to kill any old connections.
+ if !self.has_been_killed() {
+ if let Err(e) = self.kill_inner() {
+ return Task::ready(Err(e));
+ };
+ }
+
+ delegate.set_status(Some("Starting proxy"), cx);
+
+ let Some(remote_binary_relpath) = self.remote_binary_relpath.clone() else {
+ return Task::ready(Err(anyhow!("Remote binary path not set")));
+ };
+
+ let mut docker_args = vec![
+ "exec".to_string(),
+ "-w".to_string(),
+ self.remote_dir_for_server.clone(),
+ "-i".to_string(),
+ self.connection_options.container_id.to_string(),
+ ];
+ for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
+ if let Some(value) = std::env::var(env_var).ok() {
+ docker_args.push("-e".to_string());
+ docker_args.push(format!("{}='{}'", env_var, value));
+ }
+ }
+ let val = remote_binary_relpath
+ .display(self.path_style())
+ .into_owned();
+ docker_args.push(val);
+ docker_args.push("proxy".to_string());
+ docker_args.push("--identifier".to_string());
+ docker_args.push(unique_identifier);
+ if reconnect {
+ docker_args.push("--reconnect".to_string());
+ }
+ let mut command = util::command::new_smol_command("docker");
+ command
+ .kill_on_drop(true)
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .args(docker_args);
+
+ let Ok(child) = command.spawn() else {
+ return Task::ready(Err(anyhow::anyhow!(
+ "Failed to start remote server process"
+ )));
+ };
+
+ let mut proxy_process = self.proxy_process.lock();
+ *proxy_process = Some(child.id());
+
+ super::handle_rpc_messages_over_child_process_stdio(
+ child,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
+ cx,
+ )
+ }
+
+ fn upload_directory(
+ &self,
+ src_path: PathBuf,
+ dest_path: RemotePathBuf,
+ cx: &App,
+ ) -> Task<Result<()>> {
+ let dest_path_str = dest_path.to_string();
+ let src_path_display = src_path.display().to_string();
+
+ let mut command = util::command::new_smol_command("docker");
+ command.arg("cp");
+ command.arg("-a"); // Archive mode is required to assign the file ownership to the default docker exec user
+ command.arg(src_path_display);
+ command.arg(format!(
+ "{}:{}",
+ self.connection_options.container_id, dest_path_str
+ ));
+
+ cx.background_spawn(async move {
+ let output = command.output().await?;
+
+ if output.status.success() {
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Failed to upload directory"))
+ }
+ })
+ }
+
+ async fn kill(&self) -> Result<()> {
+ self.kill_inner()
+ }
+
+ fn has_been_killed(&self) -> bool {
+ self.proxy_process.lock().is_none()
+ }
+
+ fn build_command(
+ &self,
+ program: Option<String>,
+ args: &[String],
+ env: &HashMap<String, String>,
+ working_dir: Option<String>,
+ _port_forward: Option<(u16, String, u16)>,
+ ) -> Result<CommandTemplate> {
+ let mut parsed_working_dir = None;
+
+ let path_style = self.path_style();
+
+ if let Some(working_dir) = working_dir {
+ let working_dir = RemotePathBuf::new(working_dir, path_style).to_string();
+
+ const TILDE_PREFIX: &'static str = "~/";
+ if working_dir.starts_with(TILDE_PREFIX) {
+ let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
+ parsed_working_dir = Some(format!("$HOME/{working_dir}"));
+ } else {
+ parsed_working_dir = Some(working_dir);
+ }
+ }
+
+ let mut inner_program = Vec::new();
+
+ if let Some(program) = program {
+ inner_program.push(program);
+ for arg in args {
+ inner_program.push(arg.clone());
+ }
+ } else {
+ inner_program.push(self.shell());
+ inner_program.push("-l".to_string());
+ };
+
+ let mut docker_args = vec!["exec".to_string()];
+
+ if let Some(parsed_working_dir) = parsed_working_dir {
+ docker_args.push("-w".to_string());
+ docker_args.push(parsed_working_dir);
+ }
+
+ for (k, v) in env.iter() {
+ docker_args.push("-e".to_string());
+ docker_args.push(format!("{}={}", k, v));
+ }
+
+ docker_args.push("-it".to_string());
+ docker_args.push(self.connection_options.container_id.to_string());
+
+ docker_args.append(&mut inner_program);
+
+ Ok(CommandTemplate {
+ program: "docker".to_string(),
+ args: docker_args,
+ // Docker-exec pipes in environment via the "-e" argument
+ env: Default::default(),
+ })
+ }
+
+ fn build_forward_ports_command(
+ &self,
+ _forwards: Vec<(u16, String, u16)>,
+ ) -> Result<CommandTemplate> {
+ Err(anyhow::anyhow!("Not currently supported for docker_exec"))
+ }
+
+ fn connection_options(&self) -> RemoteConnectionOptions {
+ RemoteConnectionOptions::Docker(self.connection_options.clone())
+ }
+
+ fn path_style(&self) -> PathStyle {
+ self.path_style.unwrap_or(PathStyle::Posix)
+ }
+
+ fn shell(&self) -> String {
+ match &self.shell {
+ Some(shell) => shell.clone(),
+ None => self.default_system_shell(),
+ }
+ }
+
+ fn default_system_shell(&self) -> String {
+ String::from("/bin/sh")
+ }
+}
@@ -31,7 +31,8 @@ use tempfile::TempDir;
use util::{
paths::{PathStyle, RemotePathBuf},
rel_path::RelPath,
- shell::ShellKind,
+ shell::{Shell, ShellKind},
+ shell_builder::ShellBuilder,
};
pub(crate) struct SshRemoteConnection {
@@ -1362,6 +1363,8 @@ fn build_command(
} else {
write!(exec, "{ssh_shell} -l")?;
};
+ let (command, command_args) = ShellBuilder::new(&Shell::Program(ssh_shell.to_owned()), false)
+ .build(Some(exec.clone()), &[]);
let mut args = Vec::new();
args.extend(ssh_args);
@@ -1372,7 +1375,9 @@ fn build_command(
}
args.push("-t".into());
- args.push(exec);
+ args.push(command);
+ args.extend(command_args);
+
Ok(CommandTemplate {
program: "ssh".into(),
args,
@@ -1411,6 +1416,9 @@ mod tests {
"-p",
"2222",
"-t",
+ "/bin/fish",
+ "-i",
+ "-c",
"cd \"$HOME/work\" && exec env INPUT_VA=val remote_program arg1 arg2"
]
);
@@ -1443,6 +1451,9 @@ mod tests {
"-L",
"1:foo:2",
"-t",
+ "/bin/fish",
+ "-i",
+ "-c",
"cd && exec env INPUT_VA=val /bin/fish -l"
]
);
@@ -23,7 +23,8 @@ use std::{
use util::{
paths::{PathStyle, RemotePathBuf},
rel_path::RelPath,
- shell::ShellKind,
+ shell::{Shell, ShellKind},
+ shell_builder::ShellBuilder,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, schemars::JsonSchema)]
@@ -453,8 +454,10 @@ impl RemoteConnection for WslRemoteConnection {
} else {
write!(&mut exec, "{} -l", self.shell)?;
}
+ let (command, args) =
+ ShellBuilder::new(&Shell::Program(self.shell.clone()), false).build(Some(exec), &[]);
- let wsl_args = if let Some(user) = &self.connection_options.user {
+ let mut wsl_args = if let Some(user) = &self.connection_options.user {
vec![
"--distribution".to_string(),
self.connection_options.distro_name.clone(),
@@ -463,9 +466,7 @@ impl RemoteConnection for WslRemoteConnection {
"--cd".to_string(),
working_dir,
"--".to_string(),
- self.shell.clone(),
- "-c".to_string(),
- exec,
+ command,
]
} else {
vec![
@@ -474,11 +475,10 @@ impl RemoteConnection for WslRemoteConnection {
"--cd".to_string(),
working_dir,
"--".to_string(),
- self.shell.clone(),
- "-c".to_string(),
- exec,
+ command,
]
};
+ wsl_args.extend(args);
Ok(CommandTemplate {
program: "wsl.exe".to_string(),
@@ -270,26 +270,6 @@ impl http_client::HttpClient for ReqwestClient {
}
.boxed()
}
-
- fn send_multipart_form<'a>(
- &'a self,
- url: &str,
- form: reqwest::multipart::Form,
- ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>>
- {
- let response = self.client.post(url).multipart(form).send();
- self.handle
- .spawn(async move {
- let response = response.await?;
- let mut builder = http::response::Builder::new().status(response.status());
- for (k, v) in response.headers() {
- builder = builder.header(k, v)
- }
- Ok(builder.body(response.bytes().await?.into())?)
- })
- .map(|e| e?)
- .boxed()
- }
}
#[cfg(test)]
@@ -511,6 +511,11 @@ pub struct GitPanelSettingsContent {
///
/// Default: false
pub collapse_untracked_diff: Option<bool>,
+
+ /// Whether to show entries with tree or flat view in the panel
+ ///
+ /// Default: false
+ pub tree_view: Option<bool>,
}
#[derive(
@@ -889,9 +894,19 @@ pub enum ImageFileSizeUnit {
pub struct RemoteSettingsContent {
pub ssh_connections: Option<Vec<SshConnection>>,
pub wsl_connections: Option<Vec<WslConnection>>,
+ pub dev_container_connections: Option<Vec<DevContainerConnection>>,
pub read_ssh_config: Option<bool>,
}
+#[with_fallible_options]
+#[derive(
+ Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema, MergeFrom, Hash,
+)]
+pub struct DevContainerConnection {
+ pub name: SharedString,
+ pub container_id: SharedString,
+}
+
#[with_fallible_options]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
pub struct SshConnection {
@@ -901,7 +916,7 @@ pub struct SshConnection {
#[serde(default)]
pub args: Vec<String>,
#[serde(default)]
- pub projects: collections::BTreeSet<SshProject>,
+ pub projects: collections::BTreeSet<RemoteProject>,
/// Name to use for this server in UI.
pub nickname: Option<String>,
// By default Zed will download the binary to the host directly.
@@ -918,14 +933,14 @@ pub struct WslConnection {
pub distro_name: SharedString,
pub user: Option<String>,
#[serde(default)]
- pub projects: BTreeSet<SshProject>,
+ pub projects: BTreeSet<RemoteProject>,
}
#[with_fallible_options]
#[derive(
Clone, Debug, Default, Serialize, PartialEq, Eq, PartialOrd, Ord, Deserialize, JsonSchema,
)]
-pub struct SshProject {
+pub struct RemoteProject {
pub paths: Vec<String>,
}
@@ -4314,6 +4314,24 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Tree View",
+ description: "Enable to show entries in tree view list, disable to show in flat view list.",
+ field: Box::new(SettingField {
+ json_path: Some("git_panel.tree_view"),
+ pick: |settings_content| {
+ settings_content.git_panel.as_ref()?.tree_view.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .git_panel
+ .get_or_insert_default()
+ .tree_view = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Scroll Bar",
description: "How and when the scrollbar should be displayed.",
@@ -5,8 +5,8 @@ use editor::Editor;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
Action, AnyElement, App, AppContext as _, Context, DismissEvent, Entity, EventEmitter,
- Focusable, InteractiveElement, ParentElement, Render, SharedString, Styled, Subscription, Task,
- WeakEntity, Window, rems,
+ Focusable, InteractiveElement, ParentElement, Render, Styled, Subscription, Task, WeakEntity,
+ Window, rems,
};
use itertools::Itertools;
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
@@ -526,7 +526,7 @@ impl PickerDelegate for TasksModalDelegate {
};
Some(
- ListItem::new(SharedString::from(format!("tasks-modal-{ix}")))
+ ListItem::new(format!("tasks-modal-{ix}"))
.inset(true)
.start_slot::<IconWithIndicator>(icon)
.end_slot::<AnyElement>(
@@ -39,6 +39,7 @@ pub use subscription::*;
pub use sum_tree::Bias;
use sum_tree::{Dimensions, FilterCursor, SumTree, TreeMap, TreeSet};
use undo_map::UndoMap;
+use util::debug_panic;
#[cfg(any(test, feature = "test-support"))]
use util::RandomCharIter;
@@ -2439,7 +2440,7 @@ impl BufferSnapshot {
if bias == Bias::Left && offset == 0 {
Anchor::min_for_buffer(self.remote_id)
} else if bias == Bias::Right
- && ((cfg!(debug_assertions) && offset >= self.len()) || offset == self.len())
+ && ((!cfg!(debug_assertions) && offset >= self.len()) || offset == self.len())
{
Anchor::max_for_buffer(self.remote_id)
} else {
@@ -2453,7 +2454,15 @@ impl BufferSnapshot {
};
}
let (start, _, item) = self.fragments.find::<usize, _>(&None, &offset, bias);
- let fragment = item.unwrap();
+ let Some(fragment) = item else {
+ // We got a bad offset, likely out of bounds
+ debug_panic!(
+ "Failed to find fragment at offset {} (len: {})",
+ offset,
+ self.len()
+ );
+ return Anchor::max_for_buffer(self.remote_id);
+ };
let overshoot = offset - start;
Anchor {
timestamp: fragment.timestamp,
@@ -151,10 +151,10 @@ impl ApplicationMenu {
// Application menu must have same ids as first menu item in standard menu
div()
- .id(SharedString::from(format!("{}-menu-item", menu_name)))
+ .id(format!("{}-menu-item", menu_name))
.occlude()
.child(
- PopoverMenu::new(SharedString::from(format!("{}-menu-popover", menu_name)))
+ PopoverMenu::new(format!("{}-menu-popover", menu_name))
.menu(move |window, cx| {
Self::build_menu_from_items(entry.clone(), window, cx).into()
})
@@ -184,10 +184,10 @@ impl ApplicationMenu {
.collect();
div()
- .id(SharedString::from(format!("{}-menu-item", menu_name)))
+ .id(format!("{}-menu-item", menu_name))
.occlude()
.child(
- PopoverMenu::new(SharedString::from(format!("{}-menu-popover", menu_name)))
+ PopoverMenu::new(format!("{}-menu-popover", menu_name))
.menu(move |window, cx| {
Self::build_menu_from_items(entry.clone(), window, cx).into()
})
@@ -323,12 +323,18 @@ impl TitleBar {
let options = self.project.read(cx).remote_connection_options(cx)?;
let host: SharedString = options.display_name().into();
- let (nickname, icon) = match options {
- RemoteConnectionOptions::Ssh(options) => {
- (options.nickname.map(|nick| nick.into()), IconName::Server)
+ let (nickname, tooltip_title, icon) = match options {
+ RemoteConnectionOptions::Ssh(options) => (
+ options.nickname.map(|nick| nick.into()),
+ "Remote Project",
+ IconName::Server,
+ ),
+ RemoteConnectionOptions::Wsl(_) => (None, "Remote Project", IconName::Linux),
+ RemoteConnectionOptions::Docker(_dev_container_connection) => {
+ (None, "Dev Container", IconName::Box)
}
- RemoteConnectionOptions::Wsl(_) => (None, IconName::Linux),
};
+
let nickname = nickname.unwrap_or_else(|| host.clone());
let (indicator_color, meta) = match self.project.read(cx).remote_connection_state(cx)? {
@@ -375,7 +381,7 @@ impl TitleBar {
)
.tooltip(move |_window, cx| {
Tooltip::with_meta(
- "Remote Project",
+ tooltip_title,
Some(&OpenRemote {
from_existing_connection: false,
create_new_window: false,
@@ -124,7 +124,7 @@ impl ActiveToolchain {
&buffer,
window,
|this, _, event: &BufferEvent, window, cx| {
- if matches!(event, BufferEvent::LanguageChanged) {
+ if matches!(event, BufferEvent::LanguageChanged(_)) {
this._update_toolchain_task = Self::spawn_tracker_task(window, cx);
}
},
@@ -56,7 +56,10 @@ pub enum ShellKind {
Tcsh,
Rc,
Fish,
+ /// Pre-installed "legacy" powershell for windows
PowerShell,
+ /// PowerShell 7.x
+ Pwsh,
Nushell,
Cmd,
Xonsh,
@@ -238,6 +241,7 @@ impl fmt::Display for ShellKind {
ShellKind::Tcsh => write!(f, "tcsh"),
ShellKind::Fish => write!(f, "fish"),
ShellKind::PowerShell => write!(f, "powershell"),
+ ShellKind::Pwsh => write!(f, "pwsh"),
ShellKind::Nushell => write!(f, "nu"),
ShellKind::Cmd => write!(f, "cmd"),
ShellKind::Rc => write!(f, "rc"),
@@ -260,7 +264,8 @@ impl ShellKind {
.to_string_lossy();
match &*program {
- "powershell" | "pwsh" => ShellKind::PowerShell,
+ "powershell" => ShellKind::PowerShell,
+ "pwsh" => ShellKind::Pwsh,
"cmd" => ShellKind::Cmd,
"nu" => ShellKind::Nushell,
"fish" => ShellKind::Fish,
@@ -279,7 +284,7 @@ impl ShellKind {
pub fn to_shell_variable(self, input: &str) -> String {
match self {
- Self::PowerShell => Self::to_powershell_variable(input),
+ Self::PowerShell | Self::Pwsh => Self::to_powershell_variable(input),
Self::Cmd => Self::to_cmd_variable(input),
Self::Posix => input.to_owned(),
Self::Fish => input.to_owned(),
@@ -407,8 +412,12 @@ impl ShellKind {
pub fn args_for_shell(&self, interactive: bool, combined_command: String) -> Vec<String> {
match self {
- ShellKind::PowerShell => vec!["-C".to_owned(), combined_command],
- ShellKind::Cmd => vec!["/C".to_owned(), combined_command],
+ ShellKind::PowerShell | ShellKind::Pwsh => vec!["-C".to_owned(), combined_command],
+ ShellKind::Cmd => vec![
+ "/S".to_owned(),
+ "/C".to_owned(),
+ format!("\"{combined_command}\""),
+ ],
ShellKind::Posix
| ShellKind::Nushell
| ShellKind::Fish
@@ -426,7 +435,7 @@ impl ShellKind {
pub const fn command_prefix(&self) -> Option<char> {
match self {
- ShellKind::PowerShell => Some('&'),
+ ShellKind::PowerShell | ShellKind::Pwsh => Some('&'),
ShellKind::Nushell => Some('^'),
ShellKind::Posix
| ShellKind::Csh
@@ -457,6 +466,7 @@ impl ShellKind {
| ShellKind::Rc
| ShellKind::Fish
| ShellKind::PowerShell
+ | ShellKind::Pwsh
| ShellKind::Nushell
| ShellKind::Xonsh
| ShellKind::Elvish => ';',
@@ -471,6 +481,7 @@ impl ShellKind {
| ShellKind::Tcsh
| ShellKind::Rc
| ShellKind::Fish
+ | ShellKind::Pwsh
| ShellKind::PowerShell
| ShellKind::Xonsh => "&&",
ShellKind::Nushell | ShellKind::Elvish => ";",
@@ -478,11 +489,10 @@ impl ShellKind {
}
pub fn try_quote<'a>(&self, arg: &'a str) -> Option<Cow<'a, str>> {
- shlex::try_quote(arg).ok().map(|arg| match self {
- // If we are running in PowerShell, we want to take extra care when escaping strings.
- // In particular, we want to escape strings with a backtick (`) rather than a backslash (\).
- ShellKind::PowerShell => Cow::Owned(arg.replace("\\\"", "`\"").replace("\\\\", "\\")),
- ShellKind::Cmd => Cow::Owned(arg.replace("\\\\", "\\")),
+ match self {
+ ShellKind::PowerShell => Some(Self::quote_powershell(arg)),
+ ShellKind::Pwsh => Some(Self::quote_pwsh(arg)),
+ ShellKind::Cmd => Some(Self::quote_cmd(arg)),
ShellKind::Posix
| ShellKind::Csh
| ShellKind::Tcsh
@@ -490,8 +500,173 @@ impl ShellKind {
| ShellKind::Fish
| ShellKind::Nushell
| ShellKind::Xonsh
- | ShellKind::Elvish => arg,
- })
+ | ShellKind::Elvish => shlex::try_quote(arg).ok(),
+ }
+ }
+
+ fn quote_windows(arg: &str, enclose: bool) -> Cow<'_, str> {
+ if arg.is_empty() {
+ return Cow::Borrowed("\"\"");
+ }
+
+ let needs_quoting = arg.chars().any(|c| c == ' ' || c == '\t' || c == '"');
+ if !needs_quoting {
+ return Cow::Borrowed(arg);
+ }
+
+ let mut result = String::with_capacity(arg.len() + 2);
+
+ if enclose {
+ result.push('"');
+ }
+
+ let chars: Vec<char> = arg.chars().collect();
+ let mut i = 0;
+
+ while i < chars.len() {
+ if chars[i] == '\\' {
+ let mut num_backslashes = 0;
+ while i < chars.len() && chars[i] == '\\' {
+ num_backslashes += 1;
+ i += 1;
+ }
+
+ if i < chars.len() && chars[i] == '"' {
+ // Backslashes followed by quote: double the backslashes and escape the quote
+ for _ in 0..(num_backslashes * 2 + 1) {
+ result.push('\\');
+ }
+ result.push('"');
+ i += 1;
+ } else if i >= chars.len() {
+ // Trailing backslashes: double them (they precede the closing quote)
+ for _ in 0..(num_backslashes * 2) {
+ result.push('\\');
+ }
+ } else {
+ // Backslashes not followed by quote: output as-is
+ for _ in 0..num_backslashes {
+ result.push('\\');
+ }
+ }
+ } else if chars[i] == '"' {
+ // Quote not preceded by backslash: escape it
+ result.push('\\');
+ result.push('"');
+ i += 1;
+ } else {
+ result.push(chars[i]);
+ i += 1;
+ }
+ }
+
+ if enclose {
+ result.push('"');
+ }
+ Cow::Owned(result)
+ }
+
+ fn needs_quoting_powershell(s: &str) -> bool {
+ s.is_empty()
+ || s.chars().any(|c| {
+ c.is_whitespace()
+ || matches!(
+ c,
+ '"' | '`'
+ | '$'
+ | '&'
+ | '|'
+ | '<'
+ | '>'
+ | ';'
+ | '('
+ | ')'
+ | '['
+ | ']'
+ | '{'
+ | '}'
+ | ','
+ | '\''
+ | '@'
+ )
+ })
+ }
+
+ fn need_quotes_powershell(arg: &str) -> bool {
+ let mut quote_count = 0;
+ for c in arg.chars() {
+ if c == '"' {
+ quote_count += 1;
+ } else if c.is_whitespace() && (quote_count % 2 == 0) {
+ return true;
+ }
+ }
+ false
+ }
+
+ fn escape_powershell_quotes(s: &str) -> String {
+ let mut result = String::with_capacity(s.len() + 4);
+ result.push('\'');
+ for c in s.chars() {
+ if c == '\'' {
+ result.push('\'');
+ }
+ result.push(c);
+ }
+ result.push('\'');
+ result
+ }
+
+ pub fn quote_powershell(arg: &str) -> Cow<'_, str> {
+ let ps_will_quote = Self::need_quotes_powershell(arg);
+ let crt_quoted = Self::quote_windows(arg, !ps_will_quote);
+
+ if !Self::needs_quoting_powershell(arg) {
+ return crt_quoted;
+ }
+
+ Cow::Owned(Self::escape_powershell_quotes(&crt_quoted))
+ }
+
+ pub fn quote_pwsh(arg: &str) -> Cow<'_, str> {
+ if arg.is_empty() {
+ return Cow::Borrowed("''");
+ }
+
+ if !Self::needs_quoting_powershell(arg) {
+ return Cow::Borrowed(arg);
+ }
+
+ Cow::Owned(Self::escape_powershell_quotes(arg))
+ }
+
+ pub fn quote_cmd(arg: &str) -> Cow<'_, str> {
+ let crt_quoted = Self::quote_windows(arg, true);
+
+ let needs_cmd_escaping = crt_quoted.contains('"')
+ || crt_quoted.contains('%')
+ || crt_quoted
+ .chars()
+ .any(|c| matches!(c, '^' | '<' | '>' | '&' | '|' | '(' | ')'));
+
+ if !needs_cmd_escaping {
+ return crt_quoted;
+ }
+
+ let mut result = String::with_capacity(crt_quoted.len() * 2);
+ for c in crt_quoted.chars() {
+ match c {
+ '^' | '"' | '<' | '>' | '&' | '|' | '(' | ')' => {
+ result.push('^');
+ result.push(c);
+ }
+ '%' => {
+ result.push_str("%%cd:~,%");
+ }
+ _ => result.push(c),
+ }
+ }
+ Cow::Owned(result)
}
/// Quotes the given argument if necessary, taking into account the command prefix.
@@ -538,7 +713,7 @@ impl ShellKind {
match self {
ShellKind::Cmd => "",
ShellKind::Nushell => "overlay use",
- ShellKind::PowerShell => ".",
+ ShellKind::PowerShell | ShellKind::Pwsh => ".",
ShellKind::Fish
| ShellKind::Csh
| ShellKind::Tcsh
@@ -558,6 +733,7 @@ impl ShellKind {
| ShellKind::Rc
| ShellKind::Fish
| ShellKind::PowerShell
+ | ShellKind::Pwsh
| ShellKind::Nushell
| ShellKind::Xonsh
| ShellKind::Elvish => "clear",
@@ -576,6 +752,7 @@ impl ShellKind {
| ShellKind::Rc
| ShellKind::Fish
| ShellKind::PowerShell
+ | ShellKind::Pwsh
| ShellKind::Nushell
| ShellKind::Xonsh
| ShellKind::Elvish => true,
@@ -605,7 +782,7 @@ mod tests {
.try_quote("C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest \"test_foo.py::test_foo\"")
.unwrap()
.into_owned(),
- "\"C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest `\"test_foo.py::test_foo`\"\"".to_string()
+ "'C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest \\\"test_foo.py::test_foo\\\"'".to_string()
);
}
@@ -617,7 +794,113 @@ mod tests {
.try_quote("C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest \"test_foo.py::test_foo\"")
.unwrap()
.into_owned(),
- "\"C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest \\\"test_foo.py::test_foo\\\"\"".to_string()
+ "^\"C:\\Users\\johndoe\\dev\\python\\39007\\tests\\.venv\\Scripts\\python.exe -m pytest \\^\"test_foo.py::test_foo\\^\"^\"".to_string()
+ );
+ }
+
+ #[test]
+ fn test_try_quote_powershell_edge_cases() {
+ let shell_kind = ShellKind::PowerShell;
+
+ // Empty string
+ assert_eq!(
+ shell_kind.try_quote("").unwrap().into_owned(),
+ "'\"\"'".to_string()
+ );
+
+ // String without special characters (no quoting needed)
+ assert_eq!(shell_kind.try_quote("simple").unwrap(), "simple");
+
+ // String with spaces
+ assert_eq!(
+ shell_kind.try_quote("hello world").unwrap().into_owned(),
+ "'hello world'".to_string()
+ );
+
+ // String with dollar signs
+ assert_eq!(
+ shell_kind.try_quote("$variable").unwrap().into_owned(),
+ "'$variable'".to_string()
+ );
+
+ // String with backticks
+ assert_eq!(
+ shell_kind.try_quote("test`command").unwrap().into_owned(),
+ "'test`command'".to_string()
+ );
+
+ // String with multiple special characters
+ assert_eq!(
+ shell_kind
+ .try_quote("test `\"$var`\" end")
+ .unwrap()
+ .into_owned(),
+ "'test `\\\"$var`\\\" end'".to_string()
+ );
+
+ // String with backslashes and colon (path without spaces doesn't need quoting)
+ assert_eq!(
+ shell_kind.try_quote("C:\\path\\to\\file").unwrap(),
+ "C:\\path\\to\\file"
+ );
+ }
+
+ #[test]
+ fn test_try_quote_cmd_edge_cases() {
+ let shell_kind = ShellKind::Cmd;
+
+ // Empty string
+ assert_eq!(
+ shell_kind.try_quote("").unwrap().into_owned(),
+ "^\"^\"".to_string()
+ );
+
+ // String without special characters (no quoting needed)
+ assert_eq!(shell_kind.try_quote("simple").unwrap(), "simple");
+
+ // String with spaces
+ assert_eq!(
+ shell_kind.try_quote("hello world").unwrap().into_owned(),
+ "^\"hello world^\"".to_string()
+ );
+
+ // String with space and backslash (backslash not at end, so not doubled)
+ assert_eq!(
+ shell_kind.try_quote("path\\ test").unwrap().into_owned(),
+ "^\"path\\ test^\"".to_string()
+ );
+
+ // String ending with backslash (must be doubled before closing quote)
+ assert_eq!(
+ shell_kind.try_quote("test path\\").unwrap().into_owned(),
+ "^\"test path\\\\^\"".to_string()
+ );
+
+ // String ending with multiple backslashes (all doubled before closing quote)
+ assert_eq!(
+ shell_kind.try_quote("test path\\\\").unwrap().into_owned(),
+ "^\"test path\\\\\\\\^\"".to_string()
+ );
+
+ // String with embedded quote (quote is escaped, backslash before it is doubled)
+ assert_eq!(
+ shell_kind.try_quote("test\\\"quote").unwrap().into_owned(),
+ "^\"test\\\\\\^\"quote^\"".to_string()
+ );
+
+ // String with multiple backslashes before embedded quote (all doubled)
+ assert_eq!(
+ shell_kind
+ .try_quote("test\\\\\"quote")
+ .unwrap()
+ .into_owned(),
+ "^\"test\\\\\\\\\\^\"quote^\"".to_string()
+ );
+
+ // String with backslashes not before quotes (path without spaces doesn't need quoting)
+ assert_eq!(
+ shell_kind.try_quote("C:\\path\\to\\file").unwrap(),
+ "C:\\path\\to\\file"
);
}
@@ -1,3 +1,5 @@
+use std::borrow::Cow;
+
use crate::shell::get_system_shell;
use crate::shell::{Shell, ShellKind};
@@ -42,7 +44,7 @@ impl ShellBuilder {
self.program.clone()
} else {
match self.kind {
- ShellKind::PowerShell => {
+ ShellKind::PowerShell | ShellKind::Pwsh => {
format!("{} -C '{}'", self.program, command_to_use_in_label)
}
ShellKind::Cmd => {
@@ -78,11 +80,27 @@ impl ShellBuilder {
task_args: &[String],
) -> (String, Vec<String>) {
if let Some(task_command) = task_command {
- let mut combined_command = task_args.iter().fold(task_command, |mut command, arg| {
- command.push(' ');
- command.push_str(&self.kind.to_shell_variable(arg));
- command
- });
+ let task_command = self.kind.prepend_command_prefix(&task_command);
+ let task_command = if !task_args.is_empty() {
+ match self.kind.try_quote_prefix_aware(&task_command) {
+ Some(task_command) => task_command,
+ None => task_command,
+ }
+ } else {
+ task_command
+ };
+ let mut combined_command =
+ task_args
+ .iter()
+ .fold(task_command.into_owned(), |mut command, arg| {
+ command.push(' ');
+ let shell_variable = self.kind.to_shell_variable(arg);
+ command.push_str(&match self.kind.try_quote(&shell_variable) {
+ Some(shell_variable) => shell_variable,
+ None => Cow::Owned(shell_variable),
+ });
+ command
+ });
if self.redirect_stdin {
match self.kind {
ShellKind::Fish => {
@@ -99,7 +117,7 @@ impl ShellBuilder {
combined_command.insert(0, '(');
combined_command.push_str(") </dev/null");
}
- ShellKind::PowerShell => {
+ ShellKind::PowerShell | ShellKind::Pwsh => {
combined_command.insert_str(0, "$null | & {");
combined_command.push_str("}");
}
@@ -115,6 +133,10 @@ impl ShellBuilder {
(self.program, self.args)
}
+
+ pub fn kind(&self) -> ShellKind {
+ self.kind
+ }
}
#[cfg(test)]
@@ -144,7 +166,7 @@ mod test {
vec![
"-i",
"-c",
- "echo $env.hello $env.world nothing --($env.something) $ ${test"
+ "^echo '$env.hello' '$env.world' nothing '--($env.something)' '$' '${test'"
]
);
}
@@ -159,7 +181,7 @@ mod test {
.build(Some("echo".into()), &["nothing".to_string()]);
assert_eq!(program, "nu");
- assert_eq!(args, vec!["-i", "-c", "(echo nothing) </dev/null"]);
+ assert_eq!(args, vec!["-i", "-c", "(^echo nothing) </dev/null"]);
}
#[test]
@@ -159,7 +159,7 @@ async fn capture_windows(
zed_path.display()
),
]),
- ShellKind::PowerShell => cmd.args([
+ ShellKind::PowerShell | ShellKind::Pwsh => cmd.args([
"-NonInteractive",
"-NoProfile",
"-Command",
@@ -773,6 +773,52 @@ mod test {
"});
}
+ #[gpui::test]
+ async fn test_paste_system_clipboard_never(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+
+ cx.update_global(|store: &mut SettingsStore, cx| {
+ store.update_user_settings(cx, |s| {
+ s.vim.get_or_insert_default().use_system_clipboard = Some(UseSystemClipboard::Never)
+ });
+ });
+
+ cx.set_state(
+ indoc! {"
+ ˇThe quick brown
+ fox jumps over
+ the lazy dog"},
+ Mode::Normal,
+ );
+
+ cx.write_to_clipboard(ClipboardItem::new_string("something else".to_string()));
+
+ cx.simulate_keystrokes("d d");
+ cx.assert_state(
+ indoc! {"
+ ˇfox jumps over
+ the lazy dog"},
+ Mode::Normal,
+ );
+
+ cx.simulate_keystrokes("shift-v p");
+ cx.assert_state(
+ indoc! {"
+ ˇThe quick brown
+ the lazy dog"},
+ Mode::Normal,
+ );
+
+ cx.simulate_keystrokes("shift-v");
+ cx.dispatch_action(editor::actions::Paste);
+ cx.assert_state(
+ indoc! {"
+ ˇsomething else
+ the lazy dog"},
+ Mode::Normal,
+ );
+ }
+
#[gpui::test]
async fn test_numbered_registers(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -294,11 +294,10 @@ mod test {
async fn test_scroll(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
- let (line_height, visible_line_count) = cx.editor(|editor, window, _cx| {
+ let (line_height, visible_line_count) = cx.update_editor(|editor, window, cx| {
(
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size()),
editor.visible_line_count().unwrap(),
@@ -2399,7 +2399,7 @@ async fn test_clipping_on_mode_change(cx: &mut gpui::TestAppContext) {
.end;
editor.last_bounds().unwrap().origin
+ editor
- .display_to_pixel_point(current_head, &snapshot, window)
+ .display_to_pixel_point(current_head, &snapshot, window, cx)
.unwrap()
});
pixel_position.x += px(100.);
@@ -304,11 +304,10 @@ impl NeovimBackedTestContext {
self.neovim.set_option(&format!("scrolloff={}", 3)).await;
// +2 to account for the vim command UI at the bottom.
self.neovim.set_option(&format!("lines={}", rows + 2)).await;
- let (line_height, visible_line_count) = self.editor(|editor, window, _cx| {
+ let (line_height, visible_line_count) = self.update_editor(|editor, window, cx| {
(
editor
- .style()
- .unwrap()
+ .style(cx)
.text
.line_height_in_pixels(window.rem_size()),
editor.visible_line_count().unwrap(),
@@ -924,6 +924,7 @@ impl Vim {
|vim, _: &editor::actions::Paste, window, cx| match vim.mode {
Mode::Replace => vim.paste_replace(window, cx),
Mode::Visual | Mode::VisualLine | Mode::VisualBlock => {
+ vim.selected_register.replace('+');
vim.paste(&VimPaste::default(), window, cx);
}
_ => {
@@ -11,6 +11,7 @@ use zed_actions::workspace::OpenWithSystem;
use crate::Item;
/// A view to display when a certain buffer/image/other item fails to open.
+#[derive(Debug)]
pub struct InvalidItemView {
/// Which path was attempted to open.
pub abs_path: Arc<Path>,
@@ -20,7 +20,9 @@ use project::debugger::breakpoint_store::{BreakpointState, SourceBreakpoint};
use language::{LanguageName, Toolchain, ToolchainScope};
use project::WorktreeId;
-use remote::{RemoteConnectionOptions, SshConnectionOptions, WslConnectionOptions};
+use remote::{
+ DockerConnectionOptions, RemoteConnectionOptions, SshConnectionOptions, WslConnectionOptions,
+};
use sqlez::{
bindable::{Bind, Column, StaticColumnCount},
statement::Statement,
@@ -702,6 +704,10 @@ impl Domain for WorkspaceDb {
sql!(
DROP TABLE ssh_connections;
),
+ sql!(
+ ALTER TABLE remote_connections ADD COLUMN name TEXT;
+ ALTER TABLE remote_connections ADD COLUMN container_id TEXT;
+ ),
];
// Allow recovering from bad migration that was initially shipped to nightly
@@ -728,9 +734,9 @@ impl WorkspaceDb {
pub(crate) fn remote_workspace_for_roots<P: AsRef<Path>>(
&self,
worktree_roots: &[P],
- ssh_project_id: RemoteConnectionId,
+ remote_project_id: RemoteConnectionId,
) -> Option<SerializedWorkspace> {
- self.workspace_for_roots_internal(worktree_roots, Some(ssh_project_id))
+ self.workspace_for_roots_internal(worktree_roots, Some(remote_project_id))
}
pub(crate) fn workspace_for_roots_internal<P: AsRef<Path>>(
@@ -806,9 +812,20 @@ impl WorkspaceDb {
order: paths_order,
});
+ let remote_connection_options = if let Some(remote_connection_id) = remote_connection_id {
+ self.remote_connection(remote_connection_id)
+ .context("Get remote connection")
+ .log_err()
+ } else {
+ None
+ };
+
Some(SerializedWorkspace {
id: workspace_id,
- location: SerializedWorkspaceLocation::Local,
+ location: match remote_connection_options {
+ Some(options) => SerializedWorkspaceLocation::Remote(options),
+ None => SerializedWorkspaceLocation::Local,
+ },
paths,
center_group: self
.get_center_pane_group(workspace_id)
@@ -1110,10 +1127,12 @@ impl WorkspaceDb {
options: RemoteConnectionOptions,
) -> Result<RemoteConnectionId> {
let kind;
- let user;
+ let mut user = None;
let mut host = None;
let mut port = None;
let mut distro = None;
+ let mut name = None;
+ let mut container_id = None;
match options {
RemoteConnectionOptions::Ssh(options) => {
kind = RemoteConnectionKind::Ssh;
@@ -1126,8 +1145,22 @@ impl WorkspaceDb {
distro = Some(options.distro_name);
user = options.user;
}
+ RemoteConnectionOptions::Docker(options) => {
+ kind = RemoteConnectionKind::Docker;
+ container_id = Some(options.container_id);
+ name = Some(options.name);
+ }
}
- Self::get_or_create_remote_connection_query(this, kind, host, port, user, distro)
+ Self::get_or_create_remote_connection_query(
+ this,
+ kind,
+ host,
+ port,
+ user,
+ distro,
+ name,
+ container_id,
+ )
}
fn get_or_create_remote_connection_query(
@@ -1137,6 +1170,8 @@ impl WorkspaceDb {
port: Option<u16>,
user: Option<String>,
distro: Option<String>,
+ name: Option<String>,
+ container_id: Option<String>,
) -> Result<RemoteConnectionId> {
if let Some(id) = this.select_row_bound(sql!(
SELECT id
@@ -1146,7 +1181,9 @@ impl WorkspaceDb {
host IS ? AND
port IS ? AND
user IS ? AND
- distro IS ?
+ distro IS ? AND
+ name IS ? AND
+ container_id IS ?
LIMIT 1
))?((
kind.serialize(),
@@ -1154,6 +1191,8 @@ impl WorkspaceDb {
port,
user.clone(),
distro.clone(),
+ name.clone(),
+ container_id.clone(),
))? {
Ok(RemoteConnectionId(id))
} else {
@@ -1163,10 +1202,20 @@ impl WorkspaceDb {
host,
port,
user,
- distro
- ) VALUES (?1, ?2, ?3, ?4, ?5)
+ distro,
+ name,
+ container_id
+ ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
RETURNING id
- ))?((kind.serialize(), host, port, user, distro))?
+ ))?((
+ kind.serialize(),
+ host,
+ port,
+ user,
+ distro,
+ name,
+ container_id,
+ ))?
.context("failed to insert remote project")?;
Ok(RemoteConnectionId(id))
}
@@ -1249,15 +1298,23 @@ impl WorkspaceDb {
fn remote_connections(&self) -> Result<HashMap<RemoteConnectionId, RemoteConnectionOptions>> {
Ok(self.select(sql!(
SELECT
- id, kind, host, port, user, distro
+ id, kind, host, port, user, distro, container_id, name
FROM
remote_connections
))?()?
.into_iter()
- .filter_map(|(id, kind, host, port, user, distro)| {
+ .filter_map(|(id, kind, host, port, user, distro, container_id, name)| {
Some((
RemoteConnectionId(id),
- Self::remote_connection_from_row(kind, host, port, user, distro)?,
+ Self::remote_connection_from_row(
+ kind,
+ host,
+ port,
+ user,
+ distro,
+ container_id,
+ name,
+ )?,
))
})
.collect())
@@ -1267,13 +1324,13 @@ impl WorkspaceDb {
&self,
id: RemoteConnectionId,
) -> Result<RemoteConnectionOptions> {
- let (kind, host, port, user, distro) = self.select_row_bound(sql!(
- SELECT kind, host, port, user, distro
+ let (kind, host, port, user, distro, container_id, name) = self.select_row_bound(sql!(
+ SELECT kind, host, port, user, distro, container_id, name
FROM remote_connections
WHERE id = ?
))?(id.0)?
.context("no such remote connection")?;
- Self::remote_connection_from_row(kind, host, port, user, distro)
+ Self::remote_connection_from_row(kind, host, port, user, distro, container_id, name)
.context("invalid remote_connection row")
}
@@ -1283,6 +1340,8 @@ impl WorkspaceDb {
port: Option<u16>,
user: Option<String>,
distro: Option<String>,
+ container_id: Option<String>,
+ name: Option<String>,
) -> Option<RemoteConnectionOptions> {
match RemoteConnectionKind::deserialize(&kind)? {
RemoteConnectionKind::Wsl => Some(RemoteConnectionOptions::Wsl(WslConnectionOptions {
@@ -1295,6 +1354,13 @@ impl WorkspaceDb {
username: user,
..Default::default()
})),
+ RemoteConnectionKind::Docker => {
+ Some(RemoteConnectionOptions::Docker(DockerConnectionOptions {
+ container_id: container_id?,
+ name: name?,
+ upload_binary_over_docker_exec: false,
+ }))
+ }
}
}
@@ -32,6 +32,7 @@ pub(crate) struct RemoteConnectionId(pub u64);
pub(crate) enum RemoteConnectionKind {
Ssh,
Wsl,
+ Docker,
}
#[derive(Debug, PartialEq, Clone)]
@@ -75,6 +76,7 @@ impl RemoteConnectionKind {
match self {
RemoteConnectionKind::Ssh => "ssh",
RemoteConnectionKind::Wsl => "wsl",
+ RemoteConnectionKind::Docker => "docker",
}
}
@@ -82,6 +84,7 @@ impl RemoteConnectionKind {
match text {
"ssh" => Some(Self::Ssh),
"wsl" => Some(Self::Wsl),
+ "docker" => Some(Self::Docker),
_ => None,
}
}
@@ -675,6 +675,7 @@ impl ProjectItemRegistry {
Ok((project_entry_id, build_workspace_item))
}
Err(e) => {
+ log::warn!("Failed to open a project item: {e:#}");
if e.error_code() == ErrorCode::Internal {
if let Some(abs_path) =
entry_abs_path.as_deref().filter(|_| is_file)
@@ -7779,7 +7780,7 @@ pub fn open_remote_project_with_new_connection(
) -> Task<Result<Vec<Option<Box<dyn ItemHandle>>>>> {
cx.spawn(async move |cx| {
let (workspace_id, serialized_workspace) =
- serialize_remote_project(remote_connection.connection_options(), paths.clone(), cx)
+ deserialize_remote_project(remote_connection.connection_options(), paths.clone(), cx)
.await?;
let session = match cx
@@ -7833,7 +7834,7 @@ pub fn open_remote_project_with_existing_connection(
) -> Task<Result<Vec<Option<Box<dyn ItemHandle>>>>> {
cx.spawn(async move |cx| {
let (workspace_id, serialized_workspace) =
- serialize_remote_project(connection_options.clone(), paths.clone(), cx).await?;
+ deserialize_remote_project(connection_options.clone(), paths.clone(), cx).await?;
open_remote_project_inner(
project,
@@ -7935,7 +7936,7 @@ async fn open_remote_project_inner(
Ok(items.into_iter().map(|item| item?.ok()).collect())
}
-fn serialize_remote_project(
+fn deserialize_remote_project(
connection_options: RemoteConnectionOptions,
paths: Vec<PathBuf>,
cx: &AsyncApp,
@@ -3814,7 +3814,7 @@ impl BackgroundScanner {
let root_canonical_path = match &root_canonical_path {
Ok(path) => SanitizedPath::new(path),
Err(err) => {
- log::error!("failed to canonicalize root path {root_path:?}: {err}");
+ log::error!("failed to canonicalize root path {root_path:?}: {err:#}");
return true;
}
};
@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition.workspace = true
name = "zed"
-version = "0.217.0"
+version = "0.218.0"
publish.workspace = true
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]
@@ -3,7 +3,7 @@ mod zed;
use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
-use clap::{Parser, command};
+use clap::Parser;
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{Client, ProxySettings, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView;
@@ -130,6 +130,7 @@ fn fail_to_open_window(e: anyhow::Error, _cx: &mut App) {
process::exit(1);
}
+ // Maybe unify this with gpui::platform::linux::platform::ResultExt::notify_err(..)?
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
{
use ashpd::desktop::notification::{Notification, NotificationProxy, Priority};
@@ -1,8 +1,8 @@
use anyhow::{Context as _, Result};
use client::{Client, telemetry::MINIDUMP_ENDPOINT};
-use futures::AsyncReadExt;
+use futures::{AsyncReadExt, TryStreamExt};
use gpui::{App, AppContext as _, SerializedThreadTaskTimings};
-use http_client::{self, HttpClient};
+use http_client::{self, AsyncBody, HttpClient, Request};
use log::info;
use project::Project;
use proto::{CrashReport, GetCrashFilesResponse};
@@ -296,11 +296,14 @@ async fn upload_minidump(
// TODO: feature-flag-context, and more of device-context like screen resolution, available ram, device model, etc
+ let stream = form
+ .into_stream()
+ .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
+ .into_async_read();
+ let body = AsyncBody::from_reader(stream);
+ let req = Request::builder().uri(endpoint).body(body)?;
let mut response_text = String::new();
- let mut response = client
- .http_client()
- .send_multipart_form(endpoint, form)
- .await?;
+ let mut response = client.http_client().send(req).await?;
response
.body_mut()
.read_to_string(&mut response_text)
@@ -4745,6 +4745,7 @@ mod tests {
"git_panel",
"go_to_line",
"icon_theme_selector",
+ "inline_assistant",
"journal",
"keymap_editor",
"keystroke_input",
@@ -169,7 +169,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
MenuItem::os_action("Paste", editor::actions::Paste, OsAction::Paste),
MenuItem::separator(),
MenuItem::action("Find", search::buffer_search::Deploy::find()),
- MenuItem::action("Find In Project", workspace::DeploySearch::find()),
+ MenuItem::action("Find in Project", workspace::DeploySearch::find()),
MenuItem::separator(),
MenuItem::action(
"Toggle Line Comment",
@@ -280,7 +280,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
MenuItem::separator(),
MenuItem::action("Toggle Breakpoint", editor::actions::ToggleBreakpoint),
MenuItem::action("Edit Breakpoint", editor::actions::EditLogBreakpoint),
- MenuItem::action("Clear all Breakpoints", debugger_ui::ClearAllBreakpoints),
+ MenuItem::action("Clear All Breakpoints", debugger_ui::ClearAllBreakpoints),
],
},
Menu {
@@ -174,17 +174,13 @@ impl Render for QuickActionBar {
.as_ref()
.is_some_and(|menu| matches!(menu.origin(), ContextMenuOrigin::QuickActionBar))
};
- let code_action_element = if is_deployed {
- editor.update(cx, |editor, cx| {
- if let Some(style) = editor.style() {
- editor.render_context_menu(style, MAX_CODE_ACTION_MENU_LINES, window, cx)
- } else {
- None
- }
+ let code_action_element = is_deployed
+ .then(|| {
+ editor.update(cx, |editor, cx| {
+ editor.render_context_menu(MAX_CODE_ACTION_MENU_LINES, window, cx)
+ })
})
- } else {
- None
- };
+ .flatten();
v_flex()
.child(
IconButton::new("toggle_code_actions_icon", IconName::BoltOutlined)
@@ -428,6 +428,12 @@ pub struct OpenRemote {
pub create_new_window: bool,
}
+/// Opens the dev container connection modal.
+#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)]
+#[action(namespace = projects)]
+#[serde(deny_unknown_fields)]
+pub struct OpenDevContainer;
+
/// Where to spawn the task in the UI.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@@ -0,0 +1,15 @@
+[package]
+name = "zeta_prompt"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zeta_prompt.rs"
+
+[dependencies]
+serde.workspace = true
@@ -0,0 +1,165 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::Write;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+
+pub const CURSOR_MARKER: &str = "<|user_cursor|>";
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ZetaPromptInput {
+ pub cursor_path: Arc<Path>,
+ pub cursor_excerpt: Arc<str>,
+ pub editable_range_in_excerpt: Range<usize>,
+ pub cursor_offset_in_excerpt: usize,
+ pub events: Vec<Arc<Event>>,
+ pub related_files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "event")]
+pub enum Event {
+ BufferChange {
+ path: Arc<Path>,
+ old_path: Arc<Path>,
+ diff: String,
+ predicted: bool,
+ in_open_source_repo: bool,
+ },
+}
+
+pub fn write_event(prompt: &mut String, event: &Event) {
+ fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
+ for component in path.components() {
+ prompt.push('/');
+ write!(prompt, "{}", component.as_os_str().display()).ok();
+ }
+ }
+ match event {
+ Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ predicted,
+ in_open_source_repo: _,
+ } => {
+ if *predicted {
+ prompt.push_str("// User accepted prediction:\n");
+ }
+ prompt.push_str("--- a");
+ write_path_as_unix_str(prompt, old_path.as_ref());
+ prompt.push_str("\n+++ b");
+ write_path_as_unix_str(prompt, path.as_ref());
+ prompt.push('\n');
+ prompt.push_str(diff);
+ }
+ }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedFile {
+ pub path: Arc<Path>,
+ pub max_row: u32,
+ pub excerpts: Vec<RelatedExcerpt>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedExcerpt {
+ pub row_range: Range<u32>,
+ pub text: String,
+}
+
+pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
+ let mut prompt = String::new();
+ write_related_files(&mut prompt, &input.related_files);
+ write_edit_history_section(&mut prompt, input);
+ write_cursor_excerpt_section(&mut prompt, input);
+ prompt
+}
+
+pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
+ push_delimited(prompt, "related_files", &[], |prompt| {
+ for file in related_files {
+ let path_str = file.path.to_string_lossy();
+ push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
+ for excerpt in &file.excerpts {
+ push_delimited(
+ prompt,
+ "related_excerpt",
+ &[(
+ "lines",
+ &format!(
+ "{}-{}",
+ excerpt.row_range.start + 1,
+ excerpt.row_range.end + 1
+ ),
+ )],
+ |prompt| {
+ prompt.push_str(&excerpt.text);
+ prompt.push('\n');
+ },
+ );
+ }
+ });
+ }
+ });
+}
+
+fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
+ push_delimited(prompt, "edit_history", &[], |prompt| {
+ if input.events.is_empty() {
+ prompt.push_str("(No edit history)");
+ } else {
+ for event in &input.events {
+ write_event(prompt, event);
+ }
+ }
+ });
+}
+
+fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+ push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
+ let path_str = input.cursor_path.to_string_lossy();
+ push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
+ prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ push_delimited(prompt, "editable_region", &[], |prompt| {
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+ );
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+ );
+ });
+ prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ });
+ });
+}
+
+fn push_delimited(
+ prompt: &mut String,
+ tag: &'static str,
+ arguments: &[(&str, &str)],
+ cb: impl FnOnce(&mut String),
+) {
+ if !prompt.ends_with("\n") {
+ prompt.push('\n');
+ }
+ prompt.push('<');
+ prompt.push_str(tag);
+ for (arg_name, arg_value) in arguments {
+ write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
+ }
+ prompt.push_str(">\n");
+
+ cb(prompt);
+
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+ prompt.push_str("</");
+ prompt.push_str(tag);
+ prompt.push_str(">\n");
+}
@@ -109,18 +109,6 @@ git submodule init
git submodule update
```
-## Update Your Extension
-
-When developing/updating your extension, you will likely need to update its content from its submodule in the extensions repository.
-To quickly fetch the latest code for only specific extension (and avoid updating all others), use the specific path:
-
-```sh
-# From the root of the repository:
-git submodule update --remote extensions/your-extension-name
-```
-
-> Note: If you need to update all submodules (e.g., if multiple extensions have changed, or for a full clean build), you can run `git submodule update` without a path, but this will take longer.
-
## Extension License Requirements
As of October 1st, 2025, extension repositories must include a license.
@@ -177,7 +165,15 @@ To update an extension, open a PR to [the `zed-industries/extensions` repo](http
In your PR do the following:
-1. Update the extension's submodule to the commit of the new version.
+1. Update the extension's submodule to the commit of the new version. For this, you can run
+
+```sh
+# From the root of the repository:
+git submodule update --remote extensions/your-extension-name
+```
+
+to update your extension to the latest commit available in your remote repository.
+
2. Update the `version` field for the extension in `extensions.toml`
- Make sure the `version` matches the one set in `extension.toml` at the particular commit.
@@ -1,29 +1,24 @@
import os
-from datetime import datetime, timedelta
-from typing import Optional
+from datetime import date, datetime, timedelta
+from typing import Any, Optional
+import requests
import typer
-from github import Github
-from github.Issue import Issue
-from github.Repository import Repository
from pytz import timezone
from typer import Typer
app: Typer = typer.Typer()
-DATETIME_FORMAT: str = "%m/%d/%Y %I:%M %p"
-ISSUES_PER_LABEL: int = 50
+AMERICA_NEW_YORK_TIMEZONE = "America/New_York"
+DATETIME_FORMAT: str = "%B %d, %Y %I:%M %p"
+ISSUES_PER_SECTION: int = 50
+ISSUES_TO_FETCH: int = 100
+REPO_OWNER = "zed-industries"
+REPO_NAME = "zed"
+GITHUB_API_BASE_URL = "https://api.github.com"
-class IssueData:
- def __init__(self, issue: Issue) -> None:
- self.title = issue.title
- self.url: str = issue.html_url
- self.like_count: int = issue._rawData["reactions"]["+1"] # type: ignore [attr-defined]
- self.creation_datetime: str = issue.created_at.strftime(DATETIME_FORMAT)
- # TODO: Change script to support storing labels here, rather than directly in the script
- self.labels: set[str] = {label["name"] for label in issue._rawData["labels"]} # type: ignore [attr-defined]
- self._issue = issue
+EXCLUDE_LABEL = "ignore top-ranking issues"
@app.command()
@@ -32,181 +27,135 @@ def main(
issue_reference_number: Optional[int] = None,
query_day_interval: Optional[int] = None,
) -> None:
- start_time: datetime = datetime.now()
-
- start_date: datetime | None = None
+ script_start_time: datetime = datetime.now()
+ start_date: date | None = None
if query_day_interval:
- tz = timezone("america/new_york")
- current_time = datetime.now(tz).replace(
- hour=0, minute=0, second=0, microsecond=0
- )
- start_date = current_time - timedelta(days=query_day_interval)
+ tz = timezone(AMERICA_NEW_YORK_TIMEZONE)
+ today = datetime.now(tz).date()
+ start_date = today - timedelta(days=query_day_interval)
- # GitHub Workflow will pass in the token as an environment variable,
+ # GitHub Workflow will pass in the token as an argument,
# but we can place it in our env when running the script locally, for convenience
- github_token = github_token or os.getenv("GITHUB_ACCESS_TOKEN")
-
- with Github(github_token, per_page=100) as github:
- remaining_requests_before: int = github.rate_limiting[0]
- print(f"Remaining requests before: {remaining_requests_before}")
-
- repo_name: str = "zed-industries/zed"
- repository: Repository = github.get_repo(repo_name)
-
- label_to_issue_data: dict[str, list[IssueData]] = get_issue_maps(
- github, repository, start_date
+ token = github_token or os.getenv("GITHUB_ACCESS_TOKEN")
+ if not token:
+ raise typer.BadParameter(
+ "GitHub token is required. Pass --github-token or set GITHUB_ACCESS_TOKEN env var."
)
- issue_text: str = get_issue_text(label_to_issue_data)
-
- if issue_reference_number:
- top_ranking_issues_issue: Issue = repository.get_issue(issue_reference_number)
- top_ranking_issues_issue.edit(body=issue_text)
- else:
- print(issue_text)
-
- remaining_requests_after: int = github.rate_limiting[0]
- print(f"Remaining requests after: {remaining_requests_after}")
- print(f"Requests used: {remaining_requests_before - remaining_requests_after}")
-
- run_duration: timedelta = datetime.now() - start_time
- print(run_duration)
-
-
-def get_issue_maps(
- github: Github,
- repository: Repository,
- start_date: datetime | None = None,
-) -> dict[str, list[IssueData]]:
- label_to_issue_data: dict[str, list[IssueData]] = get_label_to_issue_data(
- github,
- repository,
- start_date,
- )
-
- # Create a new dictionary with labels ordered by the summation the of likes on the associated issues
- labels = list(label_to_issue_data.keys())
-
- labels.sort(
- key=lambda label: sum(
- issue_data.like_count for issue_data in label_to_issue_data[label]
- ),
- reverse=True,
- )
+ headers = {
+ "Authorization": f"token {token}",
+ "Accept": "application/vnd.github+json",
+ }
- label_to_issue_data = {label: label_to_issue_data[label] for label in labels}
+ section_to_issues = get_section_to_issues(headers, start_date)
+ issue_text: str = create_issue_text(section_to_issues)
- return label_to_issue_data
+ if issue_reference_number:
+ update_reference_issue(headers, issue_reference_number, issue_text)
+ else:
+ print(issue_text)
+ run_duration: timedelta = datetime.now() - script_start_time
+ print(f"Ran for {run_duration}")
-def get_label_to_issue_data(
- github: Github,
- repository: Repository,
- start_date: datetime | None = None,
-) -> dict[str, list[IssueData]]:
- common_queries = [
- f"repo:{repository.full_name}",
- "is:open",
- "is:issue",
- '-label:"ignore top-ranking issues"',
- "sort:reactions-+1-desc",
- ]
- date_query: str | None = (
- f"created:>={start_date.strftime('%Y-%m-%d')}" if start_date else None
- )
+def get_section_to_issues(
+ headers: dict[str, str], start_date: date | None = None
+) -> dict[str, list[dict[str, Any]]]:
+ """Fetch top-ranked issues for each section from GitHub."""
- if date_query:
- common_queries.append(date_query)
-
- common_query = " ".join(common_queries)
-
- # Because PyGithub doesn't seem to support logical operators `AND` and `OR`
- # that GitHub issue queries can use, we use lists as values, rather than
- # using `(label:bug OR type:Bug)`. This is not as efficient, as we might
- # query the same issue multiple times. Issues that are potentially queried
- # multiple times are deduplicated in the `label_to_issues` dictionary. If
- # PyGithub ever supports logical operators, we should definitely make the
- # switch.
- section_queries: dict[str, list[str]] = {
- "bug": ["label:bug", "type:Bug"],
- "crash": ["label:crash", "type:Crash"],
- "feature": ["label:feature", "type:Feature"],
- "meta": ["type:Meta"],
- "windows": ["label:windows"],
- "unlabeled": ["no:label no:type"],
+ section_filters = {
+ "Bugs": "type:Bug",
+ "Crashes": "type:Crash",
+ "Features": "type:Feature",
+ "Tracking issues": "type:Tracking",
+ "Meta issues": "type:Meta",
+ "Windows": 'label:"platform:windows"',
}
- label_to_issue_data: dict[str, list[IssueData]] = {}
-
- for section, queries in section_queries.items():
- unique_issues = set()
-
- for query in queries:
- query: str = f"{common_query} {query}"
- issues = github.search_issues(query)
-
- for issue in issues:
- unique_issues.add(issue)
-
- if len(unique_issues) <= 0:
+ section_to_issues: dict[str, list[dict[str, Any]]] = {}
+ for section, search_qualifier in section_filters.items():
+ query_parts = [
+ f"repo:{REPO_OWNER}/{REPO_NAME}",
+ "is:issue",
+ "is:open",
+ f'-label:"{EXCLUDE_LABEL}"',
+ search_qualifier,
+ ]
+
+ if start_date:
+ query_parts.append(f"created:>={start_date.strftime('%Y-%m-%d')}")
+
+ query = " ".join(query_parts)
+ url = f"{GITHUB_API_BASE_URL}/search/issues"
+ params = {
+ "q": query,
+ "sort": "reactions-+1",
+ "order": "desc",
+ "per_page": ISSUES_TO_FETCH, # this will work as long as it's ≤ 100
+ }
+
+ # we are only fetching one page on purpose
+ response = requests.get(url, headers=headers, params=params)
+ response.raise_for_status()
+ items = response.json()["items"]
+
+ issues: list[dict[str, Any]] = []
+ for item in items:
+ reactions = item["reactions"]
+ score = reactions["+1"] - reactions["-1"]
+ if score > 0:
+ issues.append({
+ "url": item["html_url"],
+ "score": score,
+ "created_at": item["created_at"],
+ })
+
+ if not issues:
continue
- issue_data: list[IssueData] = [IssueData(issue) for issue in unique_issues]
- issue_data.sort(
- key=lambda issue_data: (
- -issue_data.like_count,
- issue_data.creation_datetime,
- )
- )
-
- label_to_issue_data[section] = issue_data[0:ISSUES_PER_LABEL]
-
- return label_to_issue_data
-
+ issues.sort(key=lambda x: (-x["score"], x["created_at"]))
+ section_to_issues[section] = issues[:ISSUES_PER_SECTION]
-def get_issue_text(
- label_to_issue_data: dict[str, list[IssueData]],
-) -> str:
- tz = timezone("america/new_york")
- current_datetime: str = datetime.now(tz).strftime(f"{DATETIME_FORMAT} (%Z)")
-
- highest_ranking_issues_lines: list[str] = get_highest_ranking_issues_lines(
- label_to_issue_data
+ # Sort sections by total score (highest total first)
+ section_to_issues = dict(
+ sorted(
+ section_to_issues.items(),
+ key=lambda item: sum(issue["score"] for issue in item[1]),
+ reverse=True,
+ )
)
+ return section_to_issues
- issue_text_lines: list[str] = [
- f"*Updated on {current_datetime}*",
- *highest_ranking_issues_lines,
- "\n---\n",
- "*For details on how this issue is generated, [see the script](https://github.com/zed-industries/zed/blob/main/script/update_top_ranking_issues/main.py)*",
- ]
- return "\n".join(issue_text_lines)
+def update_reference_issue(
+ headers: dict[str, str], issue_number: int, body: str
+) -> None:
+ url = f"{GITHUB_API_BASE_URL}/repos/{REPO_OWNER}/{REPO_NAME}/issues/{issue_number}"
+ response = requests.patch(url, headers=headers, json={"body": body})
+ response.raise_for_status()
-def get_highest_ranking_issues_lines(
- label_to_issue_data: dict[str, list[IssueData]],
-) -> list[str]:
- highest_ranking_issues_lines: list[str] = []
+def create_issue_text(section_to_issues: dict[str, list[dict[str, Any]]]) -> str:
+ tz = timezone(AMERICA_NEW_YORK_TIMEZONE)
+ current_datetime: str = datetime.now(tz).strftime(f"{DATETIME_FORMAT} (%Z)")
- if label_to_issue_data:
- for label, issue_data in label_to_issue_data.items():
- highest_ranking_issues_lines.append(f"\n## {label}\n")
+ lines: list[str] = [f"*Updated on {current_datetime}*"]
- for i, issue_data in enumerate(issue_data):
- markdown_bullet_point: str = (
- f"{issue_data.url} ({issue_data.like_count} :thumbsup:)"
- )
+ for section, issues in section_to_issues.items():
+ lines.append(f"\n## {section}\n")
+ for i, issue in enumerate(issues):
+ lines.append(f"{i + 1}. {issue['url']} ({issue['score']} :thumbsup:)")
- markdown_bullet_point = f"{i + 1}. {markdown_bullet_point}"
- highest_ranking_issues_lines.append(markdown_bullet_point)
+ lines.append("\n---\n")
+ lines.append(
+ "*For details on how this issue is generated, "
+ "[see the script](https://github.com/zed-industries/zed/blob/main/script/update_top_ranking_issues/main.py)*"
+ )
- return highest_ranking_issues_lines
+ return "\n".join(lines)
if __name__ == "__main__":
app()
-
-# TODO: Sort label output into core and non core sections
@@ -5,9 +5,10 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"mypy>=1.15.0",
- "pygithub>=2.6.1",
"pytz>=2025.1",
+ "requests>=2.32.0",
"ruff>=0.9.7",
"typer>=0.15.1",
"types-pytz>=2025.1.0.20250204",
+ "types-requests>=2.32.0",
]
@@ -1,60 +1,38 @@
version = 1
-revision = 1
+revision = 3
requires-python = ">=3.13"
[[package]]
name = "certifi"
version = "2024.8.30"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/b0/ee/9b19140fe824b367c04c5e1b369942dd754c4c5462d5674002f75c4dedc1/certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9", size = 168507 }
+sdist = { url = "https://files.pythonhosted.org/packages/b0/ee/9b19140fe824b367c04c5e1b369942dd754c4c5462d5674002f75c4dedc1/certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9", size = 168507, upload-time = "2024-08-30T01:55:04.365Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", size = 167321 },
-]
-
-[[package]]
-name = "cffi"
-version = "1.17.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "pycparser" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/8d/f8/dd6c246b148639254dad4d6803eb6a54e8c85c6e11ec9df2cffa87571dbe/cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", size = 182989 },
- { url = "https://files.pythonhosted.org/packages/8b/f1/672d303ddf17c24fc83afd712316fda78dc6fce1cd53011b839483e1ecc8/cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", size = 178802 },
- { url = "https://files.pythonhosted.org/packages/0e/2d/eab2e858a91fdff70533cab61dcff4a1f55ec60425832ddfdc9cd36bc8af/cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", size = 454792 },
- { url = "https://files.pythonhosted.org/packages/75/b2/fbaec7c4455c604e29388d55599b99ebcc250a60050610fadde58932b7ee/cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", size = 478893 },
- { url = "https://files.pythonhosted.org/packages/4f/b7/6e4a2162178bf1935c336d4da8a9352cccab4d3a5d7914065490f08c0690/cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", size = 485810 },
- { url = "https://files.pythonhosted.org/packages/c7/8a/1d0e4a9c26e54746dc08c2c6c037889124d4f59dffd853a659fa545f1b40/cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", size = 471200 },
- { url = "https://files.pythonhosted.org/packages/26/9f/1aab65a6c0db35f43c4d1b4f580e8df53914310afc10ae0397d29d697af4/cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", size = 479447 },
- { url = "https://files.pythonhosted.org/packages/5f/e4/fb8b3dd8dc0e98edf1135ff067ae070bb32ef9d509d6cb0f538cd6f7483f/cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", size = 484358 },
- { url = "https://files.pythonhosted.org/packages/f1/47/d7145bf2dc04684935d57d67dff9d6d795b2ba2796806bb109864be3a151/cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", size = 488469 },
- { url = "https://files.pythonhosted.org/packages/bf/ee/f94057fa6426481d663b88637a9a10e859e492c73d0384514a17d78ee205/cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", size = 172475 },
- { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
+ { url = "https://files.pythonhosted.org/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", size = 167321, upload-time = "2024-08-30T01:55:02.591Z" },
]
[[package]]
name = "charset-normalizer"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/f2/4f/e1808dc01273379acc506d18f1504eb2d299bd4131743b9fc54d7be4df1e/charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e", size = 106620 }
+sdist = { url = "https://files.pythonhosted.org/packages/f2/4f/e1808dc01273379acc506d18f1504eb2d299bd4131743b9fc54d7be4df1e/charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e", size = 106620, upload-time = "2024-10-09T07:40:20.413Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/f3/89/68a4c86f1a0002810a27f12e9a7b22feb198c59b2f05231349fbce5c06f4/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114", size = 194617 },
- { url = "https://files.pythonhosted.org/packages/4f/cd/8947fe425e2ab0aa57aceb7807af13a0e4162cd21eee42ef5b053447edf5/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed", size = 125310 },
- { url = "https://files.pythonhosted.org/packages/5b/f0/b5263e8668a4ee9becc2b451ed909e9c27058337fda5b8c49588183c267a/charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250", size = 119126 },
- { url = "https://files.pythonhosted.org/packages/ff/6e/e445afe4f7fda27a533f3234b627b3e515a1b9429bc981c9a5e2aa5d97b6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920", size = 139342 },
- { url = "https://files.pythonhosted.org/packages/a1/b2/4af9993b532d93270538ad4926c8e37dc29f2111c36f9c629840c57cd9b3/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64", size = 149383 },
- { url = "https://files.pythonhosted.org/packages/fb/6f/4e78c3b97686b871db9be6f31d64e9264e889f8c9d7ab33c771f847f79b7/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23", size = 142214 },
- { url = "https://files.pythonhosted.org/packages/2b/c9/1c8fe3ce05d30c87eff498592c89015b19fade13df42850aafae09e94f35/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc", size = 144104 },
- { url = "https://files.pythonhosted.org/packages/ee/68/efad5dcb306bf37db7db338338e7bb8ebd8cf38ee5bbd5ceaaaa46f257e6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d", size = 146255 },
- { url = "https://files.pythonhosted.org/packages/0c/75/1ed813c3ffd200b1f3e71121c95da3f79e6d2a96120163443b3ad1057505/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88", size = 140251 },
- { url = "https://files.pythonhosted.org/packages/7d/0d/6f32255c1979653b448d3c709583557a4d24ff97ac4f3a5be156b2e6a210/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90", size = 148474 },
- { url = "https://files.pythonhosted.org/packages/ac/a0/c1b5298de4670d997101fef95b97ac440e8c8d8b4efa5a4d1ef44af82f0d/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b", size = 151849 },
- { url = "https://files.pythonhosted.org/packages/04/4f/b3961ba0c664989ba63e30595a3ed0875d6790ff26671e2aae2fdc28a399/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d", size = 149781 },
- { url = "https://files.pythonhosted.org/packages/d8/90/6af4cd042066a4adad58ae25648a12c09c879efa4849c705719ba1b23d8c/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482", size = 144970 },
- { url = "https://files.pythonhosted.org/packages/cc/67/e5e7e0cbfefc4ca79025238b43cdf8a2037854195b37d6417f3d0895c4c2/charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67", size = 94973 },
- { url = "https://files.pythonhosted.org/packages/65/97/fc9bbc54ee13d33dc54a7fcf17b26368b18505500fc01e228c27b5222d80/charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b", size = 102308 },
- { url = "https://files.pythonhosted.org/packages/bf/9b/08c0432272d77b04803958a4598a51e2a4b51c06640af8b8f0f908c18bf2/charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079", size = 49446 },
+ { url = "https://files.pythonhosted.org/packages/f3/89/68a4c86f1a0002810a27f12e9a7b22feb198c59b2f05231349fbce5c06f4/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114", size = 194617, upload-time = "2024-10-09T07:39:07.317Z" },
+ { url = "https://files.pythonhosted.org/packages/4f/cd/8947fe425e2ab0aa57aceb7807af13a0e4162cd21eee42ef5b053447edf5/charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed", size = 125310, upload-time = "2024-10-09T07:39:08.353Z" },
+ { url = "https://files.pythonhosted.org/packages/5b/f0/b5263e8668a4ee9becc2b451ed909e9c27058337fda5b8c49588183c267a/charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250", size = 119126, upload-time = "2024-10-09T07:39:09.327Z" },
+ { url = "https://files.pythonhosted.org/packages/ff/6e/e445afe4f7fda27a533f3234b627b3e515a1b9429bc981c9a5e2aa5d97b6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920", size = 139342, upload-time = "2024-10-09T07:39:10.322Z" },
+ { url = "https://files.pythonhosted.org/packages/a1/b2/4af9993b532d93270538ad4926c8e37dc29f2111c36f9c629840c57cd9b3/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64", size = 149383, upload-time = "2024-10-09T07:39:12.042Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/6f/4e78c3b97686b871db9be6f31d64e9264e889f8c9d7ab33c771f847f79b7/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23", size = 142214, upload-time = "2024-10-09T07:39:13.059Z" },
+ { url = "https://files.pythonhosted.org/packages/2b/c9/1c8fe3ce05d30c87eff498592c89015b19fade13df42850aafae09e94f35/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc", size = 144104, upload-time = "2024-10-09T07:39:14.815Z" },
+ { url = "https://files.pythonhosted.org/packages/ee/68/efad5dcb306bf37db7db338338e7bb8ebd8cf38ee5bbd5ceaaaa46f257e6/charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d", size = 146255, upload-time = "2024-10-09T07:39:15.868Z" },
+ { url = "https://files.pythonhosted.org/packages/0c/75/1ed813c3ffd200b1f3e71121c95da3f79e6d2a96120163443b3ad1057505/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88", size = 140251, upload-time = "2024-10-09T07:39:16.995Z" },
+ { url = "https://files.pythonhosted.org/packages/7d/0d/6f32255c1979653b448d3c709583557a4d24ff97ac4f3a5be156b2e6a210/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90", size = 148474, upload-time = "2024-10-09T07:39:18.021Z" },
+ { url = "https://files.pythonhosted.org/packages/ac/a0/c1b5298de4670d997101fef95b97ac440e8c8d8b4efa5a4d1ef44af82f0d/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b", size = 151849, upload-time = "2024-10-09T07:39:19.243Z" },
+ { url = "https://files.pythonhosted.org/packages/04/4f/b3961ba0c664989ba63e30595a3ed0875d6790ff26671e2aae2fdc28a399/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d", size = 149781, upload-time = "2024-10-09T07:39:20.397Z" },
+ { url = "https://files.pythonhosted.org/packages/d8/90/6af4cd042066a4adad58ae25648a12c09c879efa4849c705719ba1b23d8c/charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482", size = 144970, upload-time = "2024-10-09T07:39:21.452Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/67/e5e7e0cbfefc4ca79025238b43cdf8a2037854195b37d6417f3d0895c4c2/charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67", size = 94973, upload-time = "2024-10-09T07:39:22.509Z" },
+ { url = "https://files.pythonhosted.org/packages/65/97/fc9bbc54ee13d33dc54a7fcf17b26368b18505500fc01e228c27b5222d80/charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b", size = 102308, upload-time = "2024-10-09T07:39:23.524Z" },
+ { url = "https://files.pythonhosted.org/packages/bf/9b/08c0432272d77b04803958a4598a51e2a4b51c06640af8b8f0f908c18bf2/charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079", size = 49446, upload-time = "2024-10-09T07:40:19.383Z" },
]
[[package]]
@@ -64,68 +42,27 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
+sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121, upload-time = "2023-08-17T17:29:11.868Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 },
+ { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941, upload-time = "2023-08-17T17:29:10.08Z" },
]
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
-]
-
-[[package]]
-name = "cryptography"
-version = "43.0.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "cffi", marker = "platform_python_implementation != 'PyPy'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/de/ba/0664727028b37e249e73879348cc46d45c5c1a2a2e81e8166462953c5755/cryptography-43.0.1.tar.gz", hash = "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d", size = 686927 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/58/28/b92c98a04ba762f8cdeb54eba5c4c84e63cac037a7c5e70117d337b15ad6/cryptography-43.0.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d", size = 6223222 },
- { url = "https://files.pythonhosted.org/packages/33/13/1193774705783ba364121aa2a60132fa31a668b8ababd5edfa1662354ccd/cryptography-43.0.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062", size = 3794751 },
- { url = "https://files.pythonhosted.org/packages/5e/4b/39bb3c4c8cfb3e94e736b8d8859ce5c81536e91a1033b1d26770c4249000/cryptography-43.0.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962", size = 3981827 },
- { url = "https://files.pythonhosted.org/packages/ce/dc/1471d4d56608e1013237af334b8a4c35d53895694fbb73882d1c4fd3f55e/cryptography-43.0.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277", size = 3780034 },
- { url = "https://files.pythonhosted.org/packages/ad/43/7a9920135b0d5437cc2f8f529fa757431eb6a7736ddfadfdee1cc5890800/cryptography-43.0.1-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a", size = 3993407 },
- { url = "https://files.pythonhosted.org/packages/cc/42/9ab8467af6c0b76f3d9b8f01d1cf25b9c9f3f2151f4acfab888d21c55a72/cryptography-43.0.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042", size = 3886457 },
- { url = "https://files.pythonhosted.org/packages/a4/65/430509e31700286ec02868a2457d2111d03ccefc20349d24e58d171ae0a7/cryptography-43.0.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494", size = 4081499 },
- { url = "https://files.pythonhosted.org/packages/bb/18/a04b6467e6e09df8c73b91dcee8878f4a438a43a3603dc3cd6f8003b92d8/cryptography-43.0.1-cp37-abi3-win32.whl", hash = "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2", size = 2616504 },
- { url = "https://files.pythonhosted.org/packages/cc/73/0eacbdc437202edcbdc07f3576ed8fb8b0ab79d27bf2c5d822d758a72faa/cryptography-43.0.1-cp37-abi3-win_amd64.whl", hash = "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d", size = 3067456 },
- { url = "https://files.pythonhosted.org/packages/8a/b6/bc54b371f02cffd35ff8dc6baba88304d7cf8e83632566b4b42e00383e03/cryptography-43.0.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d", size = 6225263 },
- { url = "https://files.pythonhosted.org/packages/00/0e/8217e348a1fa417ec4c78cd3cdf24154f5e76fd7597343a35bd403650dfd/cryptography-43.0.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806", size = 3794368 },
- { url = "https://files.pythonhosted.org/packages/3d/ed/38b6be7254d8f7251fde8054af597ee8afa14f911da67a9410a45f602fc3/cryptography-43.0.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85", size = 3981750 },
- { url = "https://files.pythonhosted.org/packages/64/f3/b7946c3887cf7436f002f4cbb1e6aec77b8d299b86be48eeadfefb937c4b/cryptography-43.0.1-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c", size = 3778925 },
- { url = "https://files.pythonhosted.org/packages/ac/7e/ebda4dd4ae098a0990753efbb4b50954f1d03003846b943ea85070782da7/cryptography-43.0.1-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1", size = 3993152 },
- { url = "https://files.pythonhosted.org/packages/43/f6/feebbd78a3e341e3913846a3bb2c29d0b09b1b3af1573c6baabc2533e147/cryptography-43.0.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa", size = 3886392 },
- { url = "https://files.pythonhosted.org/packages/bd/4c/ab0b9407d5247576290b4fd8abd06b7f51bd414f04eef0f2800675512d61/cryptography-43.0.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4", size = 4082606 },
- { url = "https://files.pythonhosted.org/packages/05/36/e532a671998d6fcfdb9122da16434347a58a6bae9465e527e450e0bc60a5/cryptography-43.0.1-cp39-abi3-win32.whl", hash = "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47", size = 2617948 },
- { url = "https://files.pythonhosted.org/packages/b3/c6/c09cee6968add5ff868525c3815e5dccc0e3c6e89eec58dc9135d3c40e88/cryptography-43.0.1-cp39-abi3-win_amd64.whl", hash = "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb", size = 3070445 },
-]
-
-[[package]]
-name = "deprecated"
-version = "1.2.14"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "wrapt" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/92/14/1e41f504a246fc224d2ac264c227975427a85caf37c3979979edb9b1b232/Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3", size = 2974416 }
+sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/20/8d/778b7d51b981a96554f29136cd59ca7880bf58094338085bcf2a979a0e6a/Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c", size = 9561 },
+ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "idna"
version = "3.10"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 }
+sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
+ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
]
[[package]]
@@ -135,18 +72,18 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 }
+sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596, upload-time = "2023-06-03T06:41:14.443Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 },
+ { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528, upload-time = "2023-06-03T06:41:11.019Z" },
]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 }
+sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
+ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]]
@@ -157,102 +94,42 @@ dependencies = [
{ name = "mypy-extensions" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43", size = 3239717 }
+sdist = { url = "https://files.pythonhosted.org/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43", size = 3239717, upload-time = "2025-02-05T03:50:34.655Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445", size = 10788592 },
- { url = "https://files.pythonhosted.org/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d", size = 9753611 },
- { url = "https://files.pythonhosted.org/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5", size = 11438443 },
- { url = "https://files.pythonhosted.org/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036", size = 12402541 },
- { url = "https://files.pythonhosted.org/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357", size = 12494348 },
- { url = "https://files.pythonhosted.org/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf", size = 9373648 },
- { url = "https://files.pythonhosted.org/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e", size = 2221777 },
+ { url = "https://files.pythonhosted.org/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445", size = 10788592, upload-time = "2025-02-05T03:48:55.789Z" },
+ { url = "https://files.pythonhosted.org/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d", size = 9753611, upload-time = "2025-02-05T03:48:44.581Z" },
+ { url = "https://files.pythonhosted.org/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5", size = 11438443, upload-time = "2025-02-05T03:49:25.514Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036", size = 12402541, upload-time = "2025-02-05T03:49:57.623Z" },
+ { url = "https://files.pythonhosted.org/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357", size = 12494348, upload-time = "2025-02-05T03:48:52.361Z" },
+ { url = "https://files.pythonhosted.org/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf", size = 9373648, upload-time = "2025-02-05T03:49:11.395Z" },
+ { url = "https://files.pythonhosted.org/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e", size = 2221777, upload-time = "2025-02-05T03:50:08.348Z" },
]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 },
-]
-
-[[package]]
-name = "pycparser"
-version = "2.22"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 },
-]
-
-[[package]]
-name = "pygithub"
-version = "2.6.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "deprecated" },
- { name = "pyjwt", extra = ["crypto"] },
- { name = "pynacl" },
- { name = "requests" },
- { name = "typing-extensions" },
- { name = "urllib3" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/c0/88/e08ab18dc74b2916f48703ed1a797d57cb64eca0e23b0a9254e13cfe3911/pygithub-2.6.1.tar.gz", hash = "sha256:b5c035392991cca63959e9453286b41b54d83bf2de2daa7d7ff7e4312cebf3bf", size = 3659473 }
+sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433, upload-time = "2023-02-04T12:11:27.157Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ac/fc/a444cd19ccc8c4946a512f3827ed0b3565c88488719d800d54a75d541c0b/PyGithub-2.6.1-py3-none-any.whl", hash = "sha256:6f2fa6d076ccae475f9fc392cc6cdbd54db985d4f69b8833a28397de75ed6ca3", size = 410451 },
+ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695, upload-time = "2023-02-04T12:11:25.002Z" },
]
[[package]]
name = "pygments"
version = "2.18.0"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 },
-]
-
-[[package]]
-name = "pyjwt"
-version = "2.9.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/fb/68/ce067f09fca4abeca8771fe667d89cc347d1e99da3e093112ac329c6020e/pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c", size = 78825 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/79/84/0fdf9b18ba31d69877bd39c9cd6052b47f3761e9910c15de788e519f079f/PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850", size = 22344 },
-]
-
-[package.optional-dependencies]
-crypto = [
- { name = "cryptography" },
-]
-
-[[package]]
-name = "pynacl"
-version = "1.5.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "cffi" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/a7/22/27582568be639dfe22ddb3902225f91f2f17ceff88ce80e4db396c8986da/PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba", size = 3392854 }
+sdist = { url = "https://files.pythonhosted.org/packages/8e/62/8336eff65bcbc8e4cb5d05b55faf041285951b6e80f33e2bff2024788f31/pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199", size = 4891905, upload-time = "2024-05-04T13:42:02.013Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ce/75/0b8ede18506041c0bf23ac4d8e2971b4161cd6ce630b177d0a08eb0d8857/PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1", size = 349920 },
- { url = "https://files.pythonhosted.org/packages/59/bb/fddf10acd09637327a97ef89d2a9d621328850a72f1fdc8c08bdf72e385f/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92", size = 601722 },
- { url = "https://files.pythonhosted.org/packages/5d/70/87a065c37cca41a75f2ce113a5a2c2aa7533be648b184ade58971b5f7ccc/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394", size = 680087 },
- { url = "https://files.pythonhosted.org/packages/ee/87/f1bb6a595f14a327e8285b9eb54d41fef76c585a0edef0a45f6fc95de125/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d", size = 856678 },
- { url = "https://files.pythonhosted.org/packages/66/28/ca86676b69bf9f90e710571b67450508484388bfce09acf8a46f0b8c785f/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858", size = 1133660 },
- { url = "https://files.pythonhosted.org/packages/3d/85/c262db650e86812585e2bc59e497a8f59948a005325a11bbbc9ecd3fe26b/PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b", size = 663824 },
- { url = "https://files.pythonhosted.org/packages/fd/1a/cc308a884bd299b651f1633acb978e8596c71c33ca85e9dc9fa33a5399b9/PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff", size = 1117912 },
- { url = "https://files.pythonhosted.org/packages/25/2d/b7df6ddb0c2a33afdb358f8af6ea3b8c4d1196ca45497dd37a56f0c122be/PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543", size = 204624 },
- { url = "https://files.pythonhosted.org/packages/5e/22/d3db169895faaf3e2eda892f005f433a62db2decbcfbc2f61e6517adfa87/PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93", size = 212141 },
+ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513, upload-time = "2024-05-04T13:41:57.345Z" },
]
[[package]]
name = "pytz"
version = "2025.1"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/5f/57/df1c9157c8d5a05117e455d66fd7cf6dbc46974f832b1058ed4856785d8a/pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e", size = 319617 }
+sdist = { url = "https://files.pythonhosted.org/packages/5f/57/df1c9157c8d5a05117e455d66fd7cf6dbc46974f832b1058ed4856785d8a/pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e", size = 319617, upload-time = "2025-01-31T01:54:48.615Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/eb/38/ac33370d784287baa1c3d538978b5e2ea064d4c1b93ffbd12826c190dd10/pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57", size = 507930 },
+ { url = "https://files.pythonhosted.org/packages/eb/38/ac33370d784287baa1c3d538978b5e2ea064d4c1b93ffbd12826c190dd10/pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57", size = 507930, upload-time = "2025-01-31T01:54:45.634Z" },
]
[[package]]
@@ -265,9 +142,9 @@ dependencies = [
{ name = "idna" },
{ name = "urllib3" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 }
+sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218, upload-time = "2024-05-29T15:37:49.536Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
+ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928, upload-time = "2024-05-29T15:37:47.027Z" },
]
[[package]]
@@ -278,43 +155,43 @@ dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/aa/9e/1784d15b057b0075e5136445aaea92d23955aad2c93eaede673718a40d95/rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c", size = 222843 }
+sdist = { url = "https://files.pythonhosted.org/packages/aa/9e/1784d15b057b0075e5136445aaea92d23955aad2c93eaede673718a40d95/rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c", size = 222843, upload-time = "2024-10-04T11:50:31.453Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/67/91/5474b84e505a6ccc295b2d322d90ff6aa0746745717839ee0c5fb4fdcceb/rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1", size = 242117 },
+ { url = "https://files.pythonhosted.org/packages/67/91/5474b84e505a6ccc295b2d322d90ff6aa0746745717839ee0c5fb4fdcceb/rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1", size = 242117, upload-time = "2024-10-04T11:50:29.123Z" },
]
[[package]]
name = "ruff"
version = "0.9.7"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/39/8b/a86c300359861b186f18359adf4437ac8e4c52e42daa9eedc731ef9d5b53/ruff-0.9.7.tar.gz", hash = "sha256:643757633417907510157b206e490c3aa11cab0c087c912f60e07fbafa87a4c6", size = 3669813 }
+sdist = { url = "https://files.pythonhosted.org/packages/39/8b/a86c300359861b186f18359adf4437ac8e4c52e42daa9eedc731ef9d5b53/ruff-0.9.7.tar.gz", hash = "sha256:643757633417907510157b206e490c3aa11cab0c087c912f60e07fbafa87a4c6", size = 3669813, upload-time = "2025-02-20T13:26:52.111Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/b1/f3/3a1d22973291226df4b4e2ff70196b926b6f910c488479adb0eeb42a0d7f/ruff-0.9.7-py3-none-linux_armv6l.whl", hash = "sha256:99d50def47305fe6f233eb8dabfd60047578ca87c9dcb235c9723ab1175180f4", size = 11774588 },
- { url = "https://files.pythonhosted.org/packages/8e/c9/b881f4157b9b884f2994fd08ee92ae3663fb24e34b0372ac3af999aa7fc6/ruff-0.9.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d59105ae9c44152c3d40a9c40d6331a7acd1cdf5ef404fbe31178a77b174ea66", size = 11746848 },
- { url = "https://files.pythonhosted.org/packages/14/89/2f546c133f73886ed50a3d449e6bf4af27d92d2f960a43a93d89353f0945/ruff-0.9.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f313b5800483770bd540cddac7c90fc46f895f427b7820f18fe1822697f1fec9", size = 11177525 },
- { url = "https://files.pythonhosted.org/packages/d7/93/6b98f2c12bf28ab9def59c50c9c49508519c5b5cfecca6de871cf01237f6/ruff-0.9.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042ae32b41343888f59c0a4148f103208bf6b21c90118d51dc93a68366f4e903", size = 11996580 },
- { url = "https://files.pythonhosted.org/packages/8e/3f/b3fcaf4f6d875e679ac2b71a72f6691a8128ea3cb7be07cbb249f477c061/ruff-0.9.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87862589373b33cc484b10831004e5e5ec47dc10d2b41ba770e837d4f429d721", size = 11525674 },
- { url = "https://files.pythonhosted.org/packages/f0/48/33fbf18defb74d624535d5d22adcb09a64c9bbabfa755bc666189a6b2210/ruff-0.9.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a17e1e01bee0926d351a1ee9bc15c445beae888f90069a6192a07a84af544b6b", size = 12739151 },
- { url = "https://files.pythonhosted.org/packages/63/b5/7e161080c5e19fa69495cbab7c00975ef8a90f3679caa6164921d7f52f4a/ruff-0.9.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c1f880ac5b2cbebd58b8ebde57069a374865c73f3bf41f05fe7a179c1c8ef22", size = 13416128 },
- { url = "https://files.pythonhosted.org/packages/4e/c8/b5e7d61fb1c1b26f271ac301ff6d9de5e4d9a9a63f67d732fa8f200f0c88/ruff-0.9.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e63fc20143c291cab2841dbb8260e96bafbe1ba13fd3d60d28be2c71e312da49", size = 12870858 },
- { url = "https://files.pythonhosted.org/packages/da/cb/2a1a8e4e291a54d28259f8fc6a674cd5b8833e93852c7ef5de436d6ed729/ruff-0.9.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91ff963baed3e9a6a4eba2a02f4ca8eaa6eba1cc0521aec0987da8d62f53cbef", size = 14786046 },
- { url = "https://files.pythonhosted.org/packages/ca/6c/c8f8a313be1943f333f376d79724260da5701426c0905762e3ddb389e3f4/ruff-0.9.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88362e3227c82f63eaebf0b2eff5b88990280fb1ecf7105523883ba8c3aaf6fb", size = 12550834 },
- { url = "https://files.pythonhosted.org/packages/9d/ad/f70cf5e8e7c52a25e166bdc84c082163c9c6f82a073f654c321b4dff9660/ruff-0.9.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0372c5a90349f00212270421fe91874b866fd3626eb3b397ede06cd385f6f7e0", size = 11961307 },
- { url = "https://files.pythonhosted.org/packages/52/d5/4f303ea94a5f4f454daf4d02671b1fbfe2a318b5fcd009f957466f936c50/ruff-0.9.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d76b8ab60e99e6424cd9d3d923274a1324aefce04f8ea537136b8398bbae0a62", size = 11612039 },
- { url = "https://files.pythonhosted.org/packages/eb/c8/bd12a23a75603c704ce86723be0648ba3d4ecc2af07eecd2e9fa112f7e19/ruff-0.9.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0c439bdfc8983e1336577f00e09a4e7a78944fe01e4ea7fe616d00c3ec69a3d0", size = 12168177 },
- { url = "https://files.pythonhosted.org/packages/cc/57/d648d4f73400fef047d62d464d1a14591f2e6b3d4a15e93e23a53c20705d/ruff-0.9.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:115d1f15e8fdd445a7b4dc9a30abae22de3f6bcabeb503964904471691ef7606", size = 12610122 },
- { url = "https://files.pythonhosted.org/packages/49/79/acbc1edd03ac0e2a04ae2593555dbc9990b34090a9729a0c4c0cf20fb595/ruff-0.9.7-py3-none-win32.whl", hash = "sha256:e9ece95b7de5923cbf38893f066ed2872be2f2f477ba94f826c8defdd6ec6b7d", size = 9988751 },
- { url = "https://files.pythonhosted.org/packages/6d/95/67153a838c6b6ba7a2401241fd8a00cd8c627a8e4a0491b8d853dedeffe0/ruff-0.9.7-py3-none-win_amd64.whl", hash = "sha256:3770fe52b9d691a15f0b87ada29c45324b2ace8f01200fb0c14845e499eb0c2c", size = 11002987 },
- { url = "https://files.pythonhosted.org/packages/63/6a/aca01554949f3a401991dc32fe22837baeaccb8a0d868256cbb26a029778/ruff-0.9.7-py3-none-win_arm64.whl", hash = "sha256:b075a700b2533feb7a01130ff656a4ec0d5f340bb540ad98759b8401c32c2037", size = 10177763 },
+ { url = "https://files.pythonhosted.org/packages/b1/f3/3a1d22973291226df4b4e2ff70196b926b6f910c488479adb0eeb42a0d7f/ruff-0.9.7-py3-none-linux_armv6l.whl", hash = "sha256:99d50def47305fe6f233eb8dabfd60047578ca87c9dcb235c9723ab1175180f4", size = 11774588, upload-time = "2025-02-20T13:25:52.253Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/c9/b881f4157b9b884f2994fd08ee92ae3663fb24e34b0372ac3af999aa7fc6/ruff-0.9.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d59105ae9c44152c3d40a9c40d6331a7acd1cdf5ef404fbe31178a77b174ea66", size = 11746848, upload-time = "2025-02-20T13:25:57.279Z" },
+ { url = "https://files.pythonhosted.org/packages/14/89/2f546c133f73886ed50a3d449e6bf4af27d92d2f960a43a93d89353f0945/ruff-0.9.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f313b5800483770bd540cddac7c90fc46f895f427b7820f18fe1822697f1fec9", size = 11177525, upload-time = "2025-02-20T13:26:00.007Z" },
+ { url = "https://files.pythonhosted.org/packages/d7/93/6b98f2c12bf28ab9def59c50c9c49508519c5b5cfecca6de871cf01237f6/ruff-0.9.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:042ae32b41343888f59c0a4148f103208bf6b21c90118d51dc93a68366f4e903", size = 11996580, upload-time = "2025-02-20T13:26:03.274Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/3f/b3fcaf4f6d875e679ac2b71a72f6691a8128ea3cb7be07cbb249f477c061/ruff-0.9.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87862589373b33cc484b10831004e5e5ec47dc10d2b41ba770e837d4f429d721", size = 11525674, upload-time = "2025-02-20T13:26:06.073Z" },
+ { url = "https://files.pythonhosted.org/packages/f0/48/33fbf18defb74d624535d5d22adcb09a64c9bbabfa755bc666189a6b2210/ruff-0.9.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a17e1e01bee0926d351a1ee9bc15c445beae888f90069a6192a07a84af544b6b", size = 12739151, upload-time = "2025-02-20T13:26:08.964Z" },
+ { url = "https://files.pythonhosted.org/packages/63/b5/7e161080c5e19fa69495cbab7c00975ef8a90f3679caa6164921d7f52f4a/ruff-0.9.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c1f880ac5b2cbebd58b8ebde57069a374865c73f3bf41f05fe7a179c1c8ef22", size = 13416128, upload-time = "2025-02-20T13:26:12.54Z" },
+ { url = "https://files.pythonhosted.org/packages/4e/c8/b5e7d61fb1c1b26f271ac301ff6d9de5e4d9a9a63f67d732fa8f200f0c88/ruff-0.9.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e63fc20143c291cab2841dbb8260e96bafbe1ba13fd3d60d28be2c71e312da49", size = 12870858, upload-time = "2025-02-20T13:26:16.794Z" },
+ { url = "https://files.pythonhosted.org/packages/da/cb/2a1a8e4e291a54d28259f8fc6a674cd5b8833e93852c7ef5de436d6ed729/ruff-0.9.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91ff963baed3e9a6a4eba2a02f4ca8eaa6eba1cc0521aec0987da8d62f53cbef", size = 14786046, upload-time = "2025-02-20T13:26:19.85Z" },
+ { url = "https://files.pythonhosted.org/packages/ca/6c/c8f8a313be1943f333f376d79724260da5701426c0905762e3ddb389e3f4/ruff-0.9.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88362e3227c82f63eaebf0b2eff5b88990280fb1ecf7105523883ba8c3aaf6fb", size = 12550834, upload-time = "2025-02-20T13:26:23.082Z" },
+ { url = "https://files.pythonhosted.org/packages/9d/ad/f70cf5e8e7c52a25e166bdc84c082163c9c6f82a073f654c321b4dff9660/ruff-0.9.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0372c5a90349f00212270421fe91874b866fd3626eb3b397ede06cd385f6f7e0", size = 11961307, upload-time = "2025-02-20T13:26:26.738Z" },
+ { url = "https://files.pythonhosted.org/packages/52/d5/4f303ea94a5f4f454daf4d02671b1fbfe2a318b5fcd009f957466f936c50/ruff-0.9.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d76b8ab60e99e6424cd9d3d923274a1324aefce04f8ea537136b8398bbae0a62", size = 11612039, upload-time = "2025-02-20T13:26:30.26Z" },
+ { url = "https://files.pythonhosted.org/packages/eb/c8/bd12a23a75603c704ce86723be0648ba3d4ecc2af07eecd2e9fa112f7e19/ruff-0.9.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0c439bdfc8983e1336577f00e09a4e7a78944fe01e4ea7fe616d00c3ec69a3d0", size = 12168177, upload-time = "2025-02-20T13:26:33.452Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/57/d648d4f73400fef047d62d464d1a14591f2e6b3d4a15e93e23a53c20705d/ruff-0.9.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:115d1f15e8fdd445a7b4dc9a30abae22de3f6bcabeb503964904471691ef7606", size = 12610122, upload-time = "2025-02-20T13:26:37.365Z" },
+ { url = "https://files.pythonhosted.org/packages/49/79/acbc1edd03ac0e2a04ae2593555dbc9990b34090a9729a0c4c0cf20fb595/ruff-0.9.7-py3-none-win32.whl", hash = "sha256:e9ece95b7de5923cbf38893f066ed2872be2f2f477ba94f826c8defdd6ec6b7d", size = 9988751, upload-time = "2025-02-20T13:26:40.366Z" },
+ { url = "https://files.pythonhosted.org/packages/6d/95/67153a838c6b6ba7a2401241fd8a00cd8c627a8e4a0491b8d853dedeffe0/ruff-0.9.7-py3-none-win_amd64.whl", hash = "sha256:3770fe52b9d691a15f0b87ada29c45324b2ace8f01200fb0c14845e499eb0c2c", size = 11002987, upload-time = "2025-02-20T13:26:43.762Z" },
+ { url = "https://files.pythonhosted.org/packages/63/6a/aca01554949f3a401991dc32fe22837baeaccb8a0d868256cbb26a029778/ruff-0.9.7-py3-none-win_arm64.whl", hash = "sha256:b075a700b2533feb7a01130ff656a4ec0d5f340bb540ad98759b8401c32c2037", size = 10177763, upload-time = "2025-02-20T13:26:48.92Z" },
]
[[package]]
name = "shellingham"
version = "1.5.4"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 }
+sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 },
+ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
]
[[package]]
@@ -327,27 +204,39 @@ dependencies = [
{ name = "shellingham" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/dca7b219718afd37a0068f4f2530a727c2b74a8b6e8e0c0080a4c0de4fcd/typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a", size = 99789 }
+sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/dca7b219718afd37a0068f4f2530a727c2b74a8b6e8e0c0080a4c0de4fcd/typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a", size = 99789, upload-time = "2024-12-04T17:44:58.956Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/d0/cc/0a838ba5ca64dc832aa43f727bd586309846b0ffb2ce52422543e6075e8a/typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847", size = 44908 },
+ { url = "https://files.pythonhosted.org/packages/d0/cc/0a838ba5ca64dc832aa43f727bd586309846b0ffb2ce52422543e6075e8a/typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847", size = 44908, upload-time = "2024-12-04T17:44:57.291Z" },
]
[[package]]
name = "types-pytz"
version = "2025.1.0.20250204"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/b3/d2/2190c54d53c04491ad72a1df019c5dfa692e6ab6c2dba1be7b6c9d530e30/types_pytz-2025.1.0.20250204.tar.gz", hash = "sha256:00f750132769f1c65a4f7240bc84f13985b4da774bd17dfbe5d9cd442746bd49", size = 10352 }
+sdist = { url = "https://files.pythonhosted.org/packages/b3/d2/2190c54d53c04491ad72a1df019c5dfa692e6ab6c2dba1be7b6c9d530e30/types_pytz-2025.1.0.20250204.tar.gz", hash = "sha256:00f750132769f1c65a4f7240bc84f13985b4da774bd17dfbe5d9cd442746bd49", size = 10352, upload-time = "2025-02-04T02:39:05.553Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/be/50/65ffad73746f1d8b15992c030e0fd22965fd5ae2c0206dc28873343b3230/types_pytz-2025.1.0.20250204-py3-none-any.whl", hash = "sha256:32ca4a35430e8b94f6603b35beb7f56c32260ddddd4f4bb305fdf8f92358b87e", size = 10059, upload-time = "2025-02-04T02:39:03.899Z" },
+]
+
+[[package]]
+name = "types-requests"
+version = "2.32.4.20250913"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/36/27/489922f4505975b11de2b5ad07b4fe1dca0bca9be81a703f26c5f3acfce5/types_requests-2.32.4.20250913.tar.gz", hash = "sha256:abd6d4f9ce3a9383f269775a9835a4c24e5cd6b9f647d64f88aa4613c33def5d", size = 23113, upload-time = "2025-09-13T02:40:02.309Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/be/50/65ffad73746f1d8b15992c030e0fd22965fd5ae2c0206dc28873343b3230/types_pytz-2025.1.0.20250204-py3-none-any.whl", hash = "sha256:32ca4a35430e8b94f6603b35beb7f56c32260ddddd4f4bb305fdf8f92358b87e", size = 10059 },
+ { url = "https://files.pythonhosted.org/packages/2a/20/9a227ea57c1285986c4cf78400d0a91615d25b24e257fd9e2969606bdfae/types_requests-2.32.4.20250913-py3-none-any.whl", hash = "sha256:78c9c1fffebbe0fa487a418e0fa5252017e9c60d1a2da394077f1780f655d7e1", size = 20658, upload-time = "2025-09-13T02:40:01.115Z" },
]
[[package]]
name = "typing-extensions"
version = "4.12.2"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 }
+sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321, upload-time = "2024-06-07T18:52:15.995Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 },
+ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438, upload-time = "2024-06-07T18:52:13.582Z" },
]
[[package]]
@@ -356,37 +245,30 @@ version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "mypy" },
- { name = "pygithub" },
{ name = "pytz" },
+ { name = "requests" },
{ name = "ruff" },
{ name = "typer" },
{ name = "types-pytz" },
+ { name = "types-requests" },
]
[package.metadata]
requires-dist = [
{ name = "mypy", specifier = ">=1.15.0" },
- { name = "pygithub", specifier = ">=2.6.1" },
{ name = "pytz", specifier = ">=2025.1" },
+ { name = "requests", specifier = ">=2.32.0" },
{ name = "ruff", specifier = ">=0.9.7" },
{ name = "typer", specifier = ">=0.15.1" },
{ name = "types-pytz", specifier = ">=2025.1.0.20250204" },
+ { name = "types-requests", specifier = ">=2.32.0" },
]
[[package]]
name = "urllib3"
version = "2.2.3"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ed/63/22ba4ebfe7430b76388e7cd448d5478814d3032121827c12a2cc287e2260/urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9", size = 300677 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338 },
-]
-
-[[package]]
-name = "wrapt"
-version = "1.16.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/95/4c/063a912e20bcef7124e0df97282a8af3ff3e4b603ce84c481d6d7346be0a/wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d", size = 53972 }
+sdist = { url = "https://files.pythonhosted.org/packages/ed/63/22ba4ebfe7430b76388e7cd448d5478814d3032121827c12a2cc287e2260/urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9", size = 300677, upload-time = "2024-09-12T10:52:18.401Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362 },
+ { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338, upload-time = "2024-09-12T10:52:16.589Z" },
]
@@ -368,6 +368,8 @@ pub(crate) fn check_postgres_and_protobuf_migrations() -> NamedJob {
.runs_on(runners::LINUX_DEFAULT)
.add_env(("GIT_AUTHOR_NAME", "Protobuf Action"))
.add_env(("GIT_AUTHOR_EMAIL", "ci@zed.dev"))
+ .add_env(("GIT_COMMITTER_NAME", "Protobuf Action"))
+ .add_env(("GIT_COMMITTER_EMAIL", "ci@zed.dev"))
.add_step(steps::checkout_repo().with(("fetch-depth", 0))) // fetch full history
.add_step(remove_untracked_files())
.add_step(ensure_fresh_merge())