1mod headless;
2
3use anyhow::{Result, anyhow};
4use clap::{Args, Parser, Subcommand};
5use edit_prediction_context::EditPredictionExcerptOptions;
6use futures::channel::mpsc;
7use futures::{FutureExt as _, StreamExt as _};
8use gpui::{AppContext, Application, AsyncApp};
9use gpui::{Entity, Task};
10use language::Bias;
11use language::Buffer;
12use language::Point;
13use language_model::LlmApiToken;
14use project::{Project, ProjectPath, Worktree};
15use release_channel::AppVersion;
16use reqwest_client::ReqwestClient;
17use std::path::{Path, PathBuf};
18use std::process::exit;
19use std::str::FromStr;
20use std::sync::Arc;
21use std::time::Duration;
22use util::paths::PathStyle;
23use util::rel_path::RelPath;
24use zeta::{PerformPredictEditsParams, Zeta};
25
26use crate::headless::ZetaCliAppState;
27
// Top-level command-line arguments for the `zeta` CLI. All work is dispatched through
// the `command` subcommand. Plain `//` comments are used here on purpose: `///` doc
// comments on clap-derived items become `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
34
// Subcommands supported by the CLI (exposed by clap's default kebab-case naming as
// `context`, `zeta2-context`, and `predict`). `//` comments keep `--help` unchanged.
#[derive(Subcommand, Debug)]
enum Commands {
    // Gather zeta1 context for a cursor position and print its request body as JSON.
    Context(ContextArgs),
    // Build the zeta2 prompt for a cursor position and print it.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Request an edit prediction. The request body comes either from
    // `--predict-edits-body` (a file path, or `-` for stdin) or is gathered from the
    // same arguments the `context` subcommand takes.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
}
51
// Arguments that identify the project and cursor location used to gather context.
// NOTE(review): `requires = "worktree"` presumably makes `--worktree` mandatory whenever
// any argument of this group is given. That would let `Option<ContextArgs>` under
// `predict` be `None` when none are passed; confirm against clap's ArgGroup semantics.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the project; it is canonicalized before use.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column`: path relative to the worktree, 1-based line/column.
    #[arg(long)]
    cursor: CursorPosition,
    // Register the buffer with language servers and wait for them before gathering.
    #[arg(long)]
    use_language_server: bool,
    // Optional events input (file path or `-` for stdin). Its text is read in every
    // mode but only passed along on the zeta1 context path.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
