1mod evaluate;
2mod example;
3mod headless;
4mod paths;
5mod predict;
6mod source_location;
7mod syntax_retrieval_stats;
8mod util;
9
10use crate::evaluate::{EvaluateArguments, run_evaluate};
11use crate::example::{ExampleFormat, NamedExample};
12use crate::predict::{PredictArguments, run_zeta2_predict};
13use crate::syntax_retrieval_stats::retrieval_stats;
14use ::util::paths::PathStyle;
15use anyhow::{Result, anyhow};
16use clap::{Args, Parser, Subcommand};
17use cloud_llm_client::predict_edits_v3;
18use edit_prediction_context::{
19 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
20};
21use gpui::{Application, AsyncApp, Entity, prelude::*};
22use language::{Bias, Buffer, BufferSnapshot, Point};
23use project::{Project, Worktree};
24use reqwest_client::ReqwestClient;
25use serde_json::json;
26use std::io::{self};
27use std::time::Duration;
28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
29use zeta2::ContextMode;
30
31use crate::headless::ZetaCliAppState;
32use crate::source_location::SourceLocation;
33use crate::util::{open_buffer, open_buffer_with_language_server};
34
35#[derive(Parser, Debug)]
36#[command(name = "zeta")]
37struct ZetaCliArgs {
38 #[command(subcommand)]
39 command: Command,
40}
41
42#[derive(Subcommand, Debug)]
43enum Command {
44 Zeta1 {
45 #[command(subcommand)]
46 command: Zeta1Command,
47 },
48 Zeta2 {
49 #[command(subcommand)]
50 command: Zeta2Command,
51 },
52 ConvertExample {
53 path: PathBuf,
54 #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
55 output_format: ExampleFormat,
56 },
57 Clean,
58}
59
60#[derive(Subcommand, Debug)]
61enum Zeta1Command {
62 Context {
63 #[clap(flatten)]
64 context_args: ContextArgs,
65 },
66}
67
68#[derive(Subcommand, Debug)]
69enum Zeta2Command {
70 Syntax {
71 #[clap(flatten)]
72 args: Zeta2Args,
73 #[clap(flatten)]
74 syntax_args: Zeta2SyntaxArgs,
75 #[command(subcommand)]
76 command: Zeta2SyntaxCommand,
77 },
78 Predict(PredictArguments),
79 Eval(EvaluateArguments),
80}
81
82#[derive(Subcommand, Debug)]
83enum Zeta2SyntaxCommand {
84 Context {
85 #[clap(flatten)]
86 context_args: ContextArgs,
87 },
88 Stats {
89 #[arg(long)]
90 worktree: PathBuf,
91 #[arg(long)]
92 extension: Option<String>,
93 #[arg(long)]
94 limit: Option<usize>,
95 #[arg(long)]
96 skip: Option<usize>,
97 },
98}
99
100#[derive(Debug, Args)]
101#[group(requires = "worktree")]
102struct ContextArgs {
103 #[arg(long)]
104 worktree: PathBuf,
105 #[arg(long)]
106 cursor: SourceLocation,
107 #[arg(long)]
108 use_language_server: bool,
109 #[arg(long)]
110 edit_history: Option<FileOrStdin>,
111}
112
113#[derive(Debug, Args)]
114struct Zeta2Args {
115 #[arg(long, default_value_t = 8192)]
116 max_prompt_bytes: usize,
117 #[arg(long, default_value_t = 2048)]
118 max_excerpt_bytes: usize,
119 #[arg(long, default_value_t = 1024)]
120 min_excerpt_bytes: usize,
121 #[arg(long, default_value_t = 0.66)]
122 target_before_cursor_over_total_bytes: f32,
123 #[arg(long, default_value_t = 1024)]
124 max_diagnostic_bytes: usize,
125 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
126 prompt_format: PromptFormat,
127 #[arg(long, value_enum, default_value_t = Default::default())]
128 output_format: OutputFormat,
129 #[arg(long, default_value_t = 42)]
130 file_indexing_parallelism: usize,
131}
132
133#[derive(Debug, Args)]
134struct Zeta2SyntaxArgs {
135 #[arg(long, default_value_t = false)]
136 disable_imports_gathering: bool,
137 #[arg(long, default_value_t = u8::MAX)]
138 max_retrieved_definitions: u8,
139}
140
141fn syntax_args_to_options(
142 zeta2_args: &Zeta2Args,
143 syntax_args: &Zeta2SyntaxArgs,
144 omit_excerpt_overlaps: bool,
145) -> zeta2::ZetaOptions {
146 zeta2::ZetaOptions {
147 context: ContextMode::Syntax(EditPredictionContextOptions {
148 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
149 use_imports: !syntax_args.disable_imports_gathering,
150 excerpt: EditPredictionExcerptOptions {
151 max_bytes: zeta2_args.max_excerpt_bytes,
152 min_bytes: zeta2_args.min_excerpt_bytes,
153 target_before_cursor_over_total_bytes: zeta2_args
154 .target_before_cursor_over_total_bytes,
155 },
156 score: EditPredictionScoreOptions {
157 omit_excerpt_overlaps,
158 },
159 }),
160 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
161 max_prompt_bytes: zeta2_args.max_prompt_bytes,
162 prompt_format: zeta2_args.prompt_format.into(),
163 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
164 buffer_change_grouping_interval: Duration::ZERO,
165 }
166}
167
168#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
169enum PromptFormat {
170 MarkedExcerpt,
171 LabeledSections,
172 OnlySnippets,
173 #[default]
174 NumberedLines,
175 OldTextNewText,
176}
177
178impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
179 fn into(self) -> predict_edits_v3::PromptFormat {
180 match self {
181 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
182 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
183 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
184 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
185 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
186 }
187 }
188}
189
190#[derive(clap::ValueEnum, Default, Debug, Clone)]
191enum OutputFormat {
192 #[default]
193 Prompt,
194 Request,
195 Full,
196}
197
/// A text input that comes either from a file on disk or from standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
203
204impl FileOrStdin {
205 async fn read_to_string(&self) -> Result<String, std::io::Error> {
206 match self {
207 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
208 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
209 }
210 }
211}
212
213impl FromStr for FileOrStdin {
214 type Err = <PathBuf as FromStr>::Err;
215
216 fn from_str(s: &str) -> Result<Self, Self::Err> {
217 match s {
218 "-" => Ok(Self::Stdin),
219 _ => Ok(Self::File(PathBuf::from_str(s)?)),
220 }
221 }
222}
223
224struct LoadedContext {
225 full_path_str: String,
226 snapshot: BufferSnapshot,
227 clipped_cursor: Point,
228 worktree: Entity<Worktree>,
229 project: Entity<Project>,
230 buffer: Entity<Buffer>,
231}
232
233async fn load_context(
234 args: &ContextArgs,
235 app_state: &Arc<ZetaCliAppState>,
236 cx: &mut AsyncApp,
237) -> Result<LoadedContext> {
238 let ContextArgs {
239 worktree: worktree_path,
240 cursor,
241 use_language_server,
242 ..
243 } = args;
244
245 let worktree_path = worktree_path.canonicalize()?;
246
247 let project = cx.update(|cx| {
248 Project::local(
249 app_state.client.clone(),
250 app_state.node_runtime.clone(),
251 app_state.user_store.clone(),
252 app_state.languages.clone(),
253 app_state.fs.clone(),
254 None,
255 cx,
256 )
257 })?;
258
259 let worktree = project
260 .update(cx, |project, cx| {
261 project.create_worktree(&worktree_path, true, cx)
262 })?
263 .await?;
264
265 let mut ready_languages = HashSet::default();
266 let (_lsp_open_handle, buffer) = if *use_language_server {
267 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
268 project.clone(),
269 worktree.clone(),
270 cursor.path.clone(),
271 &mut ready_languages,
272 cx,
273 )
274 .await?;
275 (Some(lsp_open_handle), buffer)
276 } else {
277 let buffer =
278 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
279 (None, buffer)
280 };
281
282 let full_path_str = worktree
283 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
284 .display(PathStyle::local())
285 .to_string();
286
287 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
288 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
289 if clipped_cursor != cursor.point {
290 let max_row = snapshot.max_point().row;
291 if cursor.point.row < max_row {
292 return Err(anyhow!(
293 "Cursor position {:?} is out of bounds (line length is {})",
294 cursor.point,
295 snapshot.line_len(cursor.point.row)
296 ));
297 } else {
298 return Err(anyhow!(
299 "Cursor position {:?} is out of bounds (max row is {})",
300 cursor.point,
301 max_row
302 ));
303 }
304 }
305
306 Ok(LoadedContext {
307 full_path_str,
308 snapshot,
309 clipped_cursor,
310 worktree,
311 project,
312 buffer,
313 })
314}
315
316async fn zeta2_syntax_context(
317 zeta2_args: Zeta2Args,
318 syntax_args: Zeta2SyntaxArgs,
319 args: ContextArgs,
320 app_state: &Arc<ZetaCliAppState>,
321 cx: &mut AsyncApp,
322) -> Result<String> {
323 let LoadedContext {
324 worktree,
325 project,
326 buffer,
327 clipped_cursor,
328 ..
329 } = load_context(&args, app_state, cx).await?;
330
331 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
332 // the whole worktree.
333 worktree
334 .read_with(cx, |worktree, _cx| {
335 worktree.as_local().unwrap().scan_complete()
336 })?
337 .await;
338 let output = cx
339 .update(|cx| {
340 let zeta = cx.new(|cx| {
341 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
342 });
343 let indexing_done_task = zeta.update(cx, |zeta, cx| {
344 zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
345 zeta.register_buffer(&buffer, &project, cx);
346 zeta.wait_for_initial_indexing(&project, cx)
347 });
348 cx.spawn(async move |cx| {
349 indexing_done_task.await?;
350 let request = zeta
351 .update(cx, |zeta, cx| {
352 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
353 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
354 })?
355 .await?;
356
357 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
358
359 match zeta2_args.output_format {
360 OutputFormat::Prompt => anyhow::Ok(prompt_string),
361 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
362 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
363 "request": request,
364 "prompt": prompt_string,
365 "section_labels": section_labels,
366 }))?),
367 }
368 })
369 })?
370 .await?;
371
372 Ok(output)
373}
374
375async fn zeta1_context(
376 args: ContextArgs,
377 app_state: &Arc<ZetaCliAppState>,
378 cx: &mut AsyncApp,
379) -> Result<zeta::GatherContextOutput> {
380 let LoadedContext {
381 full_path_str,
382 snapshot,
383 clipped_cursor,
384 ..
385 } = load_context(&args, app_state, cx).await?;
386
387 let events = match args.edit_history {
388 Some(events) => events.read_to_string().await?,
389 None => String::new(),
390 };
391
392 let prompt_for_events = move || (events, 0);
393 cx.update(|cx| {
394 zeta::gather_context(
395 full_path_str,
396 &snapshot,
397 clipped_cursor,
398 prompt_for_events,
399 cx,
400 )
401 })?
402 .await
403}
404
405fn main() {
406 zlog::init();
407 zlog::init_output_stderr();
408 let args = ZetaCliArgs::parse();
409 let http_client = Arc::new(ReqwestClient::new());
410 let app = Application::headless().with_http_client(http_client);
411
412 app.run(move |cx| {
413 let app_state = Arc::new(headless::init(cx));
414 cx.spawn(async move |cx| {
415 match args.command {
416 Command::Zeta1 {
417 command: Zeta1Command::Context { context_args },
418 } => {
419 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
420 let result = serde_json::to_string_pretty(&context.body).unwrap();
421 println!("{}", result);
422 }
423 Command::Zeta2 { command } => match command {
424 Zeta2Command::Predict(arguments) => {
425 run_zeta2_predict(arguments, &app_state, cx).await;
426 }
427 Zeta2Command::Eval(arguments) => {
428 run_evaluate(arguments, &app_state, cx).await;
429 }
430 Zeta2Command::Syntax {
431 args,
432 syntax_args,
433 command,
434 } => {
435 let result = match command {
436 Zeta2SyntaxCommand::Context { context_args } => {
437 zeta2_syntax_context(
438 args,
439 syntax_args,
440 context_args,
441 &app_state,
442 cx,
443 )
444 .await
445 }
446 Zeta2SyntaxCommand::Stats {
447 worktree,
448 extension,
449 limit,
450 skip,
451 } => {
452 retrieval_stats(
453 worktree,
454 app_state,
455 extension,
456 limit,
457 skip,
458 syntax_args_to_options(&args, &syntax_args, false),
459 cx,
460 )
461 .await
462 }
463 };
464 println!("{}", result.unwrap());
465 }
466 },
467 Command::ConvertExample {
468 path,
469 output_format,
470 } => {
471 let example = NamedExample::load(path).unwrap();
472 example.write(output_format, io::stdout()).unwrap();
473 }
474 Command::Clean => std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap(),
475 };
476
477 let _ = cx.update(|cx| cx.quit());
478 })
479 .detach();
480 });
481}