1mod evaluate;
2mod example;
3mod headless;
4mod metrics;
5mod paths;
6mod predict;
7mod source_location;
8mod syntax_retrieval_stats;
9mod util;
10
11use crate::{
12 evaluate::run_evaluate,
13 example::{ExampleFormat, NamedExample},
14 headless::ZetaCliAppState,
15 predict::run_predict,
16 source_location::SourceLocation,
17 syntax_retrieval_stats::retrieval_stats,
18 util::{open_buffer, open_buffer_with_language_server},
19};
20use ::util::paths::PathStyle;
21use anyhow::{Result, anyhow};
22use clap::{Args, Parser, Subcommand, ValueEnum};
23use cloud_llm_client::predict_edits_v3;
24use edit_prediction_context::{
25 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
26};
27use gpui::{Application, AsyncApp, Entity, prelude::*};
28use language::{Bias, Buffer, BufferSnapshot, Point};
29use project::{Project, Worktree};
30use reqwest_client::ReqwestClient;
31use serde_json::json;
32use std::io::{self};
33use std::time::Duration;
34use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
35use zeta::ContextMode;
36
// Top-level command-line arguments for the `zeta` CLI.
//
// Plain `//` comments are used here rather than `///` so they do not alter the help text
// clap generates from doc comments.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // Only consulted when no subcommand is given: `main` then prints the shell environment
    // if this is set, and panics with "Expected a command" otherwise.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run; `None` when only top-level flags were passed.
    #[command(subcommand)]
    command: Option<Command>,
}
45
// Subcommands of the `zeta` CLI; `main` dispatches on these.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Subcommand, Debug)]
enum Command {
    // Gather and print context for a single cursor position, using the provider selected
    // in `ContextArgs`.
    Context(ContextArgs),
    // Passed to `retrieval_stats`; `main` prints the resulting value.
    ContextStats(ContextStatsArgs),
    // Dispatched to `run_predict`.
    Predict(PredictArguments),
    // Dispatched to `run_evaluate`.
    Eval(EvaluateArguments),
    // Load the example at `path` and write it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Remove the CLI's target directory (`paths::TARGET_ZETA_DIR`).
    Clean,
}
59
// Arguments for the `ContextStats` subcommand; every field is forwarded to
// `retrieval_stats` by `main`.
#[derive(Debug, Args)]
struct ContextStatsArgs {
    // Root directory of the worktree to analyze.
    #[arg(long)]
    worktree: PathBuf,
    // NOTE(review): optional file-extension filter; exact semantics live in
    // `retrieval_stats` — confirm there.
    #[arg(long)]
    extension: Option<String>,
    // NOTE(review): optional limit, interpreted by `retrieval_stats` — confirm.
    #[arg(long)]
    limit: Option<usize>,
    // NOTE(review): optional skip count, interpreted by `retrieval_stats` — confirm.
    #[arg(long)]
    skip: Option<usize>,
    // Context/prompt options, converted with `zeta2_args_to_options(.., false)` in `main`.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
73
// Arguments for the `Context` subcommand, which prints the context one provider gathers
// for a single cursor position.
#[derive(Debug, Args)]
struct ContextArgs {
    // Which context gatherer to run. NOTE(review): there is no `default_value_t`, so the
    // `#[default]` on `ContextProvider` does not make this flag optional.
    #[arg(long)]
    provider: ContextProvider,
    // Worktree root; canonicalized and opened as a local project by `load_context`.
    #[arg(long)]
    worktree: PathBuf,
    // File path (opened inside `worktree`) plus the point to gather context for.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer via `open_buffer_with_language_server` instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text, read only by the zeta1 provider; `-` reads it from stdin.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    // Options used by the syntax (zeta2) provider.
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}
89
// Context gatherer selected by `ContextArgs::provider`.
// (`//` rather than `///` so clap's possible-value help is unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    // Handled by `zeta1_context`; `main` prints its `body` as pretty JSON.
    Zeta1,
    // Handled by `zeta2_syntax_context`.
    #[default]
    Syntax,
}
96
// Flags shared by every subcommand that drives zeta2. They are converted into
// `zeta::ZetaOptions` by `zeta2_args_to_options`; each comment names the option a field
// feeds. (`//` rather than `///` so clap's help text is unchanged.)
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted to `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Not part of `ZetaOptions`: selects what `zeta2_syntax_context` prints.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // -> `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Inverted into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // -> `EditPredictionContextOptions::max_retrieved_declarations`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
120
// Arguments for the `Predict` subcommand (consumed by `run_predict`).
#[derive(Debug, Args)]
pub struct PredictArguments {
    // How the prediction is written out; defaults to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Positional path of the example to predict for.
    example_path: PathBuf,
    // Provider, cache and zeta2 options.
    #[clap(flatten)]
    options: PredictionOptions,
}
129
// Options shared by the `Predict` and `Eval` subcommands.
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    // zeta2 context/prompt flags.
    #[clap(flatten)]
    zeta2: Zeta2Args,
    // Which prediction backend to use. NOTE(review): no `default_value_t` is given, so
    // `--provider` is required even though `PredictionProvider` derives `Default`.
    #[clap(long)]
    provider: PredictionProvider,
    // Caching policy; defaults to `CacheMode::Auto`.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
