1mod headless;
  2mod retrieval_stats;
  3mod source_location;
  4mod util;
  5
  6use crate::retrieval_stats::retrieval_stats;
  7use ::util::paths::PathStyle;
  8use anyhow::{Result, anyhow};
  9use clap::{Args, Parser, Subcommand};
 10use cloud_llm_client::predict_edits_v3::{self};
 11use edit_prediction_context::{
 12    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 13};
 14use gpui::{Application, AsyncApp, prelude::*};
 15use language::Bias;
 16use language_model::LlmApiToken;
 17use project::Project;
 18use release_channel::AppVersion;
 19use reqwest_client::ReqwestClient;
 20use serde_json::json;
 21use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 22use zeta::{PerformPredictEditsParams, Zeta};
 23use zeta2::ContextMode;
 24
 25use crate::headless::ZetaCliAppState;
 26use crate::source_location::SourceLocation;
 27use crate::util::{open_buffer, open_buffer_with_language_server};
 28
// Top-level command-line arguments of the `zeta` CLI. `main` parses these and
// dispatches on `command`.
//
// Plain `//` comments are used on clap-derived items in this file because clap
// turns `///` doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run.
    #[command(subcommand)]
    command: Commands,
}
 35
 36#[derive(Subcommand, Debug)]
 37enum Commands {
 38    Context(ContextArgs),
 39    Zeta2Context {
 40        #[clap(flatten)]
 41        zeta2_args: Zeta2Args,
 42        #[clap(flatten)]
 43        context_args: ContextArgs,
 44    },
 45    Predict {
 46        #[arg(long)]
 47        predict_edits_body: Option<FileOrStdin>,
 48        #[clap(flatten)]
 49        context_args: Option<ContextArgs>,
 50    },
 51    RetrievalStats {
 52        #[clap(flatten)]
 53        zeta2_args: Zeta2Args,
 54        #[arg(long)]
 55        worktree: PathBuf,
 56        #[arg(long)]
 57        extension: Option<String>,
 58        #[arg(long)]
 59        limit: Option<usize>,
 60        #[arg(long)]
 61        skip: Option<usize>,
 62    },
 63}
 64
// Arguments identifying the worktree, file and cursor to gather prediction context
// for. They are flattened into the `Context`, `Zeta2Context` and `Predict` commands
// and consumed by `get_context`.
//
// NOTE(review): `requires = "worktree"` makes `--worktree` mandatory whenever any
// argument of this group is present; presumably this is what lets `Predict` hold
// the group as an `Option` — confirm against clap's group semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree. It is canonicalized and opened as a local
    // project worktree.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location: a path relative to the worktree plus a point in that file.
    // Parsed by `SourceLocation`; see `source_location` for the accepted syntax.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened with a language server attached instead of
    // being opened plainly.
    #[arg(long)]
    use_language_server: bool,
    // Optional events text: a file path, or `-` for stdin. It is read in both modes
    // but only handed to zeta1 context gathering.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
 77
// Tuning options for zeta2 context retrieval and prompt building.
// `Zeta2Args::to_options` converts them into `zeta2::ZetaOptions`. The one
// exception is `output_format`, which only selects what `get_context` prints.
//
// Plain `//` comments are used here because clap turns `///` doc comments into
// `--help` text.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Becomes `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Becomes the excerpt option `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Becomes the excerpt option `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Becomes the excerpt option `target_before_cursor_over_total_bytes`;
    // presumably the desired share of excerpt bytes before the cursor — TODO confirm.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Becomes `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // CLI prompt format, converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `Zeta2Context` prints: the prompt, the request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Becomes `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Negated into the context option `use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Becomes the context option `max_retrieved_declarations` (note the different
    // name); defaults to `u8::MAX`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
