main.rs

  1mod headless;
  2
  3use anyhow::{Result, anyhow};
  4use clap::{Args, Parser, Subcommand};
  5use cloud_llm_client::predict_edits_v3;
  6use edit_prediction_context::{
  7    Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
  8    SyntaxIndex, references_in_range,
  9};
 10use futures::channel::mpsc;
 11use futures::{FutureExt as _, StreamExt as _};
 12use gpui::{AppContext, Application, AsyncApp};
 13use gpui::{Entity, Task};
 14use language::{Bias, LanguageServerId};
 15use language::{Buffer, OffsetRangeExt};
 16use language::{LanguageId, Point};
 17use language_model::LlmApiToken;
 18use ordered_float::OrderedFloat;
 19use project::{Project, ProjectPath, Worktree};
 20use release_channel::AppVersion;
 21use reqwest_client::ReqwestClient;
 22use serde_json::json;
 23use std::cmp::Reverse;
 24use std::collections::{HashMap, HashSet};
 25use std::io::Write as _;
 26use std::ops::Range;
 27use std::path::{Path, PathBuf};
 28use std::process::exit;
 29use std::str::FromStr;
 30use std::sync::Arc;
 31use std::time::Duration;
 32use util::paths::PathStyle;
 33use util::rel_path::RelPath;
 34use util::{RangeExt, ResultExt as _};
 35use zeta::{PerformPredictEditsParams, Zeta};
 36
 37use crate::headless::ZetaCliAppState;
 38
 39#[derive(Parser, Debug)]
 40#[command(name = "zeta")]
 41struct ZetaCliArgs {
 42    #[command(subcommand)]
 43    command: Commands,
 44}
 45
 46#[derive(Subcommand, Debug)]
 47enum Commands {
 48    Context(ContextArgs),
 49    Zeta2Context {
 50        #[clap(flatten)]
 51        zeta2_args: Zeta2Args,
 52        #[clap(flatten)]
 53        context_args: ContextArgs,
 54    },
 55    Predict {
 56        #[arg(long)]
 57        predict_edits_body: Option<FileOrStdin>,
 58        #[clap(flatten)]
 59        context_args: Option<ContextArgs>,
 60    },
 61    RetrievalStats {
 62        #[arg(long)]
 63        worktree: PathBuf,
 64        #[arg(long, default_value_t = 42)]
 65        file_indexing_parallelism: usize,
 66    },
 67}
 68
 69#[derive(Debug, Args)]
 70#[group(requires = "worktree")]
 71struct ContextArgs {
 72    #[arg(long)]
 73    worktree: PathBuf,
 74    #[arg(long)]
 75    cursor: CursorPosition,
 76    #[arg(long)]
 77    use_language_server: bool,
 78    #[arg(long)]
 79    events: Option<FileOrStdin>,
 80}
 81
 82#[derive(Debug, Args)]
 83struct Zeta2Args {
 84    #[arg(long, default_value_t = 8192)]
 85    max_prompt_bytes: usize,
 86    #[arg(long, default_value_t = 2048)]
 87    max_excerpt_bytes: usize,
 88    #[arg(long, default_value_t = 1024)]
 89    min_excerpt_bytes: usize,
 90    #[arg(long, default_value_t = 0.66)]
 91    target_before_cursor_over_total_bytes: f32,
 92    #[arg(long, default_value_t = 1024)]
 93    max_diagnostic_bytes: usize,
 94    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 95    prompt_format: PromptFormat,
 96    #[arg(long, value_enum, default_value_t = Default::default())]
 97    output_format: OutputFormat,
 98    #[arg(long, default_value_t = 42)]
 99    file_indexing_parallelism: usize,
100}
101
102#[derive(clap::ValueEnum, Default, Debug, Clone)]
103enum PromptFormat {
104    #[default]
105    MarkedExcerpt,
106    LabeledSections,
107    OnlySnippets,
108}
109
110impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
111    fn into(self) -> predict_edits_v3::PromptFormat {
112        match self {
113            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
114            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
115            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
116        }
117    }
118}
119
120#[derive(clap::ValueEnum, Default, Debug, Clone)]
121enum OutputFormat {
122    #[default]
123    Prompt,
124    Request,
125    Full,
126}
127
/// A text input source given on the command line: a file path, or standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read everything from standard input (spelled `-` on the command line).
    Stdin,
}
133
134impl FileOrStdin {
135    async fn read_to_string(&self) -> Result<String, std::io::Error> {
136        match self {
137            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
138            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
139        }
140    }
141}
142
143impl FromStr for FileOrStdin {
144    type Err = <PathBuf as FromStr>::Err;
145
146    fn from_str(s: &str) -> Result<Self, Self::Err> {
147        match s {
148            "-" => Ok(Self::Stdin),
149            _ => Ok(Self::File(PathBuf::from_str(s)?)),
150        }
151    }
152}
153
154#[derive(Debug, Clone)]
155struct CursorPosition {
156    path: Arc<RelPath>,
157    point: Point,
158}
159
160impl FromStr for CursorPosition {
161    type Err = anyhow::Error;
162
163    fn from_str(s: &str) -> Result<Self> {
164        let parts: Vec<&str> = s.split(':').collect();
165        if parts.len() != 3 {
166            return Err(anyhow!(
167                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
168                s
169            ));
170        }
171
172        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
173        let line: u32 = parts[1]
174            .parse()
175            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
176        let column: u32 = parts[2]
177            .parse()
178            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
179
180        // Convert from 1-based to 0-based indexing
181        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
182
183        Ok(CursorPosition { path, point })
184    }
185}
186
187enum GetContextOutput {
188    Zeta1(zeta::GatherContextOutput),
189    Zeta2(String),
190}
191
192async fn get_context(
193    zeta2_args: Option<Zeta2Args>,
194    args: ContextArgs,
195    app_state: &Arc<ZetaCliAppState>,
196    cx: &mut AsyncApp,
197) -> Result<GetContextOutput> {
198    let ContextArgs {
199        worktree: worktree_path,
200        cursor,
201        use_language_server,
202        events,
203    } = args;
204
205    let worktree_path = worktree_path.canonicalize()?;
206
207    let project = cx.update(|cx| {
208        Project::local(
209            app_state.client.clone(),
210            app_state.node_runtime.clone(),
211            app_state.user_store.clone(),
212            app_state.languages.clone(),
213            app_state.fs.clone(),
214            None,
215            cx,
216        )
217    })?;
218
219    let worktree = project
220        .update(cx, |project, cx| {
221            project.create_worktree(&worktree_path, true, cx)
222        })?
223        .await?;
224
225    let mut ready_languages = HashSet::default();
226    let (_lsp_open_handle, buffer) = if use_language_server {
227        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
228            &project,
229            &worktree,
230            &cursor.path,
231            &mut ready_languages,
232            cx,
233        )
234        .await?;
235        (Some(lsp_open_handle), buffer)
236    } else {
237        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
238        (None, buffer)
239    };
240
241    let full_path_str = worktree
242        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
243        .display(PathStyle::local())
244        .to_string();
245
246    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
247    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
248    if clipped_cursor != cursor.point {
249        let max_row = snapshot.max_point().row;
250        if cursor.point.row < max_row {
251            return Err(anyhow!(
252                "Cursor position {:?} is out of bounds (line length is {})",
253                cursor.point,
254                snapshot.line_len(cursor.point.row)
255            ));
256        } else {
257            return Err(anyhow!(
258                "Cursor position {:?} is out of bounds (max row is {})",
259                cursor.point,
260                max_row
261            ));
262        }
263    }
264
265    let events = match events {
266        Some(events) => events.read_to_string().await?,
267        None => String::new(),
268    };
269
270    if let Some(zeta2_args) = zeta2_args {
271        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
272        // the whole worktree.
273        worktree
274            .read_with(cx, |worktree, _cx| {
275                worktree.as_local().unwrap().scan_complete()
276            })?
277            .await;
278        let output = cx
279            .update(|cx| {
280                let zeta = cx.new(|cx| {
281                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
282                });
283                let indexing_done_task = zeta.update(cx, |zeta, cx| {
284                    zeta.set_options(zeta2::ZetaOptions {
285                        excerpt: EditPredictionExcerptOptions {
286                            max_bytes: zeta2_args.max_excerpt_bytes,
287                            min_bytes: zeta2_args.min_excerpt_bytes,
288                            target_before_cursor_over_total_bytes: zeta2_args
289                                .target_before_cursor_over_total_bytes,
290                        },
291                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
292                        max_prompt_bytes: zeta2_args.max_prompt_bytes,
293                        prompt_format: zeta2_args.prompt_format.into(),
294                        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
295                    });
296                    zeta.register_buffer(&buffer, &project, cx);
297                    zeta.wait_for_initial_indexing(&project, cx)
298                });
299                cx.spawn(async move |cx| {
300                    indexing_done_task.await?;
301                    let request = zeta
302                        .update(cx, |zeta, cx| {
303                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
304                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
305                        })?
306                        .await?;
307
308                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
309                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
310
311                    match zeta2_args.output_format {
312                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
313                        OutputFormat::Request => {
314                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
315                        }
316                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
317                            "request": request,
318                            "prompt": prompt_string,
319                            "section_labels": section_labels,
320                        }))?),
321                    }
322                })
323            })?
324            .await?;
325        Ok(GetContextOutput::Zeta2(output))
326    } else {
327        let prompt_for_events = move || (events, 0);
328        Ok(GetContextOutput::Zeta1(
329            cx.update(|cx| {
330                zeta::gather_context(
331                    full_path_str,
332                    &snapshot,
333                    clipped_cursor,
334                    prompt_for_events,
335                    cx,
336                )
337            })?
338            .await?,
339        ))
340    }
341}
342
343pub async fn retrieval_stats(
344    worktree: PathBuf,
345    file_indexing_parallelism: usize,
346    app_state: Arc<ZetaCliAppState>,
347    cx: &mut AsyncApp,
348) -> Result<String> {
349    let worktree_path = worktree.canonicalize()?;
350
351    let project = cx.update(|cx| {
352        Project::local(
353            app_state.client.clone(),
354            app_state.node_runtime.clone(),
355            app_state.user_store.clone(),
356            app_state.languages.clone(),
357            app_state.fs.clone(),
358            None,
359            cx,
360        )
361    })?;
362
363    let worktree = project
364        .update(cx, |project, cx| {
365            project.create_worktree(&worktree_path, true, cx)
366        })?
367        .await?;
368    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
369
370    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
371    worktree
372        .read_with(cx, |worktree, _cx| {
373            worktree.as_local().unwrap().scan_complete()
374        })?
375        .await;
376
377    let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?;
378    index
379        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
380        .await?;
381    let files = index
382        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
383        .await
384        .into_iter()
385        .filter(|project_path| {
386            project_path
387                .path
388                .extension()
389                .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
390        })
391        .collect::<Vec<_>>();
392
393    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
394    cx.subscribe(&lsp_store, {
395        move |_, event, _| {
396            if let project::LspStoreEvent::LanguageServerUpdate {
397                message:
398                    client::proto::update_language_server::Variant::WorkProgress(
399                        client::proto::LspWorkProgress {
400                            message: Some(message),
401                            ..
402                        },
403                    ),
404                ..
405            } = event
406            {
407                println!("{message}")
408            }
409        }
410    })?
411    .detach();
412
413    let mut lsp_open_handles = Vec::new();
414    let mut output = std::fs::File::create("retrieval-stats.txt")?;
415    let mut results = Vec::new();
416    let mut ready_languages = HashSet::default();
417    for (file_index, project_path) in files.iter().enumerate() {
418        let processing_file_message = format!(
419            "Processing file {} of {}: {}",
420            file_index + 1,
421            files.len(),
422            project_path.path.display(PathStyle::Posix)
423        );
424        println!("{}", processing_file_message);
425        write!(output, "{processing_file_message}\n\n").ok();
426
427        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
428            &project,
429            &worktree,
430            &project_path.path,
431            &mut ready_languages,
432            cx,
433        )
434        .await
435        .log_err() else {
436            continue;
437        };
438        lsp_open_handles.push(lsp_open_handle);
439
440        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
441        let full_range = 0..snapshot.len();
442        let references = references_in_range(
443            full_range,
444            &snapshot.text(),
445            ReferenceRegion::Nearby,
446            &snapshot,
447        );
448
449        loop {
450            let is_ready = lsp_store
451                .read_with(cx, |lsp_store, _cx| {
452                    lsp_store
453                        .language_server_statuses
454                        .get(&language_server_id)
455                        .is_some_and(|status| status.pending_work.is_empty())
456                })
457                .unwrap();
458            if is_ready {
459                break;
460            }
461            cx.background_executor()
462                .timer(Duration::from_millis(10))
463                .await;
464        }
465
466        let index = index.read_with(cx, |index, _cx| index.state().clone())?;
467        let index = index.lock().await;
468        for reference in references {
469            let query_point = snapshot.offset_to_point(reference.range.start);
470            let mut single_reference_map = HashMap::default();
471            single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
472            let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
473                query_point,
474                &snapshot,
475                &zeta2::DEFAULT_EXCERPT_OPTIONS,
476                Some(&index),
477                |_, _, _| single_reference_map,
478            );
479
480            let Some(edit_prediction_context) = edit_prediction_context else {
481                let result = RetrievalStatsResult {
482                    identifier: reference.identifier,
483                    point: query_point,
484                    outcome: RetrievalStatsOutcome::NoExcerpt,
485                };
486                write!(output, "{:?}\n\n", result)?;
487                results.push(result);
488                continue;
489            };
490
491            let mut retrieved_definitions = Vec::new();
492            for scored_declaration in edit_prediction_context.declarations {
493                match &scored_declaration.declaration {
494                    Declaration::File {
495                        project_entry_id,
496                        declaration,
497                    } => {
498                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
499                            worktree
500                                .entry_for_id(*project_entry_id)
501                                .map(|entry| entry.path.clone())
502                        })?
503                        else {
504                            log::error!("bug: file project entry not found");
505                            continue;
506                        };
507                        let project_path = ProjectPath {
508                            worktree_id,
509                            path: path.clone(),
510                        };
511                        let buffer = project
512                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
513                            .await?;
514                        let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?;
515                        retrieved_definitions.push((
516                            path,
517                            rope.offset_to_point(declaration.item_range.start)
518                                ..rope.offset_to_point(declaration.item_range.end),
519                            scored_declaration.scores.declaration,
520                            scored_declaration.scores.retrieval,
521                        ));
522                    }
523                    Declaration::Buffer {
524                        project_entry_id,
525                        rope,
526                        declaration,
527                        ..
528                    } => {
529                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
530                            worktree
531                                .entry_for_id(*project_entry_id)
532                                .map(|entry| entry.path.clone())
533                        })?
534                        else {
535                            // This case happens when dependency buffers have been opened by
536                            // go-to-definition, resulting in single-file worktrees.
537                            continue;
538                        };
539                        retrieved_definitions.push((
540                            path,
541                            rope.offset_to_point(declaration.item_range.start)
542                                ..rope.offset_to_point(declaration.item_range.end),
543                            scored_declaration.scores.declaration,
544                            scored_declaration.scores.retrieval,
545                        ));
546                    }
547                }
548            }
549            retrieved_definitions
550                .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score)));
551
552            // TODO: Consider still checking language server in this case, or having a mode for
553            // this. For now assuming that the purpose of this is to refine the ranking rather than
554            // refining whether the definition is present at all.
555            if retrieved_definitions.is_empty() {
556                continue;
557            }
558
559            // TODO: Rename declaration to definition in edit_prediction_context?
560            let lsp_result = project
561                .update(cx, |project, cx| {
562                    project.definitions(&buffer, reference.range.start, cx)
563                })?
564                .await;
565            match lsp_result {
566                Ok(lsp_definitions) => {
567                    let lsp_definitions = lsp_definitions
568                        .unwrap_or_default()
569                        .into_iter()
570                        .filter_map(|definition| {
571                            definition
572                                .target
573                                .buffer
574                                .read_with(cx, |buffer, _cx| {
575                                    let path = buffer.file()?.path();
576                                    // filter out definitions from single-file worktrees
577                                    if path.is_empty() {
578                                        None
579                                    } else {
580                                        Some((
581                                            path.clone(),
582                                            definition.target.range.to_point(&buffer),
583                                        ))
584                                    }
585                                })
586                                .ok()?
587                        })
588                        .collect::<Vec<_>>();
589
590                    let result = RetrievalStatsResult {
591                        identifier: reference.identifier,
592                        point: query_point,
593                        outcome: RetrievalStatsOutcome::Success {
594                            matches: lsp_definitions
595                                .iter()
596                                .map(|(path, range)| {
597                                    retrieved_definitions.iter().position(
598                                        |(retrieved_path, retrieved_range, _, _)| {
599                                            path == retrieved_path
600                                                && retrieved_range.contains_inclusive(&range)
601                                        },
602                                    )
603                                })
604                                .collect(),
605                            lsp_definitions,
606                            retrieved_definitions,
607                        },
608                    };
609                    write!(output, "{:?}\n\n", result)?;
610                    results.push(result);
611                }
612                Err(err) => {
613                    let result = RetrievalStatsResult {
614                        identifier: reference.identifier,
615                        point: query_point,
616                        outcome: RetrievalStatsOutcome::LanguageServerError {
617                            message: err.to_string(),
618                        },
619                    };
620                    write!(output, "{:?}\n\n", result)?;
621                    results.push(result);
622                }
623            }
624        }
625    }
626
627    let mut no_excerpt_count = 0;
628    let mut error_count = 0;
629    let mut definitions_count = 0;
630    let mut top_match_count = 0;
631    let mut non_top_match_count = 0;
632    let mut ranking_involved_count = 0;
633    let mut ranking_involved_top_match_count = 0;
634    let mut ranking_involved_non_top_match_count = 0;
635    for result in &results {
636        match &result.outcome {
637            RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
638            RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
639            RetrievalStatsOutcome::Success {
640                matches,
641                retrieved_definitions,
642                ..
643            } => {
644                definitions_count += 1;
645                let top_matches = matches.contains(&Some(0));
646                if top_matches {
647                    top_match_count += 1;
648                }
649                let non_top_matches = !top_matches && matches.iter().any(|index| *index != Some(0));
650                if non_top_matches {
651                    non_top_match_count += 1;
652                }
653                if retrieved_definitions.len() > 1 {
654                    ranking_involved_count += 1;
655                    if top_matches {
656                        ranking_involved_top_match_count += 1;
657                    }
658                    if non_top_matches {
659                        ranking_involved_non_top_match_count += 1;
660                    }
661                }
662            }
663        }
664    }
665
666    println!("\nStats:\n");
667    println!("No Excerpt: {}", no_excerpt_count);
668    println!("Language Server Error: {}", error_count);
669    println!("Definitions: {}", definitions_count);
670    println!("Top Match: {}", top_match_count);
671    println!("Non-Top Match: {}", non_top_match_count);
672    println!("Ranking Involved: {}", ranking_involved_count);
673    println!(
674        "Ranking Involved Top Match: {}",
675        ranking_involved_top_match_count
676    );
677    println!(
678        "Ranking Involved Non-Top Match: {}",
679        ranking_involved_non_top_match_count
680    );
681
682    Ok("".to_string())
683}
684
685#[derive(Debug)]
686struct RetrievalStatsResult {
687    #[allow(dead_code)]
688    identifier: Identifier,
689    #[allow(dead_code)]
690    point: Point,
691    outcome: RetrievalStatsOutcome,
692}
693
694#[derive(Debug)]
695enum RetrievalStatsOutcome {
696    NoExcerpt,
697    LanguageServerError {
698        #[allow(dead_code)]
699        message: String,
700    },
701    Success {
702        matches: Vec<Option<usize>>,
703        #[allow(dead_code)]
704        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
705        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
706    },
707}
708
709pub async fn open_buffer(
710    project: &Entity<Project>,
711    worktree: &Entity<Worktree>,
712    path: &RelPath,
713    cx: &mut AsyncApp,
714) -> Result<Entity<Buffer>> {
715    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
716        worktree_id: worktree.id(),
717        path: path.into(),
718    })?;
719
720    project
721        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
722        .await
723}
724
725pub async fn open_buffer_with_language_server(
726    project: &Entity<Project>,
727    worktree: &Entity<Worktree>,
728    path: &RelPath,
729    ready_languages: &mut HashSet<LanguageId>,
730    cx: &mut AsyncApp,
731) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
732    let buffer = open_buffer(project, worktree, path, cx).await?;
733
734    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
735        (
736            project.register_buffer_with_language_servers(&buffer, cx),
737            project.path_style(cx),
738        )
739    })?;
740
741    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
742        buffer.language().map(|language| language.id())
743    })?
744    else {
745        return Err(anyhow!("No language for {}", path.display(path_style)));
746    };
747
748    let log_prefix = path.display(path_style);
749    if !ready_languages.contains(&language_id) {
750        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
751        ready_languages.insert(language_id);
752    }
753
754    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
755
756    // hacky wait for buffer to be registered with the language server
757    for _ in 0..100 {
758        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
759            buffer.update(cx, |buffer, cx| {
760                lsp_store
761                    .language_servers_for_local_buffer(&buffer, cx)
762                    .next()
763                    .map(|(_, language_server)| language_server.server_id())
764            })
765        })?
766        else {
767            cx.background_executor()
768                .timer(Duration::from_millis(10))
769                .await;
770            continue;
771        };
772
773        return Ok((lsp_open_handle, language_server_id, buffer));
774    }
775
776    return Err(anyhow!("No language server found for buffer"));
777}
778
779// TODO: Dedupe with similar function in crates/eval/src/instance.rs
780pub fn wait_for_lang_server(
781    project: &Entity<Project>,
782    buffer: &Entity<Buffer>,
783    log_prefix: String,
784    cx: &mut AsyncApp,
785) -> Task<Result<()>> {
786    println!("{}⏵ Waiting for language server", log_prefix);
787
788    let (mut tx, mut rx) = mpsc::channel(1);
789
790    let lsp_store = project
791        .read_with(cx, |project, _| project.lsp_store())
792        .unwrap();
793
794    let has_lang_server = buffer
795        .update(cx, |buffer, cx| {
796            lsp_store.update(cx, |lsp_store, cx| {
797                lsp_store
798                    .language_servers_for_local_buffer(buffer, cx)
799                    .next()
800                    .is_some()
801            })
802        })
803        .unwrap_or(false);
804
805    if has_lang_server {
806        project
807            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
808            .unwrap()
809            .detach();
810    }
811    let (mut added_tx, mut added_rx) = mpsc::channel(1);
812
813    let subscriptions = [
814        cx.subscribe(&lsp_store, {
815            let log_prefix = log_prefix.clone();
816            move |_, event, _| {
817                if let project::LspStoreEvent::LanguageServerUpdate {
818                    message:
819                        client::proto::update_language_server::Variant::WorkProgress(
820                            client::proto::LspWorkProgress {
821                                message: Some(message),
822                                ..
823                            },
824                        ),
825                    ..
826                } = event
827                {
828                    println!("{}{message}", log_prefix)
829                }
830            }
831        }),
832        cx.subscribe(project, {
833            let buffer = buffer.clone();
834            move |project, event, cx| match event {
835                project::Event::LanguageServerAdded(_, _, _) => {
836                    let buffer = buffer.clone();
837                    project
838                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
839                        .detach();
840                    added_tx.try_send(()).ok();
841                }
842                project::Event::DiskBasedDiagnosticsFinished { .. } => {
843                    tx.try_send(()).ok();
844                }
845                _ => {}
846            }
847        }),
848    ];
849
850    cx.spawn(async move |cx| {
851        if !has_lang_server {
852            // some buffers never have a language server, so this aborts quickly in that case.
853            let timeout = cx.background_executor().timer(Duration::from_secs(5));
854            futures::select! {
855                _ = added_rx.next() => {},
856                _ = timeout.fuse() => {
857                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
858                }
859            };
860        }
861        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
862        let result = futures::select! {
863            _ = rx.next() => {
864                println!("{}⚑ Language server idle", log_prefix);
865                anyhow::Ok(())
866            },
867            _ = timeout.fuse() => {
868                anyhow::bail!("LSP wait timed out after 5 minutes");
869            }
870        };
871        drop(subscriptions);
872        result
873    })
874}
875
876fn main() {
877    zlog::init();
878    zlog::init_output_stderr();
879    let args = ZetaCliArgs::parse();
880    let http_client = Arc::new(ReqwestClient::new());
881    let app = Application::headless().with_http_client(http_client);
882
883    app.run(move |cx| {
884        let app_state = Arc::new(headless::init(cx));
885        cx.spawn(async move |cx| {
886            let result = match args.command {
887                Commands::Zeta2Context {
888                    zeta2_args,
889                    context_args,
890                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
891                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
892                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
893                    Err(err) => Err(err),
894                },
895                Commands::Context(context_args) => {
896                    match get_context(None, context_args, &app_state, cx).await {
897                        Ok(GetContextOutput::Zeta1(output)) => {
898                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
899                        }
900                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
901                        Err(err) => Err(err),
902                    }
903                }
904                Commands::Predict {
905                    predict_edits_body,
906                    context_args,
907                } => {
908                    cx.spawn(async move |cx| {
909                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
910                        app_state.client.sign_in(true, cx).await?;
911                        let llm_token = LlmApiToken::default();
912                        llm_token.refresh(&app_state.client).await?;
913
914                        let predict_edits_body =
915                            if let Some(predict_edits_body) = predict_edits_body {
916                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
917                            } else if let Some(context_args) = context_args {
918                                match get_context(None, context_args, &app_state, cx).await? {
919                                    GetContextOutput::Zeta1(output) => output.body,
920                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
921                                }
922                            } else {
923                                return Err(anyhow!(
924                                    "Expected either --predict-edits-body-file \
925                                    or the required args of the `context` command."
926                                ));
927                            };
928
929                        let (response, _usage) =
930                            Zeta::perform_predict_edits(PerformPredictEditsParams {
931                                client: app_state.client.clone(),
932                                llm_token,
933                                app_version,
934                                body: predict_edits_body,
935                            })
936                            .await?;
937
938                        Ok(response.output_excerpt)
939                    })
940                    .await
941                }
942                Commands::RetrievalStats {
943                    worktree,
944                    file_indexing_parallelism,
945                } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
946            };
947            match result {
948                Ok(output) => {
949                    println!("{}", output);
950                    let _ = cx.update(|cx| cx.quit());
951                }
952                Err(e) => {
953                    eprintln!("Failed: {:?}", e);
954                    exit(1);
955                }
956            }
957        })
958        .detach();
959    });
960}