main.rs

  1mod headless;
  2mod retrieval_stats;
  3mod source_location;
  4mod util;
  5
  6use crate::retrieval_stats::retrieval_stats;
  7use ::util::paths::PathStyle;
  8use anyhow::{Result, anyhow};
  9use clap::{Args, Parser, Subcommand};
 10use cloud_llm_client::predict_edits_v3::{self};
 11use edit_prediction_context::{
 12    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 13};
 14use gpui::{Application, AsyncApp, prelude::*};
 15use language::Bias;
 16use language_model::LlmApiToken;
 17use project::Project;
 18use release_channel::AppVersion;
 19use reqwest_client::ReqwestClient;
 20use serde_json::json;
 21use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 22use zeta::{PerformPredictEditsParams, Zeta};
 23
 24use crate::headless::ZetaCliAppState;
 25use crate::source_location::SourceLocation;
 26use crate::util::{open_buffer, open_buffer_with_language_server};
 27
 28#[derive(Parser, Debug)]
 29#[command(name = "zeta")]
 30struct ZetaCliArgs {
 31    #[command(subcommand)]
 32    command: Commands,
 33}
 34
 35#[derive(Subcommand, Debug)]
 36enum Commands {
 37    Context(ContextArgs),
 38    Zeta2Context {
 39        #[clap(flatten)]
 40        zeta2_args: Zeta2Args,
 41        #[clap(flatten)]
 42        context_args: ContextArgs,
 43    },
 44    Predict {
 45        #[arg(long)]
 46        predict_edits_body: Option<FileOrStdin>,
 47        #[clap(flatten)]
 48        context_args: Option<ContextArgs>,
 49    },
 50    RetrievalStats {
 51        #[clap(flatten)]
 52        zeta2_args: Zeta2Args,
 53        #[arg(long)]
 54        worktree: PathBuf,
 55        #[arg(long)]
 56        extension: Option<String>,
 57        #[arg(long)]
 58        limit: Option<usize>,
 59        #[arg(long)]
 60        skip: Option<usize>,
 61    },
 62}
 63
 64#[derive(Debug, Args)]
 65#[group(requires = "worktree")]
 66struct ContextArgs {
 67    #[arg(long)]
 68    worktree: PathBuf,
 69    #[arg(long)]
 70    cursor: SourceLocation,
 71    #[arg(long)]
 72    use_language_server: bool,
 73    #[arg(long)]
 74    events: Option<FileOrStdin>,
 75}
 76
 77#[derive(Debug, Args)]
 78struct Zeta2Args {
 79    #[arg(long, default_value_t = 8192)]
 80    max_prompt_bytes: usize,
 81    #[arg(long, default_value_t = 2048)]
 82    max_excerpt_bytes: usize,
 83    #[arg(long, default_value_t = 1024)]
 84    min_excerpt_bytes: usize,
 85    #[arg(long, default_value_t = 0.66)]
 86    target_before_cursor_over_total_bytes: f32,
 87    #[arg(long, default_value_t = 1024)]
 88    max_diagnostic_bytes: usize,
 89    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 90    prompt_format: PromptFormat,
 91    #[arg(long, value_enum, default_value_t = Default::default())]
 92    output_format: OutputFormat,
 93    #[arg(long, default_value_t = 42)]
 94    file_indexing_parallelism: usize,
 95    #[arg(long, default_value_t = false)]
 96    disable_imports_gathering: bool,
 97    #[arg(long, default_value_t = u8::MAX)]
 98    max_retrieved_definitions: u8,
 99}