101
// Prompt format selectable on the command line. Each variant maps onto a
// `predict_edits_v3::PromptFormat` variant of the same name, except
// `NumberedLines`, which maps to `NumLinesUniDiff`.
//
// Plain `//` comments are used because clap turns `///` doc comments on
// `ValueEnum` variants into `--help` text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    // Used when `--prompt-format` is not given.
    #[default]
    NumberedLines,
}
110
111impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
112    fn into(self) -> predict_edits_v3::PromptFormat {
113        match self {
114            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
115            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
116            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
117            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
118        }
119    }
120}
121
// What `Zeta2Context` prints once the zeta2 request has been built.
//
// Plain `//` comments are used because clap turns `///` doc comments on
// `ValueEnum` variants into `--help` text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the rendered prompt string (the default).
    #[default]
    Prompt,
    // The request, pretty-printed as JSON.
    Request,
    // A JSON object with `request`, `prompt` and `section_labels` keys.
    Full,
}
129
/// Source of textual CLI input: either a file on disk or standard input.
///
/// Parsed from a command-line value, where `-` selects standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the input from this path.
    File(PathBuf),
    /// Read the input from standard input.
    Stdin,
}
135
136impl FileOrStdin {
137    async fn read_to_string(&self) -> Result<String, std::io::Error> {
138        match self {
139            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
140            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
141        }
142    }
143}
144
145impl FromStr for FileOrStdin {
146    type Err = <PathBuf as FromStr>::Err;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        match s {
150            "-" => Ok(Self::Stdin),
151            _ => Ok(Self::File(PathBuf::from_str(s)?)),
152        }
153    }
154}
155
/// Result of `get_context`; the variant depends on whether zeta2 arguments were supplied.
enum GetContextOutput {
    /// Output of `zeta::gather_context`, produced when no zeta2 arguments are given.
    Zeta1(zeta::GatherContextOutput),
    /// The zeta2 prompt and/or request, already rendered as text per `--output-format`.
    Zeta2(String),
}
160
161async fn get_context(
162    zeta2_args: Option<Zeta2Args>,
163    args: ContextArgs,
164    app_state: &Arc<ZetaCliAppState>,
165    cx: &mut AsyncApp,
166) -> Result<GetContextOutput> {
167    let ContextArgs {
168        worktree: worktree_path,
169        cursor,
170        use_language_server,
171        events,
172    } = args;
173
174    let worktree_path = worktree_path.canonicalize()?;
175
176    let project = cx.update(|cx| {
177        Project::local(
178            app_state.client.clone(),
179            app_state.node_runtime.clone(),
180            app_state.user_store.clone(),
181            app_state.languages.clone(),
182            app_state.fs.clone(),
183            None,
184            cx,
185        )
186    })?;
187
188    let worktree = project
189        .update(cx, |project, cx| {
190            project.create_worktree(&worktree_path, true, cx)
191        })?
192        .await?;
193
194    let mut ready_languages = HashSet::default();
195    let (_lsp_open_handle, buffer) = if use_language_server {
196        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
197            project.clone(),
198            worktree.clone(),
199            cursor.path.clone(),
200            &mut ready_languages,
201            cx,
202        )
203        .await?;
204        (Some(lsp_open_handle), buffer)
205    } else {
206        let buffer =
207            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
208        (None, buffer)
209    };
210
211    let full_path_str = worktree
212        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
213        .display(PathStyle::local())
214        .to_string();
215
216    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
217    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
218    if clipped_cursor != cursor.point {
219        let max_row = snapshot.max_point().row;
220        if cursor.point.row < max_row {
221            return Err(anyhow!(
222                "Cursor position {:?} is out of bounds (line length is {})",
223                cursor.point,
224                snapshot.line_len(cursor.point.row)
225            ));
226        } else {
227            return Err(anyhow!(
228                "Cursor position {:?} is out of bounds (max row is {})",
229                cursor.point,
230                max_row
231            ));
232        }
233    }
234
235    let events = match events {
236        Some(events) => events.read_to_string().await?,
237        None => String::new(),
238    };
239
240    if let Some(zeta2_args) = zeta2_args {
241        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
242        // the whole worktree.
243        worktree
244            .read_with(cx, |worktree, _cx| {
245                worktree.as_local().unwrap().scan_complete()
246            })?
247            .await;
248        let output = cx
249            .update(|cx| {
250                let zeta = cx.new(|cx| {
251                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
252                });
253                let indexing_done_task = zeta.update(cx, |zeta, cx| {
254                    zeta.set_options(zeta2_args.to_options(true));
255                    zeta.register_buffer(&buffer, &project, cx);
256                    zeta.wait_for_initial_indexing(&project, cx)
257                });
258                cx.spawn(async move |cx| {
259                    indexing_done_task.await?;
260                    let request = zeta
261                        .update(cx, |zeta, cx| {
262                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
263                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
264                        })?
265                        .await?;
266
267                    let (prompt_string, section_labels) =
268                        cloud_zeta2_prompt::build_prompt(&request)?;
269
270                    match zeta2_args.output_format {
271                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
272                        OutputFormat::Request => {
273                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
274                        }
275                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
276                            "request": request,
277                            "prompt": prompt_string,
278                            "section_labels": section_labels,
279                        }))?),
280                    }
281                })
282            })?
283            .await?;
284        Ok(GetContextOutput::Zeta2(output))
285    } else {
286        let prompt_for_events = move || (events, 0);
287        Ok(GetContextOutput::Zeta1(
288            cx.update(|cx| {
289                zeta::gather_context(
290                    full_path_str,
291                    &snapshot,
292                    clipped_cursor,
293                    prompt_for_events,
294                    cx,
295                )
296            })?
297            .await?,
298        ))
299    }
300}
301
302impl Zeta2Args {
303    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
304        zeta2::ZetaOptions {
305            context: ContextMode::Syntax(EditPredictionContextOptions {
306                max_retrieved_declarations: self.max_retrieved_definitions,
307                use_imports: !self.disable_imports_gathering,
308                excerpt: EditPredictionExcerptOptions {
309                    max_bytes: self.max_excerpt_bytes,
310                    min_bytes: self.min_excerpt_bytes,
311                    target_before_cursor_over_total_bytes: self
312                        .target_before_cursor_over_total_bytes,
313                },
314                score: EditPredictionScoreOptions {
315                    omit_excerpt_overlaps,
316                },
317            }),
318            max_diagnostic_bytes: self.max_diagnostic_bytes,
319            max_prompt_bytes: self.max_prompt_bytes,
320            prompt_format: self.prompt_format.clone().into(),
321            file_indexing_parallelism: self.file_indexing_parallelism,
322        }
323    }
324}
325
326fn main() {
327    zlog::init();
328    zlog::init_output_stderr();
329    let args = ZetaCliArgs::parse();
330    let http_client = Arc::new(ReqwestClient::new());
331    let app = Application::headless().with_http_client(http_client);
332
333    app.run(move |cx| {
334        let app_state = Arc::new(headless::init(cx));
335        cx.spawn(async move |cx| {
336            let result = match args.command {
337                Commands::Zeta2Context {
338                    zeta2_args,
339                    context_args,
340                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
341                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
342                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
343                    Err(err) => Err(err),
344                },
345                Commands::Context(context_args) => {
346                    match get_context(None, context_args, &app_state, cx).await {
347                        Ok(GetContextOutput::Zeta1(output)) => {
348                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
349                        }
350                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
351                        Err(err) => Err(err),
352                    }
353                }
354                Commands::Predict {
355                    predict_edits_body,
356                    context_args,
357                } => {
358                    cx.spawn(async move |cx| {
359                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
360                        app_state.client.sign_in(true, cx).await?;
361                        let llm_token = LlmApiToken::default();
362                        llm_token.refresh(&app_state.client).await?;
363
364                        let predict_edits_body =
365                            if let Some(predict_edits_body) = predict_edits_body {
366                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
367                            } else if let Some(context_args) = context_args {
368                                match get_context(None, context_args, &app_state, cx).await? {
369                                    GetContextOutput::Zeta1(output) => output.body,
370                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
371                                }
372                            } else {
373                                return Err(anyhow!(
374                                    "Expected either --predict-edits-body-file \
375                                    or the required args of the `context` command."
376                                ));
377                            };
378
379                        let (response, _usage) =
380                            Zeta::perform_predict_edits(PerformPredictEditsParams {
381                                client: app_state.client.clone(),
382                                llm_token,
383                                app_version,
384                                body: predict_edits_body,
385                            })
386                            .await?;
387
388                        Ok(response.output_excerpt)
389                    })
390                    .await
391                }
392                Commands::RetrievalStats {
393                    zeta2_args,
394                    worktree,
395                    extension,
396                    limit,
397                    skip,
398                } => {
399                    retrieval_stats(
400                        worktree,
401                        app_state,
402                        extension,
403                        limit,
404                        skip,
405                        (&zeta2_args).to_options(false),
406                        cx,
407                    )
408                    .await
409                }
410            };
411            match result {
412                Ok(output) => {
413                    println!("{}", output);
414                    let _ = cx.update(|cx| cx.quit());
415                }
416                Err(e) => {
417                    eprintln!("Failed: {:?}", e);
418                    exit(1);
419                }
420            }
421        })
422        .detach();
423    });
424}