main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use edit_prediction_context::EditPredictionExcerptOptions;
  6use futures::channel::mpsc;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, Application, AsyncApp};
  9use gpui::{Entity, Task};
 10use language::Bias;
 11use language::Buffer;
 12use language::Point;
 13use language_model::LlmApiToken;
 14use project::{Project, ProjectPath, Worktree};
 15use release_channel::AppVersion;
 16use reqwest_client::ReqwestClient;
 17use std::path::{Path, PathBuf};
 18use std::process::exit;
 19use std::str::FromStr;
 20use std::sync::Arc;
 21use std::time::Duration;
 22use zeta::{PerformPredictEditsParams, Zeta};
 23
 24use crate::headless::ZetaCliAppState;
 25
 26#[derive(Parser, Debug)]
 27#[command(name = "zeta")]
 28struct ZetaCliArgs {
 29    #[command(subcommand)]
 30    command: Commands,
 31}
 32
 33#[derive(Subcommand, Debug)]
 34enum Commands {
 35    Context(ContextArgs),
 36    Zeta2Context {
 37        #[clap(flatten)]
 38        zeta2_args: Zeta2Args,
 39        #[clap(flatten)]
 40        context_args: ContextArgs,
 41    },
 42    Predict {
 43        #[arg(long)]
 44        predict_edits_body: Option<FileOrStdin>,
 45        #[clap(flatten)]
 46        context_args: Option<ContextArgs>,
 47    },
 48}
 49
 50#[derive(Debug, Args)]
 51#[group(requires = "worktree")]
 52struct ContextArgs {
 53    #[arg(long)]
 54    worktree: PathBuf,
 55    #[arg(long)]
 56    cursor: CursorPosition,
 57    #[arg(long)]
 58    use_language_server: bool,
 59    #[arg(long)]
 60    events: Option<FileOrStdin>,
 61}
 62
 63#[derive(Debug, Args)]
 64struct Zeta2Args {
 65    #[arg(long, default_value_t = 8192)]
 66    max_prompt_bytes: usize,
 67    #[arg(long, default_value_t = 2048)]
 68    max_excerpt_bytes: usize,
 69    #[arg(long, default_value_t = 1024)]
 70    min_excerpt_bytes: usize,
 71    #[arg(long, default_value_t = 0.66)]
 72    target_before_cursor_over_total_bytes: f32,
 73    #[arg(long, default_value_t = 1024)]
 74    max_diagnostic_bytes: usize,
 75}
 76
 /// An input source given on the command line: `-` selects stdin, anything else is a file path.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the process's standard input.
    Stdin,
    /// Read from the file at this path.
    File(PathBuf),
}
 82
 83impl FileOrStdin {
 84    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 85        match self {
 86            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 87            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 88        }
 89    }
 90}
 91
 92impl FromStr for FileOrStdin {
 93    type Err = <PathBuf as FromStr>::Err;
 94
 95    fn from_str(s: &str) -> Result<Self, Self::Err> {
 96        match s {
 97            "-" => Ok(Self::Stdin),
 98            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104struct CursorPosition {
105    path: PathBuf,
106    point: Point,
107}
108
109impl FromStr for CursorPosition {
110    type Err = anyhow::Error;
111
112    fn from_str(s: &str) -> Result<Self> {
113        let parts: Vec<&str> = s.split(':').collect();
114        if parts.len() != 3 {
115            return Err(anyhow!(
116                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
117                s
118            ));
119        }
120
121        let path = PathBuf::from(parts[0]);
122        let line: u32 = parts[1]
123            .parse()
124            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
125        let column: u32 = parts[2]
126            .parse()
127            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
128
129        // Convert from 1-based to 0-based indexing
130        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
131
132        Ok(CursorPosition { path, point })
133    }
134}
135
136enum GetContextOutput {
137    Zeta1(zeta::GatherContextOutput),
138    Zeta2(String),
139}
140
141async fn get_context(
142    zeta2_args: Option<Zeta2Args>,
143    args: ContextArgs,
144    app_state: &Arc<ZetaCliAppState>,
145    cx: &mut AsyncApp,
146) -> Result<GetContextOutput> {
147    let ContextArgs {
148        worktree: worktree_path,
149        cursor,
150        use_language_server,
151        events,
152    } = args;
153
154    let worktree_path = worktree_path.canonicalize()?;
155    if cursor.path.is_absolute() {
156        return Err(anyhow!("Absolute paths are not supported in --cursor"));
157    }
158
159    let project = cx.update(|cx| {
160        Project::local(
161            app_state.client.clone(),
162            app_state.node_runtime.clone(),
163            app_state.user_store.clone(),
164            app_state.languages.clone(),
165            app_state.fs.clone(),
166            None,
167            cx,
168        )
169    })?;
170
171    let worktree = project
172        .update(cx, |project, cx| {
173            project.create_worktree(&worktree_path, true, cx)
174        })?
175        .await?;
176
177    let (_lsp_open_handle, buffer) = if use_language_server {
178        let (lsp_open_handle, buffer) =
179            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
180        (Some(lsp_open_handle), buffer)
181    } else {
182        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
183        (None, buffer)
184    };
185
186    let worktree_name = worktree_path
187        .file_name()
188        .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
189    let full_path_str = PathBuf::from(worktree_name)
190        .join(&cursor.path)
191        .to_string_lossy()
192        .to_string();
193
194    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
195    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
196    if clipped_cursor != cursor.point {
197        let max_row = snapshot.max_point().row;
198        if cursor.point.row < max_row {
199            return Err(anyhow!(
200                "Cursor position {:?} is out of bounds (line length is {})",
201                cursor.point,
202                snapshot.line_len(cursor.point.row)
203            ));
204        } else {
205            return Err(anyhow!(
206                "Cursor position {:?} is out of bounds (max row is {})",
207                cursor.point,
208                max_row
209            ));
210        }
211    }
212
213    let events = match events {
214        Some(events) => events.read_to_string().await?,
215        None => String::new(),
216    };
217
218    if let Some(zeta2_args) = zeta2_args {
219        Ok(GetContextOutput::Zeta2(
220            cx.update(|cx| {
221                let zeta = cx.new(|cx| {
222                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
223                });
224                zeta.update(cx, |zeta, cx| {
225                    zeta.register_buffer(&buffer, &project, cx);
226                    zeta.set_options(zeta2::ZetaOptions {
227                        excerpt: EditPredictionExcerptOptions {
228                            max_bytes: zeta2_args.max_excerpt_bytes,
229                            min_bytes: zeta2_args.min_excerpt_bytes,
230                            target_before_cursor_over_total_bytes: zeta2_args
231                                .target_before_cursor_over_total_bytes,
232                        },
233                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
234                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
235                    })
236                });
237                // TODO: Actually wait for indexing.
238                let timer = cx.background_executor().timer(Duration::from_secs(5));
239                cx.spawn(async move |cx| {
240                    timer.await;
241                    let request = zeta
242                        .update(cx, |zeta, cx| {
243                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
244                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
245                        })?
246                        .await?;
247                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
248                        &request,
249                        &cloud_zeta2_prompt::PlanOptions {
250                            max_bytes: zeta2_args.max_prompt_bytes,
251                        },
252                    )?;
253                    anyhow::Ok(planned_prompt.to_prompt_string())
254                })
255            })?
256            .await?,
257        ))
258    } else {
259        let prompt_for_events = move || (events, 0);
260        Ok(GetContextOutput::Zeta1(
261            cx.update(|cx| {
262                zeta::gather_context(
263                    full_path_str,
264                    &snapshot,
265                    clipped_cursor,
266                    prompt_for_events,
267                    cx,
268                )
269            })?
270            .await?,
271        ))
272    }
273}
274
275pub async fn open_buffer(
276    project: &Entity<Project>,
277    worktree: &Entity<Worktree>,
278    path: &Path,
279    cx: &mut AsyncApp,
280) -> Result<Entity<Buffer>> {
281    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
282        worktree_id: worktree.id(),
283        path: path.to_path_buf().into(),
284    })?;
285
286    project
287        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
288        .await
289}
290
291pub async fn open_buffer_with_language_server(
292    project: &Entity<Project>,
293    worktree: &Entity<Worktree>,
294    path: &Path,
295    cx: &mut AsyncApp,
296) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
297    let buffer = open_buffer(project, worktree, path, cx).await?;
298
299    let lsp_open_handle = project.update(cx, |project, cx| {
300        project.register_buffer_with_language_servers(&buffer, cx)
301    })?;
302
303    let log_prefix = path.to_string_lossy().to_string();
304    wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
305
306    Ok((lsp_open_handle, buffer))
307}
308
309// TODO: Dedupe with similar function in crates/eval/src/instance.rs
310pub fn wait_for_lang_server(
311    project: &Entity<Project>,
312    buffer: &Entity<Buffer>,
313    log_prefix: String,
314    cx: &mut AsyncApp,
315) -> Task<Result<()>> {
316    println!("{}⏵ Waiting for language server", log_prefix);
317
318    let (mut tx, mut rx) = mpsc::channel(1);
319
320    let lsp_store = project
321        .read_with(cx, |project, _| project.lsp_store())
322        .unwrap();
323
324    let has_lang_server = buffer
325        .update(cx, |buffer, cx| {
326            lsp_store.update(cx, |lsp_store, cx| {
327                lsp_store
328                    .language_servers_for_local_buffer(buffer, cx)
329                    .next()
330                    .is_some()
331            })
332        })
333        .unwrap_or(false);
334
335    if has_lang_server {
336        project
337            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
338            .unwrap()
339            .detach();
340    }
341
342    let subscriptions = [
343        cx.subscribe(&lsp_store, {
344            let log_prefix = log_prefix.clone();
345            move |_, event, _| {
346                if let project::LspStoreEvent::LanguageServerUpdate {
347                    message:
348                        client::proto::update_language_server::Variant::WorkProgress(
349                            client::proto::LspWorkProgress {
350                                message: Some(message),
351                                ..
352                            },
353                        ),
354                    ..
355                } = event
356                {
357                    println!("{}{message}", log_prefix)
358                }
359            }
360        }),
361        cx.subscribe(project, {
362            let buffer = buffer.clone();
363            move |project, event, cx| match event {
364                project::Event::LanguageServerAdded(_, _, _) => {
365                    let buffer = buffer.clone();
366                    project
367                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
368                        .detach();
369                }
370                project::Event::DiskBasedDiagnosticsFinished { .. } => {
371                    tx.try_send(()).ok();
372                }
373                _ => {}
374            }
375        }),
376    ];
377
378    cx.spawn(async move |cx| {
379        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
380        let result = futures::select! {
381            _ = rx.next() => {
382                println!("{}⚑ Language server idle", log_prefix);
383                anyhow::Ok(())
384            },
385            _ = timeout.fuse() => {
386                anyhow::bail!("LSP wait timed out after 5 minutes");
387            }
388        };
389        drop(subscriptions);
390        result
391    })
392}
393
394fn main() {
395    let args = ZetaCliArgs::parse();
396    let http_client = Arc::new(ReqwestClient::new());
397    let app = Application::headless().with_http_client(http_client);
398
399    app.run(move |cx| {
400        let app_state = Arc::new(headless::init(cx));
401        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
402        cx.spawn(async move |cx| {
403            let result = match args.command {
404                Commands::Zeta2Context {
405                    zeta2_args,
406                    context_args,
407                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
408                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
409                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
410                    Err(err) => Err(err),
411                },
412                Commands::Context(context_args) => {
413                    match get_context(None, context_args, &app_state, cx).await {
414                        Ok(GetContextOutput::Zeta1(output)) => {
415                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
416                        }
417                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
418                        Err(err) => Err(err),
419                    }
420                }
421                Commands::Predict {
422                    predict_edits_body,
423                    context_args,
424                } => {
425                    cx.spawn(async move |cx| {
426                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
427                        app_state.client.sign_in(true, cx).await?;
428                        let llm_token = LlmApiToken::default();
429                        llm_token.refresh(&app_state.client).await?;
430
431                        let predict_edits_body =
432                            if let Some(predict_edits_body) = predict_edits_body {
433                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
434                            } else if let Some(context_args) = context_args {
435                                match get_context(None, context_args, &app_state, cx).await? {
436                                    GetContextOutput::Zeta1(output) => output.body,
437                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
438                                }
439                            } else {
440                                return Err(anyhow!(
441                                    "Expected either --predict-edits-body-file \
442                                    or the required args of the `context` command."
443                                ));
444                            };
445
446                        let (response, _usage) =
447                            Zeta::perform_predict_edits(PerformPredictEditsParams {
448                                client: app_state.client.clone(),
449                                llm_token,
450                                app_version,
451                                body: predict_edits_body,
452                            })
453                            .await?;
454
455                        Ok(response.output_excerpt)
456                    })
457                    .await
458                }
459            };
460            match result {
461                Ok(output) => {
462                    println!("{}", output);
463                    // TODO: Remove this once the 5 second delay is properly replaced.
464                    if is_zeta2_context_command {
465                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
466                    }
467                    let _ = cx.update(|cx| cx.quit());
468                }
469                Err(e) => {
470                    eprintln!("Failed: {:?}", e);
471                    exit(1);
472                }
473            }
474        })
475        .detach();
476    });
477}