1mod evaluate;
2mod example;
3mod headless;
4mod metrics;
5mod paths;
6mod predict;
7mod source_location;
8mod syntax_retrieval_stats;
9mod util;
10
11use crate::{
12 evaluate::run_evaluate,
13 example::{ExampleFormat, NamedExample},
14 headless::ZetaCliAppState,
15 predict::run_predict,
16 source_location::SourceLocation,
17 syntax_retrieval_stats::retrieval_stats,
18 util::{open_buffer, open_buffer_with_language_server},
19};
20use ::util::paths::PathStyle;
21use anyhow::{Result, anyhow};
22use clap::{Args, Parser, Subcommand, ValueEnum};
23use cloud_llm_client::predict_edits_v3;
24use edit_prediction_context::EditPredictionExcerptOptions;
25use gpui::{Application, AsyncApp, Entity, prelude::*};
26use language::{Bias, Buffer, BufferSnapshot, Point};
27use metrics::delta_chr_f;
28use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
29use reqwest_client::ReqwestClient;
30use std::io::{self};
31use std::time::Duration;
32use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
33use zeta::ContextMode;
34use zeta::udiff::DiffLine;
35
// Top-level command-line arguments for the `zeta` CLI binary, parsed in `main`.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // When no subcommand is given, `main` prints the shell environment via
    // `::util::shell_env::print_env` if this is set, and panics otherwise.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run; `None` is only meaningful together with `--printenv`.
    #[command(subcommand)]
    command: Option<Command>,
}
44
// Subcommands understood by the `zeta` CLI; each variant is dispatched in `main`.
#[derive(Subcommand, Debug)]
enum Command {
    // Gather edit-prediction context around a cursor and print it as pretty JSON.
    Context(ContextArgs),
    // Compute retrieval statistics over a worktree via `retrieval_stats`.
    ContextStats(ContextStatsArgs),
    // Run a prediction for a single example file via `run_predict`.
    Predict(PredictArguments),
    // Evaluate the given example files via `run_evaluate`.
    Eval(EvaluateArguments),
    // Load an example from `path` and write it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Print the `delta_chr_f` score (two decimals) between two patch files.
    // Field order defines the positional argument order: golden first, then actual.
    Score {
        golden_patch: PathBuf,
        actual_patch: PathBuf,
    },
    // Remove the `crate::paths::TARGET_ZETA_DIR` directory.
    Clean,
}
62
// Arguments for the `ContextStats` subcommand; `main` forwards every field to
// `retrieval_stats` (the zeta2 flags after conversion by `zeta2_args_to_options`).
#[derive(Debug, Args)]
struct ContextStatsArgs {
    // Root directory of the worktree to analyze.
    #[arg(long)]
    worktree: PathBuf,
    // Optional extension filter. NOTE(review): how it is matched is decided inside
    // `retrieval_stats`; confirm there.
    #[arg(long)]
    extension: Option<String>,
    // Optional cap forwarded to `retrieval_stats` (presumably a maximum item count; verify).
    #[arg(long)]
    limit: Option<usize>,
    // Optional skip count forwarded to `retrieval_stats` (presumably items to skip; verify).
    #[arg(long)]
    skip: Option<usize>,
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
76
// Arguments for the `Context` subcommand, consumed by `load_context` and by the
// provider-specific `zeta1_context` / `zeta2_context` functions.
#[derive(Debug, Args)]
struct ContextArgs {
    // Which context implementation to run.
    #[arg(long)]
    provider: ContextProvider,
    // Worktree root; `load_context` canonicalizes it before creating the worktree.
    #[arg(long)]
    worktree: PathBuf,
    // File (relative to the worktree) and point at which context is gathered.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer through a language server instead of a plain buffer open.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text, read from a file or from stdin when given as `-`.
    // Only the zeta1 provider reads it.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    // Only the zeta2 provider reads these options.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
92
// Selects which context-gathering implementation the `Context` subcommand uses.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    // Handled by `zeta1_context`.
    Zeta1,
    // Handled by `zeta2_context`.
    #[default]
    Zeta2,
}
99
// Tuning flags for zeta2, flattened into several subcommands and converted to
// `zeta::ZetaOptions` by `zeta2_args_to_options`.
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // Copied to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper byte bound for the cursor excerpt (`EditPredictionExcerptOptions::max_bytes`).
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Lower byte bound for the cursor excerpt (`EditPredictionExcerptOptions::min_bytes`).
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Presumably the target share of excerpt bytes placed before the cursor; verify in
    // `edit_prediction_context`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Copied to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // NOTE(review): not read by `zeta2_args_to_options`; confirm another module uses it.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Copied to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // NOTE(review): not read by `zeta2_args_to_options`; confirm another module uses it.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // NOTE(review): not read by `zeta2_args_to_options`; confirm another module uses it.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
123
// Arguments for the `Predict` subcommand, passed to `run_predict`.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Output format for the prediction (`-f` / `--format`), Markdown by default.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Positional path of the example to predict.
    example_path: PathBuf,
    #[clap(flatten)]
    options: PredictionOptions,
}
132
// Prediction settings shared by `PredictArguments` and `EvaluateArguments`.
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    #[clap(flatten)]
    zeta2: Zeta2Args,
    // Which prediction backend to run (required, no default).
    #[clap(long)]
    provider: PredictionProvider,
    // Caching policy. Defaults to `CacheMode::Auto`, which must be resolved to a concrete mode
    // before `CacheMode::use_cached_*` is called (those methods assert the mode is not `Auto`).
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
142
// Caching policy for LLM requests/responses and search results, selected with `--cache`.
// The `///` texts on the variants below double as the CLI's value help and are runtime output.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
157
158impl CacheMode {
159 fn use_cached_llm_responses(&self) -> bool {
160 self.assert_not_auto();
161 matches!(self, CacheMode::Requests | CacheMode::Force)
162 }
163
164 fn use_cached_search_results(&self) -> bool {
165 self.assert_not_auto();
166 matches!(self, CacheMode::Force)
167 }
168
169 fn assert_not_auto(&self) {
170 assert_ne!(
171 *self,
172 CacheMode::Auto,
173 "Cache mode should not be auto at this point!"
174 );
175 }
176}
177
// Output formats accepted by `predict --format`; rendering lives in the `predict` module.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
184
// Arguments for the `Eval` subcommand, passed to `run_evaluate`.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional list of example files to evaluate.
    example_paths: Vec<PathBuf>,
    #[clap(flatten)]
    options: PredictionOptions,
    // How many times to repeat each example (`-r` / `--repetitions`, alias `--repeat`).
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    // NOTE(review): interpreted by `run_evaluate`; presumably skips running predictions.
    #[arg(long)]
    skip_prediction: bool,
}
195
// Prediction backend selected with `--provider` for `predict` and `eval`.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}
203
204fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions {
205 zeta::ZetaOptions {
206 context: ContextMode::Lsp(EditPredictionExcerptOptions {
207 max_bytes: args.max_excerpt_bytes,
208 min_bytes: args.min_excerpt_bytes,
209 target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210 }),
211 max_diagnostic_bytes: args.max_diagnostic_bytes,
212 max_prompt_bytes: args.max_prompt_bytes,
213 prompt_format: args.prompt_format.into(),
214 file_indexing_parallelism: args.file_indexing_parallelism,
215 buffer_change_grouping_interval: Duration::ZERO,
216 }
217}
218
// CLI-facing prompt format, converted into `predict_edits_v3::PromptFormat`.
// Names mirror the wire enum, except `NumberedLines`, which maps to `NumLinesUniDiff`.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
231
232impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
233 fn into(self) -> predict_edits_v3::PromptFormat {
234 match self {
235 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
236 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
237 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
238 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
239 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
240 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
241 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
242 Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
243 }
244 }
245}
246
// Value of the zeta2 `--output-format` flag (default `Prompt`).
// NOTE(review): not consumed anywhere visible in this file; confirm its reader elsewhere.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
254
// An input source given on the command line: a file path, or `-` for standard input
// (see the `FromStr` impl for this type).
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
260
261impl FileOrStdin {
262 async fn read_to_string(&self) -> Result<String, std::io::Error> {
263 match self {
264 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
265 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
266 }
267 }
268}
269
270impl FromStr for FileOrStdin {
271 type Err = <PathBuf as FromStr>::Err;
272
273 fn from_str(s: &str) -> Result<Self, Self::Err> {
274 match s {
275 "-" => Ok(Self::Stdin),
276 _ => Ok(Self::File(PathBuf::from_str(s)?)),
277 }
278 }
279}
280
// Everything `load_context` opens for a `Context` run, handed to the provider functions.
struct LoadedContext {
    // Display form of the worktree root name joined with the cursor's relative path.
    full_path_str: String,
    // Buffer snapshot taken after opening; the cursor was validated against it.
    snapshot: BufferSnapshot,
    // The requested cursor point; `load_context` returns an error unless clipping left it unchanged.
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    // Present only when the buffer was opened with a language server. Callers that need the
    // server should keep this handle alive for as long as they use it.
    lsp_open_handle: Option<OpenLspBufferHandle>,
}
290
291async fn load_context(
292 args: &ContextArgs,
293 app_state: &Arc<ZetaCliAppState>,
294 cx: &mut AsyncApp,
295) -> Result<LoadedContext> {
296 let ContextArgs {
297 worktree: worktree_path,
298 cursor,
299 use_language_server,
300 ..
301 } = args;
302
303 let worktree_path = worktree_path.canonicalize()?;
304
305 let project = cx.update(|cx| {
306 Project::local(
307 app_state.client.clone(),
308 app_state.node_runtime.clone(),
309 app_state.user_store.clone(),
310 app_state.languages.clone(),
311 app_state.fs.clone(),
312 None,
313 cx,
314 )
315 })?;
316
317 let worktree = project
318 .update(cx, |project, cx| {
319 project.create_worktree(&worktree_path, true, cx)
320 })?
321 .await?;
322
323 let mut ready_languages = HashSet::default();
324 let (lsp_open_handle, buffer) = if *use_language_server {
325 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
326 project.clone(),
327 worktree.clone(),
328 cursor.path.clone(),
329 &mut ready_languages,
330 cx,
331 )
332 .await?;
333 (Some(lsp_open_handle), buffer)
334 } else {
335 let buffer =
336 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
337 (None, buffer)
338 };
339
340 let full_path_str = worktree
341 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
342 .display(PathStyle::local())
343 .to_string();
344
345 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
346 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
347 if clipped_cursor != cursor.point {
348 let max_row = snapshot.max_point().row;
349 if cursor.point.row < max_row {
350 return Err(anyhow!(
351 "Cursor position {:?} is out of bounds (line length is {})",
352 cursor.point,
353 snapshot.line_len(cursor.point.row)
354 ));
355 } else {
356 return Err(anyhow!(
357 "Cursor position {:?} is out of bounds (max row is {})",
358 cursor.point,
359 max_row
360 ));
361 }
362 }
363
364 Ok(LoadedContext {
365 full_path_str,
366 snapshot,
367 clipped_cursor,
368 worktree,
369 project,
370 buffer,
371 lsp_open_handle,
372 })
373}
374
375async fn zeta2_context(
376 args: ContextArgs,
377 app_state: &Arc<ZetaCliAppState>,
378 cx: &mut AsyncApp,
379) -> Result<String> {
380 let LoadedContext {
381 worktree,
382 project,
383 buffer,
384 clipped_cursor,
385 lsp_open_handle: _handle,
386 ..
387 } = load_context(&args, app_state, cx).await?;
388
389 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
390 // the whole worktree.
391 worktree
392 .read_with(cx, |worktree, _cx| {
393 worktree.as_local().unwrap().scan_complete()
394 })?
395 .await;
396 let output = cx
397 .update(|cx| {
398 let zeta = cx.new(|cx| {
399 zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
400 });
401 let indexing_done_task = zeta.update(cx, |zeta, cx| {
402 zeta.set_options(zeta2_args_to_options(&args.zeta2_args));
403 zeta.register_buffer(&buffer, &project, cx);
404 zeta.wait_for_initial_indexing(&project, cx)
405 });
406 cx.spawn(async move |cx| {
407 indexing_done_task.await?;
408 let updates_rx = zeta.update(cx, |zeta, cx| {
409 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
410 zeta.set_use_context(true);
411 zeta.refresh_context_if_needed(&project, &buffer, cursor, cx);
412 zeta.project_context_updates(&project).unwrap()
413 })?;
414
415 updates_rx.recv().await.ok();
416
417 let context = zeta.update(cx, |zeta, cx| {
418 zeta.context_for_project(&project, cx).to_vec()
419 })?;
420
421 anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
422 })
423 })?
424 .await?;
425
426 Ok(output)
427}
428
429async fn zeta1_context(
430 args: ContextArgs,
431 app_state: &Arc<ZetaCliAppState>,
432 cx: &mut AsyncApp,
433) -> Result<zeta::zeta1::GatherContextOutput> {
434 let LoadedContext {
435 full_path_str,
436 snapshot,
437 clipped_cursor,
438 ..
439 } = load_context(&args, app_state, cx).await?;
440
441 let events = match args.edit_history {
442 Some(events) => events.read_to_string().await?,
443 None => String::new(),
444 };
445
446 let prompt_for_events = move || (events, 0);
447 cx.update(|cx| {
448 zeta::zeta1::gather_context(
449 full_path_str,
450 &snapshot,
451 clipped_cursor,
452 prompt_for_events,
453 cloud_llm_client::PredictEditsRequestTrigger::Cli,
454 cx,
455 )
456 })?
457 .await
458}
459
460fn main() {
461 zlog::init();
462 zlog::init_output_stderr();
463 let args = ZetaCliArgs::parse();
464 let http_client = Arc::new(ReqwestClient::new());
465 let app = Application::headless().with_http_client(http_client);
466
467 app.run(move |cx| {
468 let app_state = Arc::new(headless::init(cx));
469 cx.spawn(async move |cx| {
470 match args.command {
471 None => {
472 if args.printenv {
473 ::util::shell_env::print_env();
474 } else {
475 panic!("Expected a command");
476 }
477 }
478 Some(Command::ContextStats(arguments)) => {
479 let result = retrieval_stats(
480 arguments.worktree,
481 app_state,
482 arguments.extension,
483 arguments.limit,
484 arguments.skip,
485 zeta2_args_to_options(&arguments.zeta2_args),
486 cx,
487 )
488 .await;
489 println!("{}", result.unwrap());
490 }
491 Some(Command::Context(context_args)) => {
492 let result = match context_args.provider {
493 ContextProvider::Zeta1 => {
494 let context =
495 zeta1_context(context_args, &app_state, cx).await.unwrap();
496 serde_json::to_string_pretty(&context.body).unwrap()
497 }
498 ContextProvider::Zeta2 => {
499 zeta2_context(context_args, &app_state, cx).await.unwrap()
500 }
501 };
502 println!("{}", result);
503 }
504 Some(Command::Predict(arguments)) => {
505 run_predict(arguments, &app_state, cx).await;
506 }
507 Some(Command::Eval(arguments)) => {
508 run_evaluate(arguments, &app_state, cx).await;
509 }
510 Some(Command::ConvertExample {
511 path,
512 output_format,
513 }) => {
514 let example = NamedExample::load(path).unwrap();
515 example.write(output_format, io::stdout()).unwrap();
516 }
517 Some(Command::Score {
518 golden_patch,
519 actual_patch,
520 }) => {
521 let golden_content = std::fs::read_to_string(golden_patch).unwrap();
522 let actual_content = std::fs::read_to_string(actual_patch).unwrap();
523
524 let golden_diff: Vec<DiffLine> = golden_content
525 .lines()
526 .map(|line| DiffLine::parse(line))
527 .collect();
528
529 let actual_diff: Vec<DiffLine> = actual_content
530 .lines()
531 .map(|line| DiffLine::parse(line))
532 .collect();
533
534 let score = delta_chr_f(&golden_diff, &actual_diff);
535 println!("{:.2}", score);
536 }
537 Some(Command::Clean) => {
538 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
539 }
540 };
541
542 let _ = cx.update(|cx| cx.quit());
543 })
544 .detach();
545 });
546}