1mod evaluate;
2mod example;
3mod headless;
4mod metrics;
5mod paths;
6mod predict;
7mod source_location;
8mod util;
9
10use crate::{
11 evaluate::run_evaluate,
12 example::{ExampleFormat, NamedExample},
13 headless::ZetaCliAppState,
14 predict::run_predict,
15 source_location::SourceLocation,
16 util::{open_buffer, open_buffer_with_language_server},
17};
18use ::util::paths::PathStyle;
19use anyhow::{Result, anyhow};
20use clap::{Args, Parser, Subcommand, ValueEnum};
21use cloud_llm_client::predict_edits_v3;
22use edit_prediction::udiff::DiffLine;
23use edit_prediction_context::EditPredictionExcerptOptions;
24use gpui::{Application, AsyncApp, Entity, prelude::*};
25use language::{Bias, Buffer, BufferSnapshot, Point};
26use metrics::delta_chr_f;
27use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
28use reqwest_client::ReqwestClient;
29use std::io::{self};
30use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
31
32#[derive(Parser, Debug)]
33#[command(name = "zeta")]
34struct ZetaCliArgs {
35 #[arg(long, default_value_t = false)]
36 printenv: bool,
37 #[command(subcommand)]
38 command: Option<Command>,
39}
40
41#[derive(Subcommand, Debug)]
42enum Command {
43 Context(ContextArgs),
44 Predict(PredictArguments),
45 Eval(EvaluateArguments),
46 ConvertExample {
47 path: PathBuf,
48 #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
49 output_format: ExampleFormat,
50 },
51 Score {
52 golden_patch: PathBuf,
53 actual_patch: PathBuf,
54 },
55 Clean,
56}
57
58#[derive(Debug, Args)]
59struct ContextArgs {
60 #[arg(long)]
61 provider: ContextProvider,
62 #[arg(long)]
63 worktree: PathBuf,
64 #[arg(long)]
65 cursor: SourceLocation,
66 #[arg(long)]
67 use_language_server: bool,
68 #[arg(long)]
69 edit_history: Option<FileOrStdin>,
70 #[clap(flatten)]
71 zeta2_args: Zeta2Args,
72}
73
74#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
75enum ContextProvider {
76 Zeta1,
77 #[default]
78 Zeta2,
79}
80
81#[derive(Clone, Debug, Args)]
82struct Zeta2Args {
83 #[arg(long, default_value_t = 8192)]
84 max_prompt_bytes: usize,
85 #[arg(long, default_value_t = 2048)]
86 max_excerpt_bytes: usize,
87 #[arg(long, default_value_t = 1024)]
88 min_excerpt_bytes: usize,
89 #[arg(long, default_value_t = 0.66)]
90 target_before_cursor_over_total_bytes: f32,
91 #[arg(long, default_value_t = 1024)]
92 max_diagnostic_bytes: usize,
93 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
94 prompt_format: PromptFormat,
95 #[arg(long, value_enum, default_value_t = Default::default())]
96 output_format: OutputFormat,
97 #[arg(long, default_value_t = 42)]
98 file_indexing_parallelism: usize,
99 #[arg(long, default_value_t = false)]
100 disable_imports_gathering: bool,
101 #[arg(long, default_value_t = u8::MAX)]
102 max_retrieved_definitions: u8,
103}
104
105#[derive(Debug, Args)]
106pub struct PredictArguments {
107 #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
108 format: PredictionsOutputFormat,
109 example_path: PathBuf,
110 #[clap(flatten)]
111 options: PredictionOptions,
112}
113
114#[derive(Clone, Debug, Args)]
115pub struct PredictionOptions {
116 #[clap(flatten)]
117 zeta2: Zeta2Args,
118 #[clap(long)]
119 provider: PredictionProvider,
120 #[clap(long, value_enum, default_value_t = CacheMode::default())]
121 cache: CacheMode,
122}
123
124#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
125pub enum CacheMode {
126 /// Use cached LLM requests and responses, except when multiple repetitions are requested
127 #[default]
128 Auto,
129 /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
130 #[value(alias = "request")]
131 Requests,
132 /// Ignore existing cache entries for both LLM and search.
133 Skip,
134 /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
135 /// Useful for reproducing results and fixing bugs outside of search queries
136 Force,
137}
138
139impl CacheMode {
140 fn use_cached_llm_responses(&self) -> bool {
141 self.assert_not_auto();
142 matches!(self, CacheMode::Requests | CacheMode::Force)
143 }
144
145 fn use_cached_search_results(&self) -> bool {
146 self.assert_not_auto();
147 matches!(self, CacheMode::Force)
148 }
149
150 fn assert_not_auto(&self) {
151 assert_ne!(
152 *self,
153 CacheMode::Auto,
154 "Cache mode should not be auto at this point!"
155 );
156 }
157}
158
159#[derive(clap::ValueEnum, Debug, Clone)]
160pub enum PredictionsOutputFormat {
161 Json,
162 Md,
163 Diff,
164}
165
166#[derive(Debug, Args)]
167pub struct EvaluateArguments {
168 example_paths: Vec<PathBuf>,
169 #[clap(flatten)]
170 options: PredictionOptions,
171 #[clap(short, long, default_value_t = 1, alias = "repeat")]
172 repetitions: u16,
173 #[arg(long)]
174 skip_prediction: bool,
175}
176
177#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
178enum PredictionProvider {
179 Zeta1,
180 #[default]
181 Zeta2,
182 Sweep,
183}
184
185fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
186 edit_prediction::ZetaOptions {
187 context: EditPredictionExcerptOptions {
188 max_bytes: args.max_excerpt_bytes,
189 min_bytes: args.min_excerpt_bytes,
190 target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
191 },
192 max_prompt_bytes: args.max_prompt_bytes,
193 prompt_format: args.prompt_format.into(),
194 }
195}
196
197#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
198enum PromptFormat {
199 OnlySnippets,
200 #[default]
201 OldTextNewText,
202 Minimal,
203 MinimalQwen,
204 SeedCoder1120,
205}
206
207impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
208 fn into(self) -> predict_edits_v3::PromptFormat {
209 match self {
210 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
211 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
212 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
213 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
214 Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
215 }
216 }
217}
218
219#[derive(clap::ValueEnum, Default, Debug, Clone)]
220enum OutputFormat {
221 #[default]
222 Prompt,
223 Request,
224 Full,
225}
226
/// Where textual input comes from: a file on disk or the process's stdin.
///
/// Parsed from a command-line value through its `FromStr` impl, where `-` selects stdin.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input.
    Stdin,
}
232
233impl FileOrStdin {
234 async fn read_to_string(&self) -> Result<String, std::io::Error> {
235 match self {
236 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
237 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
238 }
239 }
240}
241
242impl FromStr for FileOrStdin {
243 type Err = <PathBuf as FromStr>::Err;
244
245 fn from_str(s: &str) -> Result<Self, Self::Err> {
246 match s {
247 "-" => Ok(Self::Stdin),
248 _ => Ok(Self::File(PathBuf::from_str(s)?)),
249 }
250 }
251}
252
253struct LoadedContext {
254 full_path_str: String,
255 snapshot: BufferSnapshot,
256 clipped_cursor: Point,
257 worktree: Entity<Worktree>,
258 project: Entity<Project>,
259 buffer: Entity<Buffer>,
260 lsp_open_handle: Option<OpenLspBufferHandle>,
261}
262
263async fn load_context(
264 args: &ContextArgs,
265 app_state: &Arc<ZetaCliAppState>,
266 cx: &mut AsyncApp,
267) -> Result<LoadedContext> {
268 let ContextArgs {
269 worktree: worktree_path,
270 cursor,
271 use_language_server,
272 ..
273 } = args;
274
275 let worktree_path = worktree_path.canonicalize()?;
276
277 let project = cx.update(|cx| {
278 Project::local(
279 app_state.client.clone(),
280 app_state.node_runtime.clone(),
281 app_state.user_store.clone(),
282 app_state.languages.clone(),
283 app_state.fs.clone(),
284 None,
285 cx,
286 )
287 })?;
288
289 let worktree = project
290 .update(cx, |project, cx| {
291 project.create_worktree(&worktree_path, true, cx)
292 })?
293 .await?;
294
295 let mut ready_languages = HashSet::default();
296 let (lsp_open_handle, buffer) = if *use_language_server {
297 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
298 project.clone(),
299 worktree.clone(),
300 cursor.path.clone(),
301 &mut ready_languages,
302 cx,
303 )
304 .await?;
305 (Some(lsp_open_handle), buffer)
306 } else {
307 let buffer =
308 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
309 (None, buffer)
310 };
311
312 let full_path_str = worktree
313 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
314 .display(PathStyle::local())
315 .to_string();
316
317 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
318 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
319 if clipped_cursor != cursor.point {
320 let max_row = snapshot.max_point().row;
321 if cursor.point.row < max_row {
322 return Err(anyhow!(
323 "Cursor position {:?} is out of bounds (line length is {})",
324 cursor.point,
325 snapshot.line_len(cursor.point.row)
326 ));
327 } else {
328 return Err(anyhow!(
329 "Cursor position {:?} is out of bounds (max row is {})",
330 cursor.point,
331 max_row
332 ));
333 }
334 }
335
336 Ok(LoadedContext {
337 full_path_str,
338 snapshot,
339 clipped_cursor,
340 worktree,
341 project,
342 buffer,
343 lsp_open_handle,
344 })
345}
346
347async fn zeta2_context(
348 args: ContextArgs,
349 app_state: &Arc<ZetaCliAppState>,
350 cx: &mut AsyncApp,
351) -> Result<String> {
352 let LoadedContext {
353 worktree,
354 project,
355 buffer,
356 clipped_cursor,
357 lsp_open_handle: _handle,
358 ..
359 } = load_context(&args, app_state, cx).await?;
360
361 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
362 // the whole worktree.
363 worktree
364 .read_with(cx, |worktree, _cx| {
365 worktree.as_local().unwrap().scan_complete()
366 })?
367 .await;
368 let output = cx
369 .update(|cx| {
370 let store = cx.new(|cx| {
371 edit_prediction::EditPredictionStore::new(
372 app_state.client.clone(),
373 app_state.user_store.clone(),
374 cx,
375 )
376 });
377 store.update(cx, |store, cx| {
378 store.set_options(zeta2_args_to_options(&args.zeta2_args));
379 store.register_buffer(&buffer, &project, cx);
380 });
381 cx.spawn(async move |cx| {
382 let updates_rx = store.update(cx, |store, cx| {
383 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
384 store.set_use_context(true);
385 store.refresh_context(&project, &buffer, cursor, cx);
386 store.project_context_updates(&project).unwrap()
387 })?;
388
389 updates_rx.recv().await.ok();
390
391 let context = store.update(cx, |store, cx| {
392 store.context_for_project(&project, cx).to_vec()
393 })?;
394
395 anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
396 })
397 })?
398 .await?;
399
400 Ok(output)
401}
402
403async fn zeta1_context(
404 args: ContextArgs,
405 app_state: &Arc<ZetaCliAppState>,
406 cx: &mut AsyncApp,
407) -> Result<edit_prediction::zeta1::GatherContextOutput> {
408 let LoadedContext {
409 full_path_str,
410 snapshot,
411 clipped_cursor,
412 ..
413 } = load_context(&args, app_state, cx).await?;
414
415 let events = match args.edit_history {
416 Some(events) => events.read_to_string().await?,
417 None => String::new(),
418 };
419
420 let prompt_for_events = move || (events, 0);
421 cx.update(|cx| {
422 edit_prediction::zeta1::gather_context(
423 full_path_str,
424 &snapshot,
425 clipped_cursor,
426 prompt_for_events,
427 cloud_llm_client::PredictEditsRequestTrigger::Cli,
428 cx,
429 )
430 })?
431 .await
432}
433
434fn main() {
435 zlog::init();
436 zlog::init_output_stderr();
437 let args = ZetaCliArgs::parse();
438 let http_client = Arc::new(ReqwestClient::new());
439 let app = Application::headless().with_http_client(http_client);
440
441 app.run(move |cx| {
442 let app_state = Arc::new(headless::init(cx));
443 cx.spawn(async move |cx| {
444 match args.command {
445 None => {
446 if args.printenv {
447 ::util::shell_env::print_env();
448 } else {
449 panic!("Expected a command");
450 }
451 }
452 Some(Command::Context(context_args)) => {
453 let result = match context_args.provider {
454 ContextProvider::Zeta1 => {
455 let context =
456 zeta1_context(context_args, &app_state, cx).await.unwrap();
457 serde_json::to_string_pretty(&context.body).unwrap()
458 }
459 ContextProvider::Zeta2 => {
460 zeta2_context(context_args, &app_state, cx).await.unwrap()
461 }
462 };
463 println!("{}", result);
464 }
465 Some(Command::Predict(arguments)) => {
466 run_predict(arguments, &app_state, cx).await;
467 }
468 Some(Command::Eval(arguments)) => {
469 run_evaluate(arguments, &app_state, cx).await;
470 }
471 Some(Command::ConvertExample {
472 path,
473 output_format,
474 }) => {
475 let example = NamedExample::load(path).unwrap();
476 example.write(output_format, io::stdout()).unwrap();
477 }
478 Some(Command::Score {
479 golden_patch,
480 actual_patch,
481 }) => {
482 let golden_content = std::fs::read_to_string(golden_patch).unwrap();
483 let actual_content = std::fs::read_to_string(actual_patch).unwrap();
484
485 let golden_diff: Vec<DiffLine> = golden_content
486 .lines()
487 .map(|line| DiffLine::parse(line))
488 .collect();
489
490 let actual_diff: Vec<DiffLine> = actual_content
491 .lines()
492 .map(|line| DiffLine::parse(line))
493 .collect();
494
495 let score = delta_chr_f(&golden_diff, &actual_diff);
496 println!("{:.2}", score);
497 }
498 Some(Command::Clean) => {
499 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
500 }
501 };
502
503 let _ = cx.update(|cx| cx.quit());
504 })
505 .detach();
506 });
507}