main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use edit_prediction_context::EditPredictionExcerptOptions;
  6use futures::channel::mpsc;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, Application, AsyncApp};
  9use gpui::{Entity, Task};
 10use language::Bias;
 11use language::Buffer;
 12use language::Point;
 13use language_model::LlmApiToken;
 14use project::{Project, ProjectPath, Worktree};
 15use release_channel::AppVersion;
 16use reqwest_client::ReqwestClient;
 17use std::path::{Path, PathBuf};
 18use std::process::exit;
 19use std::str::FromStr;
 20use std::sync::Arc;
 21use std::time::Duration;
 22use zeta::{PerformPredictEditsParams, Zeta};
 23
 24use crate::headless::ZetaCliAppState;
 25
// Top-level command-line arguments for the `zeta` binary; all behavior is selected
// through a subcommand.
// NOTE: plain `//` comments are used on purpose in the clap-derived types of this
// file, because clap's derive turns `///` doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run (see `Commands`).
    #[command(subcommand)]
    command: Commands,
}
 32
// Subcommands of the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context around a cursor; `main` prints the request body as pretty JSON.
    Context(ContextArgs),
    // Build a zeta2 prompt for a cursor, using the size limits from `Zeta2Args`;
    // `main` prints the rendered prompt string.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction, either from a pre-built request body
    // (`--predict-edits-body <path>`, or `-` for stdin) or from zeta1 context gathered
    // with the `context` arguments; `main` prints the predicted output excerpt.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
 49
// Arguments shared by every command that gathers context around a cursor.
#[derive(Debug, Args)]
// If any argument of this group is given, `--worktree` must be given too.
// NOTE(review): presumably needed so the group behaves when flattened as
// `Option<ContextArgs>` in `Predict` — confirm.
#[group(requires = "worktree")]
struct ContextArgs {
    // Project root folder; canonicalized and opened as a local worktree. Its final
    // path component is used as the root name of the file path passed to zeta1.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` (1-based); `path` must be relative to `--worktree`.
    #[arg(long)]
    cursor: CursorPosition,
    // Register the buffer with language servers and wait for them before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Optional events input (a file, or `-` for stdin); its text is handed to zeta1's
    // context gathering and is not used by `zeta2-context`.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
 62
// Size limits for the zeta2 context pipeline (`zeta2-context` command).
#[derive(Debug, Args)]
struct Zeta2Args {
    // Byte budget for the planned prompt (forwarded as `PlanOptions::max_bytes`).
    #[arg(long, default_value_t = 8192)]
    prompt_max_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    excerpt_max_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    excerpt_min_bytes: usize,
    // Forwarded as `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`;
    // judging by the name, the target ratio of excerpt bytes before the cursor to the
    // total excerpt bytes.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
}
 74
 75#[derive(Debug, Clone)]
 76enum FileOrStdin {
 77    File(PathBuf),
 78    Stdin,
 79}
 80
 81impl FileOrStdin {
 82    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 83        match self {
 84            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 85            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 86        }
 87    }
 88}
 89
 90impl FromStr for FileOrStdin {
 91    type Err = <PathBuf as FromStr>::Err;
 92
 93    fn from_str(s: &str) -> Result<Self, Self::Err> {
 94        match s {
 95            "-" => Ok(Self::Stdin),
 96            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 97        }
 98    }
 99}
