Detailed changes
@@ -21741,6 +21741,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"gpui_tokio",
+ "indoc",
"language",
"language_extension",
"language_model",
@@ -21751,6 +21752,7 @@ dependencies = [
"ordered-float 2.10.1",
"paths",
"polars",
+ "pretty_assertions",
"project",
"prompt_store",
"pulldown-cmark 0.12.2",
@@ -1,4 +1,5 @@
pub mod predict_edits_v3;
+pub mod udiff;
use std::str::FromStr;
use std::sync::Arc;
@@ -0,0 +1,270 @@
+use std::borrow::Cow;
+
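+/// One classified line of a unified diff. Anything that matches no known
+/// shape is preserved as `Garbage` rather than treated as a parse error.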
+#[derive(Debug, PartialEq)]
+pub enum DiffLine<'a> {
+ OldPath { path: Cow<'a, str> },
+ NewPath { path: Cow<'a, str> },
+ HunkHeader(Option<HunkLocation>),
+ Context(&'a str),
+ Deletion(&'a str),
+ Addition(&'a str),
+ Garbage,
+}
+
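+/// Location parsed from a `@@ -start,count +start,count @@` hunk header.
+/// Start lines are stored zero-based; `parse` converts from the one-based
+/// values in the header.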
+#[derive(Debug, PartialEq)]
+pub struct HunkLocation {
+ start_line_old: u32,
+ count_old: u32,
+ start_line_new: u32,
+ count_new: u32,
+}
+
+impl<'a> DiffLine<'a> {
+ pub fn parse(line: &'a str) -> Self {
+ Self::try_parse(line).unwrap_or(Self::Garbage)
+ }
+
+ fn try_parse(line: &'a str) -> Option<Self> {
+ if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
+ let path = parse_header_path("a/", header);
+ Some(Self::OldPath { path })
+ } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
+ Some(Self::NewPath {
+ path: parse_header_path("b/", header),
+ })
+ } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
+ if header.starts_with("...") {
+ return Some(Self::HunkHeader(None));
+ }
+
+ let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
+ let mut parts = header.split_ascii_whitespace();
+ let count_old = parts.next()?;
+ let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
+
+ Some(Self::HunkHeader(Some(HunkLocation {
+ start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
+ count_old: count_old.parse().ok()?,
+ start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
+ count_new: count_new.parse().ok()?,
+ })))
+ } else if let Some(deleted_header) = line.strip_prefix("-") {
+ Some(Self::Deletion(deleted_header))
+ } else if line.is_empty() {
+ Some(Self::Context(""))
+ } else if let Some(context) = line.strip_prefix(" ") {
+ Some(Self::Context(context))
+ } else {
+ Some(Self::Addition(line.strip_prefix("+")?))
+ }
+ }
+}
+
+fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
+ if !header.contains(['"', '\\']) {
+ let path = header.split_ascii_whitespace().next().unwrap_or(header);
+ return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
+ }
+
+ let mut path = String::with_capacity(header.len());
+ let mut in_quote = false;
+ let mut chars = header.chars().peekable();
+ let mut strip_prefix = Some(strip_prefix);
+
+ while let Some(char) = chars.next() {
+ if char == '"' {
+ in_quote = !in_quote;
+ } else if char == '\\' {
+ let Some(&next_char) = chars.peek() else {
+ break;
+ };
+ chars.next();
+ path.push(next_char);
+ } else if char.is_ascii_whitespace() && !in_quote {
+ break;
+ } else {
+ path.push(char);
+ }
+
+ if let Some(prefix) = strip_prefix
+ && path == prefix
+ {
+ strip_prefix.take();
+ path.clear();
+ }
+ }
+
+ Cow::Owned(path)
+}
+
+fn eat_required_whitespace(header: &str) -> Option<&str> {
+ let trimmed = header.trim_ascii_start();
+
+ if trimmed.len() == header.len() {
+ None
+ } else {
+ Some(trimmed)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn parse_lines_simple() {
+ let input = indoc! {"
+ diff --git a/text.txt b/text.txt
+ index 86c770d..a1fd855 100644
+ --- a/file.txt
+ +++ b/file.txt
+ @@ -1,2 +1,3 @@
+ context
+ -deleted
+ +inserted
+ garbage
+
+ --- b/file.txt
+ +++ a/file.txt
+ "};
+
+ let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
+
+ pretty_assertions::assert_eq!(
+ lines,
+ &[
+ DiffLine::Garbage,
+ DiffLine::Garbage,
+ DiffLine::OldPath {
+ path: "file.txt".into()
+ },
+ DiffLine::NewPath {
+ path: "file.txt".into()
+ },
+ DiffLine::HunkHeader(Some(HunkLocation {
+ start_line_old: 0,
+ count_old: 2,
+ start_line_new: 0,
+ count_new: 3
+ })),
+ DiffLine::Context("context"),
+ DiffLine::Deletion("deleted"),
+ DiffLine::Addition("inserted"),
+ DiffLine::Garbage,
+ DiffLine::Context(""),
+ DiffLine::OldPath {
+ path: "b/file.txt".into()
+ },
+ DiffLine::NewPath {
+ path: "a/file.txt".into()
+ },
+ ]
+ );
+ }
+
+ #[test]
+ fn file_header_extra_space() {
+ let options = ["--- file", "--- file", "---\tfile"];
+
+ for option in options {
+ pretty_assertions::assert_eq!(
+ DiffLine::parse(option),
+ DiffLine::OldPath {
+ path: "file".into()
+ },
+ "{option}",
+ );
+ }
+ }
+
+ #[test]
+ fn hunk_header_extra_space() {
+ let options = [
+ "@@ -1,2 +1,3 @@",
+ "@@  -1,2  +1,3  @@",
+ "@@\t-1,2\t+1,3\t@@",
+ "@@   -1,2 +1,3 @@",
+ "@@ -1,2  +1,3 @@",
+ "@@ -1,2 +1,3",
+ "@@ -1,2 +1,3 @@ garbage",
+ ];
+
+ for option in options {
+ pretty_assertions::assert_eq!(
+ DiffLine::parse(option),
+ DiffLine::HunkHeader(Some(HunkLocation {
+ start_line_old: 0,
+ count_old: 2,
+ start_line_new: 0,
+ count_new: 3
+ })),
+ "{option}",
+ );
+ }
+ }
+
+ #[test]
+ fn hunk_header_without_location() {
+ pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
+ }
+
+ #[test]
+ fn test_parse_path() {
+ assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
+ assert_eq!(
+ parse_header_path("a/", "foo/bar/baz.txt"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt"),
+ "foo/bar/baz.txt"
+ );
+
+ // Extra
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt 2025"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt \""),
+ "foo/bar/baz.txt"
+ );
+
+ // Quoted
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷");
+ assert_eq!(
+ parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"),
+ "foo/bar/baz quox.txt"
+ );
+ // unescaped quotes are dropped
+ assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
+
+ // Escaped
+ assert_eq!(
+ parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
+ "foo/\"bar\"/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
+ "C:\\Projects\\My App\\old file.txt"
+ );
+ }
+}
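The parser is deliberately total: every line classifies to something, and anything that doesn't match a known diff construct becomes `DiffLine::Garbage` rather than an error, so `diff --git` and `index` lines pass through harmlessly. A minimal consumer sketch (assuming the module is reachable as `cloud_llm_client::udiff`, as `apply_diff` below assumes):

```rust
use cloud_llm_client::udiff::DiffLine;

/// Print a classification for every line of a unified diff.
fn classify(diff: &str) {
    for line in diff.lines() {
        match DiffLine::parse(line) {
            DiffLine::OldPath { path } => println!("old file: {path}"),
            DiffLine::NewPath { path } => println!("new file: {path}"),
            DiffLine::HunkHeader(location) => println!("hunk: {location:?}"),
            DiffLine::Context(text) => println!("  {text}"),
            DiffLine::Deletion(text) => println!("- {text}"),
            DiffLine::Addition(text) => println!("+ {text}"),
            // `diff --git`, `index`, and other unrecognized lines land here.
            DiffLine::Garbage => {}
        }
    }
}
```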
@@ -149,6 +149,9 @@ pub fn find_related_excerpts(
.find(|model| {
model.provider_id() == MODEL_PROVIDER_ID
&& model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
})
else {
return Task::ready(Err(anyhow!("could not find context model")));
@@ -35,8 +35,8 @@ use std::str::FromStr as _;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
-use util::ResultExt as _;
use util::rel_path::RelPathBuf;
+use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod merge_excerpts;
@@ -50,8 +50,6 @@ use crate::related_excerpts::find_related_excerpts;
pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
pub use provider::ZetaEditPredictionProvider;
-const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
-
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
@@ -83,6 +81,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
max_diagnostic_bytes: 2048,
prompt_format: PromptFormat::DEFAULT,
file_indexing_parallelism: 1,
+ buffer_change_grouping_interval: Duration::from_secs(1),
};
pub struct Zeta2FeatureFlag;
@@ -118,6 +117,7 @@ pub struct ZetaOptions {
pub max_diagnostic_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
pub file_indexing_parallelism: usize,
+ pub buffer_change_grouping_interval: Duration,
}
#[derive(Debug, Clone, PartialEq)]
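The new `buffer_change_grouping_interval` field controls the edit-coalescing window in `push_event` below; `Duration::ZERO` turns coalescing off entirely. A sketch of how a caller might opt out, assuming `ZetaOptions` and `DEFAULT_OPTIONS` are exported from the crate root as shown above:

```rust
use std::time::Duration;
use zeta2::{DEFAULT_OPTIONS, ZetaOptions};

/// Options for deterministic tooling (e.g. the zeta CLI): every buffer edit
/// stays a separate history event instead of being grouped with its neighbors.
fn ungrouped_options() -> ZetaOptions {
    ZetaOptions {
        buffer_change_grouping_interval: Duration::ZERO,
        ..DEFAULT_OPTIONS
    }
}
```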
@@ -135,6 +135,7 @@ impl ContextMode {
}
}
+#[derive(Debug)]
pub enum ZetaDebugInfo {
ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
@@ -144,17 +145,20 @@ pub enum ZetaDebugInfo {
EditPredicted(ZetaEditPredictionDebugInfo),
}
+#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
pub project: Entity<Project>,
pub timestamp: Instant,
pub search_prompt: String,
}
+#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
pub project: Entity<Project>,
pub timestamp: Instant,
}
+#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
pub request: predict_edits_v3::PredictEditsRequest,
pub retrieval_time: TimeDelta,
@@ -164,6 +168,7 @@ pub struct ZetaEditPredictionDebugInfo {
pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}
+#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
pub project: Entity<Project>,
pub timestamp: Instant,
@@ -178,7 +183,7 @@ struct ZetaProject {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
- refresh_context_task: Option<Task<Option<()>>>,
+ refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
refresh_context_timestamp: Option<Instant>,
}
@@ -460,6 +465,7 @@ impl Zeta {
project: &Entity<Project>,
cx: &mut Context<Self>,
) -> BufferSnapshot {
+ let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
let zeta_project = self.get_or_init_zeta_project(project, cx);
let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
@@ -469,6 +475,7 @@ impl Zeta {
std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
Self::push_event(
zeta_project,
+ buffer_change_grouping_interval,
Event::BufferChange {
old_snapshot,
new_snapshot: new_snapshot.clone(),
@@ -480,14 +487,19 @@ impl Zeta {
new_snapshot
}
- fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+ fn push_event(
+ zeta_project: &mut ZetaProject,
+ buffer_change_grouping_interval: Duration,
+ event: Event,
+ ) {
let events = &mut zeta_project.events;
- if let Some(Event::BufferChange {
- new_snapshot: last_new_snapshot,
- timestamp: last_timestamp,
- ..
- }) = events.back_mut()
+ if buffer_change_grouping_interval > Duration::ZERO
+ && let Some(Event::BufferChange {
+ new_snapshot: last_new_snapshot,
+ timestamp: last_timestamp,
+ ..
+ }) = events.back_mut()
{
// Coalesce edits for the same buffer when they happen one after the other.
let Event::BufferChange {
@@ -496,7 +508,7 @@ impl Zeta {
timestamp,
} = &event;
- if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+ if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
&& old_snapshot.remote_id() == last_new_snapshot.remote_id()
&& old_snapshot.version == last_new_snapshot.version
{
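Restated outside the borrow-heavy match, the guard above coalesces a new buffer edit into the previous event only when all four conditions hold; a self-contained sketch of the predicate:

```rust
use std::time::Duration;

/// The coalescing rule from `push_event`: merge two edits when grouping is
/// enabled, they are close in time, they touch the same buffer, and the new
/// edit starts from exactly the state the previous edit ended in.
fn should_coalesce(
    grouping_interval: Duration,
    elapsed_since_last: Duration,
    same_buffer: bool,
    versions_chain: bool,
) -> bool {
    grouping_interval > Duration::ZERO
        && elapsed_since_last <= grouping_interval
        && same_buffer
        && versions_chain
}
```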
@@ -624,7 +636,7 @@ impl Zeta {
})
}
- fn request_prediction(
+ pub fn request_prediction(
&mut self,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@@ -1068,7 +1080,11 @@ impl Zeta {
log::debug!("refetching edit prediction context after pause");
}
this.update(cx, |this, cx| {
- this.refresh_context(project, buffer, cursor_position, cx);
+ let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
+
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ zeta_project.refresh_context_task = Some(task.log_err());
+ }
})
.ok()
}
@@ -1077,73 +1093,68 @@ impl Zeta {
// Refresh the related excerpts asynchronously. The caller is responsible for
// driving the returned task to completion and for not spawning more than one
// refresh at a time.
- fn refresh_context(
+ pub fn refresh_context(
&mut self,
project: Entity<Project>,
buffer: Entity<language::Buffer>,
cursor_position: language::Anchor,
cx: &mut Context<Self>,
- ) {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let debug_tx = self.debug_tx.clone();
-
- zeta_project
- .refresh_context_task
- .get_or_insert(cx.spawn(async move |this, cx| {
- let related_excerpts = this
- .update(cx, |this, cx| {
- let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ ) -> Task<Result<()>> {
+ cx.spawn(async move |this, cx| {
+ let related_excerpts_result = this
+ .update(cx, |this, cx| {
+ let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
- let ContextMode::Llm(options) = &this.options().context else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ let ContextMode::Llm(options) = &this.options().context else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
- let mut edit_history_unified_diff = String::new();
+ let mut edit_history_unified_diff = String::new();
- for event in zeta_project.events.iter() {
- if let Some(event) = event.to_request_event(cx) {
- writeln!(&mut edit_history_unified_diff, "{event}").ok();
- }
+ for event in zeta_project.events.iter() {
+ if let Some(event) = event.to_request_event(cx) {
+ writeln!(&mut edit_history_unified_diff, "{event}").ok();
}
+ }
- find_related_excerpts(
- buffer.clone(),
- cursor_position,
- &project,
- edit_history_unified_diff,
- options,
- debug_tx,
- cx,
- )
- })
- .ok()?
- .await
- .log_err()
- .unwrap_or_default();
- this.update(cx, |this, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
- return;
- };
- zeta_project.context = Some(related_excerpts);
- zeta_project.refresh_context_task.take();
- if let Some(debug_tx) = &this.debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalDebugInfo {
- project,
- timestamp: Instant::now(),
- },
- ))
- .ok();
+ find_related_excerpts(
+ buffer.clone(),
+ cursor_position,
+ &project,
+ edit_history_unified_diff,
+ options,
+ this.debug_tx.clone(),
+ cx,
+ )
+ })?
+ .await;
+
+ this.update(cx, |this, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+ return Ok(());
+ };
+ zeta_project.refresh_context_task.take();
+ if let Some(debug_tx) = &this.debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
+ ZetaContextRetrievalDebugInfo {
+ project,
+ timestamp: Instant::now(),
+ },
+ ))
+ .ok();
+ }
+ match related_excerpts_result {
+ Ok(excerpts) => {
+ zeta_project.context = Some(excerpts);
+ Ok(())
}
- })
- .ok()
- }));
+ Err(error) => Err(error),
+ }
+ })?
+ })
}
fn gather_nearby_diagnostics(
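`refresh_context` now hands the task back to its caller instead of stashing it: the debounce path above stores it as `Some(task.log_err())` to keep it alive, while `zeta2_predict` in the CLI below awaits it directly.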
@@ -335,6 +335,8 @@ impl Zeta2Inspector {
max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
prompt_format: zeta_options.prompt_format,
file_indexing_parallelism: zeta_options.file_indexing_parallelism,
+ buffer_change_grouping_interval: zeta_options
+ .buffer_change_grouping_interval,
},
cx,
);
@@ -13,6 +13,7 @@ name = "zeta"
path = "src/main.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
clap.workspace = true
@@ -42,7 +43,6 @@ prompt_store.workspace = true
pulldown-cmark.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
-toml.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -50,8 +50,15 @@ shellexpand.workspace = true
smol.workspace = true
soa-rs = "0.8.1"
terminal_view.workspace = true
+toml.workspace = true
util.workspace = true
watch.workspace = true
zeta.workspace = true
zeta2.workspace = true
zlog.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+pretty_assertions.workspace = true
+project = { workspace = true, features = ["test-support"] }
@@ -5,17 +5,23 @@ use std::{
fs,
io::Write,
mem,
+ ops::Range,
path::{Path, PathBuf},
};
use anyhow::{Context as _, Result};
use clap::ValueEnum;
-use gpui::http_client::Url;
+use collections::HashSet;
+use futures::AsyncWriteExt as _;
+use gpui::{AsyncApp, Entity, http_client::Url};
+use language::Buffer;
+use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
-const CURSOR_POSITION_HEADING: &str = "Cursor Position";
+const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
+const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
const REPOSITORY_URL_FIELD: &str = "repository_url";
@@ -31,9 +37,10 @@ pub struct NamedExample {
pub struct Example {
pub repository_url: String,
pub revision: String,
+ pub uncommitted_diff: String,
pub cursor_path: PathBuf,
pub cursor_position: String,
- pub edit_history: Vec<String>,
+ pub edit_history: String,
pub expected_patch: String,
pub expected_excerpts: Vec<ExpectedExcerpt>,
}
@@ -59,11 +66,11 @@ impl NamedExample {
match ext.and_then(|s| s.to_str()) {
Some("json") => Ok(Self {
- name: path.file_name().unwrap_or_default().display().to_string(),
+ name: path.file_stem().unwrap_or_default().display().to_string(),
example: serde_json::from_str(&content)?,
}),
Some("toml") => Ok(Self {
- name: path.file_name().unwrap_or_default().display().to_string(),
+ name: path.file_stem().unwrap_or_default().display().to_string(),
example: toml::from_str(&content)?,
}),
Some("md") => Self::parse_md(&content),
@@ -88,9 +95,10 @@ impl NamedExample {
example: Example {
repository_url: String::new(),
revision: String::new(),
+ uncommitted_diff: String::new(),
cursor_path: PathBuf::new(),
cursor_position: String::new(),
- edit_history: Vec::new(),
+ edit_history: String::new(),
expected_patch: String::new(),
expected_excerpts: Vec::new(),
},
@@ -152,18 +160,19 @@ impl NamedExample {
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
- if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
- named.example.edit_history.push(mem::take(&mut text));
+ let block_info = block_info.trim();
+ if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+ named.example.uncommitted_diff = mem::take(&mut text);
+ } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+ named.example.edit_history.push_str(&mem::take(&mut text));
} else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
- let path = PathBuf::from(block_info.trim());
- named.example.cursor_path = path;
+ named.example.cursor_path = block_info.into();
named.example.cursor_position = mem::take(&mut text);
} else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
named.example.expected_patch = mem::take(&mut text);
} else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
- let path = PathBuf::from(block_info.trim());
named.example.expected_excerpts.push(ExpectedExcerpt {
- path,
+ path: block_info.into(),
text: mem::take(&mut text),
});
} else {
@@ -195,13 +204,14 @@ impl NamedExample {
#[allow(unused)]
pub async fn setup_worktree(&self) -> Result<PathBuf> {
+ let (repo_owner, repo_name) = self.repo_name()?;
+ let file_name = self.file_name();
+
let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
fs::create_dir_all(&repos_dir)?;
fs::create_dir_all(&worktrees_dir)?;
- let (repo_owner, repo_name) = self.repo_name()?;
-
let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
@@ -213,36 +223,81 @@ impl NamedExample {
.await?;
}
- run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &self.example.revision],
- )
- .await?;
-
- let worktree_path = worktrees_dir.join(&self.name);
+ // Resolve the example to a revision, fetching it if needed.
+ let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await;
+ let revision = if let Ok(revision) = revision {
+ revision
+ } else {
+ run_git(
+ &repo_dir,
+ &["fetch", "--depth", "1", "origin", &self.example.revision],
+ )
+ .await?;
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
+ if revision != self.example.revision {
+ run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
+ }
+ revision
+ };
+ // Create the worktree for this example if needed.
+ let worktree_path = worktrees_dir.join(&file_name);
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", &self.example.revision]).await?;
+ run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
+ run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
run_git(
&repo_dir,
- &[
- "worktree",
- "add",
- "-f",
- &worktree_path_string,
- &self.example.revision,
- ],
+ &["worktree", "add", "-f", &worktree_path_string, &file_name],
)
.await?;
}
+ // Apply the uncommitted diff for this example.
+ if !self.example.uncommitted_diff.is_empty() {
+ let mut apply_process = smol::process::Command::new("git")
+ .current_dir(&worktree_path)
+ .args(&["apply", "-"])
+ .stdin(std::process::Stdio::piped())
+ .spawn()?;
+
+ let mut stdin = apply_process.stdin.take().unwrap();
+ stdin
+ .write_all(self.example.uncommitted_diff.as_bytes())
+ .await?;
+ stdin.close().await?;
+ drop(stdin);
+
+ let apply_result = apply_process.output().await?;
+ if !apply_result.status.success() {
+ anyhow::bail!(
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
+ }
+ }
+
Ok(worktree_path)
}
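+ /// A name safe to use as a git branch and worktree directory: the example
+ /// name lowercased, with whitespace replaced by hyphens.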
+ fn file_name(&self) -> String {
+ self.name
+ .chars()
+ .map(|c| {
+ if c.is_whitespace() {
+ '-'
+ } else {
+ c.to_ascii_lowercase()
+ }
+ })
+ .collect()
+ }
+
#[allow(unused)]
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
@@ -277,6 +332,15 @@ impl NamedExample {
Ok((owner.into(), repo.into()))
}
}
+
+ #[must_use]
+ pub async fn apply_edit_history(
+ &self,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<HashSet<Entity<Buffer>>> {
+ apply_diff(&self.example.edit_history, project, cx).await
+ }
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
@@ -308,6 +372,15 @@ impl Display for NamedExample {
)?;
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
+ write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
+ write!(f, "`````diff\n")?;
+ write!(f, "{}", self.example.uncommitted_diff)?;
+ write!(f, "`````\n")?;
+
+ write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
+ if !self.example.edit_history.is_empty() {
+ write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
+ }
+
write!(
f,
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
@@ -316,14 +389,6 @@ impl Display for NamedExample {
)?;
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
- if !self.example.edit_history.is_empty() {
- write!(f, "`````diff\n")?;
- for item in &self.example.edit_history {
- write!(f, "{item}")?;
- }
- write!(f, "`````\n")?;
- }
-
if !self.example.expected_patch.is_empty() {
write!(
f,
@@ -353,3 +418,404 @@ impl Display for NamedExample {
Ok(())
}
}
+
+#[must_use]
+pub async fn apply_diff(
+ diff: &str,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Result<HashSet<Entity<Buffer>>> {
+ use cloud_llm_client::udiff::DiffLine;
+ use std::fmt::Write;
+
+ #[derive(Debug, Default)]
+ struct HunkState {
+ context: String,
+ edits: Vec<Edit>,
+ }
+
+ #[derive(Debug)]
+ struct Edit {
+ range: Range<usize>,
+ text: String,
+ }
+
+ let mut old_path = None;
+ let mut new_path = None;
+ let mut hunk = HunkState::default();
+ let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
+ let mut open_buffers = HashSet::default();
+
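+ // Accumulate context and edits one line at a time; a hunk is flushed when
+ // the next line starts a new file or hunk, or at end of input.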
+ while let Some(diff_line) = diff_lines.next() {
+ match diff_line {
+ DiffLine::OldPath { path } => old_path = Some(path),
+ DiffLine::NewPath { path } => {
+ if old_path.is_none() {
+ anyhow::bail!(
+ "Found a new path header (`+++`) before an (`---`) old path header"
+ );
+ }
+ new_path = Some(path)
+ }
+ DiffLine::Context(ctx) => {
+ writeln!(&mut hunk.context, "{ctx}")?;
+ }
+ DiffLine::Deletion(del) => {
+ let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
+ if let Some(last_edit) = hunk.edits.last_mut()
+ && last_edit.range.end == range.start
+ {
+ last_edit.range.end = range.end;
+ } else {
+ hunk.edits.push(Edit {
+ range,
+ text: String::new(),
+ });
+ }
+ writeln!(&mut hunk.context, "{del}")?;
+ }
+ DiffLine::Addition(add) => {
+ let range = hunk.context.len()..hunk.context.len();
+ if let Some(last_edit) = hunk.edits.last_mut()
+ && last_edit.range.end == range.start
+ {
+ writeln!(&mut last_edit.text, "{add}").unwrap();
+ } else {
+ hunk.edits.push(Edit {
+ range,
+ text: format!("{add}\n"),
+ });
+ }
+ }
+ DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
+ }
+
+ let at_hunk_end = match diff_lines.peek() {
+ Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
+ _ => false,
+ };
+
+ if at_hunk_end {
+ let hunk = mem::take(&mut hunk);
+
+ let Some(old_path) = old_path.as_deref() else {
+ anyhow::bail!("Missing old path (`---`) header")
+ };
+
+ let Some(new_path) = new_path.as_deref() else {
+ anyhow::bail!("Missing new path (`+++`) header")
+ };
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let project_path = project
+ .find_project_path(old_path, cx)
+ .context("Failed to find old_path in project")?;
+
+ anyhow::Ok(project.open_buffer(project_path, cx))
+ })??
+ .await?;
+ open_buffers.insert(buffer.clone());
+
+ if old_path != new_path {
+ project
+ .update(cx, |project, cx| {
+ let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
+ // Rename to the path from the `+++` header; reusing the buffer's
+ // current path here would make the rename a no-op.
+ let new_project_path = project
+ .find_project_path(new_path, cx)
+ .context("Failed to find new_path in project")?;
+ anyhow::Ok(project.rename_entry(
+ project_file.entry_id.unwrap(),
+ new_project_path,
+ cx,
+ ))
+ })??
+ .await?;
+ }
+
+ // TODO is it worth using project search?
+ buffer.update(cx, |buffer, cx| {
+ let context_offset = if hunk.context.is_empty() {
+ 0
+ } else {
+ let text = buffer.text();
+ if let Some(offset) = text.find(&hunk.context) {
+ if text[offset + 1..].contains(&hunk.context) {
+ anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
+ }
+ offset
+ } else {
+ anyhow::bail!(
+ "Failed to match context:\n{}\n\nBuffer:\n{}",
+ hunk.context,
+ text
+ );
+ }
+ };
+
+ buffer.edit(
+ hunk.edits.into_iter().map(|edit| {
+ (
+ context_offset + edit.range.start..context_offset + edit.range.end,
+ edit.text,
+ )
+ }),
+ None,
+ cx,
+ );
+
+ anyhow::Ok(())
+ })??;
+ }
+ }
+
+ anyhow::Ok(open_buffers)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use ::fs::FakeFs;
+ use gpui::TestAppContext;
+ use indoc::indoc;
+ use pretty_assertions::assert_eq;
+ use project::Project;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ #[gpui::test]
+ async fn test_apply_diff_successful(cx: &mut TestAppContext) {
+ let buffer_1_text = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+ "# };
+
+ let buffer_1_text_final = indoc! {r#"
+ 3
+ 4
+ 5
+ "# };
+
+ let buffer_2_text = indoc! {r#"
+ six
+ seven
+ eight
+ nine
+ ten
+ "# };
+
+ let buffer_2_text_final = indoc! {r#"
+ 5
+ six
+ seven
+ 7.5
+ eight
+ nine
+ ten
+ 11
+ "# };
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+ });
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": buffer_1_text,
+ "file2": buffer_2_text,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ five
+ --- a/root/file1
+ +++ b/root/file1
+ 3
+ -four
+ -five
+ +4
+ +5
+ --- a/root/file1
+ +++ b/root/file1
+ -one
+ -two
+ 3
+ 4
+ --- a/root/file2
+ +++ b/root/file2
+ +5
+ six
+ --- a/root/file2
+ +++ b/root/file2
+ seven
+ +7.5
+ eight
+ --- a/root/file2
+ +++ b/root/file2
+ ten
+ +11
+ "#};
+
+ let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .unwrap();
+ let buffer_1 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_1.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), buffer_1_text_final);
+ });
+ let buffer_2 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_2.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), buffer_2_text_final);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
+ let buffer_1_text = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+ one
+ two
+ three
+ four
+ five
+ "# };
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+ });
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": buffer_1_text,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ five
+ "#};
+
+ apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .expect_err("Non-unique edits should fail");
+ }
+
+ #[gpui::test]
+ async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
+ let start = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+
+ four
+ five
+ "# };
+
+ let end = indoc! {r#"
+ one
+ two
+ 3
+ four
+ 5
+
+ four
+ five
+ "# };
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+ });
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": start,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ -five
+ +5
+ "#};
+
+ let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .unwrap();
+
+ let buffer_1 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_1.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), end);
+ });
+ }
+}
@@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample};
use crate::syntax_retrieval_stats::retrieval_stats;
use ::serde::Serialize;
use ::util::paths::PathStyle;
+use ::util::rel_path::RelPath;
use anyhow::{Context as _, Result, anyhow};
use clap::{Args, Parser, Subcommand};
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -21,10 +22,11 @@ use futures::channel::mpsc;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
use language_model::LanguageModelRegistry;
-use project::{Project, Worktree};
+use project::{Project, ProjectPath, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io;
+use std::time::{Duration, Instant};
use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
@@ -46,8 +48,6 @@ enum Command {
command: Zeta1Command,
},
Zeta2 {
- #[clap(flatten)]
- args: Zeta2Args,
#[command(subcommand)]
command: Zeta2Command,
},
@@ -69,15 +69,22 @@ enum Zeta1Command {
#[derive(Subcommand, Debug)]
enum Zeta2Command {
Syntax {
+ #[clap(flatten)]
+ args: Zeta2Args,
#[clap(flatten)]
syntax_args: Zeta2SyntaxArgs,
#[command(subcommand)]
command: Zeta2SyntaxCommand,
},
Llm {
+ #[clap(flatten)]
+ args: Zeta2Args,
#[command(subcommand)]
command: Zeta2LlmCommand,
},
+ Predict {
+ example_path: PathBuf,
+ },
}
#[derive(Subcommand, Debug)]
@@ -170,6 +177,7 @@ fn syntax_args_to_options(
max_prompt_bytes: zeta2_args.max_prompt_bytes,
prompt_format: zeta2_args.prompt_format.clone().into(),
file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
+ buffer_change_grouping_interval: Duration::ZERO,
}
}
@@ -319,6 +327,208 @@ async fn load_context(
})
}
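+/// End-to-end prediction for a single example: set up its worktree, apply the
+/// edit history, locate the cursor, run context retrieval, then print the
+/// gathered excerpts, the predicted diff, and a timing breakdown.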
+async fn zeta2_predict(
+ example: NamedExample,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<()> {
+ let worktree_path = example.setup_worktree().await?;
+
+ cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+ .unwrap()
+ .authenticate(cx)
+ })
+ })?
+ .await?;
+
+ app_state
+ .client
+ .sign_in_with_optional_connect(true, cx)
+ .await?;
+
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let _edited_buffers = example.apply_edit_history(&project, cx).await?;
+
+ let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })?
+ .await?;
+
+ let cursor_offset_within_excerpt = example
+ .example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))?;
+ let mut cursor_excerpt = example.example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let Some((excerpt_offset, _)) = matches.next() else {
+ anyhow::bail!(
+ "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+ );
+ };
+ assert!(matches.next().is_none());
+
+ Ok(excerpt_offset)
+ })??;
+
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_anchor =
+ cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+
+ let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+ let refresh_task = zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(&cursor_buffer, &project, cx);
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?;
+
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+ let mut context_retrieval_started_at = None;
+ let mut context_retrieval_finished_at = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ let mut prediction_started_at = None;
+ let mut prediction_finished_at = None;
+ let mut excerpts_text = String::new();
+ let mut prediction_task = None;
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ context_retrieval_started_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ context_retrieval_finished_at = Some(info.timestamp);
+
+ prediction_task = Some(zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?);
+ }
+ zeta2::ZetaDebugInfo::EditPredicted(request) => {
+ prediction_started_at = Some(Instant::now());
+ request.response_rx.await?.map_err(|err| anyhow!(err))?;
+ prediction_finished_at = Some(Instant::now());
+
+ for included_file in request.request.included_files {
+ let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut excerpts_text,
+ );
+ }
+ break;
+ }
+ _ => {}
+ }
+ }
+
+ refresh_task.await.context("context retrieval failed")?;
+ let prediction = prediction_task.unwrap().await?.context("No prediction")?;
+
+ println!("## Excerpts\n");
+ println!("{excerpts_text}");
+
+ let old_text = prediction.snapshot.text();
+ let new_text = prediction.buffer.update(cx, |buffer, cx| {
+ buffer.edit(prediction.edits.iter().cloned(), None, cx);
+ buffer.text()
+ })?;
+ let diff = language::unified_diff(&old_text, &new_text);
+
+ println!("## Prediction\n");
+ println!("{diff}");
+
+ println!("## Time\n");
+
+ let planning_search_time =
+ search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+
+ println!("Planning searches: {}ms", planning_search_time.as_millis());
+ println!(
+ "Running searches: {}ms",
+ (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
+ );
+
+ let filtering_search_time =
+ context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+ println!(
+ "Filtering context results: {}ms",
+ filtering_search_time.as_millis()
+ );
+
+ let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+ println!("Making Prediction: {}ms", prediction_time.as_millis());
+
+ println!("-------------------");
+ let total_time =
+ (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
+ println!("Total: {}ms", total_time);
+
+ let inference_time =
+ (planning_search_time + filtering_search_time + prediction_time).as_millis();
+ println!(
+ "Inference: {}ms ({:.2}%)",
+ inference_time,
+ (inference_time as f64 / total_time as f64) * 100.
+ );
+
+ anyhow::Ok(())
+}
+
async fn zeta2_syntax_context(
zeta2_args: Zeta2Args,
syntax_args: Zeta2SyntaxArgs,
@@ -616,8 +826,15 @@ fn main() {
let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
}
- Command::Zeta2 { args, command } => match command {
+ Command::Zeta2 { command } => match command {
+ Zeta2Command::Predict { example_path } => {
+ let example = NamedExample::load(example_path).unwrap();
+ zeta2_predict(example, &app_state, cx).await.unwrap();
+ let _ = cx.update(|cx| cx.quit());
+ return;
+ }
Zeta2Command::Syntax {
+ args,
syntax_args,
command,
} => match command {
@@ -643,7 +860,7 @@ fn main() {
.await
}
},
- Zeta2Command::Llm { command } => match command {
+ Zeta2Command::Llm { args, command } => match command {
Zeta2LlmCommand::Context { context_args } => {
zeta2_llm_context(args, context_args, &app_state, cx).await
}