mod evaluate;
mod example;
mod headless;
mod metrics;
mod paths;
mod predict;
mod source_location;
mod training;
mod util;

use crate::{
    evaluate::run_evaluate,
    example::{ExampleFormat, NamedExample},
    headless::ZetaCliAppState,
    predict::run_predict,
    source_location::SourceLocation,
    training::{context::ContextType, distill::run_distill},
    util::{open_buffer, open_buffer_with_language_server},
};
use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction::udiff::DiffLine;
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use reqwest_client::ReqwestClient;
use std::{collections::HashSet, io, path::PathBuf, str::FromStr, sync::Arc};

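/// Top-level arguments for the `zeta` CLI. With no subcommand, `--printenv`
/// calls `util::shell_env::print_env`; otherwise a subcommand is required.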
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[command(subcommand)]
    command: Option<Command>,
}

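/// Subcommands dispatched from [`main`]; see the corresponding `run_*` helpers
/// and match arms below for what each one does.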
#[derive(Subcommand, Debug)]
enum Command {
    Context(ContextArgs),
    Predict(PredictArguments),
    Eval(EvaluateArguments),
    Distill(DistillArguments),
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    Score {
        golden_patch: PathBuf,
        actual_patch: PathBuf,
    },
    Clean,
}

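/// Arguments for `zeta context`: gather edit-prediction context for a cursor
/// position inside a worktree, using either the zeta1 or zeta2 provider.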
#[derive(Debug, Args)]
struct ContextArgs {
    #[arg(long)]
    provider: ContextProvider,
    #[arg(long)]
    worktree: PathBuf,
    #[arg(long)]
    cursor: SourceLocation,
    #[arg(long)]
    use_language_server: bool,
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
    #[clap(flatten)]
    zeta2_args: Zeta2Args,
}

#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
    Zeta1,
    #[default]
    Zeta2,
}

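/// Tuning flags for the zeta2 provider; the excerpt/prompt byte limits and the
/// prompt format are converted into `edit_prediction::ZetaOptions` by
/// `zeta2_args_to_options` below.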
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}

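/// Arguments for `zeta predict`; handled by `predict::run_predict`.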
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[clap(flatten)]
    options: PredictionOptions,
}

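/// Arguments for `zeta distill`; handled by `training::distill::run_distill`.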
#[derive(Debug, Args)]
pub struct DistillArguments {
    split_commit_dataset: PathBuf,
    #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
    context_type: ContextType,
    #[clap(long)]
    batch: Option<String>,
}

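/// Prediction options shared by the `predict` and `eval` subcommands.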
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
    #[clap(flatten)]
    zeta2: Zeta2Args,
    #[clap(long)]
    provider: PredictionProvider,
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}

#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested.
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries.
    Force,
}

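// `Auto` is expected to be resolved to a concrete cache mode before these
// helpers are queried; `assert_not_auto` enforces that invariant.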
impl CacheMode {
    fn use_cached_llm_responses(&self) -> bool {
        self.assert_not_auto();
        matches!(self, CacheMode::Requests | CacheMode::Force)
    }

    fn use_cached_search_results(&self) -> bool {
        self.assert_not_auto();
        matches!(self, CacheMode::Force)
    }

    fn assert_not_auto(&self) {
        assert_ne!(
            *self,
            CacheMode::Auto,
            "Cache mode should not be auto at this point!"
        );
    }
}

#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}

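/// Arguments for `zeta eval`; handled by `evaluate::run_evaluate`.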
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[clap(flatten)]
    options: PredictionOptions,
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
    #[arg(long)]
    skip_prediction: bool,
}

#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
enum PredictionProvider {
    Zeta1,
    #[default]
    Zeta2,
    Sweep,
}

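/// Converts the zeta2 CLI flags into the `ZetaOptions` passed to
/// `edit_prediction::EditPredictionStore::set_options`.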
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
    edit_prediction::ZetaOptions {
        context: EditPredictionExcerptOptions {
            max_bytes: args.max_excerpt_bytes,
            min_bytes: args.min_excerpt_bytes,
            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
        },
        max_prompt_bytes: args.max_prompt_bytes,
        prompt_format: args.prompt_format.into(),
    }
}

#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    OnlySnippets,
    #[default]
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}

impl From<PromptFormat> for predict_edits_v3::PromptFormat {
    fn from(format: PromptFormat) -> Self {
        match format {
            PromptFormat::OnlySnippets => Self::OnlySnippets,
            PromptFormat::OldTextNewText => Self::OldTextNewText,
            PromptFormat::Minimal => Self::Minimal,
            PromptFormat::MinimalQwen => Self::MinimalQwen,
            PromptFormat::SeedCoder1120 => Self::SeedCoder1120,
        }
    }
}

#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}

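/// A file path argument, or `-` to read the contents from stdin.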
#[derive(Debug, Clone)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}

impl FileOrStdin {
    async fn read_to_string(&self) -> Result<String, std::io::Error> {
        match self {
            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
        }
    }
}

impl FromStr for FileOrStdin {
    type Err = <PathBuf as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "-" => Ok(Self::Stdin),
            _ => Ok(Self::File(PathBuf::from_str(s)?)),
        }
    }
}

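/// Project, worktree, and buffer state produced by [`load_context`] and shared
/// by the zeta1 and zeta2 context commands.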
struct LoadedContext {
    full_path_str: String,
    snapshot: BufferSnapshot,
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    lsp_open_handle: Option<OpenLspBufferHandle>,
}

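/// Opens the worktree and the buffer referenced by `--worktree`/`--cursor`,
/// optionally attaching a language server, and validates that the cursor
/// position exists in the buffer.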
async fn load_context(
    args: &ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<LoadedContext> {
    let ContextArgs {
        worktree: worktree_path,
        cursor,
        use_language_server,
        ..
    } = args;

    let worktree_path = worktree_path.canonicalize()?;

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;

    let mut ready_languages = HashSet::default();
    let (lsp_open_handle, buffer) = if *use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
            cursor.path.clone(),
            &mut ready_languages,
            cx,
        )
        .await?;
        (Some(lsp_open_handle), buffer)
    } else {
        let buffer =
            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
        (None, buffer)
    };

    let full_path_str = worktree
        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
        .display(PathStyle::local())
        .to_string();

    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
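    // If clipping moved the point, the requested cursor position does not exist
    // in the buffer; report whether the column or the row was out of range.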
    if clipped_cursor != cursor.point {
        let max_row = snapshot.max_point().row;
        if cursor.point.row <= max_row {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (line length is {})",
                cursor.point,
                snapshot.line_len(cursor.point.row)
            ));
        } else {
            return Err(anyhow!(
                "Cursor position {:?} is out of bounds (max row is {})",
                cursor.point,
                max_row
            ));
        }
    }

    Ok(LoadedContext {
        full_path_str,
        snapshot,
        clipped_cursor,
        worktree,
        project,
        buffer,
        lsp_open_handle,
    })
}

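/// Runs zeta2 context retrieval for the given cursor position and returns the
/// gathered context serialized as pretty-printed JSON.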
async fn zeta2_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        lsp_open_handle: _handle,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let store = cx.new(|cx| {
                edit_prediction::EditPredictionStore::new(
                    app_state.client.clone(),
                    app_state.user_store.clone(),
                    cx,
                )
            });
            store.update(cx, |store, cx| {
                store.set_options(zeta2_args_to_options(&args.zeta2_args));
                store.register_buffer(&buffer, &project, cx);
            });
            cx.spawn(async move |cx| {
                let updates_rx = store.update(cx, |store, cx| {
                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                    store.set_use_context(true);
                    store.refresh_context(&project, &buffer, cursor, cx);
                    store.project_context_updates(&project).unwrap()
                })?;

                updates_rx.recv().await.ok();

                let context = store.update(cx, |store, cx| {
                    store.context_for_project(&project, cx).to_vec()
                })?;

                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
            })
        })?
        .await?;

    Ok(output)
}

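/// Gathers zeta1-style context for the given cursor position, feeding in any
/// edit history supplied via `--edit-history`.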
async fn zeta1_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
    let LoadedContext {
        full_path_str,
        snapshot,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    let events = match args.edit_history {
        Some(events) => events.read_to_string().await?,
        None => String::new(),
    };

    let prompt_for_events = move || (events, 0);
    cx.update(|cx| {
        edit_prediction::zeta1::gather_context(
            full_path_str,
            &snapshot,
            clipped_cursor,
            prompt_for_events,
            cloud_llm_client::PredictEditsRequestTrigger::Cli,
            cx,
        )
    })?
    .await
}

fn main() {
    zlog::init();
    zlog::init_output_stderr();
    let args = ZetaCliArgs::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        cx.spawn(async move |cx| {
            match args.command {
                None => {
                    if args.printenv {
                        ::util::shell_env::print_env();
                    } else {
                        panic!("Expected a command");
                    }
                }
                Some(Command::Context(context_args)) => {
                    let result = match context_args.provider {
                        ContextProvider::Zeta1 => {
                            let context =
                                zeta1_context(context_args, &app_state, cx).await.unwrap();
                            serde_json::to_string_pretty(&context.body).unwrap()
                        }
                        ContextProvider::Zeta2 => {
                            zeta2_context(context_args, &app_state, cx).await.unwrap()
                        }
                    };
                    println!("{}", result);
                }
                Some(Command::Predict(arguments)) => {
                    run_predict(arguments, &app_state, cx).await;
                }
                Some(Command::Eval(arguments)) => {
                    run_evaluate(arguments, &app_state, cx).await;
                }
                Some(Command::Distill(arguments)) => {
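                    // Keep the gpui-managed tokio runtime entered for the duration of the distill run.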
                    let _guard = cx
                        .update(|cx| gpui_tokio::Tokio::handle(cx))
                        .unwrap()
                        .enter();
                    run_distill(arguments).await.log_err();
                }
                Some(Command::ConvertExample {
                    path,
                    output_format,
                }) => {
                    let example = NamedExample::load(path).unwrap();
                    example.write(output_format, io::stdout()).unwrap();
                }
                Some(Command::Score {
                    golden_patch,
                    actual_patch,
                }) => {
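                    // Parse both patches as unified-diff lines and compare them
                    // with the `delta_chr_f` metric.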
                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();

                    let golden_diff: Vec<DiffLine> = golden_content
                        .lines()
                        .map(DiffLine::parse)
                        .collect();

                    let actual_diff: Vec<DiffLine> = actual_content
                        .lines()
                        .map(DiffLine::parse)
                        .collect();

                    let score = delta_chr_f(&golden_diff, &actual_diff);
                    println!("{:.2}", score);
                }
                Some(Command::Clean) => {
                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
                }
            };

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}