main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3;
  6use edit_prediction_context::EditPredictionExcerptOptions;
  7use futures::channel::mpsc;
  8use futures::{FutureExt as _, StreamExt as _};
  9use gpui::{AppContext, Application, AsyncApp};
 10use gpui::{Entity, Task};
 11use language::Bias;
 12use language::Buffer;
 13use language::Point;
 14use language_model::LlmApiToken;
 15use project::{Project, ProjectPath, Worktree};
 16use release_channel::AppVersion;
 17use reqwest_client::ReqwestClient;
 18use serde_json::json;
 19use std::path::{Path, PathBuf};
 20use std::process::exit;
 21use std::str::FromStr;
 22use std::sync::Arc;
 23use std::time::Duration;
 24use util::paths::PathStyle;
 25use util::rel_path::RelPath;
 26use zeta::{PerformPredictEditsParams, Zeta};
 27
 28use crate::headless::ZetaCliAppState;
 29
// Top-level CLI arguments for the `zeta` binary; all work is dispatched through `command`.
// (Plain `//` comments are used on clap-derived items so the `--help` output is unaffected.)
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
 36
 37#[derive(Subcommand, Debug)]
 38enum Commands {
 39    Context(ContextArgs),
 40    Zeta2Context {
 41        #[clap(flatten)]
 42        zeta2_args: Zeta2Args,
 43        #[clap(flatten)]
 44        context_args: ContextArgs,
 45    },
 46    Predict {
 47        #[arg(long)]
 48        predict_edits_body: Option<FileOrStdin>,
 49        #[clap(flatten)]
 50        context_args: Option<ContextArgs>,
 51    },
 52}
 53
// Arguments shared by the context-gathering subcommands (`context`, `zeta2-context`, and
// optionally `predict`). Plain `//` comments keep the `--help` output unchanged.
// NOTE(review): the group-level `requires = "worktree"` appears intended for the
// `Option<ContextArgs>` flatten in `Predict`, so that supplying any of these args also demands
// `--worktree` — confirm against clap's `ArgGroup::requires` semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory to open as the project's worktree; canonicalized before use.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` (1-based), with `path` relative to the worktree root.
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is registered with language servers and the command waits for
    // them before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Optional events text (file path, or `-` for stdin); only consumed by the zeta1 path.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
 66
