main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3;
  6use edit_prediction_context::{
  7    Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
  8    SyntaxIndex, references_in_range,
  9};
 10use futures::channel::mpsc;
 11use futures::{FutureExt as _, StreamExt as _};
 12use gpui::{AppContext, Application, AsyncApp};
 13use gpui::{Entity, Task};
 14use language::Bias;
 15use language::Point;
 16use language::{Buffer, OffsetRangeExt};
 17use language_model::LlmApiToken;
 18use ordered_float::OrderedFloat;
 19use project::{Project, ProjectPath, Worktree};
 20use release_channel::AppVersion;
 21use reqwest_client::ReqwestClient;
 22use serde_json::json;
 23use std::cmp::Reverse;
 24use std::collections::HashMap;
 25use std::io::Write as _;
 26use std::ops::Range;
 27use std::path::{Path, PathBuf};
 28use std::process::exit;
 29use std::str::FromStr;
 30use std::sync::Arc;
 31use std::time::Duration;
 32use util::paths::PathStyle;
 33use util::rel_path::RelPath;
 34use util::{RangeExt, ResultExt as _};
 35use zeta::{PerformPredictEditsParams, Zeta};
 36
 37use crate::headless::ZetaCliAppState;
 38
 39#[derive(Parser, Debug)]
 40#[command(name = "zeta")]
 41struct ZetaCliArgs {
 42    #[command(subcommand)]
 43    command: Commands,
 44}
 45
 46#[derive(Subcommand, Debug)]
 47enum Commands {
 48    Context(ContextArgs),
 49    Zeta2Context {
 50        #[clap(flatten)]
 51        zeta2_args: Zeta2Args,
 52        #[clap(flatten)]
 53        context_args: ContextArgs,
 54    },
 55    Predict {
 56        #[arg(long)]
 57        predict_edits_body: Option<FileOrStdin>,
 58        #[clap(flatten)]
 59        context_args: Option<ContextArgs>,
 60    },
 61    RetrievalStats {
 62        #[arg(long)]
 63        worktree: PathBuf,
 64        #[arg(long, default_value_t = 42)]
 65        file_indexing_parallelism: usize,
 66    },
 67}
 68
 69#[derive(Debug, Args)]
 70#[group(requires = "worktree")]
 71struct ContextArgs {
 72    #[arg(long)]
 73    worktree: PathBuf,
 74    #[arg(long)]
 75    cursor: CursorPosition,
 76    #[arg(long)]
 77    use_language_server: bool,
 78    #[arg(long)]
 79    events: Option<FileOrStdin>,
 80}
 81
 82#[derive(Debug, Args)]
 83struct Zeta2Args {
 84    #[arg(long, default_value_t = 8192)]
 85    max_prompt_bytes: usize,
 86    #[arg(long, default_value_t = 2048)]
 87    max_excerpt_bytes: usize,
 88    #[arg(long, default_value_t = 1024)]
 89    min_excerpt_bytes: usize,
 90    #[arg(long, default_value_t = 0.66)]
 91    target_before_cursor_over_total_bytes: f32,
 92    #[arg(long, default_value_t = 1024)]
 93    max_diagnostic_bytes: usize,
 94    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 95    prompt_format: PromptFormat,
 96    #[arg(long, value_enum, default_value_t = Default::default())]
 97    output_format: OutputFormat,
 98    #[arg(long, default_value_t = 42)]
 99    file_indexing_parallelism: usize,
100}
101
102#[derive(clap::ValueEnum, Default, Debug, Clone)]
103enum PromptFormat {
104    #[default]
105    MarkedExcerpt,
106    LabeledSections,
107    OnlySnippets,
108}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111    fn into(self) -> predict_edits_v3::PromptFormat {
112        match self {
113            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116        }
117    }
118}
119
120#[derive(clap::ValueEnum, Default, Debug, Clone)]
121enum OutputFormat {
122    #[default]
123    Prompt,
124    Request,
125    Full,
126}
127
/// An input source named on the command line: a file path, or `-` for stdin.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input.
    Stdin,
}
133
134impl FileOrStdin {
135    async fn read_to_string(&self) -> Result<String, std::io::Error> {
136        match self {
137            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
138            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
139        }
140    }
141}
142
143impl FromStr for FileOrStdin {
144    type Err = <PathBuf as FromStr>::Err;
145
146    fn from_str(s: &str) -> Result<Self, Self::Err> {
147        match s {
148            "-" => Ok(Self::Stdin),
149            _ => Ok(Self::File(PathBuf::from_str(s)?)),
150        }
151    }
152}
153
154#[derive(Debug, Clone)]
155struct CursorPosition {
156    path: Arc<RelPath>,
157    point: Point,
158}
159
160impl FromStr for CursorPosition {
161    type Err = anyhow::Error;
162
163    fn from_str(s: &str) -> Result<Self> {
164        let parts: Vec<&str> = s.split(':').collect();
165        if parts.len() != 3 {
166            return Err(anyhow!(
167                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
168                s
169            ));
170        }
171
172        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
173        let line: u32 = parts[1]
174            .parse()
175            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
176        let column: u32 = parts[2]
177            .parse()
178            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
179
180        // Convert from 1-based to 0-based indexing
181        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
182
183        Ok(CursorPosition { path, point })
184    }
185}
186
187enum GetContextOutput {
188    Zeta1(zeta::GatherContextOutput),
189    Zeta2(String),
190}
191
192async fn get_context(
193    zeta2_args: Option<Zeta2Args>,
194    args: ContextArgs,
195    app_state: &Arc<ZetaCliAppState>,
196    cx: &mut AsyncApp,
197) -> Result<GetContextOutput> {
198    let ContextArgs {
199        worktree: worktree_path,
200        cursor,
201        use_language_server,
202        events,
203    } = args;
204
205    let worktree_path = worktree_path.canonicalize()?;
206
207    let project = cx.update(|cx| {
208        Project::local(
209            app_state.client.clone(),
210            app_state.node_runtime.clone(),
211            app_state.user_store.clone(),
212            app_state.languages.clone(),
213            app_state.fs.clone(),
214            None,
215            cx,
216        )
217    })?;
218
219    let worktree = project
220        .update(cx, |project, cx| {
221            project.create_worktree(&worktree_path, true, cx)
222        })?
223        .await?;
224
225    let (_lsp_open_handle, buffer) = if use_language_server {
226        let (lsp_open_handle, buffer) =
227            open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
228        (Some(lsp_open_handle), buffer)
229    } else {
230        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
231        (None, buffer)
232    };
233
234    let full_path_str = worktree
235        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
236        .display(PathStyle::local())
237        .to_string();
238
239    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
240    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
241    if clipped_cursor != cursor.point {
242        let max_row = snapshot.max_point().row;
243        if cursor.point.row < max_row {
244            return Err(anyhow!(
245                "Cursor position {:?} is out of bounds (line length is {})",
246                cursor.point,
247                snapshot.line_len(cursor.point.row)
248            ));
249        } else {
250            return Err(anyhow!(
251                "Cursor position {:?} is out of bounds (max row is {})",
252                cursor.point,
253                max_row
254            ));
255        }
256    }
257
258    let events = match events {
259        Some(events) => events.read_to_string().await?,
260        None => String::new(),
261    };
262
263    if let Some(zeta2_args) = zeta2_args {
264        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
265        // the whole worktree.
266        worktree
267            .read_with(cx, |worktree, _cx| {
268                worktree.as_local().unwrap().scan_complete()
269            })?
270            .await;
271        let output = cx
272            .update(|cx| {
273                let zeta = cx.new(|cx| {
274                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
275                });
276                let indexing_done_task = zeta.update(cx, |zeta, cx| {
277                    zeta.set_options(zeta2::ZetaOptions {
278                        excerpt: EditPredictionExcerptOptions {
279                            max_bytes: zeta2_args.max_excerpt_bytes,
280                            min_bytes: zeta2_args.min_excerpt_bytes,
281                            target_before_cursor_over_total_bytes: zeta2_args
282                                .target_before_cursor_over_total_bytes,
283                        },
284                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
285                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
286                        prompt_format: zeta2_args.prompt_format.into(),
287                        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
288                    });
289                    zeta.register_buffer(&buffer, &project, cx);
290                    zeta.wait_for_initial_indexing(&project, cx)
291                });
292                cx.spawn(async move |cx| {
293                    indexing_done_task.await?;
294                    let request = zeta
295                        .update(cx, |zeta, cx| {
296                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
297                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
298                        })?
299                        .await?;
300
301                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
302                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
303
304                    match zeta2_args.output_format {
305                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
306                        OutputFormat::Request => {
307                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
308                        }
309                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
310                            "request": request,
311                            "prompt": prompt_string,
312                            "section_labels": section_labels,
313                        }))?),
314                    }
315                })
316            })?
317            .await?;
318        Ok(GetContextOutput::Zeta2(output))
319    } else {
320        let prompt_for_events = move || (events, 0);
321        Ok(GetContextOutput::Zeta1(
322            cx.update(|cx| {
323                zeta::gather_context(
324                    full_path_str,
325                    &snapshot,
326                    clipped_cursor,
327                    prompt_for_events,
328                    cx,
329                )
330            })?
331            .await?,
332        ))
333    }
334}
335
336pub async fn retrieval_stats(
337    worktree: PathBuf,
338    file_indexing_parallelism: usize,
339    app_state: Arc<ZetaCliAppState>,
340    cx: &mut AsyncApp,
341) -> Result<String> {
342    let worktree_path = worktree.canonicalize()?;
343
344    let project = cx.update(|cx| {
345        Project::local(
346            app_state.client.clone(),
347            app_state.node_runtime.clone(),
348            app_state.user_store.clone(),
349            app_state.languages.clone(),
350            app_state.fs.clone(),
351            None,
352            cx,
353        )
354    })?;
355
356    let worktree = project
357        .update(cx, |project, cx| {
358            project.create_worktree(&worktree_path, true, cx)
359        })?
360        .await?;
361    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
362
363    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
364    worktree
365        .read_with(cx, |worktree, _cx| {
366            worktree.as_local().unwrap().scan_complete()
367        })?
368        .await;
369
370    let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
371    index
372        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
373        .await?;
374    let files = index
375        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
376        .await;
377
378    let mut lsp_open_handles = Vec::new();
379    let mut output = std::fs::File::create("retrieval-stats.txt")?;
380    let mut results = Vec::new();
381    for (file_index, project_path) in files.iter().enumerate() {
382        println!(
383            "Processing file {} of {}: {}",
384            file_index + 1,
385            files.len(),
386            project_path.path.display(PathStyle::Posix)
387        );
388        let Some((lsp_open_handle, buffer)) =
389            open_buffer_with_language_server(&project, &worktree, &project_path.path, cx)
390                .await
391                .log_err()
392        else {
393            continue;
394        };
395        lsp_open_handles.push(lsp_open_handle);
396
397        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
398        let full_range = 0..snapshot.len();
399        let references = references_in_range(
400            full_range,
401            &snapshot.text(),
402            ReferenceRegion::Nearby,
403            &snapshot,
404        );
405
406        let index = index.read_with(cx, |index, _cx| index.state().clone())?;
407        let index = index.lock().await;
408        for reference in references {
409            let query_point = snapshot.offset_to_point(reference.range.start);
410            let mut single_reference_map = HashMap::default();
411            single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
412            let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
413                query_point,
414                &snapshot,
415                &zeta2::DEFAULT_EXCERPT_OPTIONS,
416                Some(&index),
417                |_, _, _| single_reference_map,
418            );
419
420            let Some(edit_prediction_context) = edit_prediction_context else {
421                let result = RetrievalStatsResult {
422                    identifier: reference.identifier,
423                    point: query_point,
424                    outcome: RetrievalStatsOutcome::NoExcerpt,
425                };
426                write!(output, "{:?}\n\n", result)?;
427                results.push(result);
428                continue;
429            };
430
431            let mut retrieved_definitions = Vec::new();
432            for scored_declaration in edit_prediction_context.declarations {
433                match &scored_declaration.declaration {
434                    Declaration::File {
435                        project_entry_id,
436                        declaration,
437                    } => {
438                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
439                            worktree
440                                .entry_for_id(*project_entry_id)
441                                .map(|entry| entry.path.clone())
442                        })?
443                        else {
444                            log::error!("bug: file project entry not found");
445                            continue;
446                        };
447                        let project_path = ProjectPath {
448                            worktree_id,
449                            path: path.clone(),
450                        };
451                        let buffer = project
452                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
453                            .await?;
454                        let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
455                        retrieved_definitions.push((
456                            path,
457                            rope.offset_to_point(declaration.item_range.start)
458                                ..rope.offset_to_point(declaration.item_range.end),
459                            scored_declaration.scores.declaration,
460                            scored_declaration.scores.retrieval,
461                        ));
462                    }
463                    Declaration::Buffer {
464                        project_entry_id,
465                        rope,
466                        declaration,
467                        ..
468                    } => {
469                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
470                            worktree
471                                .entry_for_id(*project_entry_id)
472                                .map(|entry| entry.path.clone())
473                        })?
474                        else {
475                            log::error!("bug: buffer project entry not found");
476                            continue;
477                        };
478                        retrieved_definitions.push((
479                            path,
480                            rope.offset_to_point(declaration.item_range.start)
481                                ..rope.offset_to_point(declaration.item_range.end),
482                            scored_declaration.scores.declaration,
483                            scored_declaration.scores.retrieval,
484                        ));
485                    }
486                }
487            }
488            retrieved_definitions
489                .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));
490
491            // TODO: Consider still checking language server in this case, or having a mode for
492            // this. For now assuming that the purpose of this is to refine the ranking rather than
493            // refining whether the definition is present at all.
494            if retrieved_definitions.is_empty() {
495                continue;
496            }
497
498            // TODO: Rename declaration to definition in edit_prediction_context?
499            let lsp_result = project
500                .update(cx, |project, cx| {
501                    project.definitions(&buffer, reference.range.start, cx)
502                })?
503                .await;
504            match lsp_result {
505                Ok(lsp_definitions) => {
506                    let lsp_definitions = lsp_definitions
507                        .unwrap_or_default()
508                        .into_iter()
509                        .filter_map(|definition| {
510                            definition
511                                .target
512                                .buffer
513                                .read_with(cx, |buffer, _cx| {
514                                    Some((
515                                        buffer.file()?.path().clone(),
516                                        definition.target.range.to_point(&buffer),
517                                    ))
518                                })
519                                .ok()?
520                        })
521                        .collect::<Vec<_>>();
522
523                    let result = RetrievalStatsResult {
524                        identifier: reference.identifier,
525                        point: query_point,
526                        outcome: RetrievalStatsOutcome::Success {
527                            matches: lsp_definitions
528                                .iter()
529                                .map(|(path, range)| {
530                                    retrieved_definitions.iter().position(
531                                        |(retrieved_path, retrieved_range, _, _)| {
532                                            path == retrieved_path
533                                                && retrieved_range.contains_inclusive(&range)
534                                        },
535                                    )
536                                })
537                                .collect(),
538                            lsp_definitions,
539                            retrieved_definitions,
540                        },
541                    };
542                    write!(output, "{:?}\n\n", result)?;
543                    results.push(result);
544                }
545                Err(err) => {
546                    let result = RetrievalStatsResult {
547                        identifier: reference.identifier,
548                        point: query_point,
549                        outcome: RetrievalStatsOutcome::LanguageServerError {
550                            message: err.to_string(),
551                        },
552                    };
553                    write!(output, "{:?}\n\n", result)?;
554                    results.push(result);
555                }
556            }
557        }
558    }
559
560    let mut no_excerpt_count = 0;
561    let mut error_count = 0;
562    let mut definitions_count = 0;
563    let mut top_match_count = 0;
564    let mut non_top_match_count = 0;
565    let mut ranking_involved_count = 0;
566    let mut ranking_involved_top_match_count = 0;
567    let mut ranking_involved_non_top_match_count = 0;
568    for result in &results {
569        match &result.outcome {
570            RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
571            RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
572            RetrievalStatsOutcome::Success {
573                matches,
574                retrieved_definitions,
575                ..
576            } => {
577                definitions_count += 1;
578                let top_matches = matches.contains(&Some(0));
579                if top_matches {
580                    top_match_count += 1;
581                }
582                let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
583                if non_top_matches {
584                    non_top_match_count += 1;
585                }
586                if retrieved_definitions.len() > 1 {
587                    ranking_involved_count += 1;
588                    if top_matches {
589                        ranking_involved_top_match_count += 1;
590                    }
591                    if non_top_matches {
592                        ranking_involved_non_top_match_count += 1;
593                    }
594                }
595            }
596        }
597    }
598
599    println!("\nStats:\n");
600    println!("No Excerpt: {}", no_excerpt_count);
601    println!("Language Server Error: {}", error_count);
602    println!("Definitions: {}", definitions_count);
603    println!("Top Match: {}", top_match_count);
604    println!("Non-Top Match: {}", non_top_match_count);
605    println!("Ranking Involved: {}", ranking_involved_count);
606    println!(
607        "Ranking Involved Top Match: {}",
608        ranking_involved_top_match_count
609    );
610    println!(
611        "Ranking Involved Non-Top Match: {}",
612        ranking_involved_non_top_match_count
613    );
614
615    Ok("".to_string())
616}
617
618#[derive(Debug)]
619struct RetrievalStatsResult {
620    #[allow(dead_code)]
621    identifier: Identifier,
622    #[allow(dead_code)]
623    point: Point,
624    outcome: RetrievalStatsOutcome,
625}
626
627#[derive(Debug)]
628enum RetrievalStatsOutcome {
629    NoExcerpt,
630    LanguageServerError {
631        #[allow(dead_code)]
632        message: String,
633    },
634    Success {
635        matches: Vec<Option<usize>>,
636        #[allow(dead_code)]
637        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
638        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
639    },
640}
641
642pub async fn open_buffer(
643    project: &Entity<Project>,
644    worktree: &Entity<Worktree>,
645    path: &RelPath,
646    cx: &mut AsyncApp,
647) -> Result<Entity<Buffer>> {
648    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
649        worktree_id: worktree.id(),
650        path: path.into(),
651    })?;
652
653    project
654        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
655        .await
656}
657
658pub async fn open_buffer_with_language_server(
659    project: &Entity<Project>,
660    worktree: &Entity<Worktree>,
661    path: &RelPath,
662    cx: &mut AsyncApp,
663) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
664    let buffer = open_buffer(project, worktree, path, cx).await?;
665
666    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
667        (
668            project.register_buffer_with_language_servers(&buffer, cx),
669            project.path_style(cx),
670        )
671    })?;
672
673    let log_prefix = path.display(path_style);
674    wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
675
676    Ok((lsp_open_handle, buffer))
677}
678
679// TODO: Dedupe with similar function in crates/eval/src/instance.rs
680pub fn wait_for_lang_server(
681    project: &Entity<Project>,
682    buffer: &Entity<Buffer>,
683    log_prefix: String,
684    cx: &mut AsyncApp,
685) -> Task<Result<()>> {
686    println!("{}⏵ Waiting for language server", log_prefix);
687
688    let (mut tx, mut rx) = mpsc::channel(1);
689
690    let lsp_store = project
691        .read_with(cx, |project, _| project.lsp_store())
692        .unwrap();
693
694    let has_lang_server = buffer
695        .update(cx, |buffer, cx| {
696            lsp_store.update(cx, |lsp_store, cx| {
697                lsp_store
698                    .language_servers_for_local_buffer(buffer, cx)
699                    .next()
700                    .is_some()
701            })
702        })
703        .unwrap_or(false);
704
705    if has_lang_server {
706        project
707            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
708            .unwrap()
709            .detach();
710    }
711    let (mut added_tx, mut added_rx) = mpsc::channel(1);
712
713    let subscriptions = [
714        cx.subscribe(&lsp_store, {
715            let log_prefix = log_prefix.clone();
716            move |_, event, _| {
717                if let project::LspStoreEvent::LanguageServerUpdate {
718                    message:
719                        client::proto::update_language_server::Variant::WorkProgress(
720                            client::proto::LspWorkProgress {
721                                message: Some(message),
722                                ..
723                            },
724                        ),
725                    ..
726                } = event
727                {
728                    println!("{}{message}", log_prefix)
729                }
730            }
731        }),
732        cx.subscribe(project, {
733            let buffer = buffer.clone();
734            move |project, event, cx| match event {
735                project::Event::LanguageServerAdded(_, _, _) => {
736                    let buffer = buffer.clone();
737                    project
738                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
739                        .detach();
740                    added_tx.try_send(()).ok();
741                }
742                project::Event::DiskBasedDiagnosticsFinished { .. } => {
743                    tx.try_send(()).ok();
744                }
745                _ => {}
746            }
747        }),
748    ];
749
750    cx.spawn(async move |cx| {
751        if !has_lang_server {
752            // some buffers never have a language server, so this aborts quickly in that case.
753            let timeout = cx.background_executor().timer(Duration::from_secs(1));
754            futures::select! {
755                _ = added_rx.next() => {},
756                _ = timeout.fuse() => {
757                    anyhow::bail!("Waiting for language server add timed out after 1 second");
758                }
759            };
760        }
761        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
762        let result = futures::select! {
763            _ = rx.next() => {
764                println!("{}⚑ Language server idle", log_prefix);
765                anyhow::Ok(())
766            },
767            _ = timeout.fuse() => {
768                anyhow::bail!("LSP wait timed out after 5 minutes");
769            }
770        };
771        drop(subscriptions);
772        result
773    })
774}
775
776fn main() {
777    zlog::init();
778    zlog::init_output_stderr();
779    let args = ZetaCliArgs::parse();
780    let http_client = Arc::new(ReqwestClient::new());
781    let app = Application::headless().with_http_client(http_client);
782
783    app.run(move |cx| {
784        let app_state = Arc::new(headless::init(cx));
785        cx.spawn(async move |cx| {
786            let result = match args.command {
787                Commands::Zeta2Context {
788                    zeta2_args,
789                    context_args,
790                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
791                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
792                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
793                    Err(err) => Err(err),
794                },
795                Commands::Context(context_args) => {
796                    match get_context(None, context_args, &app_state, cx).await {
797                        Ok(GetContextOutput::Zeta1(output)) => {
798                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
799                        }
800                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
801                        Err(err) => Err(err),
802                    }
803                }
804                Commands::Predict {
805                    predict_edits_body,
806                    context_args,
807                } => {
808                    cx.spawn(async move |cx| {
809                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
810                        app_state.client.sign_in(true, cx).await?;
811                        let llm_token = LlmApiToken::default();
812                        llm_token.refresh(&app_state.client).await?;
813
814                        let predict_edits_body =
815                            if let Some(predict_edits_body) = predict_edits_body {
816                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
817                            } else if let Some(context_args) = context_args {
818                                match get_context(None, context_args, &app_state, cx).await? {
819                                    GetContextOutput::Zeta1(output) => output.body,
820                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
821                                }
822                            } else {
823                                return Err(anyhow!(
824                                    "Expected either --predict-edits-body-file \
825                                    or the required args of the `context` command."
826                                ));
827                            };
828
829                        let (response, _usage) =
830                            Zeta::perform_predict_edits(PerformPredictEditsParams {
831                                client: app_state.client.clone(),
832                                llm_token,
833                                app_version,
834                                body: predict_edits_body,
835                            })
836                            .await?;
837
838                        Ok(response.output_excerpt)
839                    })
840                    .await
841                }
842                Commands::RetrievalStats {
843                    worktree,
844                    file_indexing_parallelism,
845                } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
846            };
847            match result {
848                Ok(output) => {
849                    println!("{}", output);
850                    let _ = cx.update(|cx| cx.quit());
851                }
852                Err(e) => {
853                    eprintln!("Failed: {:?}", e);
854                    exit(1);
855                }
856            }
857        })
858        .detach();
859    });
860}