1mod evaluate;
2mod example;
3mod headless;
4mod metrics;
5mod paths;
6mod predict;
7mod source_location;
8mod syntax_retrieval_stats;
9mod util;
10
11use crate::{
12 evaluate::run_evaluate,
13 example::{ExampleFormat, NamedExample},
14 headless::ZetaCliAppState,
15 predict::run_predict,
16 source_location::SourceLocation,
17 syntax_retrieval_stats::retrieval_stats,
18 util::{open_buffer, open_buffer_with_language_server},
19};
20use ::util::paths::PathStyle;
21use anyhow::{Result, anyhow};
22use clap::{Args, Parser, Subcommand, ValueEnum};
23use cloud_llm_client::predict_edits_v3;
24use edit_prediction_context::{
25 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
26};
27use gpui::{Application, AsyncApp, Entity, prelude::*};
28use language::{Bias, Buffer, BufferSnapshot, Point};
29use metrics::delta_chr_f;
30use project::{Project, Worktree};
31use reqwest_client::ReqwestClient;
32use serde_json::json;
33use std::io::{self};
34use std::time::Duration;
35use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
36use zeta::ContextMode;
37use zeta::udiff::DiffLine;
38
39#[derive(Parser, Debug)]
40#[command(name = "zeta")]
41struct ZetaCliArgs {
42 #[arg(long, default_value_t = false)]
43 printenv: bool,
44 #[command(subcommand)]
45 command: Option<Command>,
46}
47
48#[derive(Subcommand, Debug)]
49enum Command {
50 Context(ContextArgs),
51 ContextStats(ContextStatsArgs),
52 Predict(PredictArguments),
53 Eval(EvaluateArguments),
54 ConvertExample {
55 path: PathBuf,
56 #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
57 output_format: ExampleFormat,
58 },
59 Score {
60 golden_patch: PathBuf,
61 actual_patch: PathBuf,
62 },
63 Clean,
64}
65
66#[derive(Debug, Args)]
67struct ContextStatsArgs {
68 #[arg(long)]
69 worktree: PathBuf,
70 #[arg(long)]
71 extension: Option<String>,
72 #[arg(long)]
73 limit: Option<usize>,
74 #[arg(long)]
75 skip: Option<usize>,
76 #[clap(flatten)]
77 zeta2_args: Zeta2Args,
78}
79
80#[derive(Debug, Args)]
81struct ContextArgs {
82 #[arg(long)]
83 provider: ContextProvider,
84 #[arg(long)]
85 worktree: PathBuf,
86 #[arg(long)]
87 cursor: SourceLocation,
88 #[arg(long)]
89 use_language_server: bool,
90 #[arg(long)]
91 edit_history: Option<FileOrStdin>,
92 #[clap(flatten)]
93 zeta2_args: Zeta2Args,
94}
95
96#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
97enum ContextProvider {
98 Zeta1,
99 #[default]
100 Syntax,
101}
102
103#[derive(Clone, Debug, Args)]
104struct Zeta2Args {
105 #[arg(long, default_value_t = 8192)]
106 max_prompt_bytes: usize,
107 #[arg(long, default_value_t = 2048)]
108 max_excerpt_bytes: usize,
109 #[arg(long, default_value_t = 1024)]
110 min_excerpt_bytes: usize,
111 #[arg(long, default_value_t = 0.66)]
112 target_before_cursor_over_total_bytes: f32,
113 #[arg(long, default_value_t = 1024)]
114 max_diagnostic_bytes: usize,
115 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
116 prompt_format: PromptFormat,
117 #[arg(long, value_enum, default_value_t = Default::default())]
118 output_format: OutputFormat,
119 #[arg(long, default_value_t = 42)]
120 file_indexing_parallelism: usize,
121 #[arg(long, default_value_t = false)]
122 disable_imports_gathering: bool,
123 #[arg(long, default_value_t = u8::MAX)]
124 max_retrieved_definitions: u8,
125}
126
127#[derive(Debug, Args)]
128pub struct PredictArguments {
129 #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
130 format: PredictionsOutputFormat,
131 example_path: PathBuf,
132 #[clap(flatten)]
133 options: PredictionOptions,
134}
135
136#[derive(Clone, Debug, Args)]
137pub struct PredictionOptions {
138 #[clap(flatten)]
139 zeta2: Zeta2Args,
140 #[clap(long)]
141 provider: PredictionProvider,
142 #[clap(long, value_enum, default_value_t = CacheMode::default())]
143 cache: CacheMode,
144}
145
146#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
147pub enum CacheMode {
148 /// Use cached LLM requests and responses, except when multiple repetitions are requested
149 #[default]
150 Auto,
151 /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
152 #[value(alias = "request")]
153 Requests,
154 /// Ignore existing cache entries for both LLM and search.
155 Skip,
156 /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
157 /// Useful for reproducing results and fixing bugs outside of search queries
158 Force,
159}
160
161impl CacheMode {
162 fn use_cached_llm_responses(&self) -> bool {
163 self.assert_not_auto();
164 matches!(self, CacheMode::Requests | CacheMode::Force)
165 }
166
167 fn use_cached_search_results(&self) -> bool {
168 self.assert_not_auto();
169 matches!(self, CacheMode::Force)
170 }
171
172 fn assert_not_auto(&self) {
173 assert_ne!(
174 *self,
175 CacheMode::Auto,
176 "Cache mode should not be auto at this point!"
177 );
178 }
179}
180
181#[derive(clap::ValueEnum, Debug, Clone)]
182pub enum PredictionsOutputFormat {
183 Json,
184 Md,
185 Diff,
186}
187
188#[derive(Debug, Args)]
189pub struct EvaluateArguments {
190 example_paths: Vec<PathBuf>,
191 #[clap(flatten)]
192 options: PredictionOptions,
193 #[clap(short, long, default_value_t = 1, alias = "repeat")]
194 repetitions: u16,
195 #[arg(long)]
196 skip_prediction: bool,
197}
198
199#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
200enum PredictionProvider {
201 Zeta1,
202 #[default]
203 Zeta2,
204 Sweep,
205}
206
207fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
208 zeta::ZetaOptions {
209 context: ContextMode::Syntax(EditPredictionContextOptions {
210 max_retrieved_declarations: args.max_retrieved_definitions,
211 use_imports: !args.disable_imports_gathering,
212 excerpt: EditPredictionExcerptOptions {
213 max_bytes: args.max_excerpt_bytes,
214 min_bytes: args.min_excerpt_bytes,
215 target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
216 },
217 score: EditPredictionScoreOptions {
218 omit_excerpt_overlaps,
219 },
220 }),
221 max_diagnostic_bytes: args.max_diagnostic_bytes,
222 max_prompt_bytes: args.max_prompt_bytes,
223 prompt_format: args.prompt_format.into(),
224 file_indexing_parallelism: args.file_indexing_parallelism,
225 buffer_change_grouping_interval: Duration::ZERO,
226 }
227}
228
229#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
230enum PromptFormat {
231 MarkedExcerpt,
232 LabeledSections,
233 OnlySnippets,
234 #[default]
235 NumberedLines,
236 OldTextNewText,
237 Minimal,
238 MinimalQwen,
239 SeedCoder1120,
240}
241
242impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
243 fn into(self) -> predict_edits_v3::PromptFormat {
244 match self {
245 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
246 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
247 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
248 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
249 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
250 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
251 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
252 Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
253 }
254 }
255}
256
257#[derive(clap::ValueEnum, Default, Debug, Clone)]
258enum OutputFormat {
259 #[default]
260 Prompt,
261 Request,
262 Full,
263}
264
265#[derive(Debug, Clone)]
266enum FileOrStdin {
267 File(PathBuf),
268 Stdin,
269}
270
271impl FileOrStdin {
272 async fn read_to_string(&self) -> Result<String, std::io::Error> {
273 match self {
274 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
275 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
276 }
277 }
278}
279
280impl FromStr for FileOrStdin {
281 type Err = <PathBuf as FromStr>::Err;
282
283 fn from_str(s: &str) -> Result<Self, Self::Err> {
284 match s {
285 "-" => Ok(Self::Stdin),
286 _ => Ok(Self::File(PathBuf::from_str(s)?)),
287 }
288 }
289}
290
291struct LoadedContext {
292 full_path_str: String,
293 snapshot: BufferSnapshot,
294 clipped_cursor: Point,
295 worktree: Entity<Worktree>,
296 project: Entity<Project>,
297 buffer: Entity<Buffer>,
298}
299
300async fn load_context(
301 args: &ContextArgs,
302 app_state: &Arc<ZetaCliAppState>,
303 cx: &mut AsyncApp,
304) -> Result<LoadedContext> {
305 let ContextArgs {
306 worktree: worktree_path,
307 cursor,
308 use_language_server,
309 ..
310 } = args;
311
312 let worktree_path = worktree_path.canonicalize()?;
313
314 let project = cx.update(|cx| {
315 Project::local(
316 app_state.client.clone(),
317 app_state.node_runtime.clone(),
318 app_state.user_store.clone(),
319 app_state.languages.clone(),
320 app_state.fs.clone(),
321 None,
322 cx,
323 )
324 })?;
325
326 let worktree = project
327 .update(cx, |project, cx| {
328 project.create_worktree(&worktree_path, true, cx)
329 })?
330 .await?;
331
332 let mut ready_languages = HashSet::default();
333 let (_lsp_open_handle, buffer) = if *use_language_server {
334 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
335 project.clone(),
336 worktree.clone(),
337 cursor.path.clone(),
338 &mut ready_languages,
339 cx,
340 )
341 .await?;
342 (Some(lsp_open_handle), buffer)
343 } else {
344 let buffer =
345 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
346 (None, buffer)
347 };
348
349 let full_path_str = worktree
350 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
351 .display(PathStyle::local())
352 .to_string();
353
354 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
355 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
356 if clipped_cursor != cursor.point {
357 let max_row = snapshot.max_point().row;
358 if cursor.point.row < max_row {
359 return Err(anyhow!(
360 "Cursor position {:?} is out of bounds (line length is {})",
361 cursor.point,
362 snapshot.line_len(cursor.point.row)
363 ));
364 } else {
365 return Err(anyhow!(
366 "Cursor position {:?} is out of bounds (max row is {})",
367 cursor.point,
368 max_row
369 ));
370 }
371 }
372
373 Ok(LoadedContext {
374 full_path_str,
375 snapshot,
376 clipped_cursor,
377 worktree,
378 project,
379 buffer,
380 })
381}
382
383async fn zeta2_syntax_context(
384 args: ContextArgs,
385 app_state: &Arc<ZetaCliAppState>,
386 cx: &mut AsyncApp,
387) -> Result<String> {
388 let LoadedContext {
389 worktree,
390 project,
391 buffer,
392 clipped_cursor,
393 ..
394 } = load_context(&args, app_state, cx).await?;
395
396 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
397 // the whole worktree.
398 worktree
399 .read_with(cx, |worktree, _cx| {
400 worktree.as_local().unwrap().scan_complete()
401 })?
402 .await;
403 let output = cx
404 .update(|cx| {
405 let zeta = cx.new(|cx| {
406 zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
407 });
408 let indexing_done_task = zeta.update(cx, |zeta, cx| {
409 zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
410 zeta.register_buffer(&buffer, &project, cx);
411 zeta.wait_for_initial_indexing(&project, cx)
412 });
413 cx.spawn(async move |cx| {
414 indexing_done_task.await?;
415 let request = zeta
416 .update(cx, |zeta, cx| {
417 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
418 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
419 })?
420 .await?;
421
422 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
423
424 match args.zeta2_args.output_format {
425 OutputFormat::Prompt => anyhow::Ok(prompt_string),
426 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
427 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
428 "request": request,
429 "prompt": prompt_string,
430 "section_labels": section_labels,
431 }))?),
432 }
433 })
434 })?
435 .await?;
436
437 Ok(output)
438}
439
440async fn zeta1_context(
441 args: ContextArgs,
442 app_state: &Arc<ZetaCliAppState>,
443 cx: &mut AsyncApp,
444) -> Result<zeta::zeta1::GatherContextOutput> {
445 let LoadedContext {
446 full_path_str,
447 snapshot,
448 clipped_cursor,
449 ..
450 } = load_context(&args, app_state, cx).await?;
451
452 let events = match args.edit_history {
453 Some(events) => events.read_to_string().await?,
454 None => String::new(),
455 };
456
457 let prompt_for_events = move || (events, 0);
458 cx.update(|cx| {
459 zeta::zeta1::gather_context(
460 full_path_str,
461 &snapshot,
462 clipped_cursor,
463 prompt_for_events,
464 cloud_llm_client::PredictEditsRequestTrigger::Cli,
465 cx,
466 )
467 })?
468 .await
469}
470
471fn main() {
472 zlog::init();
473 zlog::init_output_stderr();
474 let args = ZetaCliArgs::parse();
475 let http_client = Arc::new(ReqwestClient::new());
476 let app = Application::headless().with_http_client(http_client);
477
478 app.run(move |cx| {
479 let app_state = Arc::new(headless::init(cx));
480 cx.spawn(async move |cx| {
481 match args.command {
482 None => {
483 if args.printenv {
484 ::util::shell_env::print_env();
485 return;
486 } else {
487 panic!("Expected a command");
488 }
489 }
490 Some(Command::ContextStats(arguments)) => {
491 let result = retrieval_stats(
492 arguments.worktree,
493 app_state,
494 arguments.extension,
495 arguments.limit,
496 arguments.skip,
497 zeta2_args_to_options(&arguments.zeta2_args, false),
498 cx,
499 )
500 .await;
501 println!("{}", result.unwrap());
502 }
503 Some(Command::Context(context_args)) => {
504 let result = match context_args.provider {
505 ContextProvider::Zeta1 => {
506 let context =
507 zeta1_context(context_args, &app_state, cx).await.unwrap();
508 serde_json::to_string_pretty(&context.body).unwrap()
509 }
510 ContextProvider::Syntax => {
511 zeta2_syntax_context(context_args, &app_state, cx)
512 .await
513 .unwrap()
514 }
515 };
516 println!("{}", result);
517 }
518 Some(Command::Predict(arguments)) => {
519 run_predict(arguments, &app_state, cx).await;
520 }
521 Some(Command::Eval(arguments)) => {
522 run_evaluate(arguments, &app_state, cx).await;
523 }
524 Some(Command::ConvertExample {
525 path,
526 output_format,
527 }) => {
528 let example = NamedExample::load(path).unwrap();
529 example.write(output_format, io::stdout()).unwrap();
530 }
531 Some(Command::Score {
532 golden_patch,
533 actual_patch,
534 }) => {
535 let golden_content = std::fs::read_to_string(golden_patch).unwrap();
536 let actual_content = std::fs::read_to_string(actual_patch).unwrap();
537
538 let golden_diff: Vec<DiffLine> = golden_content
539 .lines()
540 .map(|line| DiffLine::parse(line))
541 .collect();
542
543 let actual_diff: Vec<DiffLine> = actual_content
544 .lines()
545 .map(|line| DiffLine::parse(line))
546 .collect();
547
548 let score = delta_chr_f(&golden_diff, &actual_diff);
549 println!("{:.2}", score);
550 }
551 Some(Command::Clean) => {
552 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
553 }
554 };
555
556 let _ = cx.update(|cx| cx.quit());
557 })
558 .detach();
559 });
560}