64
// Tuning options for zeta2 prompt construction (used by `zeta2-context`).
// `get_context` forwards the values to `zeta2::ZetaOptions` /
// `EditPredictionExcerptOptions`, and forwards `max_prompt_bytes` to
// `cloud_zeta2_prompt::PlanOptions`.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Prompt byte budget; also used as the prompt planner's `max_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded as the excerpt `max_bytes` (presumably the excerpt size upper bound).
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded as the excerpt `min_bytes` (presumably the excerpt size lower bound).
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded as-is; by its name, the desired fraction of excerpt bytes that precede
    // the cursor (confirm in edit_prediction_context).
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded as `max_diagnostic_bytes` (presumably a diagnostics byte budget).
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
}
78
/// Where a CLI input is read from: the literal argument `-` selects standard input,
/// and any other value names a file on disk.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read everything from the process's standard input.
    Stdin,
    /// Read everything from the file at this path.
    File(PathBuf),
}
84
85impl FileOrStdin {
86 async fn read_to_string(&self) -> Result<String, std::io::Error> {
87 match self {
88 FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
89 FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
90 }
91 }
92}
93
94impl FromStr for FileOrStdin {
95 type Err = <PathBuf as FromStr>::Err;
96
97 fn from_str(s: &str) -> Result<Self, Self::Err> {
98 match s {
99 "-" => Ok(Self::Stdin),
100 _ => Ok(Self::File(PathBuf::from_str(s)?)),
101 }
102 }
103}
104
/// Cursor location parsed from a `path:line:column` command-line argument.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// File path, relative to the worktree root.
    path: Arc<RelPath>,
    /// 0-based row/column, converted from the 1-based values given on the command line.
    point: Point,
}
110
111impl FromStr for CursorPosition {
112 type Err = anyhow::Error;
113
114 fn from_str(s: &str) -> Result<Self> {
115 let parts: Vec<&str> = s.split(':').collect();
116 if parts.len() != 3 {
117 return Err(anyhow!(
118 "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
119 s
120 ));
121 }
122
123 let path = RelPath::from_std_path(Path::new(&parts[0]), PathStyle::local())?;
124 let line: u32 = parts[1]
125 .parse()
126 .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
127 let column: u32 = parts[2]
128 .parse()
129 .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
130
131 // Convert from 1-based to 0-based indexing
132 let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
133
134 Ok(CursorPosition { path, point })
135 }
136}
137
/// Output of `get_context`: the zeta1 path yields structured gather-context output,
/// while the zeta2 path yields a fully rendered prompt string.
enum GetContextOutput {
    Zeta1(zeta::GatherContextOutput),
    Zeta2(String),
}
142
143async fn get_context(
144 zeta2_args: Option<Zeta2Args>,
145 args: ContextArgs,
146 app_state: &Arc<ZetaCliAppState>,
147 cx: &mut AsyncApp,
148) -> Result<GetContextOutput> {
149 let ContextArgs {
150 worktree: worktree_path,
151 cursor,
152 use_language_server,
153 events,
154 } = args;
155
156 let worktree_path = worktree_path.canonicalize()?;
157
158 let project = cx.update(|cx| {
159 Project::local(
160 app_state.client.clone(),
161 app_state.node_runtime.clone(),
162 app_state.user_store.clone(),
163 app_state.languages.clone(),
164 app_state.fs.clone(),
165 None,
166 cx,
167 )
168 })?;
169
170 let worktree = project
171 .update(cx, |project, cx| {
172 project.create_worktree(&worktree_path, true, cx)
173 })?
174 .await?;
175
176 let (_lsp_open_handle, buffer) = if use_language_server {
177 let (lsp_open_handle, buffer) =
178 open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
179 (Some(lsp_open_handle), buffer)
180 } else {
181 let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
182 (None, buffer)
183 };
184
185 let full_path_str = worktree
186 .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
187 .display(PathStyle::local())
188 .to_string();
189
190 let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
191 let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
192 if clipped_cursor != cursor.point {
193 let max_row = snapshot.max_point().row;
194 if cursor.point.row < max_row {
195 return Err(anyhow!(
196 "Cursor position {:?} is out of bounds (line length is {})",
197 cursor.point,
198 snapshot.line_len(cursor.point.row)
199 ));
200 } else {
201 return Err(anyhow!(
202 "Cursor position {:?} is out of bounds (max row is {})",
203 cursor.point,
204 max_row
205 ));
206 }
207 }
208
209 let events = match events {
210 Some(events) => events.read_to_string().await?,
211 None => String::new(),
212 };
213
214 if let Some(zeta2_args) = zeta2_args {
215 Ok(GetContextOutput::Zeta2(
216 cx.update(|cx| {
217 let zeta = cx.new(|cx| {
218 zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
219 });
220 zeta.update(cx, |zeta, cx| {
221 zeta.register_buffer(&buffer, &project, cx);
222 zeta.set_options(zeta2::ZetaOptions {
223 excerpt: EditPredictionExcerptOptions {
224 max_bytes: zeta2_args.max_excerpt_bytes,
225 min_bytes: zeta2_args.min_excerpt_bytes,
226 target_before_cursor_over_total_bytes: zeta2_args
227 .target_before_cursor_over_total_bytes,
228 },
229 max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
230 max_prompt_bytes: zeta2_args.max_prompt_bytes,
231 })
232 });
233 // TODO: Actually wait for indexing.
234 let timer = cx.background_executor().timer(Duration::from_secs(5));
235 cx.spawn(async move |cx| {
236 timer.await;
237 let request = zeta
238 .update(cx, |zeta, cx| {
239 let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
240 zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
241 })?
242 .await?;
243 let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
244 &request,
245 &cloud_zeta2_prompt::PlanOptions {
246 max_bytes: zeta2_args.max_prompt_bytes,
247 },
248 )?;
249 anyhow::Ok(planned_prompt.to_prompt_string())
250 })
251 })?
252 .await?,
253 ))
254 } else {
255 let prompt_for_events = move || (events, 0);
256 Ok(GetContextOutput::Zeta1(
257 cx.update(|cx| {
258 zeta::gather_context(
259 full_path_str,
260 &snapshot,
261 clipped_cursor,
262 prompt_for_events,
263 cx,
264 )
265 })?
266 .await?,
267 ))
268 }
269}
270
271pub async fn open_buffer(
272 project: &Entity<Project>,
273 worktree: &Entity<Worktree>,
274 path: &RelPath,
275 cx: &mut AsyncApp,
276) -> Result<Entity<Buffer>> {
277 let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
278 worktree_id: worktree.id(),
279 path: path.into(),
280 })?;
281
282 project
283 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
284 .await
285}
286
287pub async fn open_buffer_with_language_server(
288 project: &Entity<Project>,
289 worktree: &Entity<Worktree>,
290 path: &RelPath,
291 cx: &mut AsyncApp,
292) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
293 let buffer = open_buffer(project, worktree, path, cx).await?;
294
295 let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
296 (
297 project.register_buffer_with_language_servers(&buffer, cx),
298 project.path_style(cx),
299 )
300 })?;
301
302 let log_prefix = path.display(path_style);
303 wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
304
305 Ok((lsp_open_handle, buffer))
306}
307
308// TODO: Dedupe with similar function in crates/eval/src/instance.rs
309pub fn wait_for_lang_server(
310 project: &Entity<Project>,
311 buffer: &Entity<Buffer>,
312 log_prefix: String,
313 cx: &mut AsyncApp,
314) -> Task<Result<()>> {
315 println!("{}⏵ Waiting for language server", log_prefix);
316
317 let (mut tx, mut rx) = mpsc::channel(1);
318
319 let lsp_store = project
320 .read_with(cx, |project, _| project.lsp_store())
321 .unwrap();
322
323 let has_lang_server = buffer
324 .update(cx, |buffer, cx| {
325 lsp_store.update(cx, |lsp_store, cx| {
326 lsp_store
327 .language_servers_for_local_buffer(buffer, cx)
328 .next()
329 .is_some()
330 })
331 })
332 .unwrap_or(false);
333
334 if has_lang_server {
335 project
336 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
337 .unwrap()
338 .detach();
339 }
340
341 let subscriptions = [
342 cx.subscribe(&lsp_store, {
343 let log_prefix = log_prefix.clone();
344 move |_, event, _| {
345 if let project::LspStoreEvent::LanguageServerUpdate {
346 message:
347 client::proto::update_language_server::Variant::WorkProgress(
348 client::proto::LspWorkProgress {
349 message: Some(message),
350 ..
351 },
352 ),
353 ..
354 } = event
355 {
356 println!("{}⟲ {message}", log_prefix)
357 }
358 }
359 }),
360 cx.subscribe(project, {
361 let buffer = buffer.clone();
362 move |project, event, cx| match event {
363 project::Event::LanguageServerAdded(_, _, _) => {
364 let buffer = buffer.clone();
365 project
366 .update(cx, |project, cx| project.save_buffer(buffer, cx))
367 .detach();
368 }
369 project::Event::DiskBasedDiagnosticsFinished { .. } => {
370 tx.try_send(()).ok();
371 }
372 _ => {}
373 }
374 }),
375 ];
376
377 cx.spawn(async move |cx| {
378 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
379 let result = futures::select! {
380 _ = rx.next() => {
381 println!("{}⚑ Language server idle", log_prefix);
382 anyhow::Ok(())
383 },
384 _ = timeout.fuse() => {
385 anyhow::bail!("LSP wait timed out after 5 minutes");
386 }
387 };
388 drop(subscriptions);
389 result
390 })
391}
392
393fn main() {
394 let args = ZetaCliArgs::parse();
395 let http_client = Arc::new(ReqwestClient::new());
396 let app = Application::headless().with_http_client(http_client);
397
398 app.run(move |cx| {
399 let app_state = Arc::new(headless::init(cx));
400 let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
401 cx.spawn(async move |cx| {
402 let result = match args.command {
403 Commands::Zeta2Context {
404 zeta2_args,
405 context_args,
406 } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
407 Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
408 Ok(GetContextOutput::Zeta2(output)) => Ok(output),
409 Err(err) => Err(err),
410 },
411 Commands::Context(context_args) => {
412 match get_context(None, context_args, &app_state, cx).await {
413 Ok(GetContextOutput::Zeta1(output)) => {
414 Ok(serde_json::to_string_pretty(&output.body).unwrap())
415 }
416 Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
417 Err(err) => Err(err),
418 }
419 }
420 Commands::Predict {
421 predict_edits_body,
422 context_args,
423 } => {
424 cx.spawn(async move |cx| {
425 let app_version = cx.update(|cx| AppVersion::global(cx))?;
426 app_state.client.sign_in(true, cx).await?;
427 let llm_token = LlmApiToken::default();
428 llm_token.refresh(&app_state.client).await?;
429
430 let predict_edits_body =
431 if let Some(predict_edits_body) = predict_edits_body {
432 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
433 } else if let Some(context_args) = context_args {
434 match get_context(None, context_args, &app_state, cx).await? {
435 GetContextOutput::Zeta1(output) => output.body,
436 GetContextOutput::Zeta2 { .. } => unreachable!(),
437 }
438 } else {
439 return Err(anyhow!(
440 "Expected either --predict-edits-body-file \
441 or the required args of the `context` command."
442 ));
443 };
444
445 let (response, _usage) =
446 Zeta::perform_predict_edits(PerformPredictEditsParams {
447 client: app_state.client.clone(),
448 llm_token,
449 app_version,
450 body: predict_edits_body,
451 })
452 .await?;
453
454 Ok(response.output_excerpt)
455 })
456 .await
457 }
458 };
459 match result {
460 Ok(output) => {
461 println!("{}", output);
462 // TODO: Remove this once the 5 second delay is properly replaced.
463 if is_zeta2_context_command {
464 eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
465 }
466 let _ = cx.update(|cx| cx.quit());
467 }
468 Err(e) => {
469 eprintln!("Failed: {:?}", e);
470 exit(1);
471 }
472 }
473 })
474 .detach();
475 });
476}