Detailed changes
@@ -20606,6 +20606,42 @@ dependencies = [
"zlog",
]
+[[package]]
+name = "zeta_cli"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "client",
+ "debug_adapter_extension",
+ "extension",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "gpui_tokio",
+ "language",
+ "language_extension",
+ "language_model",
+ "language_models",
+ "languages",
+ "node_runtime",
+ "paths",
+ "project",
+ "prompt_store",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "settings",
+ "shellexpand 2.1.2",
+ "smol",
+ "terminal_view",
+ "util",
+ "watch",
+ "workspace-hack",
+ "zeta",
+]
+
[[package]]
name = "zip"
version = "0.6.6"
@@ -189,6 +189,7 @@ members = [
"crates/zed",
"crates/zed_actions",
"crates/zeta",
+ "crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -18,7 +18,7 @@ use collections::{HashMap, HashSet};
use extension::ExtensionHostProxy;
use futures::future;
use gpui::http_client::read_proxy_from_env;
-use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
+use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
@@ -337,7 +337,8 @@ pub struct AgentAppState {
}
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
- release_channel::init(SemanticVersion::default(), cx);
+ let app_version = AppVersion::global(cx);
+ release_channel::init(app_version, cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
@@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
// Set User-Agent so we can download language servers from GitHub
let user_agent = format!(
"Zed/{} ({}; {})",
- AppVersion::global(cx),
+ app_version,
std::env::consts::OS,
std::env::consts::ARCH
);
@@ -146,14 +146,14 @@ pub struct InlineCompletion {
input_events: Arc<str>,
input_excerpt: Arc<str>,
output_excerpt: Arc<str>,
- request_sent_at: Instant,
+ buffer_snapshotted_at: Instant,
response_received_at: Instant,
}
impl InlineCompletion {
fn latency(&self) -> Duration {
self.response_received_at
- .duration_since(self.request_sent_at)
+ .duration_since(self.buffer_snapshotted_at)
}
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
@@ -391,104 +391,48 @@ impl Zeta {
+ Send
+ 'static,
{
+ let buffer = buffer.clone();
+ let buffer_snapshotted_at = Instant::now();
let snapshot = self.report_changes_for_buffer(&buffer, cx);
- let diagnostic_groups = snapshot.diagnostic_groups(None);
- let cursor_point = cursor.to_point(&snapshot);
- let cursor_offset = cursor_point.to_offset(&snapshot);
- let events = self.events.clone();
- let path: Arc<Path> = snapshot
- .file()
- .map(|f| Arc::from(f.full_path(cx).as_path()))
- .unwrap_or_else(|| Arc::from(Path::new("untitled")));
-
let zeta = cx.entity();
+ let events = self.events.clone();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
- let buffer = buffer.clone();
-
- let local_lsp_store =
- project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
- let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
- Some(
- diagnostic_groups
- .into_iter()
- .filter_map(|(language_server_id, diagnostic_group)| {
- let language_server =
- local_lsp_store.running_language_server_for_id(language_server_id)?;
-
- Some((
- language_server.name(),
- diagnostic_group.resolve::<usize>(&snapshot),
- ))
- })
- .collect::<Vec<_>>(),
- )
- } else {
- None
- };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|f| Arc::from(f.full_path(cx).as_path()))
+ .unwrap_or_else(|| Arc::from(Path::new("untitled")));
+ let full_path_str = full_path.to_string_lossy().to_string();
+ let cursor_point = cursor.to_point(&snapshot);
+ let cursor_offset = cursor_point.to_offset(&snapshot);
+ let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
+ let gather_task = gather_context(
+ project,
+ full_path_str,
+ &snapshot,
+ cursor_point,
+ make_events_prompt,
+ can_collect_data,
+ cx,
+ );
cx.spawn(async move |this, cx| {
- let request_sent_at = Instant::now();
-
- struct BackgroundValues {
- input_events: String,
- input_excerpt: String,
- speculated_output: String,
- editable_range: Range<usize>,
- input_outline: String,
- }
-
- let values = cx
- .background_spawn({
- let snapshot = snapshot.clone();
- let path = path.clone();
- async move {
- let path = path.to_string_lossy();
- let input_excerpt = excerpt_for_cursor_position(
- cursor_point,
- &path,
- &snapshot,
- MAX_REWRITE_TOKENS,
- MAX_CONTEXT_TOKENS,
- );
- let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
- let input_outline = prompt_for_outline(&snapshot);
-
- anyhow::Ok(BackgroundValues {
- input_events,
- input_excerpt: input_excerpt.prompt,
- speculated_output: input_excerpt.speculated_output,
- editable_range: input_excerpt.editable_range.to_offset(&snapshot),
- input_outline,
- })
- }
- })
- .await?;
+ let GatherContextOutput {
+ body,
+ editable_range,
+ } = gather_task.await?;
log::debug!(
"Events:\n{}\nExcerpt:\n{:?}",
- values.input_events,
- values.input_excerpt
+ body.input_events,
+ body.input_excerpt
);
- let body = PredictEditsBody {
- input_events: values.input_events.clone(),
- input_excerpt: values.input_excerpt.clone(),
- speculated_output: Some(values.speculated_output),
- outline: Some(values.input_outline.clone()),
- can_collect_data,
- diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
- diagnostic_groups
- .into_iter()
- .map(|(name, diagnostic_group)| {
- Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
- })
- .collect::<Result<Vec<_>>>()
- .log_err()
- }),
- };
+ let input_outline = body.outline.clone().unwrap_or_default();
+ let input_events = body.input_events.clone();
+ let input_excerpt = body.input_excerpt.clone();
let response = perform_predict_edits(PerformPredictEditsParams {
client,
@@ -546,13 +490,13 @@ impl Zeta {
response,
buffer,
&snapshot,
- values.editable_range,
+ editable_range,
cursor_offset,
- path,
- values.input_outline,
- values.input_events,
- values.input_excerpt,
- request_sent_at,
+ full_path,
+ input_outline,
+ input_events,
+ input_excerpt,
+ buffer_snapshotted_at,
&cx,
)
.await
@@ -751,7 +695,7 @@ and then another
)
}
- fn perform_predict_edits(
+ pub fn perform_predict_edits(
params: PerformPredictEditsParams,
) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
async move {
@@ -906,7 +850,7 @@ and then another
input_outline: String,
input_events: String,
input_excerpt: String,
- request_sent_at: Instant,
+ buffer_snapshotted_at: Instant,
cx: &AsyncApp,
) -> Task<Result<Option<InlineCompletion>>> {
let snapshot = snapshot.clone();
@@ -952,7 +896,7 @@ and then another
input_events: input_events.into(),
input_excerpt: input_excerpt.into(),
output_excerpt,
- request_sent_at,
+ buffer_snapshotted_at,
response_received_at: Instant::now(),
}))
})
@@ -1136,7 +1080,7 @@ and then another
}
}
-struct PerformPredictEditsParams {
+pub struct PerformPredictEditsParams {
pub client: Arc<Client>,
pub llm_token: LlmApiToken,
pub app_version: SemanticVersion,
@@ -1211,6 +1155,77 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
+pub struct GatherContextOutput {
+ pub body: PredictEditsBody,
+ pub editable_range: Range<usize>,
+}
+
+pub fn gather_context(
+ project: Option<&Entity<Project>>,
+ full_path_str: String,
+ snapshot: &BufferSnapshot,
+ cursor_point: language::Point,
+ make_events_prompt: impl FnOnce() -> String + Send + 'static,
+ can_collect_data: bool,
+ cx: &App,
+) -> Task<Result<GatherContextOutput>> {
+ let local_lsp_store =
+ project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
+ let diagnostic_groups: Vec<(String, serde_json::Value)> =
+ if let Some(local_lsp_store) = local_lsp_store {
+ snapshot
+ .diagnostic_groups(None)
+ .into_iter()
+ .filter_map(|(language_server_id, diagnostic_group)| {
+ let language_server =
+ local_lsp_store.running_language_server_for_id(language_server_id)?;
+ let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot);
+ let language_server_name = language_server.name().to_string();
+ let serialized = serde_json::to_value(diagnostic_group).unwrap();
+ Some((language_server_name, serialized))
+ })
+ .collect::<Vec<_>>()
+ } else {
+ Vec::new()
+ };
+
+ cx.background_spawn({
+ let snapshot = snapshot.clone();
+ async move {
+ let diagnostic_groups = if diagnostic_groups.is_empty() {
+ None
+ } else {
+ Some(diagnostic_groups)
+ };
+
+ let input_excerpt = excerpt_for_cursor_position(
+ cursor_point,
+ &full_path_str,
+ &snapshot,
+ MAX_REWRITE_TOKENS,
+ MAX_CONTEXT_TOKENS,
+ );
+ let input_events = make_events_prompt();
+ let input_outline = prompt_for_outline(&snapshot);
+ let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+
+ let body = PredictEditsBody {
+ input_events,
+ input_excerpt: input_excerpt.prompt,
+ speculated_output: Some(input_excerpt.speculated_output),
+ outline: Some(input_outline),
+ can_collect_data,
+ diagnostic_groups,
+ };
+
+ Ok(GatherContextOutput {
+ body,
+ editable_range,
+ })
+ }
+ })
+}
+
fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
let mut input_outline = String::new();
@@ -1261,7 +1276,7 @@ struct RegisteredBuffer {
}
#[derive(Clone)]
-enum Event {
+pub enum Event {
BufferChange {
old_snapshot: BufferSnapshot,
new_snapshot: BufferSnapshot,
@@ -1845,7 +1860,7 @@ mod tests {
input_events: "".into(),
input_excerpt: "".into(),
output_excerpt: "".into(),
- request_sent_at: Instant::now(),
+ buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
};
@@ -0,0 +1,45 @@
+[package]
+name = "zeta_cli"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[[bin]]
+name = "zeta"
+path = "src/main.rs"
+
+[dependencies]
+anyhow.workspace = true
+clap.workspace = true
+client.workspace = true
+debug_adapter_extension.workspace = true
+extension.workspace = true
+fs.workspace = true
+futures.workspace = true
+gpui.workspace = true
+gpui_tokio.workspace = true
+language.workspace = true
+language_extension.workspace = true
+language_model.workspace = true
+language_models.workspace = true
+languages = { workspace = true, features = ["load-grammars"] }
+node_runtime.workspace = true
+paths.workspace = true
+project.workspace = true
+prompt_store.workspace = true
+release_channel.workspace = true
+reqwest_client.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+shellexpand.workspace = true
+terminal_view.workspace = true
+util.workspace = true
+watch.workspace = true
+workspace-hack.workspace = true
+zeta.workspace = true
+smol.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,14 @@
+fn main() {
+ let cargo_toml =
+ std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml");
+ let version = cargo_toml
+ .lines()
+ .find(|line| line.starts_with("version = "))
+ .expect("Version not found in crates/zed/Cargo.toml")
+ .split('=')
+ .nth(1)
+ .expect("Invalid version format")
+ .trim()
+ .trim_matches('"');
+ println!("cargo:rustc-env=ZED_PKG_VERSION={}", version);
+}
@@ -0,0 +1,128 @@
+use client::{Client, ProxySettings, UserStore};
+use extension::ExtensionHostProxy;
+use fs::RealFs;
+use gpui::http_client::read_proxy_from_env;
+use gpui::{App, AppContext, Entity};
+use gpui_tokio::Tokio;
+use language::LanguageRegistry;
+use language_extension::LspAccess;
+use node_runtime::{NodeBinaryOptions, NodeRuntime};
+use project::Project;
+use project::project_settings::ProjectSettings;
+use release_channel::AppVersion;
+use reqwest_client::ReqwestClient;
+use settings::{Settings, SettingsStore};
+use std::path::PathBuf;
+use std::sync::Arc;
+use util::ResultExt as _;
+
+/// Headless subset of `workspace::AppState`.
+pub struct ZetaCliAppState {
+ pub languages: Arc<LanguageRegistry>,
+ pub client: Arc<Client>,
+ pub user_store: Entity<UserStore>,
+ pub fs: Arc<dyn fs::Fs>,
+ pub node_runtime: NodeRuntime,
+}
+
+// TODO: dedupe with crates/eval/src/eval.rs
+pub fn init(cx: &mut App) -> ZetaCliAppState {
+ let app_version = AppVersion::load(env!("ZED_PKG_VERSION"));
+ release_channel::init(app_version, cx);
+ gpui_tokio::init(cx);
+
+ let mut settings_store = SettingsStore::new(cx);
+ settings_store
+ .set_default_settings(settings::default_settings().as_ref(), cx)
+ .unwrap();
+ cx.set_global(settings_store);
+ client::init_settings(cx);
+
+ // Set User-Agent so we can download language servers from GitHub
+ let user_agent = format!(
+ "Zed/{} ({}; {})",
+ app_version,
+ std::env::consts::OS,
+ std::env::consts::ARCH
+ );
+ let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
+ let proxy_url = proxy_str
+ .as_ref()
+ .and_then(|input| input.parse().ok())
+ .or_else(read_proxy_from_env);
+ let http = {
+ let _guard = Tokio::handle(cx).enter();
+
+ ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
+ .expect("could not start HTTP client")
+ };
+ cx.set_http_client(Arc::new(http));
+
+ Project::init_settings(cx);
+
+ let client = Client::production(cx);
+ cx.set_http_client(client.http_client());
+
+ let git_binary_path = None;
+ let fs = Arc::new(RealFs::new(
+ git_binary_path,
+ cx.background_executor().clone(),
+ ));
+
+ let mut languages = LanguageRegistry::new(cx.background_executor().clone());
+ languages.set_language_server_download_dir(paths::languages_dir().clone());
+ let languages = Arc::new(languages);
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+ extension::init(cx);
+
+ let (mut tx, rx) = watch::channel(None);
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let settings = &ProjectSettings::get_global(cx).node;
+ let options = NodeBinaryOptions {
+ allow_path_lookup: !settings.ignore_system_version,
+ allow_binary_download: true,
+ use_paths: settings.path.as_ref().map(|node_path| {
+ let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
+ let npm_path = settings
+ .npm_path
+ .as_ref()
+ .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
+ (
+ node_path.clone(),
+ npm_path.unwrap_or_else(|| {
+ let base_path = PathBuf::new();
+ node_path.parent().unwrap_or(&base_path).join("npm")
+ }),
+ )
+ }),
+ };
+ tx.send(Some(options)).log_err();
+ })
+ .detach();
+ let node_runtime = NodeRuntime::new(client.http_client(), None, rx);
+
+ let extension_host_proxy = ExtensionHostProxy::global(cx);
+
+ language::init(cx);
+ debug_adapter_extension::init(extension_host_proxy.clone(), cx);
+ language_extension::init(
+ LspAccess::Noop,
+ extension_host_proxy.clone(),
+ languages.clone(),
+ );
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), cx);
+ languages::init(languages.clone(), node_runtime.clone(), cx);
+ prompt_store::init(cx);
+ terminal_view::init(cx);
+
+ ZetaCliAppState {
+ languages,
+ client,
+ user_store,
+ fs,
+ node_runtime,
+ }
+}
@@ -0,0 +1,376 @@
+mod headless;
+
+use anyhow::{Result, anyhow};
+use clap::{Args, Parser, Subcommand};
+use futures::channel::mpsc;
+use futures::{FutureExt as _, StreamExt as _};
+use gpui::{AppContext, Application, AsyncApp};
+use gpui::{Entity, Task};
+use language::Bias;
+use language::Buffer;
+use language::Point;
+use language_model::LlmApiToken;
+use project::{Project, ProjectPath};
+use release_channel::AppVersion;
+use reqwest_client::ReqwestClient;
+use std::path::{Path, PathBuf};
+use std::process::exit;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
+
+use crate::headless::ZetaCliAppState;
+
+#[derive(Parser, Debug)]
+#[command(name = "zeta")]
+struct ZetaCliArgs {
+ #[command(subcommand)]
+ command: Commands,
+}
+
+#[derive(Subcommand, Debug)]
+enum Commands {
+ Context(ContextArgs),
+ Predict {
+ #[arg(long)]
+ predict_edits_body: Option<FileOrStdin>,
+ #[clap(flatten)]
+ context_args: Option<ContextArgs>,
+ },
+}
+
+#[derive(Debug, Args)]
+#[group(requires = "worktree")]
+struct ContextArgs {
+ #[arg(long)]
+ worktree: PathBuf,
+ #[arg(long)]
+ cursor: CursorPosition,
+ #[arg(long)]
+ use_language_server: bool,
+ #[arg(long)]
+ events: Option<FileOrStdin>,
+}
+
+#[derive(Debug, Clone)]
+enum FileOrStdin {
+ File(PathBuf),
+ Stdin,
+}
+
+impl FileOrStdin {
+ async fn read_to_string(&self) -> Result<String, std::io::Error> {
+ match self {
+ FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
+ FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
+ }
+ }
+}
+
+impl FromStr for FileOrStdin {
+ type Err = <PathBuf as FromStr>::Err;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "-" => Ok(Self::Stdin),
+ _ => Ok(Self::File(PathBuf::from_str(s)?)),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct CursorPosition {
+ path: PathBuf,
+ point: Point,
+}
+
+impl FromStr for CursorPosition {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self> {
+ let parts: Vec<&str> = s.split(':').collect();
+ if parts.len() != 3 {
+ return Err(anyhow!(
+ "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
+ s
+ ));
+ }
+
+ let path = PathBuf::from(parts[0]);
+ let line: u32 = parts[1]
+ .parse()
+ .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
+ let column: u32 = parts[2]
+ .parse()
+ .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
+
+ // Convert from 1-based to 0-based indexing
+ let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
+
+ Ok(CursorPosition { path, point })
+ }
+}
+
+async fn get_context(
+ args: ContextArgs,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<GatherContextOutput> {
+ let ContextArgs {
+ worktree: worktree_path,
+ cursor,
+ use_language_server,
+ events,
+ } = args;
+
+ let worktree_path = worktree_path.canonicalize()?;
+ if cursor.path.is_absolute() {
+ return Err(anyhow!("Absolute paths are not supported in --cursor"));
+ }
+
+ let (project, _lsp_open_handle, buffer) = if use_language_server {
+ let (project, lsp_open_handle, buffer) =
+ open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?;
+ (Some(project), Some(lsp_open_handle), buffer)
+ } else {
+ let abs_path = worktree_path.join(&cursor.path);
+ let content = smol::fs::read_to_string(&abs_path).await?;
+ let buffer = cx.new(|cx| Buffer::local(content, cx))?;
+ (None, None, buffer)
+ };
+
+ let worktree_name = worktree_path
+ .file_name()
+ .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
+ let full_path_str = PathBuf::from(worktree_name)
+ .join(&cursor.path)
+ .to_string_lossy()
+ .to_string();
+
+ let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
+ let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
+ if clipped_cursor != cursor.point {
+ let max_row = snapshot.max_point().row;
+ if cursor.point.row < max_row {
+ return Err(anyhow!(
+ "Cursor position {:?} is out of bounds (line length is {})",
+ cursor.point,
+ snapshot.line_len(cursor.point.row)
+ ));
+ } else {
+ return Err(anyhow!(
+ "Cursor position {:?} is out of bounds (max row is {})",
+ cursor.point,
+ max_row
+ ));
+ }
+ }
+
+ let events = match events {
+ Some(events) => events.read_to_string().await?,
+ None => String::new(),
+ };
+ let can_collect_data = false;
+ cx.update(|cx| {
+ gather_context(
+ project.as_ref(),
+ full_path_str,
+ &snapshot,
+ clipped_cursor,
+ move || events,
+ can_collect_data,
+ cx,
+ )
+ })?
+ .await
+}
+
+pub async fn open_buffer_with_language_server(
+ worktree_path: &Path,
+ path: &Path,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(worktree_path, true, cx)
+ })?
+ .await?;
+
+ let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
+ worktree_id: worktree.id(),
+ path: path.to_path_buf().into(),
+ })?;
+
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+
+ let lsp_open_handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&buffer, cx)
+ })?;
+
+ let log_prefix = path.to_string_lossy().to_string();
+ wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
+
+ Ok((project, lsp_open_handle, buffer))
+}
+
+// TODO: Dedupe with similar function in crates/eval/src/instance.rs
+pub fn wait_for_lang_server(
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ log_prefix: String,
+ cx: &mut AsyncApp,
+) -> Task<Result<()>> {
+ println!("{}⏵ Waiting for language server", log_prefix);
+
+ let (mut tx, mut rx) = mpsc::channel(1);
+
+ let lsp_store = project
+ .read_with(cx, |project, _| project.lsp_store())
+ .unwrap();
+
+ let has_lang_server = buffer
+ .update(cx, |buffer, cx| {
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(&buffer, cx)
+ .next()
+ .is_some()
+ })
+ })
+ .unwrap_or(false);
+
+ if has_lang_server {
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .unwrap()
+ .detach();
+ }
+
+ let subscriptions = [
+ cx.subscribe(&lsp_store, {
+ let log_prefix = log_prefix.clone();
+ move |_, event, _| match event {
+ project::LspStoreEvent::LanguageServerUpdate {
+ message:
+ client::proto::update_language_server::Variant::WorkProgress(
+ client::proto::LspWorkProgress {
+ message: Some(message),
+ ..
+ },
+ ),
+ ..
+ } => println!("{}⟲ {message}", log_prefix),
+ _ => {}
+ }
+ }),
+ cx.subscribe(&project, {
+ let buffer = buffer.clone();
+ move |project, event, cx| match event {
+ project::Event::LanguageServerAdded(_, _, _) => {
+ let buffer = buffer.clone();
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer, cx))
+ .detach();
+ }
+ project::Event::DiskBasedDiagnosticsFinished { .. } => {
+ tx.try_send(()).ok();
+ }
+ _ => {}
+ }
+ }),
+ ];
+
+ cx.spawn(async move |cx| {
+ let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
+ let result = futures::select! {
+ _ = rx.next() => {
+ println!("{}⚑ Language server idle", log_prefix);
+ anyhow::Ok(())
+ },
+ _ = timeout.fuse() => {
+ anyhow::bail!("LSP wait timed out after 5 minutes");
+ }
+ };
+ drop(subscriptions);
+ result
+ })
+}
+
+fn main() {
+ let args = ZetaCliArgs::parse();
+ let http_client = Arc::new(ReqwestClient::new());
+ let app = Application::headless().with_http_client(http_client);
+
+ app.run(move |cx| {
+ let app_state = Arc::new(headless::init(cx));
+ cx.spawn(async move |cx| {
+ let result = match args.command {
+ Commands::Context(context_args) => get_context(context_args, &app_state, cx)
+ .await
+ .map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
+ Commands::Predict {
+ predict_edits_body,
+ context_args,
+ } => {
+ cx.spawn(async move |cx| {
+ let app_version = cx.update(|cx| AppVersion::global(cx))?;
+ app_state.client.sign_in(true, cx).await?;
+ let llm_token = LlmApiToken::default();
+ llm_token.refresh(&app_state.client).await?;
+
+ let predict_edits_body =
+ if let Some(predict_edits_body) = predict_edits_body {
+ serde_json::from_str(&predict_edits_body.read_to_string().await?)?
+ } else if let Some(context_args) = context_args {
+ get_context(context_args, &app_state, cx).await?.body
+ } else {
+ return Err(anyhow!(
+ "Expected either --predict-edits-body-file \
+ or the required args of the `context` command."
+ ));
+ };
+
+ let (response, _usage) =
+ Zeta::perform_predict_edits(PerformPredictEditsParams {
+ client: app_state.client.clone(),
+ llm_token,
+ app_version,
+ body: predict_edits_body,
+ })
+ .await?;
+
+ Ok(response.output_excerpt)
+ })
+ .await
+ }
+ };
+ match result {
+ Ok(output) => {
+ println!("{}", output);
+ let _ = cx.update(|cx| cx.quit());
+ }
+ Err(e) => {
+ eprintln!("Failed: {:?}", e);
+ exit(1);
+ }
+ }
+ })
+ .detach();
+ });
+}