main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3::PromptFormat;
  6use edit_prediction_context::EditPredictionExcerptOptions;
  7use futures::channel::mpsc;
  8use futures::{FutureExt as _, StreamExt as _};
  9use gpui::{AppContext, Application, AsyncApp};
 10use gpui::{Entity, Task};
 11use language::Bias;
 12use language::Buffer;
 13use language::Point;
 14use language_model::LlmApiToken;
 15use project::{Project, ProjectPath, Worktree};
 16use release_channel::AppVersion;
 17use reqwest_client::ReqwestClient;
 18use std::path::{Path, PathBuf};
 19use std::process::exit;
 20use std::str::FromStr;
 21use std::sync::Arc;
 22use std::time::Duration;
 23use util::paths::PathStyle;
 24use util::rel_path::RelPath;
 25use zeta::{PerformPredictEditsParams, Zeta};
 26
 27use crate::headless::ZetaCliAppState;
 28
 29#[derive(Parser, Debug)]
 30#[command(name = "zeta")]
 31struct ZetaCliArgs {
 32    #[command(subcommand)]
 33    command: Commands,
 34}
 35
 36#[derive(Subcommand, Debug)]
 37enum Commands {
 38    Context(ContextArgs),
 39    Zeta2Context {
 40        #[clap(flatten)]
 41        zeta2_args: Zeta2Args,
 42        #[clap(flatten)]
 43        context_args: ContextArgs,
 44    },
 45    Predict {
 46        #[arg(long)]
 47        predict_edits_body: Option<FileOrStdin>,
 48        #[clap(flatten)]
 49        context_args: Option<ContextArgs>,
 50    },
 51}
 52
 53#[derive(Debug, Args)]
 54#[group(requires = "worktree")]
 55struct ContextArgs {
 56    #[arg(long)]
 57    worktree: PathBuf,
 58    #[arg(long)]
 59    cursor: CursorPosition,
 60    #[arg(long)]
 61    use_language_server: bool,
 62    #[arg(long)]
 63    events: Option<FileOrStdin>,
 64}
 65
 66#[derive(Debug, Args)]
 67struct Zeta2Args {
 68    #[arg(long, default_value_t = 8192)]
 69    max_prompt_bytes: usize,
 70    #[arg(long, default_value_t = 2048)]
 71    max_excerpt_bytes: usize,
 72    #[arg(long, default_value_t = 1024)]
 73    min_excerpt_bytes: usize,
 74    #[arg(long, default_value_t = 0.66)]
 75    target_before_cursor_over_total_bytes: f32,
 76    #[arg(long, default_value_t = 1024)]
 77    max_diagnostic_bytes: usize,
 78    #[arg(long, value_parser = parse_format)]
 79    format: PromptFormat,
 80}
 81
 82fn parse_format(s: &str) -> Result<PromptFormat> {
 83    match s {
 84        "marked_excerpt" => Ok(PromptFormat::MarkedExcerpt),
 85        "labeled_sections" => Ok(PromptFormat::LabeledSections),
 86        _ => Err(anyhow!("Invalid format: {}", s)),
 87    }
 88}
 89
 90#[derive(Debug, Clone)]
 91enum FileOrStdin {
 92    File(PathBuf),
 93    Stdin,
 94}
 95
 96impl FileOrStdin {
 97    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 98        match self {
 99            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
100            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
101        }
102    }
103}
104
105impl FromStr for FileOrStdin {
106    type Err = <PathBuf as FromStr>::Err;
107
108    fn from_str(s: &str) -> Result<Self, Self::Err> {
109        match s {
110            "-" => Ok(Self::Stdin),
111            _ => Ok(Self::File(PathBuf::from_str(s)?)),
112        }
113    }
114}
115
116#[derive(Debug, Clone)]
117struct CursorPosition {
118    path: Arc<RelPath>,
119    point: Point,
120}
121
122impl FromStr for CursorPosition {
123    type Err = anyhow::Error;
124
125    fn from_str(s: &str) -> Result<Self> {
126        let parts: Vec<&str> = s.split(':').collect();
127        if parts.len() != 3 {
128            return Err(anyhow!(
129                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
130                s
131            ));
132        }
133
134        let path = RelPath::from_std_path(Path::new(&parts[0]), PathStyle::local())?;
135        let line: u32 = parts[1]
136            .parse()
137            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
138        let column: u32 = parts[2]
139            .parse()
140            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
141
142        // Convert from 1-based to 0-based indexing
143        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
144
145        Ok(CursorPosition { path, point })
146    }
147}
148
149enum GetContextOutput {
150    Zeta1(zeta::GatherContextOutput),
151    Zeta2(String),
152}
153
154async fn get_context(
155    zeta2_args: Option<Zeta2Args>,
156    args: ContextArgs,
157    app_state: &Arc<ZetaCliAppState>,
158    cx: &mut AsyncApp,
159) -> Result<GetContextOutput> {
160    let ContextArgs {
161        worktree: worktree_path,
162        cursor,
163        use_language_server,
164        events,
165    } = args;
166
167    let worktree_path = worktree_path.canonicalize()?;
168
169    let project = cx.update(|cx| {
170        Project::local(
171            app_state.client.clone(),
172            app_state.node_runtime.clone(),
173            app_state.user_store.clone(),
174            app_state.languages.clone(),
175            app_state.fs.clone(),
176            None,
177            cx,
178        )
179    })?;
180
181    let worktree = project
182        .update(cx, |project, cx| {
183            project.create_worktree(&worktree_path, true, cx)
184        })?
185        .await?;
186
187    let (_lsp_open_handle, buffer) = if use_language_server {
188        let (lsp_open_handle, buffer) =
189            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
190        (Some(lsp_open_handle), buffer)
191    } else {
192        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
193        (None, buffer)
194    };
195
196    let full_path_str = worktree
197        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
198        .display(PathStyle::local())
199        .to_string();
200
201    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
202    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
203    if clipped_cursor != cursor.point {
204        let max_row = snapshot.max_point().row;
205        if cursor.point.row < max_row {
206            return Err(anyhow!(
207                "Cursor position {:?} is out of bounds (line length is {})",
208                cursor.point,
209                snapshot.line_len(cursor.point.row)
210            ));
211        } else {
212            return Err(anyhow!(
213                "Cursor position {:?} is out of bounds (max row is {})",
214                cursor.point,
215                max_row
216            ));
217        }
218    }
219
220    let events = match events {
221        Some(events) => events.read_to_string().await?,
222        None => String::new(),
223    };
224
225    if let Some(zeta2_args) = zeta2_args {
226        Ok(GetContextOutput::Zeta2(
227            cx.update(|cx| {
228                let zeta = cx.new(|cx| {
229                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
230                });
231                zeta.update(cx, |zeta, cx| {
232                    zeta.register_buffer(&buffer, &project, cx);
233                    zeta.set_options(zeta2::ZetaOptions {
234                        excerpt: EditPredictionExcerptOptions {
235                            max_bytes: zeta2_args.max_excerpt_bytes,
236                            min_bytes: zeta2_args.min_excerpt_bytes,
237                            target_before_cursor_over_total_bytes: zeta2_args
238                                .target_before_cursor_over_total_bytes,
239                        },
240                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
241                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
242                        prompt_format: zeta2_args.format,
243                    })
244                });
245                // TODO: Actually wait for indexing.
246                let timer = cx.background_executor().timer(Duration::from_secs(5));
247                cx.spawn(async move |cx| {
248                    timer.await;
249                    let request = zeta
250                        .update(cx, |zeta, cx| {
251                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
252                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
253                        })?
254                        .await?;
255                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
256                    // TODO: Output the section label ranges
257                    anyhow::Ok(planned_prompt.to_prompt_string()?.0)
258                })
259            })?
260            .await?,
261        ))
262    } else {
263        let prompt_for_events = move || (events, 0);
264        Ok(GetContextOutput::Zeta1(
265            cx.update(|cx| {
266                zeta::gather_context(
267                    full_path_str,
268                    &snapshot,
269                    clipped_cursor,
270                    prompt_for_events,
271                    cx,
272                )
273            })?
274            .await?,
275        ))
276    }
277}
278
279pub async fn open_buffer(
280    project: &Entity<Project>,
281    worktree: &Entity<Worktree>,
282    path: &RelPath,
283    cx: &mut AsyncApp,
284) -> Result<Entity<Buffer>> {
285    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
286        worktree_id: worktree.id(),
287        path: path.into(),
288    })?;
289
290    project
291        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
292        .await
293}
294
295pub async fn open_buffer_with_language_server(
296    project: &Entity<Project>,
297    worktree: &Entity<Worktree>,
298    path: &RelPath,
299    cx: &mut AsyncApp,
300) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
301    let buffer = open_buffer(project, worktree, path, cx).await?;
302
303    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
304        (
305            project.register_buffer_with_language_servers(&buffer, cx),
306            project.path_style(cx),
307        )
308    })?;
309
310    let log_prefix = path.display(path_style);
311    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
312
313    Ok((lsp_open_handle, buffer))
314}
315
316// TODO: Dedupe with similar function in crates/eval/src/instance.rs
317pub fn wait_for_lang_server(
318    project: &Entity<Project>,
319    buffer: &Entity<Buffer>,
320    log_prefix: String,
321    cx: &mut AsyncApp,
322) -> Task<Result<()>> {
323    println!("{}⏵ Waiting for language server", log_prefix);
324
325    let (mut tx, mut rx) = mpsc::channel(1);
326
327    let lsp_store = project
328        .read_with(cx, |project, _| project.lsp_store())
329        .unwrap();
330
331    let has_lang_server = buffer
332        .update(cx, |buffer, cx| {
333            lsp_store.update(cx, |lsp_store, cx| {
334                lsp_store
335                    .language_servers_for_local_buffer(buffer, cx)
336                    .next()
337                    .is_some()
338            })
339        })
340        .unwrap_or(false);
341
342    if has_lang_server {
343        project
344            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
345            .unwrap()
346            .detach();
347    }
348
349    let subscriptions = [
350        cx.subscribe(&lsp_store, {
351            let log_prefix = log_prefix.clone();
352            move |_, event, _| {
353                if let project::LspStoreEvent::LanguageServerUpdate {
354                    message:
355                        client::proto::update_language_server::Variant::WorkProgress(
356                            client::proto::LspWorkProgress {
357                                message: Some(message),
358                                ..
359                            },
360                        ),
361                    ..
362                } = event
363                {
364                    println!("{}{message}", log_prefix)
365                }
366            }
367        }),
368        cx.subscribe(project, {
369            let buffer = buffer.clone();
370            move |project, event, cx| match event {
371                project::Event::LanguageServerAdded(_, _, _) => {
372                    let buffer = buffer.clone();
373                    project
374                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
375                        .detach();
376                }
377                project::Event::DiskBasedDiagnosticsFinished { .. } => {
378                    tx.try_send(()).ok();
379                }
380                _ => {}
381            }
382        }),
383    ];
384
385    cx.spawn(async move |cx| {
386        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
387        let result = futures::select! {
388            _ = rx.next() => {
389                println!("{}⚑ Language server idle", log_prefix);
390                anyhow::Ok(())
391            },
392            _ = timeout.fuse() => {
393                anyhow::bail!("LSP wait timed out after 5 minutes");
394            }
395        };
396        drop(subscriptions);
397        result
398    })
399}
400
401fn main() {
402    let args = ZetaCliArgs::parse();
403    let http_client = Arc::new(ReqwestClient::new());
404    let app = Application::headless().with_http_client(http_client);
405
406    app.run(move |cx| {
407        let app_state = Arc::new(headless::init(cx));
408        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
409        cx.spawn(async move |cx| {
410            let result = match args.command {
411                Commands::Zeta2Context {
412                    zeta2_args,
413                    context_args,
414                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
415                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
416                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
417                    Err(err) => Err(err),
418                },
419                Commands::Context(context_args) => {
420                    match get_context(None, context_args, &app_state, cx).await {
421                        Ok(GetContextOutput::Zeta1(output)) => {
422                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
423                        }
424                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
425                        Err(err) => Err(err),
426                    }
427                }
428                Commands::Predict {
429                    predict_edits_body,
430                    context_args,
431                } => {
432                    cx.spawn(async move |cx| {
433                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
434                        app_state.client.sign_in(true, cx).await?;
435                        let llm_token = LlmApiToken::default();
436                        llm_token.refresh(&app_state.client).await?;
437
438                        let predict_edits_body =
439                            if let Some(predict_edits_body) = predict_edits_body {
440                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
441                            } else if let Some(context_args) = context_args {
442                                match get_context(None, context_args, &app_state, cx).await? {
443                                    GetContextOutput::Zeta1(output) => output.body,
444                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
445                                }
446                            } else {
447                                return Err(anyhow!(
448                                    "Expected either --predict-edits-body-file \
449                                    or the required args of the `context` command."
450                                ));
451                            };
452
453                        let (response, _usage) =
454                            Zeta::perform_predict_edits(PerformPredictEditsParams {
455                                client: app_state.client.clone(),
456                                llm_token,
457                                app_version,
458                                body: predict_edits_body,
459                            })
460                            .await?;
461
462                        Ok(response.output_excerpt)
463                    })
464                    .await
465                }
466            };
467            match result {
468                Ok(output) => {
469                    println!("{}", output);
470                    // TODO: Remove this once the 5 second delay is properly replaced.
471                    if is_zeta2_context_command {
472                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
473                    }
474                    let _ = cx.update(|cx| cx.quit());
475                }
476                Err(e) => {
477                    eprintln!("Failed: {:?}", e);
478                    exit(1);
479                }
480            }
481        })
482        .detach();
483    });
484}