1mod evaluate;
2mod example;
3mod headless;
4mod paths;
5mod predict;
6mod source_location;
7mod syntax_retrieval_stats;
8mod util;
9
10use crate::evaluate::{EvaluateArguments, run_evaluate};
11use crate::example::{ExampleFormat, NamedExample};
12use crate::predict::{PredictArguments, run_zeta2_predict};
13use crate::syntax_retrieval_stats::retrieval_stats;
14use ::util::paths::PathStyle;
15use anyhow::{Result, anyhow};
16use clap::{Args, Parser, Subcommand};
17use cloud_llm_client::predict_edits_v3;
18use edit_prediction_context::{
19 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
20};
21use gpui::{Application, AsyncApp, Entity, prelude::*};
22use language::{Bias, Buffer, BufferSnapshot, Point};
23use project::{Project, Worktree};
24use reqwest_client::ReqwestClient;
25use serde_json::json;
26use std::io::{self};
27use std::time::Duration;
28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
29use zeta2::ContextMode;
30
31use crate::headless::ZetaCliAppState;
32use crate::source_location::SourceLocation;
33use crate::util::{open_buffer, open_buffer_with_language_server};
34
// Top-level command-line arguments for the `zeta` CLI.
//
// Plain `//` comments are used on this clap-derived type instead of `///` because clap
// turns doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run (see `Command`).
    #[command(subcommand)]
    command: Command,
}
41
// Top-level subcommands, dispatched in `main`.
//
// `//` (not `///`) comments on purpose: clap would turn doc comments into `--help` text.
#[derive(Subcommand, Debug)]
enum Command {
    // zeta1 commands (see `Zeta1Command`).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // zeta2 commands (see `Zeta2Command`).
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Loads the example at `path` via `NamedExample::load` and writes it to stdout in
    // `output_format` (`ExampleFormat::Md` unless overridden).
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
}
58
// Subcommands of `zeta zeta1`.
#[derive(Subcommand, Debug)]
enum Zeta1Command {
    // Gathers zeta1 context for one cursor position (see `zeta1_context`); `main` prints
    // the resulting `body` as pretty-printed JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
66
// Subcommands of `zeta zeta2`.
#[derive(Subcommand, Debug)]
enum Zeta2Command {
    // Syntax-based context retrieval. `args` and `syntax_args` are shared by every
    // `Zeta2SyntaxCommand` and turned into `zeta2::ZetaOptions` by
    // `syntax_args_to_options`.
    Syntax {
        #[clap(flatten)]
        args: Zeta2Args,
        #[clap(flatten)]
        syntax_args: Zeta2SyntaxArgs,
        #[command(subcommand)]
        command: Zeta2SyntaxCommand,
    },
    // Handled by `run_zeta2_predict` (arguments defined in the `predict` module).
    Predict(PredictArguments),
    // Handled by `run_evaluate` (arguments defined in the `evaluate` module).
    Eval(EvaluateArguments),
}
80
// Subcommands of `zeta zeta2 syntax`; `main` prints the string each one returns.
#[derive(Subcommand, Debug)]
enum Zeta2SyntaxCommand {
    // Builds the zeta2 prompt/request for one cursor position (see `zeta2_syntax_context`).
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Runs `retrieval_stats` over the worktree rooted at `worktree`.
    // NOTE(review): `extension`, `limit` and `skip` are forwarded unchanged; presumably a
    // file-extension filter plus a cap / offset on the files examined — confirm in
    // `syntax_retrieval_stats`.
    Stats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
98
// Arguments shared by the `context` subcommands; consumed by `load_context`.
//
// NOTE(review): `worktree` is already a required (non-`Option`, no default) argument, so
// the `requires = "worktree"` group constraint below appears redundant — confirm before
// removing it.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree to open; `load_context` canonicalizes it.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location: its `path` is resolved relative to `worktree`, and its `point` must
    // lie inside that file's buffer.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer through `open_buffer_with_language_server` instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Edit history text, read from a file path, or from stdin when given as `-`.
    // Within this file, only `zeta1_context` reads it.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
111
// Tuning flags for zeta2. All of them are copied into `zeta2::ZetaOptions` by
// `syntax_args_to_options` except `output_format`, which selects what
// `zeta2_syntax_context` returns.
#[derive(Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> excerpt options `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> excerpt options `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> excerpt options `target_before_cursor_over_total_bytes`; presumably the fraction
    // of the excerpt placed before the cursor — confirm in `edit_prediction_context`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // -> `ZetaOptions::prompt_format`, after conversion to `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // What to print; see `OutputFormat`.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // -> `ZetaOptions::file_indexing_parallelism`.
    // NOTE(review): the default of 42 is not explained here — confirm it is intentional.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
131
// Flags specific to syntax-based retrieval; see `syntax_args_to_options`.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // When set, `use_imports` is `false` in the retrieval options.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // -> `max_retrieved_declarations`; defaults to `u8::MAX` (255).
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
139
140fn syntax_args_to_options(
141 zeta2_args: &Zeta2Args,
142 syntax_args: &Zeta2SyntaxArgs,
143 omit_excerpt_overlaps: bool,
144) -> zeta2::ZetaOptions {
145 zeta2::ZetaOptions {
146 context: ContextMode::Syntax(EditPredictionContextOptions {
147 max_retrieved_declarations: syntax_args.max_retrieved_definitions,
148 use_imports: !syntax_args.disable_imports_gathering,
149 excerpt: EditPredictionExcerptOptions {
150 max_bytes: zeta2_args.max_excerpt_bytes,
151 min_bytes: zeta2_args.min_excerpt_bytes,
152 target_before_cursor_over_total_bytes: zeta2_args
153 .target_before_cursor_over_total_bytes,
154 },
155 score: EditPredictionScoreOptions {
156 omit_excerpt_overlaps,
157 },
158 }),
159 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
160 max_prompt_bytes: zeta2_args.max_prompt_bytes,
161 prompt_format: zeta2_args.prompt_format.into(),
162 file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
163 buffer_change_grouping_interval: Duration::ZERO,
164 }
165}
166
// CLI-facing prompt format, converted to `predict_edits_v3::PromptFormat` before use.
// Variants map to the identically named wire variants, except `NumberedLines`, which maps
// to `NumLinesUniDiff`.
// `//` (not `///`) comments on purpose: on a `ValueEnum`, doc comments become
// possible-value help text.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
}
176
177impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
178 fn into(self) -> predict_edits_v3::PromptFormat {
179 match self {
180 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
181 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
182 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
183 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
184 Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
185 }
186 }
187}
188
// What `zeta2_syntax_context` returns:
// - `Prompt` (default): the built prompt text;
// - `Request`: the request, pretty-printed as JSON;
// - `Full`: a pretty-printed JSON object with `request`, `prompt` and `section_labels`.
// `//` (not `///`) comments on purpose: on a `ValueEnum`, doc comments become
// possible-value help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
196
/// Source of text input: either a file path or standard input.
///
/// Parsed from the command line via `FromStr`, where the string `-` selects stdin.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
202
203impl FileOrStdin {
204 async fn read_to_string(&self) -> Result<String, std::io::Error> {
205 match self {
206 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
207 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
208 }
209 }
210}
211
212impl FromStr for FileOrStdin {
213 type Err = <PathBuf as FromStr>::Err;
214
215 fn from_str(s: &str) -> Result<Self, Self::Err> {
216 match s {
217 "-" => Ok(Self::Stdin),
218 _ => Ok(Self::File(PathBuf::from_str(s)?)),
219 }
220 }
221}
222
/// Everything `load_context` produces: the opened project, worktree and buffer, plus a
/// cursor that has been checked against the buffer.
struct LoadedContext {
    /// The worktree's root name joined with the cursor's relative path, rendered with the
    /// local path style.
    full_path_str: String,
    /// Snapshot of `buffer`, taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The cursor point after clipping to the snapshot. `load_context` returns an error
    /// whenever clipping would move the point, so this equals the requested cursor point.
    clipped_cursor: Point,
    /// The worktree created for the `--worktree` directory.
    worktree: Entity<Worktree>,
    /// The local project that owns `worktree`.
    project: Entity<Project>,
    /// The buffer for the file named by the cursor.
    buffer: Entity<Buffer>,
}
231
232async fn load_context(
233 args: &ContextArgs,
234 app_state: &Arc<ZetaCliAppState>,
235 cx: &mut AsyncApp,
236) -> Result<LoadedContext> {
237 let ContextArgs {
238 worktree: worktree_path,
239 cursor,
240 use_language_server,
241 ..
242 } = args;
243
244 let worktree_path = worktree_path.canonicalize()?;
245
246 let project = cx.update(|cx| {
247 Project::local(
248 app_state.client.clone(),
249 app_state.node_runtime.clone(),
250 app_state.user_store.clone(),
251 app_state.languages.clone(),
252 app_state.fs.clone(),
253 None,
254 cx,
255 )
256 })?;
257
258 let worktree = project
259 .update(cx, |project, cx| {
260 project.create_worktree(&worktree_path, true, cx)
261 })?
262 .await?;
263
264 let mut ready_languages = HashSet::default();
265 let (_lsp_open_handle, buffer) = if *use_language_server {
266 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
267 project.clone(),
268 worktree.clone(),
269 cursor.path.clone(),
270 &mut ready_languages,
271 cx,
272 )
273 .await?;
274 (Some(lsp_open_handle), buffer)
275 } else {
276 let buffer =
277 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
278 (None, buffer)
279 };
280
281 let full_path_str = worktree
282 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
283 .display(PathStyle::local())
284 .to_string();
285
286 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
287 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
288 if clipped_cursor != cursor.point {
289 let max_row = snapshot.max_point().row;
290 if cursor.point.row < max_row {
291 return Err(anyhow!(
292 "Cursor position {:?} is out of bounds (line length is {})",
293 cursor.point,
294 snapshot.line_len(cursor.point.row)
295 ));
296 } else {
297 return Err(anyhow!(
298 "Cursor position {:?} is out of bounds (max row is {})",
299 cursor.point,
300 max_row
301 ));
302 }
303 }
304
305 Ok(LoadedContext {
306 full_path_str,
307 snapshot,
308 clipped_cursor,
309 worktree,
310 project,
311 buffer,
312 })
313}
314
315async fn zeta2_syntax_context(
316 zeta2_args: Zeta2Args,
317 syntax_args: Zeta2SyntaxArgs,
318 args: ContextArgs,
319 app_state: &Arc<ZetaCliAppState>,
320 cx: &mut AsyncApp,
321) -> Result<String> {
322 let LoadedContext {
323 worktree,
324 project,
325 buffer,
326 clipped_cursor,
327 ..
328 } = load_context(&args, app_state, cx).await?;
329
330 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
331 // the whole worktree.
332 worktree
333 .read_with(cx, |worktree, _cx| {
334 worktree.as_local().unwrap().scan_complete()
335 })?
336 .await;
337 let output = cx
338 .update(|cx| {
339 let zeta = cx.new(|cx| {
340 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
341 });
342 let indexing_done_task = zeta.update(cx, |zeta, cx| {
343 zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
344 zeta.register_buffer(&buffer, &project, cx);
345 zeta.wait_for_initial_indexing(&project, cx)
346 });
347 cx.spawn(async move |cx| {
348 indexing_done_task.await?;
349 let request = zeta
350 .update(cx, |zeta, cx| {
351 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
352 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
353 })?
354 .await?;
355
356 let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
357
358 match zeta2_args.output_format {
359 OutputFormat::Prompt => anyhow::Ok(prompt_string),
360 OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
361 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
362 "request": request,
363 "prompt": prompt_string,
364 "section_labels": section_labels,
365 }))?),
366 }
367 })
368 })?
369 .await?;
370
371 Ok(output)
372}
373
374async fn zeta1_context(
375 args: ContextArgs,
376 app_state: &Arc<ZetaCliAppState>,
377 cx: &mut AsyncApp,
378) -> Result<zeta::GatherContextOutput> {
379 let LoadedContext {
380 full_path_str,
381 snapshot,
382 clipped_cursor,
383 ..
384 } = load_context(&args, app_state, cx).await?;
385
386 let events = match args.edit_history {
387 Some(events) => events.read_to_string().await?,
388 None => String::new(),
389 };
390
391 let prompt_for_events = move || (events, 0);
392 cx.update(|cx| {
393 zeta::gather_context(
394 full_path_str,
395 &snapshot,
396 clipped_cursor,
397 prompt_for_events,
398 cx,
399 )
400 })?
401 .await
402}
403
/// Entry point of the `zeta` CLI.
///
/// Initializes `zlog` (with stderr output) and parses the command line. It then starts a
/// headless GPUI application that uses a `ReqwestClient` for HTTP. Inside the app it spawns
/// a task that runs the selected subcommand, prints its result to stdout, and then asks the
/// app to quit.
///
/// Failures in the `zeta1 context`, `zeta2 syntax` and `convert-example` paths panic via
/// `unwrap()`. NOTE(review): the values awaited from `run_zeta2_predict` / `run_evaluate`
/// are not inspected here; their error handling lives in their own modules — confirm
/// there.
fn main() {
    zlog::init();
    zlog::init_output_stderr();
    let args = ZetaCliArgs::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        cx.spawn(async move |cx| {
            match args.command {
                Command::Zeta1 {
                    command: Zeta1Command::Context { context_args },
                } => {
                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
                    let result = serde_json::to_string_pretty(&context.body).unwrap();
                    println!("{}", result);
                }
                Command::Zeta2 { command } => match command {
                    Zeta2Command::Predict(arguments) => {
                        run_zeta2_predict(arguments, &app_state, cx).await;
                    }
                    Zeta2Command::Eval(arguments) => {
                        run_evaluate(arguments, &app_state, cx).await;
                    }
                    Zeta2Command::Syntax {
                        args,
                        syntax_args,
                        command,
                    } => {
                        let result = match command {
                            Zeta2SyntaxCommand::Context { context_args } => {
                                zeta2_syntax_context(
                                    args,
                                    syntax_args,
                                    context_args,
                                    &app_state,
                                    cx,
                                )
                                .await
                            }
                            Zeta2SyntaxCommand::Stats {
                                worktree,
                                extension,
                                limit,
                                skip,
                            } => {
                                // `app_state` is moved here (this arm is its last use).
                                // `false` for `omit_excerpt_overlaps`, whereas the
                                // `context` path passes `true`.
                                retrieval_stats(
                                    worktree,
                                    app_state,
                                    extension,
                                    limit,
                                    skip,
                                    syntax_args_to_options(&args, &syntax_args, false),
                                    cx,
                                )
                                .await
                            }
                        };
                        println!("{}", result.unwrap());
                    }
                },
                Command::ConvertExample {
                    path,
                    output_format,
                } => {
                    let example = NamedExample::load(path).unwrap();
                    example.write(output_format, io::stdout()).unwrap();
                }
            };

            // Quit once the command has finished. The `update` result is deliberately
            // ignored; NOTE(review): presumably it can only fail if the app is already
            // shutting down — confirm.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}