139
// Caching policy for LLM requests/responses and search results.
//
// The `///` comments on the variants below double as clap's possible-value help text.
// `Auto` must be resolved to a concrete mode before `use_cached_llm_responses` /
// `use_cached_search_results` are called; both assert this.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
154
155impl CacheMode {
156 fn use_cached_llm_responses(&self) -> bool {
157 self.assert_not_auto();
158 matches!(self, CacheMode::Requests | CacheMode::Force)
159 }
160
161 fn use_cached_search_results(&self) -> bool {
162 self.assert_not_auto();
163 matches!(self, CacheMode::Force)
164 }
165
166 fn assert_not_auto(&self) {
167 assert_ne!(
168 *self,
169 CacheMode::Auto,
170 "Cache mode should not be auto at this point!"
171 );
172 }
173}
174
// Output format for `PredictArguments::format` (the `Predict` subcommand).
// (`//` rather than `///` so clap's possible-value help is unchanged.)
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
181
// Arguments for the `Eval` subcommand (consumed by `run_evaluate`).
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional list of example files to evaluate.
    example_paths: Vec<PathBuf>,
    // Provider, cache and zeta2 options.
    #[clap(flatten)]
    options: PredictionOptions,
    // How many times to run each example; `--repeat` is accepted as an alias.
    // (`CacheMode::Auto`'s help text refers to this when multiple repetitions are requested.)
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    // NOTE(review): presumably skips running new predictions; confirm in `run_evaluate`.
    #[arg(long)]
    skip_prediction: bool,
}
192
// Prediction backend selected by `PredictionOptions::provider`.
// (`//` rather than `///` so clap's possible-value help is unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}
200
201fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
202 zeta::ZetaOptions {
203 context: ContextMode::Syntax(EditPredictionContextOptions {
204 max_retrieved_declarations: args.max_retrieved_definitions,
205 use_imports: !args.disable_imports_gathering,
206 excerpt: EditPredictionExcerptOptions {
207 max_bytes: args.max_excerpt_bytes,
208 min_bytes: args.min_excerpt_bytes,
209 target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210 },
211 score: EditPredictionScoreOptions {
212 omit_excerpt_overlaps,
213 },
214 }),
215 max_diagnostic_bytes: args.max_diagnostic_bytes,
216 max_prompt_bytes: args.max_prompt_bytes,
217 prompt_format: args.prompt_format.into(),
218 file_indexing_parallelism: args.file_indexing_parallelism,
219 buffer_change_grouping_interval: Duration::ZERO,
220 }
221}
222
// CLI-facing mirror of `predict_edits_v3::PromptFormat`, converted when building
// `ZetaOptions`. Variants map one-to-one by name, except `NumberedLines`, which maps to
// `NumLinesUniDiff`. (`//` rather than `///` so clap's possible-value help is unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
235
236impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
237 fn into(self) -> predict_edits_v3::PromptFormat {
238 match self {
239 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
240 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
241 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
242 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
243 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
244 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
245 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
246 Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
247 }
248 }
249}
250
// What `zeta2_syntax_context` prints (selected by `Zeta2Args::output_format`).
// (`//` rather than `///` so clap's possible-value help is unchanged.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the rendered prompt string.
    #[default]
    Prompt,
    // The cloud request, as pretty-printed JSON.
    Request,
    // A pretty-printed JSON object with `request`, `prompt` and `section_labels`.
    Full,
}
258
/// An input source given on the command line: a file path, or `-` for standard input.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
264
265impl FileOrStdin {
266 async fn read_to_string(&self) -> Result<String, std::io::Error> {
267 match self {
268 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
269 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
270 }
271 }
272}
273
274impl FromStr for FileOrStdin {
275 type Err = <PathBuf as FromStr>::Err;
276
277 fn from_str(s: &str) -> Result<Self, Self::Err> {
278 match s {
279 "-" => Ok(Self::Stdin),
280 _ => Ok(Self::File(PathBuf::from_str(s)?)),
281 }
282 }
283}
284
/// Everything `load_context` produces for a `Context` invocation.
struct LoadedContext {
    /// The worktree's root name joined with the cursor's path, rendered with the local path style.
    full_path_str: String,
    /// Snapshot of `buffer` taken while validating the cursor.
    snapshot: BufferSnapshot,
    /// The cursor point after `clip_point`; equal to the requested point, since
    /// `load_context` rejects cursors that had to be clipped.
    clipped_cursor: Point,
    /// The worktree created for `ContextArgs::worktree`.
    worktree: Entity<Worktree>,
    /// The local project the worktree was added to.
    project: Entity<Project>,
    /// The buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
}
293
294async fn load_context(
295 args: &ContextArgs,
296 app_state: &Arc<ZetaCliAppState>,
297 cx: &mut AsyncApp,
298) -> Result<LoadedContext> {
299 let ContextArgs {
300 worktree: worktree_path,
301 cursor,
302 use_language_server,
303 ..
304 } = args;
305
306 let worktree_path = worktree_path.canonicalize()?;
307
308 let project = cx.update(|cx| {
309 Project::local(
310 app_state.client.clone(),
311 app_state.node_runtime.clone(),
312 app_state.user_store.clone(),
313 app_state.languages.clone(),
314 app_state.fs.clone(),
315 None,
316 cx,
317 )
318 })?;
319
320 let worktree = project
321 .update(cx, |project, cx| {
322 project.create_worktree(&worktree_path, true, cx)
323 })?
324 .await?;
325
326 let mut ready_languages = HashSet::default();
327 let (_lsp_open_handle, buffer) = if *use_language_server {
328 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
329 project.clone(),
330 worktree.clone(),
331 cursor.path.clone(),
332 &mut ready_languages,
333 cx,
334 )
335 .await?;
336 (Some(lsp_open_handle), buffer)
337 } else {
338 let buffer =
339 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
340 (None, buffer)
341 };
342
343 let full_path_str = worktree
344 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
345 .display(PathStyle::local())
346 .to_string();
347
348 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
349 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
350 if clipped_cursor != cursor.point {
351 let max_row = snapshot.max_point().row;
352 if cursor.point.row < max_row {
353 return Err(anyhow!(
354 "Cursor position {:?} is out of bounds (line length is {})",
355 cursor.point,
356 snapshot.line_len(cursor.point.row)
357 ));
358 } else {
359 return Err(anyhow!(
360 "Cursor position {:?} is out of bounds (max row is {})",
361 cursor.point,
362 max_row
363 ));
364 }
365 }
366
367 Ok(LoadedContext {
368 full_path_str,
369 snapshot,
370 clipped_cursor,
371 worktree,
372 project,
373 buffer,
374 })
375}
376
/// Builds the zeta2 ("syntax" provider) context for the cursor in `args` and renders it
/// according to `args.zeta2_args.output_format`.
///
/// The steps are:
/// 1. Load the project and buffer with `load_context`.
/// 2. Wait for the worktree scan to finish.
/// 3. Create a `zeta::Zeta` configured from `args.zeta2_args` and register the buffer.
/// 4. Wait for its initial indexing.
/// 5. Build the cloud request for the cursor.
/// 6. Render the prompt with `cloud_zeta2_prompt::build_prompt`.
///
/// Returns the prompt text ([`OutputFormat::Prompt`]), the pretty-printed request JSON
/// ([`OutputFormat::Request`]), or a JSON object holding the request, prompt and section
/// labels ([`OutputFormat::Full`]).
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // Wait for the worktree scan before starting zeta2 so that wait_for_initial_indexing
    // waits for the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            // The worktree was created in a `Project::local` project by `load_context`,
            // so it is expected to be local here.
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                // `true` here is `omit_excerpt_overlaps`; `ContextStats` passes `false`.
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
433
434async fn zeta1_context(
435 args: ContextArgs,
436 app_state: &Arc<ZetaCliAppState>,
437 cx: &mut AsyncApp,
438) -> Result<zeta::zeta1::GatherContextOutput> {
439 let LoadedContext {
440 full_path_str,
441 snapshot,
442 clipped_cursor,
443 ..
444 } = load_context(&args, app_state, cx).await?;
445
446 let events = match args.edit_history {
447 Some(events) => events.read_to_string().await?,
448 None => String::new(),
449 };
450
451 let prompt_for_events = move || (events, 0);
452 cx.update(|cx| {
453 zeta::zeta1::gather_context(
454 full_path_str,
455 &snapshot,
456 clipped_cursor,
457 prompt_for_events,
458 cloud_llm_client::PredictEditsRequestTrigger::Cli,
459 cx,
460 )
461 })?
462 .await
463}
464
465fn main() {
466 zlog::init();
467 zlog::init_output_stderr();
468 let args = ZetaCliArgs::parse();
469 let http_client = Arc::new(ReqwestClient::new());
470 let app = Application::headless().with_http_client(http_client);
471
472 app.run(move |cx| {
473 let app_state = Arc::new(headless::init(cx));
474 cx.spawn(async move |cx| {
475 match args.command {
476 None => {
477 if args.printenv {
478 ::util::shell_env::print_env();
479 return;
480 } else {
481 panic!("Expected a command");
482 }
483 }
484 Some(Command::ContextStats(arguments)) => {
485 let result = retrieval_stats(
486 arguments.worktree,
487 app_state,
488 arguments.extension,
489 arguments.limit,
490 arguments.skip,
491 zeta2_args_to_options(&arguments.zeta2_args, false),
492 cx,
493 )
494 .await;
495 println!("{}", result.unwrap());
496 }
497 Some(Command::Context(context_args)) => {
498 let result = match context_args.provider {
499 ContextProvider::Zeta1 => {
500 let context =
501 zeta1_context(context_args, &app_state, cx).await.unwrap();
502 serde_json::to_string_pretty(&context.body).unwrap()
503 }
504 ContextProvider::Syntax => {
505 zeta2_syntax_context(context_args, &app_state, cx)
506 .await
507 .unwrap()
508 }
509 };
510 println!("{}", result);
511 }
512 Some(Command::Predict(arguments)) => {
513 run_predict(arguments, &app_state, cx).await;
514 }
515 Some(Command::Eval(arguments)) => {
516 run_evaluate(arguments, &app_state, cx).await;
517 }
518 Some(Command::ConvertExample {
519 path,
520 output_format,
521 }) => {
522 let example = NamedExample::load(path).unwrap();
523 example.write(output_format, io::stdout()).unwrap();
524 }
525 Some(Command::Clean) => {
526 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
527 }
528 };
529
530 let _ = cx.update(|cx| cx.quit());
531 })
532 .detach();
533 });
534}