1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use cloud_llm_client::predict_edits_v3::PromptFormat;
6use edit_prediction_context::EditPredictionExcerptOptions;
7use futures::channel::mpsc;
8use futures::{FutureExt as _, StreamExt as _};
9use gpui::{AppContext, Application, AsyncApp};
10use gpui::{Entity, Task};
11use language::Bias;
12use language::Buffer;
13use language::Point;
14use language_model::LlmApiToken;
15use project::{Project, ProjectPath, Worktree};
16use release_channel::AppVersion;
17use reqwest_client::ReqwestClient;
18use std::path::{Path, PathBuf};
19use std::process::exit;
20use std::str::FromStr;
21use std::sync::Arc;
22use std::time::Duration;
23use util::paths::PathStyle;
24use util::rel_path::RelPath;
25use zeta::{PerformPredictEditsParams, Zeta};
26
27use crate::headless::ZetaCliAppState;
28
// Top-level command-line arguments for the `zeta` binary; the chosen subcommand
// determines what `main` does.
// NOTE: plain `//` comments are used on clap-derived types on purpose — clap turns
// `///` doc comments into `--help` text, which would change the CLI output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
35
// Subcommands of the `zeta` CLI, dispatched in `main`.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for a cursor and print the request body as pretty JSON.
    Context(ContextArgs),
    // Build the zeta2 prompt for a cursor and print it.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction and print the returned output excerpt. The request
    // body is read from `--predict-edits-body` (a file path, or `-` for stdin) when
    // given; otherwise it is built from the same args as the `context` subcommand.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
52
// Arguments that locate the file and cursor for which context is gathered.
// The group makes `--worktree` required whenever any of these args is supplied;
// presumably this matters where the struct is flattened as `Option<ContextArgs>`
// (the `predict` subcommand) — TODO confirm against clap's group semantics.
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the project; canonicalized by `get_context` before use.
    #[arg(long)]
    worktree: PathBuf,
    // `path:line:column` with 1-based line/column and `path` relative to the worktree.
    #[arg(long)]
    cursor: CursorPosition,
    // When set, the buffer is registered with language servers and `get_context`
    // waits for them before gathering context.
    #[arg(long)]
    use_language_server: bool,
    // Events text for zeta1 context gathering: a file path, or `-` for stdin.
    // An absent value is treated as an empty string.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