// Tuning knobs for `zeta2-context`. Every field except `output_format` is passed into
// `zeta2::ZetaOptions` by `get_context`; `output_format` only selects what gets printed.
// Plain `//` comments keep the `--help` output unchanged.
#[derive(Debug, Args)]
struct Zeta2Args {
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // The next three fields populate `EditPredictionExcerptOptions`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` via `.into()`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // What to print: the prompt text, the request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
 86
// `--prompt-format` choices for `zeta2-context`; each maps to the same-named variant of
// `predict_edits_v3::PromptFormat`. `MarkedExcerpt` is the default.
// (Plain `//` comments so clap's help text for the values is unaffected.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
 94
 95impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 96    fn into(self) -> predict_edits_v3::PromptFormat {
 97        match self {
 98            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 99            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
100            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
101        }
102    }
103}
104
// `--output-format` choices for `zeta2-context`:
//   Prompt  - print the rendered prompt string (default);
//   Request - print the request serialized as pretty JSON;
//   Both    - print a pretty JSON object with `request` and `prompt` keys.
// (Plain `//` comments so clap's help text for the values is unaffected.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Both,
}
112
/// Where a textual CLI input comes from: the literal value `-` selects standard input, and any
/// other value is taken as a file path (see the `FromStr` impl for this type).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the whole of the process's standard input.
    Stdin,
    /// Read the whole file at this path.
    File(PathBuf),
}
118
119impl FileOrStdin {
120    async fn read_to_string(&self) -> Result<String, std::io::Error> {
121        match self {
122            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
123            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
124        }
125    }
126}
127
128impl FromStr for FileOrStdin {
129    type Err = <PathBuf as FromStr>::Err;
130
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        match s {
133            "-" => Ok(Self::Stdin),
134            _ => Ok(Self::File(PathBuf::from_str(s)?)),
135        }
136    }
137}
138
/// A cursor location parsed from the `--cursor` CLI argument (`path:line:column`).
#[derive(Debug, Clone)]
struct CursorPosition {
    /// File path, relative to the worktree root.
    path: Arc<RelPath>,
    /// Zero-based position; the `FromStr` impl converts from the 1-based CLI input.
    point: Point,
}
144
145impl FromStr for CursorPosition {
146    type Err = anyhow::Error;
147
148    fn from_str(s: &str) -> Result<Self> {
149        let parts: Vec<&str> = s.split(':').collect();
150        if parts.len() != 3 {
151            return Err(anyhow!(
152                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
153                s
154            ));
155        }
156
157        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
158        let line: u32 = parts[1]
159            .parse()
160            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
161        let column: u32 = parts[2]
162            .parse()
163            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
164
165        // Convert from 1-based to 0-based indexing
166        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
167
168        Ok(CursorPosition { path, point })
169    }
170}
171
172enum GetContextOutput {
173    Zeta1(zeta::GatherContextOutput),
174    Zeta2(String),
175}
176
177async fn get_context(
178    zeta2_args: Option<Zeta2Args>,
179    args: ContextArgs,
180    app_state: &Arc<ZetaCliAppState>,
181    cx: &mut AsyncApp,
182) -> Result<GetContextOutput> {
183    let ContextArgs {
184        worktree: worktree_path,
185        cursor,
186        use_language_server,
187        events,
188    } = args;
189
190    let worktree_path = worktree_path.canonicalize()?;
191
192    let project = cx.update(|cx| {
193        Project::local(
194            app_state.client.clone(),
195            app_state.node_runtime.clone(),
196            app_state.user_store.clone(),
197            app_state.languages.clone(),
198            app_state.fs.clone(),
199            None,
200            cx,
201        )
202    })?;
203
204    let worktree = project
205        .update(cx, |project, cx| {
206            project.create_worktree(&worktree_path, true, cx)
207        })?
208        .await?;
209
210    let (_lsp_open_handle, buffer) = if use_language_server {
211        let (lsp_open_handle, buffer) =
212            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
213        (Some(lsp_open_handle), buffer)
214    } else {
215        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
216        (None, buffer)
217    };
218
219    let full_path_str = worktree
220        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
221        .display(PathStyle::local())
222        .to_string();
223
224    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
225    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
226    if clipped_cursor != cursor.point {
227        let max_row = snapshot.max_point().row;
228        if cursor.point.row < max_row {
229            return Err(anyhow!(
230                "Cursor position {:?} is out of bounds (line length is {})",
231                cursor.point,
232                snapshot.line_len(cursor.point.row)
233            ));
234        } else {
235            return Err(anyhow!(
236                "Cursor position {:?} is out of bounds (max row is {})",
237                cursor.point,
238                max_row
239            ));
240        }
241    }
242
243    let events = match events {
244        Some(events) => events.read_to_string().await?,
245        None => String::new(),
246    };
247
248    if let Some(zeta2_args) = zeta2_args {
249        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
250        // the whole worktree.
251        worktree
252            .read_with(cx, |worktree, _cx| {
253                worktree.as_local().unwrap().scan_complete()
254            })?
255            .await;
256        let output = cx
257            .update(|cx| {
258                let zeta = cx.new(|cx| {
259                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
260                });
261                let indexing_done_task = zeta.update(cx, |zeta, cx| {
262                    zeta.set_options(zeta2::ZetaOptions {
263                        excerpt: EditPredictionExcerptOptions {
264                            max_bytes: zeta2_args.max_excerpt_bytes,
265                            min_bytes: zeta2_args.min_excerpt_bytes,
266                            target_before_cursor_over_total_bytes: zeta2_args
267                                .target_before_cursor_over_total_bytes,
268                        },
269                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
270                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
271                        prompt_format: zeta2_args.prompt_format.into(),
272                        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
273                    });
274                    zeta.register_buffer(&buffer, &project, cx);
275                    zeta.wait_for_initial_indexing(&project, cx)
276                });
277                cx.spawn(async move |cx| {
278                    indexing_done_task.await?;
279                    let request = zeta
280                        .update(cx, |zeta, cx| {
281                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
282                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
283                        })?
284                        .await?;
285
286                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
287                    let prompt_string = planned_prompt.to_prompt_string()?.0;
288                    match zeta2_args.output_format {
289                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
290                        OutputFormat::Request => {
291                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
292                        }
293                        OutputFormat::Both => anyhow::Ok(serde_json::to_string_pretty(&json!({
294                            "request": request,
295                            "prompt": prompt_string,
296                        }))?),
297                    }
298                })
299            })?
300            .await?;
301        Ok(GetContextOutput::Zeta2(output))
302    } else {
303        let prompt_for_events = move || (events, 0);
304        Ok(GetContextOutput::Zeta1(
305            cx.update(|cx| {
306                zeta::gather_context(
307                    full_path_str,
308                    &snapshot,
309                    clipped_cursor,
310                    prompt_for_events,
311                    cx,
312                )
313            })?
314            .await?,
315        ))
316    }
317}
318
319pub async fn open_buffer(
320    project: &Entity<Project>,
321    worktree: &Entity<Worktree>,
322    path: &RelPath,
323    cx: &mut AsyncApp,
324) -> Result<Entity<Buffer>> {
325    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
326        worktree_id: worktree.id(),
327        path: path.into(),
328    })?;
329
330    project
331        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
332        .await
333}
334
335pub async fn open_buffer_with_language_server(
336    project: &Entity<Project>,
337    worktree: &Entity<Worktree>,
338    path: &RelPath,
339    cx: &mut AsyncApp,
340) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
341    let buffer = open_buffer(project, worktree, path, cx).await?;
342
343    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
344        (
345            project.register_buffer_with_language_servers(&buffer, cx),
346            project.path_style(cx),
347        )
348    })?;
349
350    let log_prefix = path.display(path_style);
351    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
352
353    Ok((lsp_open_handle, buffer))
354}
355
356// TODO: Dedupe with similar function in crates/eval/src/instance.rs
357pub fn wait_for_lang_server(
358    project: &Entity<Project>,
359    buffer: &Entity<Buffer>,
360    log_prefix: String,
361    cx: &mut AsyncApp,
362) -> Task<Result<()>> {
363    println!("{}⏵ Waiting for language server", log_prefix);
364
365    let (mut tx, mut rx) = mpsc::channel(1);
366
367    let lsp_store = project
368        .read_with(cx, |project, _| project.lsp_store())
369        .unwrap();
370
371    let has_lang_server = buffer
372        .update(cx, |buffer, cx| {
373            lsp_store.update(cx, |lsp_store, cx| {
374                lsp_store
375                    .language_servers_for_local_buffer(buffer, cx)
376                    .next()
377                    .is_some()
378            })
379        })
380        .unwrap_or(false);
381
382    if has_lang_server {
383        project
384            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
385            .unwrap()
386            .detach();
387    }
388
389    let subscriptions = [
390        cx.subscribe(&lsp_store, {
391            let log_prefix = log_prefix.clone();
392            move |_, event, _| {
393                if let project::LspStoreEvent::LanguageServerUpdate {
394                    message:
395                        client::proto::update_language_server::Variant::WorkProgress(
396                            client::proto::LspWorkProgress {
397                                message: Some(message),
398                                ..
399                            },
400                        ),
401                    ..
402                } = event
403                {
404                    println!("{}{message}", log_prefix)
405                }
406            }
407        }),
408        cx.subscribe(project, {
409            let buffer = buffer.clone();
410            move |project, event, cx| match event {
411                project::Event::LanguageServerAdded(_, _, _) => {
412                    let buffer = buffer.clone();
413                    project
414                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
415                        .detach();
416                }
417                project::Event::DiskBasedDiagnosticsFinished { .. } => {
418                    tx.try_send(()).ok();
419                }
420                _ => {}
421            }
422        }),
423    ];
424
425    cx.spawn(async move |cx| {
426        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
427        let result = futures::select! {
428            _ = rx.next() => {
429                println!("{}⚑ Language server idle", log_prefix);
430                anyhow::Ok(())
431            },
432            _ = timeout.fuse() => {
433                anyhow::bail!("LSP wait timed out after 5 minutes");
434            }
435        };
436        drop(subscriptions);
437        result
438    })
439}
440
441fn main() {
442    zlog::init();
443    zlog::init_output_stderr();
444    let args = ZetaCliArgs::parse();
445    let http_client = Arc::new(ReqwestClient::new());
446    let app = Application::headless().with_http_client(http_client);
447
448    app.run(move |cx| {
449        let app_state = Arc::new(headless::init(cx));
450        cx.spawn(async move |cx| {
451            let result = match args.command {
452                Commands::Zeta2Context {
453                    zeta2_args,
454                    context_args,
455                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
456                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
457                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
458                    Err(err) => Err(err),
459                },
460                Commands::Context(context_args) => {
461                    match get_context(None, context_args, &app_state, cx).await {
462                        Ok(GetContextOutput::Zeta1(output)) => {
463                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
464                        }
465                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
466                        Err(err) => Err(err),
467                    }
468                }
469                Commands::Predict {
470                    predict_edits_body,
471                    context_args,
472                } => {
473                    cx.spawn(async move |cx| {
474                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
475                        app_state.client.sign_in(true, cx).await?;
476                        let llm_token = LlmApiToken::default();
477                        llm_token.refresh(&app_state.client).await?;
478
479                        let predict_edits_body =
480                            if let Some(predict_edits_body) = predict_edits_body {
481                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
482                            } else if let Some(context_args) = context_args {
483                                match get_context(None, context_args, &app_state, cx).await? {
484                                    GetContextOutput::Zeta1(output) => output.body,
485                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
486                                }
487                            } else {
488                                return Err(anyhow!(
489                                    "Expected either --predict-edits-body-file \
490                                    or the required args of the `context` command."
491                                ));
492                            };
493
494                        let (response, _usage) =
495                            Zeta::perform_predict_edits(PerformPredictEditsParams {
496                                client: app_state.client.clone(),
497                                llm_token,
498                                app_version,
499                                body: predict_edits_body,
500                            })
501                            .await?;
502
503                        Ok(response.output_excerpt)
504                    })
505                    .await
506                }
507            };
508            match result {
509                Ok(output) => {
510                    println!("{}", output);
511                    let _ = cx.update(|cx| cx.quit());
512                }
513                Err(e) => {
514                    eprintln!("Failed: {:?}", e);
515                    exit(1);
516                }
517            }
518        })
519        .detach();
520    });
521}