1mod headless;
2mod retrieval_stats;
3mod source_location;
4mod util;
5
6use crate::retrieval_stats::retrieval_stats;
7use ::util::paths::PathStyle;
8use anyhow::{Result, anyhow};
9use clap::{Args, Parser, Subcommand};
10use cloud_llm_client::predict_edits_v3::{self};
11use edit_prediction_context::{
12 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
13 SimilarSnippetOptions,
14};
15use gpui::{Application, AsyncApp, prelude::*};
16use language::Bias;
17use language_model::LlmApiToken;
18use project::Project;
19use release_channel::AppVersion;
20use reqwest_client::ReqwestClient;
21use serde_json::json;
22use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
23use zeta::{PerformPredictEditsParams, Zeta};
24
25use crate::headless::ZetaCliAppState;
26use crate::source_location::SourceLocation;
27use crate::util::{open_buffer, open_buffer_with_language_server};
28
29#[derive(Parser, Debug)]
30#[command(name = "zeta")]
31struct ZetaCliArgs {
32 #[command(subcommand)]
33 command: Commands,
34}
35
36#[derive(Subcommand, Debug)]
37enum Commands {
38 Context(ContextArgs),
39 Zeta2Context {
40 #[clap(flatten)]
41 zeta2_args: Zeta2Args,
42 #[clap(flatten)]
43 context_args: ContextArgs,
44 },
45 Predict {
46 #[arg(long)]
47 predict_edits_body: Option<FileOrStdin>,
48 #[clap(flatten)]
49 context_args: Option<ContextArgs>,
50 },
51 RetrievalStats {
52 #[clap(flatten)]
53 zeta2_args: Zeta2Args,
54 #[arg(long)]
55 worktree: PathBuf,
56 #[arg(long)]
57 extension: Option<String>,
58 #[arg(long)]
59 limit: Option<usize>,
60 #[arg(long)]
61 skip: Option<usize>,
62 },
63}
64
65#[derive(Debug, Args)]
66#[group(requires = "worktree")]
67struct ContextArgs {
68 #[arg(long)]
69 worktree: PathBuf,
70 #[arg(long)]
71 cursor: SourceLocation,
72 #[arg(long)]
73 use_language_server: bool,
74 #[arg(long)]
75 events: Option<FileOrStdin>,
76}
77
78#[derive(Debug, Args)]
79struct Zeta2Args {
80 #[arg(long, default_value_t = 8192)]
81 max_prompt_bytes: usize,
82 #[arg(long, default_value_t = 2048)]
83 max_excerpt_bytes: usize,
84 #[arg(long, default_value_t = 1024)]
85 min_excerpt_bytes: usize,
86 #[arg(long, default_value_t = 0.66)]
87 target_before_cursor_over_total_bytes: f32,
88 #[arg(long, default_value_t = 1024)]
89 max_diagnostic_bytes: usize,
90 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
91 prompt_format: PromptFormat,
92 #[arg(long, value_enum, default_value_t = Default::default())]
93 output_format: OutputFormat,
94 #[arg(long, default_value_t = 42)]
95 file_indexing_parallelism: usize,
96 #[arg(long, default_value_t = false)]
97 disable_imports_gathering: bool,
98 #[arg(long, default_value_t = u8::MAX)]
99 max_retrieved_definitions: u8,
100}
101
102#[derive(clap::ValueEnum, Default, Debug, Clone)]
103enum PromptFormat {
104 MarkedExcerpt,
105 LabeledSections,
106 OnlySnippets,
107 #[default]
108 NumberedLines,
109}
110
111impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
112 fn into(self) -> predict_edits_v3::PromptFormat {
113 match self {
114 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
115 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
116 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
117 Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
118 }
119 }
120}
121
122#[derive(clap::ValueEnum, Default, Debug, Clone)]
123enum OutputFormat {
124 #[default]
125 Prompt,
126 Request,
127 Full,
128}
129
/// An input source given on the command line: a file path, or standard input.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the input from the file at this path.
    File(PathBuf),
    /// Read the input from the process's standard input.
    Stdin,
}
135
136impl FileOrStdin {
137 async fn read_to_string(&self) -> Result<String, std::io::Error> {
138 match self {
139 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
140 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
141 }
142 }
143}
144
145impl FromStr for FileOrStdin {
146 type Err = <PathBuf as FromStr>::Err;
147
148 fn from_str(s: &str) -> Result<Self, Self::Err> {
149 match s {
150 "-" => Ok(Self::Stdin),
151 _ => Ok(Self::File(PathBuf::from_str(s)?)),
152 }
153 }
154}
155
156enum GetContextOutput {
157 Zeta1(zeta::GatherContextOutput),
158 Zeta2(String),
159}
160
161async fn get_context(
162 zeta2_args: Option<Zeta2Args>,
163 args: ContextArgs,
164 app_state: &Arc<ZetaCliAppState>,
165 cx: &mut AsyncApp,
166) -> Result<GetContextOutput> {
167 let ContextArgs {
168 worktree: worktree_path,
169 cursor,
170 use_language_server,
171 events,
172 } = args;
173
174 let worktree_path = worktree_path.canonicalize()?;
175
176 let project = cx.update(|cx| {
177 Project::local(
178 app_state.client.clone(),
179 app_state.node_runtime.clone(),
180 app_state.user_store.clone(),
181 app_state.languages.clone(),
182 app_state.fs.clone(),
183 None,
184 cx,
185 )
186 })?;
187
188 let worktree = project
189 .update(cx, |project, cx| {
190 project.create_worktree(&worktree_path, true, cx)
191 })?
192 .await?;
193
194 let mut ready_languages = HashSet::default();
195 let (_lsp_open_handle, buffer) = if use_language_server {
196 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
197 project.clone(),
198 worktree.clone(),
199 cursor.path.clone(),
200 &mut ready_languages,
201 cx,
202 )
203 .await?;
204 (Some(lsp_open_handle), buffer)
205 } else {
206 let buffer =
207 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
208 (None, buffer)
209 };
210
211 let full_path_str = worktree
212 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
213 .display(PathStyle::local())
214 .to_string();
215
216 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
217 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
218 if clipped_cursor != cursor.point {
219 let max_row = snapshot.max_point().row;
220 if cursor.point.row < max_row {
221 return Err(anyhow!(
222 "Cursor position {:?} is out of bounds (line length is {})",
223 cursor.point,
224 snapshot.line_len(cursor.point.row)
225 ));
226 } else {
227 return Err(anyhow!(
228 "Cursor position {:?} is out of bounds (max row is {})",
229 cursor.point,
230 max_row
231 ));
232 }
233 }
234
235 let events = match events {
236 Some(events) => events.read_to_string().await?,
237 None => String::new(),
238 };
239
240 if let Some(zeta2_args) = zeta2_args {
241 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
242 // the whole worktree.
243 worktree
244 .read_with(cx, |worktree, _cx| {
245 worktree.as_local().unwrap().scan_complete()
246 })?
247 .await;
248 let output = cx
249 .update(|cx| {
250 let zeta = cx.new(|cx| {
251 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
252 });
253 let indexing_done_task = zeta.update(cx, |zeta, cx| {
254 zeta.set_options(zeta2_args.to_options(true));
255 zeta.register_buffer(&buffer, &project, cx);
256 zeta.wait_for_initial_indexing(&project, cx)
257 });
258 cx.spawn(async move |cx| {
259 indexing_done_task.await?;
260 let request = zeta
261 .update(cx, |zeta, cx| {
262 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
263 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
264 })?
265 .await?;
266
267 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
268 let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
269
270 match zeta2_args.output_format {
271 OutputFormat::Prompt => anyhow::Ok(prompt_string),
272 OutputFormat::Request => {
273 anyhow::Ok(serde_json::to_string_pretty(&request)?)
274 }
275 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
276 "request": request,
277 "prompt": prompt_string,
278 "section_labels": section_labels,
279 }))?),
280 }
281 })
282 })?
283 .await?;
284 Ok(GetContextOutput::Zeta2(output))
285 } else {
286 let prompt_for_events = move || (events, 0);
287 Ok(GetContextOutput::Zeta1(
288 cx.update(|cx| {
289 zeta::gather_context(
290 full_path_str,
291 &snapshot,
292 clipped_cursor,
293 prompt_for_events,
294 cx,
295 )
296 })?
297 .await?,
298 ))
299 }
300}
301
302impl Zeta2Args {
303 fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
304 zeta2::ZetaOptions {
305 context: EditPredictionContextOptions {
306 max_retrieved_declarations: self.max_retrieved_definitions,
307 use_imports: !self.disable_imports_gathering,
308 excerpt: EditPredictionExcerptOptions {
309 max_bytes: self.max_excerpt_bytes,
310 min_bytes: self.min_excerpt_bytes,
311 target_before_cursor_over_total_bytes: self
312 .target_before_cursor_over_total_bytes,
313 },
314 score: EditPredictionScoreOptions {
315 omit_excerpt_overlaps,
316 },
317 // todo! configuration
318 similar_snippets: SimilarSnippetOptions::default(),
319 },
320 max_diagnostic_bytes: self.max_diagnostic_bytes,
321 max_prompt_bytes: self.max_prompt_bytes,
322 prompt_format: self.prompt_format.clone().into(),
323 file_indexing_parallelism: self.file_indexing_parallelism,
324 }
325 }
326}
327
328fn main() {
329 zlog::init();
330 zlog::init_output_stderr();
331 let args = ZetaCliArgs::parse();
332 let http_client = Arc::new(ReqwestClient::new());
333 let app = Application::headless().with_http_client(http_client);
334
335 app.run(move |cx| {
336 let app_state = Arc::new(headless::init(cx));
337 cx.spawn(async move |cx| {
338 let result = match args.command {
339 Commands::Zeta2Context {
340 zeta2_args,
341 context_args,
342 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
343 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
344 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
345 Err(err) => Err(err),
346 },
347 Commands::Context(context_args) => {
348 match get_context(None, context_args, &app_state, cx).await {
349 Ok(GetContextOutput::Zeta1(output)) => {
350 Ok(serde_json::to_string_pretty(&output.body).unwrap())
351 }
352 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
353 Err(err) => Err(err),
354 }
355 }
356 Commands::Predict {
357 predict_edits_body,
358 context_args,
359 } => {
360 cx.spawn(async move |cx| {
361 let app_version = cx.update(|cx| AppVersion::global(cx))?;
362 app_state.client.sign_in(true, cx).await?;
363 let llm_token = LlmApiToken::default();
364 llm_token.refresh(&app_state.client).await?;
365
366 let predict_edits_body =
367 if let Some(predict_edits_body) = predict_edits_body {
368 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
369 } else if let Some(context_args) = context_args {
370 match get_context(None, context_args, &app_state, cx).await? {
371 GetContextOutput::Zeta1(output) => output.body,
372 GetContextOutput::Zeta2 { .. } => unreachable!(),
373 }
374 } else {
375 return Err(anyhow!(
376 "Expected either --predict-edits-body-file \
377 or the required args of the `context` command."
378 ));
379 };
380
381 let (response, _usage) =
382 Zeta::perform_predict_edits(PerformPredictEditsParams {
383 client: app_state.client.clone(),
384 llm_token,
385 app_version,
386 body: predict_edits_body,
387 })
388 .await?;
389
390 Ok(response.output_excerpt)
391 })
392 .await
393 }
394 Commands::RetrievalStats {
395 zeta2_args,
396 worktree,
397 extension,
398 limit,
399 skip,
400 } => {
401 retrieval_stats(
402 worktree,
403 app_state,
404 extension,
405 limit,
406 skip,
407 (&zeta2_args).to_options(false),
408 cx,
409 )
410 .await
411 }
412 };
413 match result {
414 Ok(output) => {
415 println!("{}", output);
416 let _ = cx.update(|cx| cx.quit());
417 }
418 Err(e) => {
419 eprintln!("Failed: {:?}", e);
420 exit(1);
421 }
422 }
423 })
424 .detach();
425 });
426}