65
// Tuning options for the `zeta2-context` subcommand. `get_context` copies each field
// into `zeta2::ZetaOptions` (the excerpt fields go into `EditPredictionExcerptOptions`).
// (Plain `//` comments: clap would turn `///` doc comments into help text.)
#[derive(Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> `EditPredictionExcerptOptions::max_bytes`
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::min_bytes`
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // -> `ZetaOptions::prompt_format`; required. Accepted values are `marked_excerpt`
    // and `labeled_sections` (see `parse_format`).
    #[arg(long, value_parser = parse_format)]
    format: PromptFormat,
}
81
82fn parse_format(s: &str) -> Result<PromptFormat> {
83 match s {
84 "marked_excerpt" => Ok(PromptFormat::MarkedExcerpt),
85 "labeled_sections" => Ok(PromptFormat::LabeledSections),
86 _ => Err(anyhow!("Invalid format: {}", s)),
87 }
88}
89
/// An input source named on the command line: a file path, or `-` for stdin.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the file at this path.
    File(PathBuf),
    /// Read all of standard input.
    Stdin,
}
95
96impl FileOrStdin {
97 async fn read_to_string(&self) -> Result<String, std::io::Error> {
98 match self {
99 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
100 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
101 }
102 }
103}
104
105impl FromStr for FileOrStdin {
106 type Err = <PathBuf as FromStr>::Err;
107
108 fn from_str(s: &str) -> Result<Self, Self::Err> {
109 match s {
110 "-" => Ok(Self::Stdin),
111 _ => Ok(Self::File(PathBuf::from_str(s)?)),
112 }
113 }
114}
115
/// A cursor location parsed from a `path:line:column` command-line argument.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// Path of the file, relative to the worktree root.
    path: Arc<RelPath>,
    /// Zero-based row/column (converted from the 1-based command-line values).
    point: Point,
}
121
122impl FromStr for CursorPosition {
123 type Err = anyhow::Error;
124
125 fn from_str(s: &str) -> Result<Self> {
126 let parts: Vec<&str> = s.split(':').collect();
127 if parts.len() != 3 {
128 return Err(anyhow!(
129 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
130 s
131 ));
132 }
133
134 let path = RelPath::from_std_path(Path::new(&parts[0]), PathStyle::local())?;
135 let line: u32 = parts[1]
136 .parse()
137 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
138 let column: u32 = parts[2]
139 .parse()
140 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
141
142 // Convert from 1-based to 0-based indexing
143 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
144
145 Ok(CursorPosition { path, point })
146 }
147}
148
/// Result of [`get_context`]; the variant depends on whether zeta2 args were supplied.
enum GetContextOutput {
    /// Output of `zeta::gather_context` (zeta1 pipeline).
    Zeta1(zeta::GatherContextOutput),
    /// Rendered zeta2 prompt string.
    Zeta2(String),
}
153
154async fn get_context(
155 zeta2_args: Option<Zeta2Args>,
156 args: ContextArgs,
157 app_state: &Arc<ZetaCliAppState>,
158 cx: &mut AsyncApp,
159) -> Result<GetContextOutput> {
160 let ContextArgs {
161 worktree: worktree_path,
162 cursor,
163 use_language_server,
164 events,
165 } = args;
166
167 let worktree_path = worktree_path.canonicalize()?;
168
169 let project = cx.update(|cx| {
170 Project::local(
171 app_state.client.clone(),
172 app_state.node_runtime.clone(),
173 app_state.user_store.clone(),
174 app_state.languages.clone(),
175 app_state.fs.clone(),
176 None,
177 cx,
178 )
179 })?;
180
181 let worktree = project
182 .update(cx, |project, cx| {
183 project.create_worktree(&worktree_path, true, cx)
184 })?
185 .await?;
186
187 let (_lsp_open_handle, buffer) = if use_language_server {
188 let (lsp_open_handle, buffer) =
189 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
190 (Some(lsp_open_handle), buffer)
191 } else {
192 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
193 (None, buffer)
194 };
195
196 let full_path_str = worktree
197 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
198 .display(PathStyle::local())
199 .to_string();
200
201 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
202 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
203 if clipped_cursor != cursor.point {
204 let max_row = snapshot.max_point().row;
205 if cursor.point.row < max_row {
206 return Err(anyhow!(
207 "Cursor position {:?} is out of bounds (line length is {})",
208 cursor.point,
209 snapshot.line_len(cursor.point.row)
210 ));
211 } else {
212 return Err(anyhow!(
213 "Cursor position {:?} is out of bounds (max row is {})",
214 cursor.point,
215 max_row
216 ));
217 }
218 }
219
220 let events = match events {
221 Some(events) => events.read_to_string().await?,
222 None => String::new(),
223 };
224
225 if let Some(zeta2_args) = zeta2_args {
226 Ok(GetContextOutput::Zeta2(
227 cx.update(|cx| {
228 let zeta = cx.new(|cx| {
229 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
230 });
231 zeta.update(cx, |zeta, cx| {
232 zeta.register_buffer(&buffer, &project, cx);
233 zeta.set_options(zeta2::ZetaOptions {
234 excerpt: EditPredictionExcerptOptions {
235 max_bytes: zeta2_args.max_excerpt_bytes,
236 min_bytes: zeta2_args.min_excerpt_bytes,
237 target_before_cursor_over_total_bytes: zeta2_args
238 .target_before_cursor_over_total_bytes,
239 },
240 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
241 max_prompt_bytes: zeta2_args.max_prompt_bytes,
242 prompt_format: zeta2_args.format,
243 })
244 });
245 // TODO: Actually wait for indexing.
246 let timer = cx.background_executor().timer(Duration::from_secs(5));
247 cx.spawn(async move |cx| {
248 timer.await;
249 let request = zeta
250 .update(cx, |zeta, cx| {
251 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
252 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
253 })?
254 .await?;
255 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
256 // TODO: Output the section label ranges
257 anyhow::Ok(planned_prompt.to_prompt_string()?.0)
258 })
259 })?
260 .await?,
261 ))
262 } else {
263 let prompt_for_events = move || (events, 0);
264 Ok(GetContextOutput::Zeta1(
265 cx.update(|cx| {
266 zeta::gather_context(
267 full_path_str,
268 &snapshot,
269 clipped_cursor,
270 prompt_for_events,
271 cx,
272 )
273 })?
274 .await?,
275 ))
276 }
277}
278
279pub async fn open_buffer(
280 project: &Entity<Project>,
281 worktree: &Entity<Worktree>,
282 path: &RelPath,
283 cx: &mut AsyncApp,
284) -> Result<Entity<Buffer>> {
285 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
286 worktree_id: worktree.id(),
287 path: path.into(),
288 })?;
289
290 project
291 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
292 .await
293}
294
295pub async fn open_buffer_with_language_server(
296 project: &Entity<Project>,
297 worktree: &Entity<Worktree>,
298 path: &RelPath,
299 cx: &mut AsyncApp,
300) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
301 let buffer = open_buffer(project, worktree, path, cx).await?;
302
303 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
304 (
305 project.register_buffer_with_language_servers(&buffer, cx),
306 project.path_style(cx),
307 )
308 })?;
309
310 let log_prefix = path.display(path_style);
311 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
312
313 Ok((lsp_open_handle, buffer))
314}
315
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Returns a task that resolves `Ok(())` at the first `DiskBasedDiagnosticsFinished`
/// event from `project` after this call. The task fails if no such event arrives
/// within 5 minutes.
///
/// Progress is printed to stdout, each line prefixed with `log_prefix`. The output
/// is a "waiting" line up front, every LSP work-progress message, and an "idle" line
/// on success.
///
/// The buffer is saved immediately if it already has a language server, and again
/// whenever a language server is added to the project. Presumably the save is what
/// prompts the server to (re)compute diagnostics — TODO confirm.
///
/// # Panics
/// Panics if `read_with`/`update` on the project fails (the `unwrap`s below).
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Capacity-1 channel used as a "diagnostics finished" signal; extra signals are
    // dropped by `try_send(..).ok()` below.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    // Whether a local language server is already attached to this buffer
    // (treated as `false` if the update fails).
    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }

    // Both subscriptions live until the wait below completes (dropped explicitly).
    let subscriptions = [
        // Echo LSP work-progress messages.
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}⟲ {message}", log_prefix)
                }
            }
        }),
        // Save the buffer when a server is added; signal when diagnostics finish.
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        // 60 * 5 seconds = 5 minutes.
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        drop(subscriptions);
        result
    })
}
400
401fn main() {
402 let args = ZetaCliArgs::parse();
403 let http_client = Arc::new(ReqwestClient::new());
404 let app = Application::headless().with_http_client(http_client);
405
406 app.run(move |cx| {
407 let app_state = Arc::new(headless::init(cx));
408 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
409 cx.spawn(async move |cx| {
410 let result = match args.command {
411 Commands::Zeta2Context {
412 zeta2_args,
413 context_args,
414 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
415 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
416 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
417 Err(err) => Err(err),
418 },
419 Commands::Context(context_args) => {
420 match get_context(None, context_args, &app_state, cx).await {
421 Ok(GetContextOutput::Zeta1(output)) => {
422 Ok(serde_json::to_string_pretty(&output.body).unwrap())
423 }
424 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
425 Err(err) => Err(err),
426 }
427 }
428 Commands::Predict {
429 predict_edits_body,
430 context_args,
431 } => {
432 cx.spawn(async move |cx| {
433 let app_version = cx.update(|cx| AppVersion::global(cx))?;
434 app_state.client.sign_in(true, cx).await?;
435 let llm_token = LlmApiToken::default();
436 llm_token.refresh(&app_state.client).await?;
437
438 let predict_edits_body =
439 if let Some(predict_edits_body) = predict_edits_body {
440 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
441 } else if let Some(context_args) = context_args {
442 match get_context(None, context_args, &app_state, cx).await? {
443 GetContextOutput::Zeta1(output) => output.body,
444 GetContextOutput::Zeta2 { .. } => unreachable!(),
445 }
446 } else {
447 return Err(anyhow!(
448 "Expected either --predict-edits-body-file \
449 or the required args of the `context` command."
450 ));
451 };
452
453 let (response, _usage) =
454 Zeta::perform_predict_edits(PerformPredictEditsParams {
455 client: app_state.client.clone(),
456 llm_token,
457 app_version,
458 body: predict_edits_body,
459 })
460 .await?;
461
462 Ok(response.output_excerpt)
463 })
464 .await
465 }
466 };
467 match result {
468 Ok(output) => {
469 println!("{}", output);
470 // TODO: Remove this once the 5 second delay is properly replaced.
471 if is_zeta2_context_command {
472 eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
473 }
474 let _ = cx.update(|cx| cx.quit());
475 }
476 Err(e) => {
477 eprintln!("Failed: {:?}", e);
478 exit(1);
479 }
480 }
481 })
482 .detach();
483 });
484}