1mod evaluate;
2mod example;
3mod headless;
4mod paths;
5mod predict;
6mod source_location;
7mod syntax_retrieval_stats;
8mod util;
9
10use crate::evaluate::{EvaluateArguments, run_evaluate};
11use crate::example::{ExampleFormat, NamedExample};
12use crate::predict::{PredictArguments, run_zeta2_predict};
13use crate::syntax_retrieval_stats::retrieval_stats;
14use ::util::paths::PathStyle;
15use anyhow::{Result, anyhow};
16use clap::{Args, Parser, Subcommand};
17use cloud_llm_client::predict_edits_v3;
18use edit_prediction_context::{
19 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
20};
21use gpui::{Application, AsyncApp, Entity, prelude::*};
22use language::{Bias, Buffer, BufferSnapshot, Point};
23use project::{Project, Worktree};
24use reqwest_client::ReqwestClient;
25use serde_json::json;
26use std::io::{self};
27use std::time::Duration;
28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
29use zeta2::ContextMode;
30
31use crate::headless::ZetaCliAppState;
32use crate::source_location::SourceLocation;
33use crate::util::{open_buffer, open_buffer_with_language_server};
34
// Top-level command-line arguments for the `zeta` CLI, parsed in `main` with
// `ZetaCliArgs::parse()`. Plain `//` comments are used on clap-derived items so that they
// do not alter the generated `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; dispatched in `main`.
    #[command(subcommand)]
    command: Command,
}
41
// Top-level subcommands of the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Command {
    // Commands for zeta1 (see `Zeta1Command`).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // Commands for zeta2 (see `Zeta2Command`).
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Loads the example at `path` with `NamedExample::load` and writes it to stdout in
    // `output_format` (`ExampleFormat::Md` unless overridden).
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
}
58
59#[derive(Subcommand, Debug)]
60enum Zeta1Command {
61 Context {
62 #[clap(flatten)]
63 context_args: ContextArgs,
64 },
65}
66
67#[derive(Subcommand, Debug)]
68enum Zeta2Command {
69 Syntax {
70 #[clap(flatten)]
71 args: Zeta2Args,
72 #[clap(flatten)]
73 syntax_args: Zeta2SyntaxArgs,
74 #[command(subcommand)]
75 command: Zeta2SyntaxCommand,
76 },
77 Predict(PredictArguments),
78 Eval(EvaluateArguments),
79}
80
81#[derive(Subcommand, Debug)]
82enum Zeta2SyntaxCommand {
83 Context {
84 #[clap(flatten)]
85 context_args: ContextArgs,
86 },
87 Stats {
88 #[arg(long)]
89 worktree: PathBuf,
90 #[arg(long)]
91 extension: Option<String>,
92 #[arg(long)]
93 limit: Option<usize>,
94 #[arg(long)]
95 skip: Option<usize>,
96 },
97}
98
// Arguments shared by the `context` subcommands; consumed by `load_context`.
// NOTE(review): `worktree` is already a required (non-`Option`) argument, so this group's
// `requires = "worktree"` looks redundant — confirm before removing.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to open; canonicalized by `load_context`.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location. Its `path` is resolved relative to the worktree root and its `point`
    // is validated against the opened buffer.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer via `open_buffer_with_language_server` instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text read from a file, or from stdin when given as `-`. In this file only
    // `zeta1_context` reads it.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
111
// Options shared by the `zeta2 syntax` subcommands. Every field except `output_format` is
// copied into `zeta2::ZetaOptions` by `syntax_args_to_options`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded to `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded to `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded to `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How `zeta2_syntax_context` renders its result (prompt, request JSON, or both).
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
131
// Syntax-retrieval options for `zeta2 syntax`, converted by `syntax_args_to_options`.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // When set, `EditPredictionContextOptions::use_imports` is `false`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded to `EditPredictionContextOptions::max_retrieved_declarations`; defaults to
    // `u8::MAX`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