100
101#[derive(Debug, Clone)]
102struct CursorPosition {
103    path: PathBuf,
104    point: Point,
105}
106
107impl FromStr for CursorPosition {
108    type Err = anyhow::Error;
109
110    fn from_str(s: &str) -> Result<Self> {
111        let parts: Vec<&str> = s.split(':').collect();
112        if parts.len() != 3 {
113            return Err(anyhow!(
114                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
115                s
116            ));
117        }
118
119        let path = PathBuf::from(parts[0]);
120        let line: u32 = parts[1]
121            .parse()
122            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
123        let column: u32 = parts[2]
124            .parse()
125            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
126
127        // Convert from 1-based to 0-based indexing
128        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
129
130        Ok(CursorPosition { path, point })
131    }
132}
133
134enum GetContextOutput {
135    Zeta1(zeta::GatherContextOutput),
136    Zeta2(String),
137}
138
139async fn get_context(
140    zeta2_args: Option<Zeta2Args>,
141    args: ContextArgs,
142    app_state: &Arc<ZetaCliAppState>,
143    cx: &mut AsyncApp,
144) -> Result<GetContextOutput> {
145    let ContextArgs {
146        worktree: worktree_path,
147        cursor,
148        use_language_server,
149        events,
150    } = args;
151
152    let worktree_path = worktree_path.canonicalize()?;
153    if cursor.path.is_absolute() {
154        return Err(anyhow!("Absolute paths are not supported in --cursor"));
155    }
156
157    let project = cx.update(|cx| {
158        Project::local(
159            app_state.client.clone(),
160            app_state.node_runtime.clone(),
161            app_state.user_store.clone(),
162            app_state.languages.clone(),
163            app_state.fs.clone(),
164            None,
165            cx,
166        )
167    })?;
168
169    let worktree = project
170        .update(cx, |project, cx| {
171            project.create_worktree(&worktree_path, true, cx)
172        })?
173        .await?;
174
175    let (_lsp_open_handle, buffer) = if use_language_server {
176        let (lsp_open_handle, buffer) =
177            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
178        (Some(lsp_open_handle), buffer)
179    } else {
180        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
181        (None, buffer)
182    };
183
184    let worktree_name = worktree_path
185        .file_name()
186        .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
187    let full_path_str = PathBuf::from(worktree_name)
188        .join(&cursor.path)
189        .to_string_lossy()
190        .to_string();
191
192    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
193    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
194    if clipped_cursor != cursor.point {
195        let max_row = snapshot.max_point().row;
196        if cursor.point.row < max_row {
197            return Err(anyhow!(
198                "Cursor position {:?} is out of bounds (line length is {})",
199                cursor.point,
200                snapshot.line_len(cursor.point.row)
201            ));
202        } else {
203            return Err(anyhow!(
204                "Cursor position {:?} is out of bounds (max row is {})",
205                cursor.point,
206                max_row
207            ));
208        }
209    }
210
211    let events = match events {
212        Some(events) => events.read_to_string().await?,
213        None => String::new(),
214    };
215
216    if let Some(zeta2_args) = zeta2_args {
217        Ok(GetContextOutput::Zeta2(
218            cx.update(|cx| {
219                let zeta = cx.new(|cx| {
220                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
221                });
222                zeta.update(cx, |zeta, cx| {
223                    zeta.register_buffer(&buffer, &project, cx);
224                    zeta.excerpt_options = EditPredictionExcerptOptions {
225                        max_bytes: zeta2_args.excerpt_max_bytes,
226                        min_bytes: zeta2_args.excerpt_min_bytes,
227                        target_before_cursor_over_total_bytes: zeta2_args
228                            .target_before_cursor_over_total_bytes,
229                    }
230                });
231                // TODO: Actually wait for indexing.
232                let timer = cx.background_executor().timer(Duration::from_secs(5));
233                cx.spawn(async move |cx| {
234                    timer.await;
235                    let request = zeta
236                        .update(cx, |zeta, cx| {
237                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
238                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
239                        })?
240                        .await?;
241                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
242                        &request,
243                        &cloud_zeta2_prompt::PlanOptions {
244                            max_bytes: zeta2_args.prompt_max_bytes,
245                        },
246                    )?;
247                    anyhow::Ok(planned_prompt.to_prompt_string())
248                })
249            })?
250            .await?,
251        ))
252    } else {
253        let prompt_for_events = move || (events, 0);
254        Ok(GetContextOutput::Zeta1(
255            cx.update(|cx| {
256                zeta::gather_context(
257                    full_path_str,
258                    &snapshot,
259                    clipped_cursor,
260                    prompt_for_events,
261                    cx,
262                )
263            })?
264            .await?,
265        ))
266    }
267}
268
269pub async fn open_buffer(
270    project: &Entity<Project>,
271    worktree: &Entity<Worktree>,
272    path: &Path,
273    cx: &mut AsyncApp,
274) -> Result<Entity<Buffer>> {
275    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
276        worktree_id: worktree.id(),
277        path: path.to_path_buf().into(),
278    })?;
279
280    project
281        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
282        .await
283}
284
285pub async fn open_buffer_with_language_server(
286    project: &Entity<Project>,
287    worktree: &Entity<Worktree>,
288    path: &Path,
289    cx: &mut AsyncApp,
290) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
291    let buffer = open_buffer(project, worktree, path, cx).await?;
292
293    let lsp_open_handle = project.update(cx, |project, cx| {
294        project.register_buffer_with_language_servers(&buffer, cx)
295    })?;
296
297    let log_prefix = path.to_string_lossy().to_string();
298    wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
299
300    Ok((lsp_open_handle, buffer))
301}
302
303// TODO: Dedupe with similar function in crates/eval/src/instance.rs
304pub fn wait_for_lang_server(
305    project: &Entity<Project>,
306    buffer: &Entity<Buffer>,
307    log_prefix: String,
308    cx: &mut AsyncApp,
309) -> Task<Result<()>> {
310    println!("{}⏵ Waiting for language server", log_prefix);
311
312    let (mut tx, mut rx) = mpsc::channel(1);
313
314    let lsp_store = project
315        .read_with(cx, |project, _| project.lsp_store())
316        .unwrap();
317
318    let has_lang_server = buffer
319        .update(cx, |buffer, cx| {
320            lsp_store.update(cx, |lsp_store, cx| {
321                lsp_store
322                    .language_servers_for_local_buffer(buffer, cx)
323                    .next()
324                    .is_some()
325            })
326        })
327        .unwrap_or(false);
328
329    if has_lang_server {
330        project
331            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
332            .unwrap()
333            .detach();
334    }
335
336    let subscriptions = [
337        cx.subscribe(&lsp_store, {
338            let log_prefix = log_prefix.clone();
339            move |_, event, _| {
340                if let project::LspStoreEvent::LanguageServerUpdate {
341                    message:
342                        client::proto::update_language_server::Variant::WorkProgress(
343                            client::proto::LspWorkProgress {
344                                message: Some(message),
345                                ..
346                            },
347                        ),
348                    ..
349                } = event
350                {
351                    println!("{}{message}", log_prefix)
352                }
353            }
354        }),
355        cx.subscribe(project, {
356            let buffer = buffer.clone();
357            move |project, event, cx| match event {
358                project::Event::LanguageServerAdded(_, _, _) => {
359                    let buffer = buffer.clone();
360                    project
361                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
362                        .detach();
363                }
364                project::Event::DiskBasedDiagnosticsFinished { .. } => {
365                    tx.try_send(()).ok();
366                }
367                _ => {}
368            }
369        }),
370    ];
371
372    cx.spawn(async move |cx| {
373        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
374        let result = futures::select! {
375            _ = rx.next() => {
376                println!("{}⚑ Language server idle", log_prefix);
377                anyhow::Ok(())
378            },
379            _ = timeout.fuse() => {
380                anyhow::bail!("LSP wait timed out after 5 minutes");
381            }
382        };
383        drop(subscriptions);
384        result
385    })
386}
387
388fn main() {
389    let args = ZetaCliArgs::parse();
390    let http_client = Arc::new(ReqwestClient::new());
391    let app = Application::headless().with_http_client(http_client);
392
393    app.run(move |cx| {
394        let app_state = Arc::new(headless::init(cx));
395        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
396        cx.spawn(async move |cx| {
397            let result = match args.command {
398                Commands::Zeta2Context {
399                    zeta2_args,
400                    context_args,
401                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
402                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
403                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
404                    Err(err) => Err(err),
405                },
406                Commands::Context(context_args) => {
407                    match get_context(None, context_args, &app_state, cx).await {
408                        Ok(GetContextOutput::Zeta1(output)) => {
409                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
410                        }
411                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
412                        Err(err) => Err(err),
413                    }
414                }
415                Commands::Predict {
416                    predict_edits_body,
417                    context_args,
418                } => {
419                    cx.spawn(async move |cx| {
420                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
421                        app_state.client.sign_in(true, cx).await?;
422                        let llm_token = LlmApiToken::default();
423                        llm_token.refresh(&app_state.client).await?;
424
425                        let predict_edits_body =
426                            if let Some(predict_edits_body) = predict_edits_body {
427                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
428                            } else if let Some(context_args) = context_args {
429                                match get_context(None, context_args, &app_state, cx).await? {
430                                    GetContextOutput::Zeta1(output) => output.body,
431                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
432                                }
433                            } else {
434                                return Err(anyhow!(
435                                    "Expected either --predict-edits-body-file \
436                                    or the required args of the `context` command."
437                                ));
438                            };
439
440                        let (response, _usage) =
441                            Zeta::perform_predict_edits(PerformPredictEditsParams {
442                                client: app_state.client.clone(),
443                                llm_token,
444                                app_version,
445                                body: predict_edits_body,
446                            })
447                            .await?;
448
449                        Ok(response.output_excerpt)
450                    })
451                    .await
452                }
453            };
454            match result {
455                Ok(output) => {
456                    println!("{}", output);
457                    // TODO: Remove this once the 5 second delay is properly replaced.
458                    if is_zeta2_context_command {
459                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
460                    }
461                    let _ = cx.update(|cx| cx.quit());
462                }
463                Err(e) => {
464                    eprintln!("Failed: {:?}", e);
465                    exit(1);
466                }
467            }
468        })
469        .detach();
470    });
471}