1mod headless;
2mod retrieval_stats;
3mod source_location;
4mod util;
5
6use crate::retrieval_stats::retrieval_stats;
7use ::util::paths::PathStyle;
8use anyhow::{Result, anyhow};
9use clap::{Args, Parser, Subcommand};
10use cloud_llm_client::predict_edits_v3::{self};
11use edit_prediction_context::{
12 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
13};
14use gpui::{Application, AsyncApp, prelude::*};
15use language::Bias;
16use language_model::LlmApiToken;
17use project::Project;
18use release_channel::AppVersion;
19use reqwest_client::ReqwestClient;
20use serde_json::json;
21use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
22use zeta::{PerformPredictEditsParams, Zeta};
23
24use crate::headless::ZetaCliAppState;
25use crate::source_location::SourceLocation;
26use crate::util::{open_buffer, open_buffer_with_language_server};
27
// Top-level command-line arguments for the `zeta` CLI.
//
// The clap-derived types in this file use plain `//` comments on purpose: clap turns `///`
// doc comments into `--help` text, so adding them would change the CLI's output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; dispatched in `main`.
    #[command(subcommand)]
    command: Commands,
}
34
// Subcommands supported by the CLI. clap derives the subcommand and `--flag` names from the
// variant and field names, so they must not be renamed casually. Plain `//` comments are used
// so that clap's generated `--help` text is unaffected.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for a cursor location; `main` prints the resulting request body
    // as pretty-printed JSON.
    Context(ContextArgs),
    // Gather zeta2 context for a cursor location; the output is rendered according to
    // `Zeta2Args::output_format`.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction. The request body comes either from `predict_edits_body`
    // (a file path, or `-` for stdin, parsed as JSON) or from zeta1 context gathered with
    // `context_args`; `main` reports an error when neither is provided.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Compute retrieval statistics for `worktree`. `extension`, `limit` and `skip` are
    // forwarded unchanged to `retrieval_stats`, together with options from `zeta2_args`.
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
63
// Arguments identifying the worktree and cursor location to gather context for.
//
// NOTE(review): the `requires = "worktree"` group presumably lets `Predict`'s
// `Option<ContextArgs>` be omitted entirely while still demanding `--worktree` whenever any
// of these arguments is given — confirm against clap's argument-group semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree; `get_context` canonicalizes it before opening it.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location, parsed from a string by `SourceLocation` (format defined in the
    // `source_location` module). Its path is resolved relative to the worktree.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened via `open_buffer_with_language_server` instead of
    // `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional events input (file path or `-` for stdin). Its contents are passed to zeta1
    // context gathering; the zeta2 path in `get_context` does not use them.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