139
140fn syntax_args_to_options(
141 zeta2_args: &Zeta2Args,
142 syntax_args: &Zeta2SyntaxArgs,
143 omit_excerpt_overlaps: bool,
144) -> zeta2::ZetaOptions {
145 zeta2::ZetaOptions {
146 context: ContextMode::Syntax(EditPredictionContextOptions {
147 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
148 use_imports: !syntax_args.disable_imports_gathering,
149 excerpt: EditPredictionExcerptOptions {
150 max_bytes: zeta2_args.max_excerpt_bytes,
151 min_bytes: zeta2_args.min_excerpt_bytes,
152 target_before_cursor_over_total_bytes: zeta2_args
153 .target_before_cursor_over_total_bytes,
154 },
155 score: EditPredictionScoreOptions {
156 omit_excerpt_overlaps,
157 },
158 }),
159 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
160 max_prompt_bytes: zeta2_args.max_prompt_bytes,
161 prompt_format: zeta2_args.prompt_format.clone().into(),
162 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
163 buffer_change_grouping_interval: Duration::ZERO,
164 }
165}
166
// CLI-facing prompt format selector, converted into `predict_edits_v3::PromptFormat` when
// building `ZetaOptions`. `NumberedLines` (the default) maps to `NumLinesUniDiff`; the other
// variants map to the same-named upstream variant.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
}
175
176impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
177 fn into(self) -> predict_edits_v3::PromptFormat {
178 match self {
179 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
180 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
181 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
182 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
183 }
184 }
185}
186
// How `zeta2_syntax_context` renders its result:
// - `Prompt` (default): the prompt string built by `cloud_zeta2_prompt::build_prompt`.
// - `Request`: the request as pretty-printed JSON.
// - `Full`: a pretty-printed JSON object with `request`, `prompt` and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
194
/// An input source given on the command line: a file path, or standard input when the
/// argument is `-` (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from this file path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
200
201impl FileOrStdin {
202 async fn read_to_string(&self) -> Result<String, std::io::Error> {
203 match self {
204 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
205 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
206 }
207 }
208}
209
210impl FromStr for FileOrStdin {
211 type Err = <PathBuf as FromStr>::Err;
212
213 fn from_str(s: &str) -> Result<Self, Self::Err> {
214 match s {
215 "-" => Ok(Self::Stdin),
216 _ => Ok(Self::File(PathBuf::from_str(s)?)),
217 }
218 }
219}
220
/// Everything `load_context` opens and validates for a `context` subcommand.
struct LoadedContext {
    /// Worktree root name joined with the cursor's relative path, rendered with the local
    /// path style.
    full_path_str: String,
    /// Snapshot of `buffer` taken after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor point. `load_context` only returns `Ok` when clipping left it
    /// unchanged.
    clipped_cursor: Point,
    /// Worktree created from the canonicalized `--worktree` path.
    worktree: Entity<Worktree>,
    /// Local project that owns `worktree`.
    project: Entity<Project>,
    /// Buffer for the cursor's file.
    buffer: Entity<Buffer>,
}
229
230async fn load_context(
231 args: &ContextArgs,
232 app_state: &Arc<ZetaCliAppState>,
233 cx: &mut AsyncApp,
234) -> Result<LoadedContext> {
235 let ContextArgs {
236 worktree: worktree_path,
237 cursor,
238 use_language_server,
239 ..
240 } = args;
241
242 let worktree_path = worktree_path.canonicalize()?;
243
244 let project = cx.update(|cx| {
245 Project::local(
246 app_state.client.clone(),
247 app_state.node_runtime.clone(),
248 app_state.user_store.clone(),
249 app_state.languages.clone(),
250 app_state.fs.clone(),
251 None,
252 cx,
253 )
254 })?;
255
256 let worktree = project
257 .update(cx, |project, cx| {
258 project.create_worktree(&worktree_path, true, cx)
259 })?
260 .await?;
261
262 let mut ready_languages = HashSet::default();
263 let (_lsp_open_handle, buffer) = if *use_language_server {
264 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
265 project.clone(),
266 worktree.clone(),
267 cursor.path.clone(),
268 &mut ready_languages,
269 cx,
270 )
271 .await?;
272 (Some(lsp_open_handle), buffer)
273 } else {
274 let buffer =
275 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
276 (None, buffer)
277 };
278
279 let full_path_str = worktree
280 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
281 .display(PathStyle::local())
282 .to_string();
283
284 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
285 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
286 if clipped_cursor != cursor.point {
287 let max_row = snapshot.max_point().row;
288 if cursor.point.row < max_row {
289 return Err(anyhow!(
290 "Cursor position {:?} is out of bounds (line length is {})",
291 cursor.point,
292 snapshot.line_len(cursor.point.row)
293 ));
294 } else {
295 return Err(anyhow!(
296 "Cursor position {:?} is out of bounds (max row is {})",
297 cursor.point,
298 max_row
299 ));
300 }
301 }
302
303 Ok(LoadedContext {
304 full_path_str,
305 snapshot,
306 clipped_cursor,
307 worktree,
308 project,
309 buffer,
310 })
311}
312
313async fn zeta2_syntax_context(
314 zeta2_args: Zeta2Args,
315 syntax_args: Zeta2SyntaxArgs,
316 args: ContextArgs,
317 app_state: &Arc<ZetaCliAppState>,
318 cx: &mut AsyncApp,
319) -> Result<String> {
320 let LoadedContext {
321 worktree,
322 project,
323 buffer,
324 clipped_cursor,
325 ..
326 } = load_context(&args, app_state, cx).await?;
327
328 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
329 // the whole worktree.
330 worktree
331 .read_with(cx, |worktree, _cx| {
332 worktree.as_local().unwrap().scan_complete()
333 })?
334 .await;
335 let output = cx
336 .update(|cx| {
337 let zeta = cx.new(|cx| {
338 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
339 });
340 let indexing_done_task = zeta.update(cx, |zeta, cx| {
341 zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
342 zeta.register_buffer(&buffer, &project, cx);
343 zeta.wait_for_initial_indexing(&project, cx)
344 });
345 cx.spawn(async move |cx| {
346 indexing_done_task.await?;
347 let request = zeta
348 .update(cx, |zeta, cx| {
349 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
350 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
351 })?
352 .await?;
353
354 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
355
356 match zeta2_args.output_format {
357 OutputFormat::Prompt => anyhow::Ok(prompt_string),
358 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
359 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
360 "request": request,
361 "prompt": prompt_string,
362 "section_labels": section_labels,
363 }))?),
364 }
365 })
366 })?
367 .await?;
368
369 Ok(output)
370}
371
372async fn zeta1_context(
373 args: ContextArgs,
374 app_state: &Arc<ZetaCliAppState>,
375 cx: &mut AsyncApp,
376) -> Result<zeta::GatherContextOutput> {
377 let LoadedContext {
378 full_path_str,
379 snapshot,
380 clipped_cursor,
381 ..
382 } = load_context(&args, app_state, cx).await?;
383
384 let events = match args.edit_history {
385 Some(events) => events.read_to_string().await?,
386 None => String::new(),
387 };
388
389 let prompt_for_events = move || (events, 0);
390 cx.update(|cx| {
391 zeta::gather_context(
392 full_path_str,
393 &snapshot,
394 clipped_cursor,
395 prompt_for_events,
396 cx,
397 )
398 })?
399 .await
400}
401
402fn main() {
403 zlog::init();
404 zlog::init_output_stderr();
405 let args = ZetaCliArgs::parse();
406 let http_client = Arc::new(ReqwestClient::new());
407 let app = Application::headless().with_http_client(http_client);
408
409 app.run(move |cx| {
410 let app_state = Arc::new(headless::init(cx));
411 cx.spawn(async move |cx| {
412 match args.command {
413 Command::Zeta1 {
414 command: Zeta1Command::Context { context_args },
415 } => {
416 let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
417 let result = serde_json::to_string_pretty(&context.body).unwrap();
418 println!("{}", result);
419 }
420 Command::Zeta2 { command } => match command {
421 Zeta2Command::Predict(arguments) => {
422 run_zeta2_predict(arguments, &app_state, cx).await;
423 }
424 Zeta2Command::Eval(arguments) => {
425 run_evaluate(arguments, &app_state, cx).await;
426 }
427 Zeta2Command::Syntax {
428 args,
429 syntax_args,
430 command,
431 } => {
432 let result = match command {
433 Zeta2SyntaxCommand::Context { context_args } => {
434 zeta2_syntax_context(
435 args,
436 syntax_args,
437 context_args,
438 &app_state,
439 cx,
440 )
441 .await
442 }
443 Zeta2SyntaxCommand::Stats {
444 worktree,
445 extension,
446 limit,
447 skip,
448 } => {
449 retrieval_stats(
450 worktree,
451 app_state,
452 extension,
453 limit,
454 skip,
455 syntax_args_to_options(&args, &syntax_args, false),
456 cx,
457 )
458 .await
459 }
460 };
461 println!("{}", result.unwrap());
462 }
463 },
464 Command::ConvertExample {
465 path,
466 output_format,
467 } => {
468 let example = NamedExample::load(path).unwrap();
469 example.write(output_format, io::stdout()).unwrap();
470 }
471 };
472
473 let _ = cx.update(|cx| cx.quit());
474 })
475 .detach();
476 });
477}