1mod evaluate;
2mod example;
3mod headless;
4mod paths;
5mod predict;
6mod source_location;
7mod syntax_retrieval_stats;
8mod util;
9
10use crate::evaluate::{EvaluateArguments, run_evaluate};
11use crate::example::{ExampleFormat, NamedExample};
12use crate::predict::{PredictArguments, run_zeta2_predict};
13use crate::syntax_retrieval_stats::retrieval_stats;
14use ::util::paths::PathStyle;
15use anyhow::{Result, anyhow};
16use clap::{Args, Parser, Subcommand};
17use cloud_llm_client::predict_edits_v3;
18use edit_prediction_context::{
19 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
20};
21use gpui::{Application, AsyncApp, Entity, prelude::*};
22use language::{Bias, Buffer, BufferSnapshot, Point};
23use project::{Project, Worktree};
24use reqwest_client::ReqwestClient;
25use serde_json::json;
26use std::io::{self};
27use std::time::Duration;
28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
29use zeta2::ContextMode;
30
31use crate::headless::ZetaCliAppState;
32use crate::source_location::SourceLocation;
33use crate::util::{open_buffer, open_buffer_with_language_server};
34
35#[derive(Parser, Debug)]
36#[command(name = "zeta")]
37struct ZetaCliArgs {
38 #[arg(long, default_value_t = false)]
39 printenv: bool,
40 #[command(subcommand)]
41 command: Option<Command>,
42}
43
44#[derive(Subcommand, Debug)]
45enum Command {
46 Zeta1 {
47 #[command(subcommand)]
48 command: Zeta1Command,
49 },
50 Zeta2 {
51 #[command(subcommand)]
52 command: Zeta2Command,
53 },
54 ConvertExample {
55 path: PathBuf,
56 #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
57 output_format: ExampleFormat,
58 },
59 Clean,
60}
61
62#[derive(Subcommand, Debug)]
63enum Zeta1Command {
64 Context {
65 #[clap(flatten)]
66 context_args: ContextArgs,
67 },
68}
69
70#[derive(Subcommand, Debug)]
71enum Zeta2Command {
72 Syntax {
73 #[clap(flatten)]
74 args: Zeta2Args,
75 #[clap(flatten)]
76 syntax_args: Zeta2SyntaxArgs,
77 #[command(subcommand)]
78 command: Zeta2SyntaxCommand,
79 },
80 Predict(PredictArguments),
81 Eval(EvaluateArguments),
82}
83
84#[derive(Subcommand, Debug)]
85enum Zeta2SyntaxCommand {
86 Context {
87 #[clap(flatten)]
88 context_args: ContextArgs,
89 },
90 Stats {
91 #[arg(long)]
92 worktree: PathBuf,
93 #[arg(long)]
94 extension: Option<String>,
95 #[arg(long)]
96 limit: Option<usize>,
97 #[arg(long)]
98 skip: Option<usize>,
99 },
100}
101
102#[derive(Debug, Args)]
103#[group(requires = "worktree")]
104struct ContextArgs {
105 #[arg(long)]
106 worktree: PathBuf,
107 #[arg(long)]
108 cursor: SourceLocation,
109 #[arg(long)]
110 use_language_server: bool,
111 #[arg(long)]
112 edit_history: Option<FileOrStdin>,
113}
114
115#[derive(Debug, Args)]
116struct Zeta2Args {
117 #[arg(long, default_value_t = 8192)]
118 max_prompt_bytes: usize,
119 #[arg(long, default_value_t = 2048)]
120 max_excerpt_bytes: usize,
121 #[arg(long, default_value_t = 1024)]
122 min_excerpt_bytes: usize,
123 #[arg(long, default_value_t = 0.66)]
124 target_before_cursor_over_total_bytes: f32,
125 #[arg(long, default_value_t = 1024)]
126 max_diagnostic_bytes: usize,
127 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
128 prompt_format: PromptFormat,
129 #[arg(long, value_enum, default_value_t = Default::default())]
130 output_format: OutputFormat,
131 #[arg(long, default_value_t = 42)]
132 file_indexing_parallelism: usize,
133}
134
135#[derive(Debug, Args)]
136struct Zeta2SyntaxArgs {
137 #[arg(long, default_value_t = false)]
138 disable_imports_gathering: bool,
139 #[arg(long, default_value_t = u8::MAX)]
140 max_retrieved_definitions: u8,
141}
142
143fn syntax_args_to_options(
144 zeta2_args: &Zeta2Args,
145 syntax_args: &Zeta2SyntaxArgs,
146 omit_excerpt_overlaps: bool,
147) -> zeta2::ZetaOptions {
148 zeta2::ZetaOptions {
149 context: ContextMode::Syntax(EditPredictionContextOptions {
150 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
151 use_imports: !syntax_args.disable_imports_gathering,
152 excerpt: EditPredictionExcerptOptions {
153 max_bytes: zeta2_args.max_excerpt_bytes,
154 min_bytes: zeta2_args.min_excerpt_bytes,
155 target_before_cursor_over_total_bytes: zeta2_args
156 .target_before_cursor_over_total_bytes,
157 },
158 score: EditPredictionScoreOptions {
159 omit_excerpt_overlaps,
160 },
161 }),
162 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
163 max_prompt_bytes: zeta2_args.max_prompt_bytes,
164 prompt_format: zeta2_args.prompt_format.into(),
165 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
166 buffer_change_grouping_interval: Duration::ZERO,
167 }
168}
169
170#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
171enum PromptFormat {
172 MarkedExcerpt,
173 LabeledSections,
174 OnlySnippets,
175 #[default]
176 NumberedLines,
177 OldTextNewText,
178 Minimal,
179 MinimalQwen,
180}
181
182impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
183 fn into(self) -> predict_edits_v3::PromptFormat {
184 match self {
185 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
186 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
187 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
188 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
189 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
190 Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
191 Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
192 }
193 }
194}
195
196#[derive(clap::ValueEnum, Default, Debug, Clone)]
197enum OutputFormat {
198 #[default]
199 Prompt,
200 Request,
201 Full,
202}
203
/// Where to read textual input from: a file on disk or standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the file at this path.
    File(PathBuf),
    /// Read standard input until end of file.
    Stdin,
}
209
210impl FileOrStdin {
211 async fn read_to_string(&self) -> Result<String, std::io::Error> {
212 match self {
213 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
214 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
215 }
216 }
217}
218
219impl FromStr for FileOrStdin {
220 type Err = <PathBuf as FromStr>::Err;
221
222 fn from_str(s: &str) -> Result<Self, Self::Err> {
223 match s {
224 "-" => Ok(Self::Stdin),
225 _ => Ok(Self::File(PathBuf::from_str(s)?)),
226 }
227 }
228}
229
230struct LoadedContext {
231 full_path_str: String,
232 snapshot: BufferSnapshot,
233 clipped_cursor: Point,
234 worktree: Entity<Worktree>,
235 project: Entity<Project>,
236 buffer: Entity<Buffer>,
237}
238
239async fn load_context(
240 args: &ContextArgs,
241 app_state: &Arc<ZetaCliAppState>,
242 cx: &mut AsyncApp,
243) -> Result<LoadedContext> {
244 let ContextArgs {
245 worktree: worktree_path,
246 cursor,
247 use_language_server,
248 ..
249 } = args;
250
251 let worktree_path = worktree_path.canonicalize()?;
252
253 let project = cx.update(|cx| {
254 Project::local(
255 app_state.client.clone(),
256 app_state.node_runtime.clone(),
257 app_state.user_store.clone(),
258 app_state.languages.clone(),
259 app_state.fs.clone(),
260 None,
261 cx,
262 )
263 })?;
264
265 let worktree = project
266 .update(cx, |project, cx| {
267 project.create_worktree(&worktree_path, true, cx)
268 })?
269 .await?;
270
271 let mut ready_languages = HashSet::default();
272 let (_lsp_open_handle, buffer) = if *use_language_server {
273 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
274 project.clone(),
275 worktree.clone(),
276 cursor.path.clone(),
277 &mut ready_languages,
278 cx,
279 )
280 .await?;
281 (Some(lsp_open_handle), buffer)
282 } else {
283 let buffer =
284 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
285 (None, buffer)
286 };
287
288 let full_path_str = worktree
289 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
290 .display(PathStyle::local())
291 .to_string();
292
293 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
294 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
295 if clipped_cursor != cursor.point {
296 let max_row = snapshot.max_point().row;
297 if cursor.point.row < max_row {
298 return Err(anyhow!(
299 "Cursor position {:?} is out of bounds (line length is {})",
300 cursor.point,
301 snapshot.line_len(cursor.point.row)
302 ));
303 } else {
304 return Err(anyhow!(
305 "Cursor position {:?} is out of bounds (max row is {})",
306 cursor.point,
307 max_row
308 ));
309 }
310 }
311
312 Ok(LoadedContext {
313 full_path_str,
314 snapshot,
315 clipped_cursor,
316 worktree,
317 project,
318 buffer,
319 })
320}
321
322async fn zeta2_syntax_context(
323 zeta2_args: Zeta2Args,
324 syntax_args: Zeta2SyntaxArgs,
325 args: ContextArgs,
326 app_state: &Arc<ZetaCliAppState>,
327 cx: &mut AsyncApp,
328) -> Result<String> {
329 let LoadedContext {
330 worktree,
331 project,
332 buffer,
333 clipped_cursor,
334 ..
335 } = load_context(&args, app_state, cx).await?;
336
337 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
338 // the whole worktree.
339 worktree
340 .read_with(cx, |worktree, _cx| {
341 worktree.as_local().unwrap().scan_complete()
342 })?
343 .await;
344 let output = cx
345 .update(|cx| {
346 let zeta = cx.new(|cx| {
347 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
348 });
349 let indexing_done_task = zeta.update(cx, |zeta, cx| {
350 zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
351 zeta.register_buffer(&buffer, &project, cx);
352 zeta.wait_for_initial_indexing(&project, cx)
353 });
354 cx.spawn(async move |cx| {
355 indexing_done_task.await?;
356 let request = zeta
357 .update(cx, |zeta, cx| {
358 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
359 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
360 })?
361 .await?;
362
363 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
364
365 match zeta2_args.output_format {
366 OutputFormat::Prompt => anyhow::Ok(prompt_string),
367 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
368 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
369 "request": request,
370 "prompt": prompt_string,
371 "section_labels": section_labels,
372 }))?),
373 }
374 })
375 })?
376 .await?;
377
378 Ok(output)
379}
380
381async fn zeta1_context(
382 args: ContextArgs,
383 app_state: &Arc<ZetaCliAppState>,
384 cx: &mut AsyncApp,
385) -> Result<zeta::GatherContextOutput> {
386 let LoadedContext {
387 full_path_str,
388 snapshot,
389 clipped_cursor,
390 ..
391 } = load_context(&args, app_state, cx).await?;
392
393 let events = match args.edit_history {
394 Some(events) => events.read_to_string().await?,
395 None => String::new(),
396 };
397
398 let prompt_for_events = move || (events, 0);
399 cx.update(|cx| {
400 zeta::gather_context(
401 full_path_str,
402 &snapshot,
403 clipped_cursor,
404 prompt_for_events,
405 cx,
406 )
407 })?
408 .await
409}
410
411fn main() {
412 zlog::init();
413 zlog::init_output_stderr();
414 let args = ZetaCliArgs::parse();
415 let http_client = Arc::new(ReqwestClient::new());
416 let app = Application::headless().with_http_client(http_client);
417
418 app.run(move |cx| {
419 let app_state = Arc::new(headless::init(cx));
420 cx.spawn(async move |cx| {
421 match args.command {
422 None => {
423 if args.printenv {
424 ::util::shell_env::print_env();
425 return;
426 } else {
427 panic!("Expected a command");
428 }
429 }
430 Some(Command::Zeta1 {
431 command: Zeta1Command::Context { context_args },
432 }) => {
433 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
434 let result = serde_json::to_string_pretty(&context.body).unwrap();
435 println!("{}", result);
436 }
437 Some(Command::Zeta2 { command }) => match command {
438 Zeta2Command::Predict(arguments) => {
439 run_zeta2_predict(arguments, &app_state, cx).await;
440 }
441 Zeta2Command::Eval(arguments) => {
442 run_evaluate(arguments, &app_state, cx).await;
443 }
444 Zeta2Command::Syntax {
445 args,
446 syntax_args,
447 command,
448 } => {
449 let result = match command {
450 Zeta2SyntaxCommand::Context { context_args } => {
451 zeta2_syntax_context(
452 args,
453 syntax_args,
454 context_args,
455 &app_state,
456 cx,
457 )
458 .await
459 }
460 Zeta2SyntaxCommand::Stats {
461 worktree,
462 extension,
463 limit,
464 skip,
465 } => {
466 retrieval_stats(
467 worktree,
468 app_state,
469 extension,
470 limit,
471 skip,
472 syntax_args_to_options(&args, &syntax_args, false),
473 cx,
474 )
475 .await
476 }
477 };
478 println!("{}", result.unwrap());
479 }
480 },
481 Some(Command::ConvertExample {
482 path,
483 output_format,
484 }) => {
485 let example = NamedExample::load(path).unwrap();
486 example.write(output_format, io::stdout()).unwrap();
487 }
488 Some(Command::Clean) => {
489 std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
490 }
491 };
492
493 let _ = cx.update(|cx| cx.quit());
494 })
495 .detach();
496 });
497}