1mod evaluate;
2mod example;
3mod headless;
4mod paths;
5mod predict;
6mod source_location;
7mod syntax_retrieval_stats;
8mod util;
9
10use crate::{
11 evaluate::run_evaluate,
12 example::{ExampleFormat, NamedExample},
13 headless::ZetaCliAppState,
14 predict::run_predict,
15 source_location::SourceLocation,
16 syntax_retrieval_stats::retrieval_stats,
17 util::{open_buffer, open_buffer_with_language_server},
18};
19use ::util::paths::PathStyle;
20use anyhow::{Result, anyhow};
21use clap::{Args, Parser, Subcommand, ValueEnum};
22use cloud_llm_client::predict_edits_v3;
23use edit_prediction_context::{
24 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
25};
26use gpui::{Application, AsyncApp, Entity, prelude::*};
27use language::{Bias, Buffer, BufferSnapshot, Point};
28use project::{Project, Worktree};
29use reqwest_client::ReqwestClient;
30use serde_json::json;
31use std::io::{self};
32use std::time::Duration;
33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
34use zeta::ContextMode;
35
// Top-level command line for the `zeta` binary.
//
// Plain `//` comments are used on clap-derived items on purpose: `///` doc
// comments would be picked up by clap and change the `--help` output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // Only consulted when no subcommand is given: `main` prints the shell
    // environment if this is set and panics with "Expected a command" otherwise.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run; optional so that `--printenv` can be used alone.
    #[command(subcommand)]
    command: Option<Command>,
}
44
// Subcommands of the `zeta` CLI; each variant is dispatched in `main`.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Subcommand, Debug)]
enum Command {
    // Print the context gathered for a cursor position, using the provider
    // selected in `ContextArgs`.
    Context(ContextArgs),
    // Compute retrieval statistics over a worktree via `retrieval_stats`.
    ContextStats(ContextStatsArgs),
    // Run a prediction for a single example file via `run_predict`.
    Predict(PredictArguments),
    // Evaluate predictions over the given example files via `run_evaluate`.
    Eval(EvaluateArguments),
    // Load the example at `path` and write it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Remove the zeta target directory (`paths::TARGET_ZETA_DIR`).
    Clean,
}
58
// Arguments for `zeta context-stats`; every field is forwarded to
// `retrieval_stats` in `main`. (`//` comments keep clap's help text unchanged.)
#[derive(Debug, Args)]
struct ContextStatsArgs {
    // Root of the worktree to analyze.
    #[arg(long)]
    worktree: PathBuf,
    // Optional file extension filter, passed through as-is.
    #[arg(long)]
    extension: Option<String>,
    // Optional cap on the amount of work, passed through as-is.
    #[arg(long)]
    limit: Option<usize>,
    // Optional number of items to skip, passed through as-is.
    #[arg(long)]
    skip: Option<usize>,
    // Converted with `zeta2_args_to_options(.., false)`, i.e. without omitting
    // excerpt overlaps.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
72
// Arguments for `zeta context`. (`//` comments keep clap's help text unchanged.)
#[derive(Debug, Args)]
struct ContextArgs {
    // Chooses between `zeta1_context` and `zeta2_syntax_context` in `main`.
    #[arg(long)]
    provider: ContextProvider,
    // Worktree root; canonicalized by `load_context` before use.
    #[arg(long)]
    worktree: PathBuf,
    // File and position of the cursor; the path is opened within the worktree
    // and the position must lie inside that file's contents.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer through a language server instead of as a plain buffer.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text, read from a file or from stdin when given as `-`.
    // Only the zeta1 provider reads it.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
88
// Context-gathering backend for `zeta context`. Variant order determines the
// order clap lists the possible values in. (`//` keeps help text unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    // Uses `zeta::zeta1::gather_context`.
    Zeta1,
    // Uses zeta2 syntax-based retrieval (`zeta2_syntax_context`).
    #[default]
    Syntax,
}
95
// Tuning knobs for zeta2, flattened into several subcommands and converted to
// `zeta::ZetaOptions` by `zeta2_args_to_options`.
// (`//` comments keep clap's help text unchanged.)
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // Forwarded to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded to the excerpt options' `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded to the excerpt options' `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded to the excerpt options unchanged; presumably the fraction of the
    // excerpt placed before the cursor (not validated here) — TODO confirm.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How `zeta context` (syntax provider) renders its output.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Inverted into the context options' `use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded as `max_retrieved_declarations`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
119
// Arguments for `zeta predict`, consumed by `run_predict`.
// (`//` comments keep clap's help text unchanged.)
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Output format for the prediction (`-f/--format`, Markdown by default).
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Positional path of the example to predict.
    example_path: PathBuf,
    #[clap(flatten)]
    options: PredictionOptions,
}
128
// Options shared by `zeta predict` and `zeta eval`.
// (`//` comments keep clap's help text unchanged.)
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    #[clap(flatten)]
    zeta2: Zeta2Args,
    // Which prediction backend to run; required on the command line.
    #[clap(long)]
    provider: PredictionProvider,
    // Caching policy for LLM requests and search results.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