100
101#[derive(clap::ValueEnum, Default, Debug, Clone)]
102enum PromptFormat {
103    MarkedExcerpt,
104    LabeledSections,
105    OnlySnippets,
106    #[default]
107    NumberedLines,
108}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111    fn into(self) -> predict_edits_v3::PromptFormat {
112        match self {
113            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
117        }
118    }
119}
120
121#[derive(clap::ValueEnum, Default, Debug, Clone)]
122enum OutputFormat {
123    #[default]
124    Prompt,
125    Request,
126    Full,
127}
128
/// An input that is read either from a file on disk or from standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read standard input until end of file.
    Stdin,
}
134
135impl FileOrStdin {
136    async fn read_to_string(&self) -> Result<String, std::io::Error> {
137        match self {
138            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
139            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
140        }
141    }
142}
143
144impl FromStr for FileOrStdin {
145    type Err = <PathBuf as FromStr>::Err;
146
147    fn from_str(s: &str) -> Result<Self, Self::Err> {
148        match s {
149            "-" => Ok(Self::Stdin),
150            _ => Ok(Self::File(PathBuf::from_str(s)?)),
151        }
152    }
153}
154
155enum GetContextOutput {
156    Zeta1(zeta::GatherContextOutput),
157    Zeta2(String),
158}
159
160async fn get_context(
161    zeta2_args: Option<Zeta2Args>,
162    args: ContextArgs,
163    app_state: &Arc<ZetaCliAppState>,
164    cx: &mut AsyncApp,
165) -> Result<GetContextOutput> {
166    let ContextArgs {
167        worktree: worktree_path,
168        cursor,
169        use_language_server,
170        events,
171    } = args;
172
173    let worktree_path = worktree_path.canonicalize()?;
174
175    let project = cx.update(|cx| {
176        Project::local(
177            app_state.client.clone(),
178            app_state.node_runtime.clone(),
179            app_state.user_store.clone(),
180            app_state.languages.clone(),
181            app_state.fs.clone(),
182            None,
183            cx,
184        )
185    })?;
186
187    let worktree = project
188        .update(cx, |project, cx| {
189            project.create_worktree(&worktree_path, true, cx)
190        })?
191        .await?;
192
193    let mut ready_languages = HashSet::default();
194    let (_lsp_open_handle, buffer) = if use_language_server {
195        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
196            project.clone(),
197            worktree.clone(),
198            cursor.path.clone(),
199            &mut ready_languages,
200            cx,
201        )
202        .await?;
203        (Some(lsp_open_handle), buffer)
204    } else {
205        let buffer =
206            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
207        (None, buffer)
208    };
209
210    let full_path_str = worktree
211        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
212        .display(PathStyle::local())
213        .to_string();
214
215    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
216    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
217    if clipped_cursor != cursor.point {
218        let max_row = snapshot.max_point().row;
219        if cursor.point.row < max_row {
220            return Err(anyhow!(
221                "Cursor position {:?} is out of bounds (line length is {})",
222                cursor.point,
223                snapshot.line_len(cursor.point.row)
224            ));
225        } else {
226            return Err(anyhow!(
227                "Cursor position {:?} is out of bounds (max row is {})",
228                cursor.point,
229                max_row
230            ));
231        }
232    }
233
234    let events = match events {
235        Some(events) => events.read_to_string().await?,
236        None => String::new(),
237    };
238
239    if let Some(zeta2_args) = zeta2_args {
240        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
241        // the whole worktree.
242        worktree
243            .read_with(cx, |worktree, _cx| {
244                worktree.as_local().unwrap().scan_complete()
245            })?
246            .await;
247        let output = cx
248            .update(|cx| {
249                let zeta = cx.new(|cx| {
250                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
251                });
252                let indexing_done_task = zeta.update(cx, |zeta, cx| {
253                    zeta.set_options(zeta2_args.to_options(true));
254                    zeta.register_buffer(&buffer, &project, cx);
255                    zeta.wait_for_initial_indexing(&project, cx)
256                });
257                cx.spawn(async move |cx| {
258                    indexing_done_task.await?;
259                    let request = zeta
260                        .update(cx, |zeta, cx| {
261                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
262                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
263                        })?
264                        .await?;
265
266                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
267                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
268
269                    match zeta2_args.output_format {
270                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
271                        OutputFormat::Request => {
272                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
273                        }
274                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
275                            "request": request,
276                            "prompt": prompt_string,
277                            "section_labels": section_labels,
278                        }))?),
279                    }
280                })
281            })?
282            .await?;
283        Ok(GetContextOutput::Zeta2(output))
284    } else {
285        let prompt_for_events = move || (events, 0);
286        Ok(GetContextOutput::Zeta1(
287            cx.update(|cx| {
288                zeta::gather_context(
289                    full_path_str,
290                    &snapshot,
291                    clipped_cursor,
292                    prompt_for_events,
293                    cx,
294                )
295            })?
296            .await?,
297        ))
298    }
299}
300
301impl Zeta2Args {
302    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
303        zeta2::ZetaOptions {
304            context: EditPredictionContextOptions {
305                max_retrieved_declarations: self.max_retrieved_definitions,
306                use_imports: !self.disable_imports_gathering,
307                excerpt: EditPredictionExcerptOptions {
308                    max_bytes: self.max_excerpt_bytes,
309                    min_bytes: self.min_excerpt_bytes,
310                    target_before_cursor_over_total_bytes: self
311                        .target_before_cursor_over_total_bytes,
312                },
313                score: EditPredictionScoreOptions {
314                    omit_excerpt_overlaps,
315                },
316            },
317            max_diagnostic_bytes: self.max_diagnostic_bytes,
318            max_prompt_bytes: self.max_prompt_bytes,
319            prompt_format: self.prompt_format.clone().into(),
320            file_indexing_parallelism: self.file_indexing_parallelism,
321        }
322    }
323}
324
325fn main() {
326    zlog::init();
327    zlog::init_output_stderr();
328    let args = ZetaCliArgs::parse();
329    let http_client = Arc::new(ReqwestClient::new());
330    let app = Application::headless().with_http_client(http_client);
331
332    app.run(move |cx| {
333        let app_state = Arc::new(headless::init(cx));
334        cx.spawn(async move |cx| {
335            let result = match args.command {
336                Commands::Zeta2Context {
337                    zeta2_args,
338                    context_args,
339                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
340                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
341                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
342                    Err(err) => Err(err),
343                },
344                Commands::Context(context_args) => {
345                    match get_context(None, context_args, &app_state, cx).await {
346                        Ok(GetContextOutput::Zeta1(output)) => {
347                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
348                        }
349                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
350                        Err(err) => Err(err),
351                    }
352                }
353                Commands::Predict {
354                    predict_edits_body,
355                    context_args,
356                } => {
357                    cx.spawn(async move |cx| {
358                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
359                        app_state.client.sign_in(true, cx).await?;
360                        let llm_token = LlmApiToken::default();
361                        llm_token.refresh(&app_state.client).await?;
362
363                        let predict_edits_body =
364                            if let Some(predict_edits_body) = predict_edits_body {
365                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
366                            } else if let Some(context_args) = context_args {
367                                match get_context(None, context_args, &app_state, cx).await? {
368                                    GetContextOutput::Zeta1(output) => output.body,
369                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
370                                }
371                            } else {
372                                return Err(anyhow!(
373                                    "Expected either --predict-edits-body-file \
374                                    or the required args of the `context` command."
375                                ));
376                            };
377
378                        let (response, _usage) =
379                            Zeta::perform_predict_edits(PerformPredictEditsParams {
380                                client: app_state.client.clone(),
381                                llm_token,
382                                app_version,
383                                body: predict_edits_body,
384                            })
385                            .await?;
386
387                        Ok(response.output_excerpt)
388                    })
389                    .await
390                }
391                Commands::RetrievalStats {
392                    zeta2_args,
393                    worktree,
394                    extension,
395                    limit,
396                    skip,
397                } => {
398                    retrieval_stats(
399                        worktree,
400                        app_state,
401                        extension,
402                        limit,
403                        skip,
404                        (&zeta2_args).to_options(false),
405                        cx,
406                    )
407                    .await
408                }
409            };
410            match result {
411                Ok(output) => {
412                    println!("{}", output);
413                    let _ = cx.update(|cx| cx.quit());
414                }
415                Err(e) => {
416                    eprintln!("Failed: {:?}", e);
417                    exit(1);
418                }
419            }
420        })
421        .detach();
422    });
423}