76
// Tuning options for zeta2 context gathering and prompt rendering. `Zeta2Args::to_options`
// translates these into `zeta2::ZetaOptions`. Plain `//` comments keep clap's `--help`
// output unchanged.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded as `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`;
    // presumably the desired share of excerpt bytes before the cursor — confirm in
    // `edit_prediction_context`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded as `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Prompt layout; converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How the zeta2 context command prints its result: the rendered prompt, the request as
    // JSON, or a JSON object containing both plus the section labels.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded as `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // When set, `EditPredictionContextOptions::use_imports` is turned off.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // When set, `EditPredictionContextOptions::use_references` is turned off.
    #[arg(long, default_value_t = false)]
    disable_reference_retrieval: bool,
}
100
// CLI-facing prompt formats accepted by `--prompt-format`. Each maps onto a
// `predict_edits_v3::PromptFormat`; note that `NumberedLines` maps to `NumLinesUniDiff`.
// The variant names are the CLI values, so they are part of the interface.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111 fn into(self) -> predict_edits_v3::PromptFormat {
112 match self {
113 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
117 }
118 }
119}
120
// How the zeta2 context output is printed (selected with `--output-format`):
// - `Prompt`: the rendered prompt string (default).
// - `Request`: the cloud request, pretty-printed as JSON.
// - `Full`: a JSON object with `request`, `prompt` and `section_labels` keys.
// The variant names are the CLI values, so they are part of the interface.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
128
// An input source given on the command line: a file path, or stdin when the argument is `-`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    // Read from this path.
    File(PathBuf),
    // Read from the process's standard input.
    Stdin,
}
134
135impl FileOrStdin {
136 async fn read_to_string(&self) -> Result<String, std::io::Error> {
137 match self {
138 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
139 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
140 }
141 }
142}
143
144impl FromStr for FileOrStdin {
145 type Err = <PathBuf as FromStr>::Err;
146
147 fn from_str(s: &str) -> Result<Self, Self::Err> {
148 match s {
149 "-" => Ok(Self::Stdin),
150 _ => Ok(Self::File(PathBuf::from_str(s)?)),
151 }
152 }
153}
154
// Result of `get_context`; which variant is produced depends on whether zeta2 args were
// supplied.
enum GetContextOutput {
    // Output of `zeta::gather_context` (returned when no zeta2 args are given).
    Zeta1(zeta::GatherContextOutput),
    // zeta2 context already rendered according to `Zeta2Args::output_format`.
    Zeta2(String),
}
159
160async fn get_context(
161 zeta2_args: Option<Zeta2Args>,
162 args: ContextArgs,
163 app_state: &Arc<ZetaCliAppState>,
164 cx: &mut AsyncApp,
165) -> Result<GetContextOutput> {
166 let ContextArgs {
167 worktree: worktree_path,
168 cursor,
169 use_language_server,
170 events,
171 } = args;
172
173 let worktree_path = worktree_path.canonicalize()?;
174
175 let project = cx.update(|cx| {
176 Project::local(
177 app_state.client.clone(),
178 app_state.node_runtime.clone(),
179 app_state.user_store.clone(),
180 app_state.languages.clone(),
181 app_state.fs.clone(),
182 None,
183 cx,
184 )
185 })?;
186
187 let worktree = project
188 .update(cx, |project, cx| {
189 project.create_worktree(&worktree_path, true, cx)
190 })?
191 .await?;
192
193 let mut ready_languages = HashSet::default();
194 let (_lsp_open_handle, buffer) = if use_language_server {
195 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
196 project.clone(),
197 worktree.clone(),
198 cursor.path.clone(),
199 &mut ready_languages,
200 cx,
201 )
202 .await?;
203 (Some(lsp_open_handle), buffer)
204 } else {
205 let buffer =
206 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
207 (None, buffer)
208 };
209
210 let full_path_str = worktree
211 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
212 .display(PathStyle::local())
213 .to_string();
214
215 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
216 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
217 if clipped_cursor != cursor.point {
218 let max_row = snapshot.max_point().row;
219 if cursor.point.row < max_row {
220 return Err(anyhow!(
221 "Cursor position {:?} is out of bounds (line length is {})",
222 cursor.point,
223 snapshot.line_len(cursor.point.row)
224 ));
225 } else {
226 return Err(anyhow!(
227 "Cursor position {:?} is out of bounds (max row is {})",
228 cursor.point,
229 max_row
230 ));
231 }
232 }
233
234 let events = match events {
235 Some(events) => events.read_to_string().await?,
236 None => String::new(),
237 };
238
239 if let Some(zeta2_args) = zeta2_args {
240 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
241 // the whole worktree.
242 worktree
243 .read_with(cx, |worktree, _cx| {
244 worktree.as_local().unwrap().scan_complete()
245 })?
246 .await;
247 let output = cx
248 .update(|cx| {
249 let zeta = cx.new(|cx| {
250 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
251 });
252 let indexing_done_task = zeta.update(cx, |zeta, cx| {
253 zeta.set_options(zeta2_args.to_options(true));
254 zeta.register_buffer(&buffer, &project, cx);
255 zeta.wait_for_initial_indexing(&project, cx)
256 });
257 cx.spawn(async move |cx| {
258 indexing_done_task.await?;
259 let request = zeta
260 .update(cx, |zeta, cx| {
261 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
262 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
263 })?
264 .await?;
265
266 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
267 let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
268
269 match zeta2_args.output_format {
270 OutputFormat::Prompt => anyhow::Ok(prompt_string),
271 OutputFormat::Request => {
272 anyhow::Ok(serde_json::to_string_pretty(&request)?)
273 }
274 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
275 "request": request,
276 "prompt": prompt_string,
277 "section_labels": section_labels,
278 }))?),
279 }
280 })
281 })?
282 .await?;
283 Ok(GetContextOutput::Zeta2(output))
284 } else {
285 let prompt_for_events = move || (events, 0);
286 Ok(GetContextOutput::Zeta1(
287 cx.update(|cx| {
288 zeta::gather_context(
289 full_path_str,
290 &snapshot,
291 clipped_cursor,
292 prompt_for_events,
293 cx,
294 )
295 })?
296 .await?,
297 ))
298 }
299}
300
301impl Zeta2Args {
302 fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
303 zeta2::ZetaOptions {
304 context: EditPredictionContextOptions {
305 use_references: !self.disable_reference_retrieval,
306 use_imports: !self.disable_imports_gathering,
307 excerpt: EditPredictionExcerptOptions {
308 max_bytes: self.max_excerpt_bytes,
309 min_bytes: self.min_excerpt_bytes,
310 target_before_cursor_over_total_bytes: self
311 .target_before_cursor_over_total_bytes,
312 },
313 score: EditPredictionScoreOptions {
314 omit_excerpt_overlaps,
315 },
316 },
317 max_diagnostic_bytes: self.max_diagnostic_bytes,
318 max_prompt_bytes: self.max_prompt_bytes,
319 prompt_format: self.prompt_format.clone().into(),
320 file_indexing_parallelism: self.file_indexing_parallelism,
321 }
322 }
323}
324
325fn main() {
326 zlog::init();
327 zlog::init_output_stderr();
328 let args = ZetaCliArgs::parse();
329 let http_client = Arc::new(ReqwestClient::new());
330 let app = Application::headless().with_http_client(http_client);
331
332 app.run(move |cx| {
333 let app_state = Arc::new(headless::init(cx));
334 cx.spawn(async move |cx| {
335 let result = match args.command {
336 Commands::Zeta2Context {
337 zeta2_args,
338 context_args,
339 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
340 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
341 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
342 Err(err) => Err(err),
343 },
344 Commands::Context(context_args) => {
345 match get_context(None, context_args, &app_state, cx).await {
346 Ok(GetContextOutput::Zeta1(output)) => {
347 Ok(serde_json::to_string_pretty(&output.body).unwrap())
348 }
349 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
350 Err(err) => Err(err),
351 }
352 }
353 Commands::Predict {
354 predict_edits_body,
355 context_args,
356 } => {
357 cx.spawn(async move |cx| {
358 let app_version = cx.update(|cx| AppVersion::global(cx))?;
359 app_state.client.sign_in(true, cx).await?;
360 let llm_token = LlmApiToken::default();
361 llm_token.refresh(&app_state.client).await?;
362
363 let predict_edits_body =
364 if let Some(predict_edits_body) = predict_edits_body {
365 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
366 } else if let Some(context_args) = context_args {
367 match get_context(None, context_args, &app_state, cx).await? {
368 GetContextOutput::Zeta1(output) => output.body,
369 GetContextOutput::Zeta2 { .. } => unreachable!(),
370 }
371 } else {
372 return Err(anyhow!(
373 "Expected either --predict-edits-body-file \
374 or the required args of the `context` command."
375 ));
376 };
377
378 let (response, _usage) =
379 Zeta::perform_predict_edits(PerformPredictEditsParams {
380 client: app_state.client.clone(),
381 llm_token,
382 app_version,
383 body: predict_edits_body,
384 })
385 .await?;
386
387 Ok(response.output_excerpt)
388 })
389 .await
390 }
391 Commands::RetrievalStats {
392 zeta2_args,
393 worktree,
394 extension,
395 limit,
396 skip,
397 } => {
398 retrieval_stats(
399 worktree,
400 app_state,
401 extension,
402 limit,
403 skip,
404 (&zeta2_args).to_options(false),
405 cx,
406 )
407 .await
408 }
409 };
410 match result {
411 Ok(output) => {
412 println!("{}", output);
413 let _ = cx.update(|cx| cx.quit());
414 }
415 Err(e) => {
416 eprintln!("Failed: {:?}", e);
417 exit(1);
418 }
419 }
420 })
421 .detach();
422 });
423}