138
// Caching policy for prediction runs. The `///` lines below are clap help text
// for each value and are intentionally left unchanged.
// `Auto` must be replaced by a concrete mode before the `use_cached_*` queries
// are called (they assert this); the resolution presumably happens in the
// predict/evaluate code — not visible here.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
153
154impl CacheMode {
155 fn use_cached_llm_responses(&self) -> bool {
156 self.assert_not_auto();
157 matches!(self, CacheMode::Requests | CacheMode::Force)
158 }
159
160 fn use_cached_search_results(&self) -> bool {
161 self.assert_not_auto();
162 matches!(self, CacheMode::Force)
163 }
164
165 fn assert_not_auto(&self) {
166 assert_ne!(
167 *self,
168 CacheMode::Auto,
169 "Cache mode should not be auto at this point!"
170 );
171 }
172}
173
// Output format for `zeta predict` (`--format`). Variant order determines the
// order clap lists the possible values in. (`//` keeps help text unchanged.)
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
180
// Arguments for `zeta eval`, consumed by `run_evaluate`.
// (`//` comments keep clap's help text unchanged.)
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional list of example files to evaluate.
    example_paths: Vec<PathBuf>,
    #[clap(flatten)]
    options: PredictionOptions,
    // Number of times to repeat each prediction (`-r/--repetitions`, alias
    // `--repeat`).
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    // NOTE(review): presumably evaluates without running new predictions —
    // confirm against `run_evaluate`.
    #[arg(long)]
    skip_prediction: bool,
}
191
// Prediction backend selectable via `--provider`. Variant order determines the
// order clap lists the possible values in. (`//` keeps help text unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}
199
200fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
201 zeta::ZetaOptions {
202 context: ContextMode::Syntax(EditPredictionContextOptions {
203 max_retrieved_declarations: args.max_retrieved_definitions,
204 use_imports: !args.disable_imports_gathering,
205 excerpt: EditPredictionExcerptOptions {
206 max_bytes: args.max_excerpt_bytes,
207 min_bytes: args.min_excerpt_bytes,
208 target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
209 },
210 score: EditPredictionScoreOptions {
211 omit_excerpt_overlaps,
212 },
213 }),
214 max_diagnostic_bytes: args.max_diagnostic_bytes,
215 max_prompt_bytes: args.max_prompt_bytes,
216 prompt_format: args.prompt_format.into(),
217 file_indexing_parallelism: args.file_indexing_parallelism,
218 buffer_change_grouping_interval: Duration::ZERO,
219 }
220}
221
// Prompt format selectable on the command line; converted into
// `predict_edits_v3::PromptFormat` (`NumberedLines` maps to `NumLinesUniDiff`,
// every other variant to its namesake). Variant order determines the order
// clap lists the possible values in. (`//` keeps help text unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
234
235impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
236 fn into(self) -> predict_edits_v3::PromptFormat {
237 match self {
238 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
239 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
240 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
241 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
242 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
243 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
244 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
245 Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
246 }
247 }
248}
249
// Output of `zeta context` with the syntax provider. `Prompt` prints the
// rendered prompt, `Request` the pretty-printed request JSON, and `Full` a JSON
// object with the request, prompt and section labels. (`//` keeps help text
// unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
257
/// Input source given on the command line: a file path, or standard input
/// when the argument is `-` (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
263
264impl FileOrStdin {
265 async fn read_to_string(&self) -> Result<String, std::io::Error> {
266 match self {
267 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
268 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
269 }
270 }
271}
272
273impl FromStr for FileOrStdin {
274 type Err = <PathBuf as FromStr>::Err;
275
276 fn from_str(s: &str) -> Result<Self, Self::Err> {
277 match s {
278 "-" => Ok(Self::Stdin),
279 _ => Ok(Self::File(PathBuf::from_str(s)?)),
280 }
281 }
282}
283
/// Everything `load_context` opened and validated for a `zeta context` run.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered with the
    /// local path style.
    full_path_str: String,
    /// Snapshot of `buffer` taken when the cursor was validated.
    snapshot: BufferSnapshot,
    /// The requested cursor position; `load_context` returns an error rather
    /// than this struct if clipping would have moved it.
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    /// Buffer for the file named in the cursor location.
    buffer: Entity<Buffer>,
}
292
293async fn load_context(
294 args: &ContextArgs,
295 app_state: &Arc<ZetaCliAppState>,
296 cx: &mut AsyncApp,
297) -> Result<LoadedContext> {
298 let ContextArgs {
299 worktree: worktree_path,
300 cursor,
301 use_language_server,
302 ..
303 } = args;
304
305 let worktree_path = worktree_path.canonicalize()?;
306
307 let project = cx.update(|cx| {
308 Project::local(
309 app_state.client.clone(),
310 app_state.node_runtime.clone(),
311 app_state.user_store.clone(),
312 app_state.languages.clone(),
313 app_state.fs.clone(),
314 None,
315 cx,
316 )
317 })?;
318
319 let worktree = project
320 .update(cx, |project, cx| {
321 project.create_worktree(&worktree_path, true, cx)
322 })?
323 .await?;
324
325 let mut ready_languages = HashSet::default();
326 let (_lsp_open_handle, buffer) = if *use_language_server {
327 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
328 project.clone(),
329 worktree.clone(),
330 cursor.path.clone(),
331 &mut ready_languages,
332 cx,
333 )
334 .await?;
335 (Some(lsp_open_handle), buffer)
336 } else {
337 let buffer =
338 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
339 (None, buffer)
340 };
341
342 let full_path_str = worktree
343 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
344 .display(PathStyle::local())
345 .to_string();
346
347 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
348 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
349 if clipped_cursor != cursor.point {
350 let max_row = snapshot.max_point().row;
351 if cursor.point.row < max_row {
352 return Err(anyhow!(
353 "Cursor position {:?} is out of bounds (line length is {})",
354 cursor.point,
355 snapshot.line_len(cursor.point.row)
356 ));
357 } else {
358 return Err(anyhow!(
359 "Cursor position {:?} is out of bounds (max row is {})",
360 cursor.point,
361 max_row
362 ));
363 }
364 }
365
366 Ok(LoadedContext {
367 full_path_str,
368 snapshot,
369 clipped_cursor,
370 worktree,
371 project,
372 buffer,
373 })
374}
375
/// Builds a zeta2 request for the cursor described by `args` using
/// syntax-based context retrieval, and renders it according to
/// `args.zeta2_args.output_format`.
///
/// Returns one of the following:
/// - `Prompt`: the prompt text;
/// - `Request`: the pretty-printed request JSON;
/// - `Full`: a pretty-printed JSON object with `request`, `prompt` and
///   `section_labels`.
///
/// # Errors
///
/// Propagates failures from loading the context, indexing, building the
/// request or prompt, and JSON serialization.
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    // NOTE(review): `as_local().unwrap()` relies on the worktree coming from the
    // `Project::local` created in `load_context`; a non-local worktree would panic.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure zeta (with excerpt overlaps omitted), register the target
            // buffer, and obtain a task that completes once initial indexing of
            // the project is done.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                // Take a fresh snapshot after indexing and anchor the validated
                // cursor in it before building the request.
                let request = zeta
                    .update(cx, |zeta, cx| {
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
432
433async fn zeta1_context(
434 args: ContextArgs,
435 app_state: &Arc<ZetaCliAppState>,
436 cx: &mut AsyncApp,
437) -> Result<zeta::zeta1::GatherContextOutput> {
438 let LoadedContext {
439 full_path_str,
440 snapshot,
441 clipped_cursor,
442 ..
443 } = load_context(&args, app_state, cx).await?;
444
445 let events = match args.edit_history {
446 Some(events) => events.read_to_string().await?,
447 None => String::new(),
448 };
449
450 let prompt_for_events = move || (events, 0);
451 cx.update(|cx| {
452 zeta::zeta1::gather_context(
453 full_path_str,
454 &snapshot,
455 clipped_cursor,
456 prompt_for_events,
457 cloud_llm_client::PredictEditsRequestTrigger::Cli,
458 cx,
459 )
460 })?
461 .await
462}
463
464fn main() {
465 zlog::init();
466 zlog::init_output_stderr();
467 let args = ZetaCliArgs::parse();
468 let http_client = Arc::new(ReqwestClient::new());
469 let app = Application::headless().with_http_client(http_client);
470
471 app.run(move |cx| {
472 let app_state = Arc::new(headless::init(cx));
473 cx.spawn(async move |cx| {
474 match args.command {
475 None => {
476 if args.printenv {
477 ::util::shell_env::print_env();
478 return;
479 } else {
480 panic!("Expected a command");
481 }
482 }
483 Some(Command::ContextStats(arguments)) => {
484 let result = retrieval_stats(
485 arguments.worktree,
486 app_state,
487 arguments.extension,
488 arguments.limit,
489 arguments.skip,
490 zeta2_args_to_options(&arguments.zeta2_args, false),
491 cx,
492 )
493 .await;
494 println!("{}", result.unwrap());
495 }
496 Some(Command::Context(context_args)) => {
497 let result = match context_args.provider {
498 ContextProvider::Zeta1 => {
499 let context =
500 zeta1_context(context_args, &app_state, cx).await.unwrap();
501 serde_json::to_string_pretty(&context.body).unwrap()
502 }
503 ContextProvider::Syntax => {
504 zeta2_syntax_context(context_args, &app_state, cx)
505 .await
506 .unwrap()
507 }
508 };
509 println!("{}", result);
510 }
511 Some(Command::Predict(arguments)) => {
512 run_predict(arguments, &app_state, cx).await;
513 }
514 Some(Command::Eval(arguments)) => {
515 run_evaluate(arguments, &app_state, cx).await;
516 }
517 Some(Command::ConvertExample {
518 path,
519 output_format,
520 }) => {
521 let example = NamedExample::load(path).unwrap();
522 example.write(output_format, io::stdout()).unwrap();
523 }
524 Some(Command::Clean) => {
525 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
526 }
527 };
528
529 let _ = cx.update(|cx| cx.quit());
530 })
531 .detach();
532 });
533}