main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use edit_prediction_context::EditPredictionExcerptOptions;
  6use futures::channel::mpsc;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, Application, AsyncApp};
  9use gpui::{Entity, Task};
 10use language::Bias;
 11use language::Buffer;
 12use language::Point;
 13use language_model::LlmApiToken;
 14use project::{Project, ProjectPath, Worktree};
 15use release_channel::AppVersion;
 16use reqwest_client::ReqwestClient;
 17use std::path::{Path, PathBuf};
 18use std::process::exit;
 19use std::str::FromStr;
 20use std::sync::Arc;
 21use std::time::Duration;
 22use util::paths::PathStyle;
 23use util::rel_path::RelPath;
 24use zeta::{PerformPredictEditsParams, Zeta};
 25
 26use crate::headless::ZetaCliAppState;
 27
 28#[derive(Parser, Debug)]
 29#[command(name = "zeta")]
 30struct ZetaCliArgs {
 31    #[command(subcommand)]
 32    command: Commands,
 33}
 34
 35#[derive(Subcommand, Debug)]
 36enum Commands {
 37    Context(ContextArgs),
 38    Zeta2Context {
 39        #[clap(flatten)]
 40        zeta2_args: Zeta2Args,
 41        #[clap(flatten)]
 42        context_args: ContextArgs,
 43    },
 44    Predict {
 45        #[arg(long)]
 46        predict_edits_body: Option<FileOrStdin>,
 47        #[clap(flatten)]
 48        context_args: Option<ContextArgs>,
 49    },
 50}
 51
 52#[derive(Debug, Args)]
 53#[group(requires = "worktree")]
 54struct ContextArgs {
 55    #[arg(long)]
 56    worktree: PathBuf,
 57    #[arg(long)]
 58    cursor: CursorPosition,
 59    #[arg(long)]
 60    use_language_server: bool,
 61    #[arg(long)]
 62    events: Option<FileOrStdin>,
 63}
 64
 65#[derive(Debug, Args)]
 66struct Zeta2Args {
 67    #[arg(long, default_value_t = 8192)]
 68    max_prompt_bytes: usize,
 69    #[arg(long, default_value_t = 2048)]
 70    max_excerpt_bytes: usize,
 71    #[arg(long, default_value_t = 1024)]
 72    min_excerpt_bytes: usize,
 73    #[arg(long, default_value_t = 0.66)]
 74    target_before_cursor_over_total_bytes: f32,
 75    #[arg(long, default_value_t = 1024)]
 76    max_diagnostic_bytes: usize,
 77}
 78
 79#[derive(Debug, Clone)]
 80enum FileOrStdin {
 81    File(PathBuf),
 82    Stdin,
 83}
 84
 85impl FileOrStdin {
 86    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 87        match self {
 88            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 89            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 90        }
 91    }
 92}
 93
 94impl FromStr for FileOrStdin {
 95    type Err = <PathBuf as FromStr>::Err;
 96
 97    fn from_str(s: &str) -> Result<Self, Self::Err> {
 98        match s {
 99            "-" => Ok(Self::Stdin),
100            _ => Ok(Self::File(PathBuf::from_str(s)?)),
101        }
102    }
103}
104
105#[derive(Debug, Clone)]
106struct CursorPosition {
107    path: Arc<RelPath>,
108    point: Point,
109}
110
111impl FromStr for CursorPosition {
112    type Err = anyhow::Error;
113
114    fn from_str(s: &str) -> Result<Self> {
115        let parts: Vec<&str> = s.split(':').collect();
116        if parts.len() != 3 {
117            return Err(anyhow!(
118                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
119                s
120            ));
121        }
122
123        let path = RelPath::from_std_path(Path::new(&parts[0]), PathStyle::local())?;
124        let line: u32 = parts[1]
125            .parse()
126            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
127        let column: u32 = parts[2]
128            .parse()
129            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
130
131        // Convert from 1-based to 0-based indexing
132        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
133
134        Ok(CursorPosition { path, point })
135    }
136}
137
138enum GetContextOutput {
139    Zeta1(zeta::GatherContextOutput),
140    Zeta2(String),
141}
142
143async fn get_context(
144    zeta2_args: Option<Zeta2Args>,
145    args: ContextArgs,
146    app_state: &Arc<ZetaCliAppState>,
147    cx: &mut AsyncApp,
148) -> Result<GetContextOutput> {
149    let ContextArgs {
150        worktree: worktree_path,
151        cursor,
152        use_language_server,
153        events,
154    } = args;
155
156    let worktree_path = worktree_path.canonicalize()?;
157
158    let project = cx.update(|cx| {
159        Project::local(
160            app_state.client.clone(),
161            app_state.node_runtime.clone(),
162            app_state.user_store.clone(),
163            app_state.languages.clone(),
164            app_state.fs.clone(),
165            None,
166            cx,
167        )
168    })?;
169
170    let worktree = project
171        .update(cx, |project, cx| {
172            project.create_worktree(&worktree_path, true, cx)
173        })?
174        .await?;
175
176    let (_lsp_open_handle, buffer) = if use_language_server {
177        let (lsp_open_handle, buffer) =
178            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
179        (Some(lsp_open_handle), buffer)
180    } else {
181        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
182        (None, buffer)
183    };
184
185    let full_path_str = worktree
186        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
187        .display(PathStyle::local())
188        .to_string();
189
190    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
191    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
192    if clipped_cursor != cursor.point {
193        let max_row = snapshot.max_point().row;
194        if cursor.point.row < max_row {
195            return Err(anyhow!(
196                "Cursor position {:?} is out of bounds (line length is {})",
197                cursor.point,
198                snapshot.line_len(cursor.point.row)
199            ));
200        } else {
201            return Err(anyhow!(
202                "Cursor position {:?} is out of bounds (max row is {})",
203                cursor.point,
204                max_row
205            ));
206        }
207    }
208
209    let events = match events {
210        Some(events) => events.read_to_string().await?,
211        None => String::new(),
212    };
213
214    if let Some(zeta2_args) = zeta2_args {
215        Ok(GetContextOutput::Zeta2(
216            cx.update(|cx| {
217                let zeta = cx.new(|cx| {
218                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
219                });
220                zeta.update(cx, |zeta, cx| {
221                    zeta.register_buffer(&buffer, &project, cx);
222                    zeta.set_options(zeta2::ZetaOptions {
223                        excerpt: EditPredictionExcerptOptions {
224                            max_bytes: zeta2_args.max_excerpt_bytes,
225                            min_bytes: zeta2_args.min_excerpt_bytes,
226                            target_before_cursor_over_total_bytes: zeta2_args
227                                .target_before_cursor_over_total_bytes,
228                        },
229                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
230                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
231                    })
232                });
233                // TODO: Actually wait for indexing.
234                let timer = cx.background_executor().timer(Duration::from_secs(5));
235                cx.spawn(async move |cx| {
236                    timer.await;
237                    let request = zeta
238                        .update(cx, |zeta, cx| {
239                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
240                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
241                        })?
242                        .await?;
243                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
244                        &request,
245                        &cloud_zeta2_prompt::PlanOptions {
246                            max_bytes: zeta2_args.max_prompt_bytes,
247                        },
248                    )?;
249                    anyhow::Ok(planned_prompt.to_prompt_string())
250                })
251            })?
252            .await?,
253        ))
254    } else {
255        let prompt_for_events = move || (events, 0);
256        Ok(GetContextOutput::Zeta1(
257            cx.update(|cx| {
258                zeta::gather_context(
259                    full_path_str,
260                    &snapshot,
261                    clipped_cursor,
262                    prompt_for_events,
263                    cx,
264                )
265            })?
266            .await?,
267        ))
268    }
269}
270
271pub async fn open_buffer(
272    project: &Entity<Project>,
273    worktree: &Entity<Worktree>,
274    path: &RelPath,
275    cx: &mut AsyncApp,
276) -> Result<Entity<Buffer>> {
277    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
278        worktree_id: worktree.id(),
279        path: path.into(),
280    })?;
281
282    project
283        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
284        .await
285}
286
287pub async fn open_buffer_with_language_server(
288    project: &Entity<Project>,
289    worktree: &Entity<Worktree>,
290    path: &RelPath,
291    cx: &mut AsyncApp,
292) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
293    let buffer = open_buffer(project, worktree, path, cx).await?;
294
295    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
296        (
297            project.register_buffer_with_language_servers(&buffer, cx),
298            project.path_style(cx),
299        )
300    })?;
301
302    let log_prefix = path.display(path_style);
303    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
304
305    Ok((lsp_open_handle, buffer))
306}
307
308// TODO: Dedupe with similar function in crates/eval/src/instance.rs
309pub fn wait_for_lang_server(
310    project: &Entity<Project>,
311    buffer: &Entity<Buffer>,
312    log_prefix: String,
313    cx: &mut AsyncApp,
314) -> Task<Result<()>> {
315    println!("{}⏵ Waiting for language server", log_prefix);
316
317    let (mut tx, mut rx) = mpsc::channel(1);
318
319    let lsp_store = project
320        .read_with(cx, |project, _| project.lsp_store())
321        .unwrap();
322
323    let has_lang_server = buffer
324        .update(cx, |buffer, cx| {
325            lsp_store.update(cx, |lsp_store, cx| {
326                lsp_store
327                    .language_servers_for_local_buffer(buffer, cx)
328                    .next()
329                    .is_some()
330            })
331        })
332        .unwrap_or(false);
333
334    if has_lang_server {
335        project
336            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
337            .unwrap()
338            .detach();
339    }
340
341    let subscriptions = [
342        cx.subscribe(&lsp_store, {
343            let log_prefix = log_prefix.clone();
344            move |_, event, _| {
345                if let project::LspStoreEvent::LanguageServerUpdate {
346                    message:
347                        client::proto::update_language_server::Variant::WorkProgress(
348                            client::proto::LspWorkProgress {
349                                message: Some(message),
350                                ..
351                            },
352                        ),
353                    ..
354                } = event
355                {
356                    println!("{}{message}", log_prefix)
357                }
358            }
359        }),
360        cx.subscribe(project, {
361            let buffer = buffer.clone();
362            move |project, event, cx| match event {
363                project::Event::LanguageServerAdded(_, _, _) => {
364                    let buffer = buffer.clone();
365                    project
366                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
367                        .detach();
368                }
369                project::Event::DiskBasedDiagnosticsFinished { .. } => {
370                    tx.try_send(()).ok();
371                }
372                _ => {}
373            }
374        }),
375    ];
376
377    cx.spawn(async move |cx| {
378        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
379        let result = futures::select! {
380            _ = rx.next() => {
381                println!("{}⚑ Language server idle", log_prefix);
382                anyhow::Ok(())
383            },
384            _ = timeout.fuse() => {
385                anyhow::bail!("LSP wait timed out after 5 minutes");
386            }
387        };
388        drop(subscriptions);
389        result
390    })
391}
392
393fn main() {
394    let args = ZetaCliArgs::parse();
395    let http_client = Arc::new(ReqwestClient::new());
396    let app = Application::headless().with_http_client(http_client);
397
398    app.run(move |cx| {
399        let app_state = Arc::new(headless::init(cx));
400        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
401        cx.spawn(async move |cx| {
402            let result = match args.command {
403                Commands::Zeta2Context {
404                    zeta2_args,
405                    context_args,
406                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
407                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
408                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
409                    Err(err) => Err(err),
410                },
411                Commands::Context(context_args) => {
412                    match get_context(None, context_args, &app_state, cx).await {
413                        Ok(GetContextOutput::Zeta1(output)) => {
414                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
415                        }
416                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
417                        Err(err) => Err(err),
418                    }
419                }
420                Commands::Predict {
421                    predict_edits_body,
422                    context_args,
423                } => {
424                    cx.spawn(async move |cx| {
425                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
426                        app_state.client.sign_in(true, cx).await?;
427                        let llm_token = LlmApiToken::default();
428                        llm_token.refresh(&app_state.client).await?;
429
430                        let predict_edits_body =
431                            if let Some(predict_edits_body) = predict_edits_body {
432                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
433                            } else if let Some(context_args) = context_args {
434                                match get_context(None, context_args, &app_state, cx).await? {
435                                    GetContextOutput::Zeta1(output) => output.body,
436                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
437                                }
438                            } else {
439                                return Err(anyhow!(
440                                    "Expected either --predict-edits-body-file \
441                                    or the required args of the `context` command."
442                                ));
443                            };
444
445                        let (response, _usage) =
446                            Zeta::perform_predict_edits(PerformPredictEditsParams {
447                                client: app_state.client.clone(),
448                                llm_token,
449                                app_version,
450                                body: predict_edits_body,
451                            })
452                            .await?;
453
454                        Ok(response.output_excerpt)
455                    })
456                    .await
457                }
458            };
459            match result {
460                Ok(output) => {
461                    println!("{}", output);
462                    // TODO: Remove this once the 5 second delay is properly replaced.
463                    if is_zeta2_context_command {
464                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
465                    }
466                    let _ = cx.update(|cx| cx.quit());
467                }
468                Err(e) => {
469                    eprintln!("Failed: {:?}", e);
470                    exit(1);
471                }
472            }
473        })
474        .detach();
475    });
476}