1mod headless;
2mod retrieval_stats;
3mod source_location;
4mod util;
5
6use crate::retrieval_stats::retrieval_stats;
7use ::util::paths::PathStyle;
8use anyhow::{Result, anyhow};
9use clap::{Args, Parser, Subcommand};
10use cloud_llm_client::predict_edits_v3::{self};
11use edit_prediction_context::{
12 EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
13};
14use gpui::{Application, AsyncApp, prelude::*};
15use language::Bias;
16use language_model::LlmApiToken;
17use project::Project;
18use release_channel::AppVersion;
19use reqwest_client::ReqwestClient;
20use serde_json::json;
21use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
22use zeta::{PerformPredictEditsParams, Zeta};
23
24use crate::headless::ZetaCliAppState;
25use crate::source_location::SourceLocation;
26use crate::util::{open_buffer, open_buffer_with_language_server};
27
28#[derive(Parser, Debug)]
29#[command(name = "zeta")]
30struct ZetaCliArgs {
31 #[command(subcommand)]
32 command: Commands,
33}
34
35#[derive(Subcommand, Debug)]
36enum Commands {
37 Context(ContextArgs),
38 Zeta2Context {
39 #[clap(flatten)]
40 zeta2_args: Zeta2Args,
41 #[clap(flatten)]
42 context_args: ContextArgs,
43 },
44 Predict {
45 #[arg(long)]
46 predict_edits_body: Option<FileOrStdin>,
47 #[clap(flatten)]
48 context_args: Option<ContextArgs>,
49 },
50 RetrievalStats {
51 #[clap(flatten)]
52 zeta2_args: Zeta2Args,
53 #[arg(long)]
54 worktree: PathBuf,
55 #[arg(long)]
56 extension: Option<String>,
57 #[arg(long)]
58 limit: Option<usize>,
59 #[arg(long)]
60 skip: Option<usize>,
61 },
62}
63
64#[derive(Debug, Args)]
65#[group(requires = "worktree")]
66struct ContextArgs {
67 #[arg(long)]
68 worktree: PathBuf,
69 #[arg(long)]
70 cursor: SourceLocation,
71 #[arg(long)]
72 use_language_server: bool,
73 #[arg(long)]
74 events: Option<FileOrStdin>,
75}
76
77#[derive(Debug, Args)]
78struct Zeta2Args {
79 #[arg(long, default_value_t = 8192)]
80 max_prompt_bytes: usize,
81 #[arg(long, default_value_t = 2048)]
82 max_excerpt_bytes: usize,
83 #[arg(long, default_value_t = 1024)]
84 min_excerpt_bytes: usize,
85 #[arg(long, default_value_t = 0.66)]
86 target_before_cursor_over_total_bytes: f32,
87 #[arg(long, default_value_t = 1024)]
88 max_diagnostic_bytes: usize,
89 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
90 prompt_format: PromptFormat,
91 #[arg(long, value_enum, default_value_t = Default::default())]
92 output_format: OutputFormat,
93 #[arg(long, default_value_t = 42)]
94 file_indexing_parallelism: usize,
95 #[arg(long, default_value_t = false)]
96 disable_imports_gathering: bool,
97}
98
99#[derive(clap::ValueEnum, Default, Debug, Clone)]
100enum PromptFormat {
101 MarkedExcerpt,
102 LabeledSections,
103 OnlySnippets,
104 #[default]
105 NumberedLines,
106}
107
108impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
109 fn into(self) -> predict_edits_v3::PromptFormat {
110 match self {
111 Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
112 Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
113 Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
114 Self::NumberedLines => predict_edits_v3::PromptFormat::NumberedLines,
115 }
116 }
117}
118
119#[derive(clap::ValueEnum, Default, Debug, Clone)]
120enum OutputFormat {
121 #[default]
122 Prompt,
123 Request,
124 Full,
125}
126
/// An input source given on the command line: either a file path, or `-` for stdin.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the input from this file.
    File(PathBuf),
    /// Read the input from standard input.
    Stdin,
}
132
133impl FileOrStdin {
134 async fn read_to_string(&self) -> Result<String, std::io::Error> {
135 match self {
136 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
137 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
138 }
139 }
140}
141
142impl FromStr for FileOrStdin {
143 type Err = <PathBuf as FromStr>::Err;
144
145 fn from_str(s: &str) -> Result<Self, Self::Err> {
146 match s {
147 "-" => Ok(Self::Stdin),
148 _ => Ok(Self::File(PathBuf::from_str(s)?)),
149 }
150 }
151}
152
153enum GetContextOutput {
154 Zeta1(zeta::GatherContextOutput),
155 Zeta2(String),
156}
157
158async fn get_context(
159 zeta2_args: Option<Zeta2Args>,
160 args: ContextArgs,
161 app_state: &Arc<ZetaCliAppState>,
162 cx: &mut AsyncApp,
163) -> Result<GetContextOutput> {
164 let ContextArgs {
165 worktree: worktree_path,
166 cursor,
167 use_language_server,
168 events,
169 } = args;
170
171 let worktree_path = worktree_path.canonicalize()?;
172
173 let project = cx.update(|cx| {
174 Project::local(
175 app_state.client.clone(),
176 app_state.node_runtime.clone(),
177 app_state.user_store.clone(),
178 app_state.languages.clone(),
179 app_state.fs.clone(),
180 None,
181 cx,
182 )
183 })?;
184
185 let worktree = project
186 .update(cx, |project, cx| {
187 project.create_worktree(&worktree_path, true, cx)
188 })?
189 .await?;
190
191 let mut ready_languages = HashSet::default();
192 let (_lsp_open_handle, buffer) = if use_language_server {
193 let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
194 project.clone(),
195 worktree.clone(),
196 cursor.path.clone(),
197 &mut ready_languages,
198 cx,
199 )
200 .await?;
201 (Some(lsp_open_handle), buffer)
202 } else {
203 let buffer =
204 open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
205 (None, buffer)
206 };
207
208 let full_path_str = worktree
209 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
210 .display(PathStyle::local())
211 .to_string();
212
213 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
214 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
215 if clipped_cursor != cursor.point {
216 let max_row = snapshot.max_point().row;
217 if cursor.point.row < max_row {
218 return Err(anyhow!(
219 "Cursor position {:?} is out of bounds (line length is {})",
220 cursor.point,
221 snapshot.line_len(cursor.point.row)
222 ));
223 } else {
224 return Err(anyhow!(
225 "Cursor position {:?} is out of bounds (max row is {})",
226 cursor.point,
227 max_row
228 ));
229 }
230 }
231
232 let events = match events {
233 Some(events) => events.read_to_string().await?,
234 None => String::new(),
235 };
236
237 if let Some(zeta2_args) = zeta2_args {
238 // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
239 // the whole worktree.
240 worktree
241 .read_with(cx, |worktree, _cx| {
242 worktree.as_local().unwrap().scan_complete()
243 })?
244 .await;
245 let output = cx
246 .update(|cx| {
247 let zeta = cx.new(|cx| {
248 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
249 });
250 let indexing_done_task = zeta.update(cx, |zeta, cx| {
251 zeta.set_options(zeta2_args.to_options(true));
252 zeta.register_buffer(&buffer, &project, cx);
253 zeta.wait_for_initial_indexing(&project, cx)
254 });
255 cx.spawn(async move |cx| {
256 indexing_done_task.await?;
257 let request = zeta
258 .update(cx, |zeta, cx| {
259 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
260 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
261 })?
262 .await?;
263
264 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
265 let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
266
267 match zeta2_args.output_format {
268 OutputFormat::Prompt => anyhow::Ok(prompt_string),
269 OutputFormat::Request => {
270 anyhow::Ok(serde_json::to_string_pretty(&request)?)
271 }
272 OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
273 "request": request,
274 "prompt": prompt_string,
275 "section_labels": section_labels,
276 }))?),
277 }
278 })
279 })?
280 .await?;
281 Ok(GetContextOutput::Zeta2(output))
282 } else {
283 let prompt_for_events = move || (events, 0);
284 Ok(GetContextOutput::Zeta1(
285 cx.update(|cx| {
286 zeta::gather_context(
287 full_path_str,
288 &snapshot,
289 clipped_cursor,
290 prompt_for_events,
291 cx,
292 )
293 })?
294 .await?,
295 ))
296 }
297}
298
299impl Zeta2Args {
300 fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
301 zeta2::ZetaOptions {
302 context: EditPredictionContextOptions {
303 use_imports: !self.disable_imports_gathering,
304 excerpt: EditPredictionExcerptOptions {
305 max_bytes: self.max_excerpt_bytes,
306 min_bytes: self.min_excerpt_bytes,
307 target_before_cursor_over_total_bytes: self
308 .target_before_cursor_over_total_bytes,
309 },
310 score: EditPredictionScoreOptions {
311 omit_excerpt_overlaps,
312 },
313 },
314 max_diagnostic_bytes: self.max_diagnostic_bytes,
315 max_prompt_bytes: self.max_prompt_bytes,
316 prompt_format: self.prompt_format.clone().into(),
317 file_indexing_parallelism: self.file_indexing_parallelism,
318 }
319 }
320}
321
322fn main() {
323 zlog::init();
324 zlog::init_output_stderr();
325 let args = ZetaCliArgs::parse();
326 let http_client = Arc::new(ReqwestClient::new());
327 let app = Application::headless().with_http_client(http_client);
328
329 app.run(move |cx| {
330 let app_state = Arc::new(headless::init(cx));
331 cx.spawn(async move |cx| {
332 let result = match args.command {
333 Commands::Zeta2Context {
334 zeta2_args,
335 context_args,
336 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
337 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
338 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
339 Err(err) => Err(err),
340 },
341 Commands::Context(context_args) => {
342 match get_context(None, context_args, &app_state, cx).await {
343 Ok(GetContextOutput::Zeta1(output)) => {
344 Ok(serde_json::to_string_pretty(&output.body).unwrap())
345 }
346 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
347 Err(err) => Err(err),
348 }
349 }
350 Commands::Predict {
351 predict_edits_body,
352 context_args,
353 } => {
354 cx.spawn(async move |cx| {
355 let app_version = cx.update(|cx| AppVersion::global(cx))?;
356 app_state.client.sign_in(true, cx).await?;
357 let llm_token = LlmApiToken::default();
358 llm_token.refresh(&app_state.client).await?;
359
360 let predict_edits_body =
361 if let Some(predict_edits_body) = predict_edits_body {
362 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
363 } else if let Some(context_args) = context_args {
364 match get_context(None, context_args, &app_state, cx).await? {
365 GetContextOutput::Zeta1(output) => output.body,
366 GetContextOutput::Zeta2 { .. } => unreachable!(),
367 }
368 } else {
369 return Err(anyhow!(
370 "Expected either --predict-edits-body-file \
371 or the required args of the `context` command."
372 ));
373 };
374
375 let (response, _usage) =
376 Zeta::perform_predict_edits(PerformPredictEditsParams {
377 client: app_state.client.clone(),
378 llm_token,
379 app_version,
380 body: predict_edits_body,
381 })
382 .await?;
383
384 Ok(response.output_excerpt)
385 })
386 .await
387 }
388 Commands::RetrievalStats {
389 zeta2_args,
390 worktree,
391 extension,
392 limit,
393 skip,
394 } => {
395 retrieval_stats(
396 worktree,
397 app_state,
398 extension,
399 limit,
400 skip,
401 (&zeta2_args).to_options(false),
402 cx,
403 )
404 .await
405 }
406 };
407 match result {
408 Ok(output) => {
409 println!("{}", output);
410 let _ = cx.update(|cx| cx.quit());
411 }
412 Err(e) => {
413 eprintln!("Failed: {:?}", e);
414 exit(1);
415 }
416 }
417 })
418 .detach();
419 });
420}