main.rs

  1mod headless;
  2
use anyhow::{Context as _, Result, anyhow};
use clap::{Args, Parser, Subcommand};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, Application, AsyncApp};
use gpui::{Entity, Task};
use language::Bias;
use language::Buffer;
use language::Point;
use language_model::LlmApiToken;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
 22
 23use crate::headless::ZetaCliAppState;
 24
// Top-level command-line arguments of the `zeta` binary.
// Plain `//` comments are used on clap-derived items because `///` doc comments would be
// picked up by clap as `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; dispatched on in `main`.
    #[command(subcommand)]
    command: Commands,
}
 31
 32#[derive(Subcommand, Debug)]
 33enum Commands {
 34    Context(ContextArgs),
 35    Predict {
 36        #[arg(long)]
 37        predict_edits_body: Option<FileOrStdin>,
 38        #[clap(flatten)]
 39        context_args: Option<ContextArgs>,
 40    },
 41}
 42
// Arguments locating the cursor whose edit-prediction context should be gathered.
// Used directly by `zeta context` and flattened into `zeta predict`.
// The group requires `--worktree` whenever any argument of this group is supplied.
// Plain `//` comments are used so clap's generated `--help` text stays unchanged.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree; `get_context` canonicalizes it before use.
    #[arg(long)]
    worktree: PathBuf,
    // `path:line:column` with 1-based line/column; `path` must be relative to the worktree.
    #[arg(long)]
    cursor: CursorPosition,
    // Open the file through a project with language servers instead of reading it directly
    // from disk.
    #[arg(long)]
    use_language_server: bool,
    // Optional events input passed to context gathering: a file path, or `-` for stdin.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
 55
/// An input source named on the command line: a file path, or `-` for standard input.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input (selected by passing `-`).
    Stdin,
}
 61
 62impl FileOrStdin {
 63    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 64        match self {
 65            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 66            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 67        }
 68    }
 69}
 70
 71impl FromStr for FileOrStdin {
 72    type Err = <PathBuf as FromStr>::Err;
 73
 74    fn from_str(s: &str) -> Result<Self, Self::Err> {
 75        match s {
 76            "-" => Ok(Self::Stdin),
 77            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 78        }
 79    }
 80}
 81
/// A cursor location parsed from `--cursor`.
#[derive(Debug, Clone)]
struct CursorPosition {
    /// File path exactly as given on the command line (must be relative to the worktree).
    path: PathBuf,
    /// Zero-based row and column, converted from the 1-based values on the command line.
    point: Point,
}
 87
 88impl FromStr for CursorPosition {
 89    type Err = anyhow::Error;
 90
 91    fn from_str(s: &str) -> Result<Self> {
 92        let parts: Vec<&str> = s.split(':').collect();
 93        if parts.len() != 3 {
 94            return Err(anyhow!(
 95                "Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
 96                s
 97            ));
 98        }
 99
100        let path = PathBuf::from(parts[0]);
101        let line: u32 = parts[1]
102            .parse()
103            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
104        let column: u32 = parts[2]
105            .parse()
106            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
107
108        // Convert from 1-based to 0-based indexing
109        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
110
111        Ok(CursorPosition { path, point })
112    }
113}
114
115async fn get_context(
116    args: ContextArgs,
117    app_state: &Arc<ZetaCliAppState>,
118    cx: &mut AsyncApp,
119) -> Result<GatherContextOutput> {
120    let ContextArgs {
121        worktree: worktree_path,
122        cursor,
123        use_language_server,
124        events,
125    } = args;
126
127    let worktree_path = worktree_path.canonicalize()?;
128    if cursor.path.is_absolute() {
129        return Err(anyhow!("Absolute paths are not supported in --cursor"));
130    }
131
132    let (project, _lsp_open_handle, buffer) = if use_language_server {
133        let (project, lsp_open_handle, buffer) =
134            open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?;
135        (Some(project), Some(lsp_open_handle), buffer)
136    } else {
137        let abs_path = worktree_path.join(&cursor.path);
138        let content = smol::fs::read_to_string(&abs_path).await?;
139        let buffer = cx.new(|cx| Buffer::local(content, cx))?;
140        (None, None, buffer)
141    };
142
143    let worktree_name = worktree_path
144        .file_name()
145        .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
146    let full_path_str = PathBuf::from(worktree_name)
147        .join(&cursor.path)
148        .to_string_lossy()
149        .to_string();
150
151    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
152    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
153    if clipped_cursor != cursor.point {
154        let max_row = snapshot.max_point().row;
155        if cursor.point.row < max_row {
156            return Err(anyhow!(
157                "Cursor position {:?} is out of bounds (line length is {})",
158                cursor.point,
159                snapshot.line_len(cursor.point.row)
160            ));
161        } else {
162            return Err(anyhow!(
163                "Cursor position {:?} is out of bounds (max row is {})",
164                cursor.point,
165                max_row
166            ));
167        }
168    }
169
170    let events = match events {
171        Some(events) => events.read_to_string().await?,
172        None => String::new(),
173    };
174    // Enable gathering extra data not currently needed for edit predictions
175    let can_collect_data = true;
176    let git_info = None;
177    let recent_files = None;
178    let mut gather_context_output = cx
179        .update(|cx| {
180            gather_context(
181                project.as_ref(),
182                full_path_str,
183                &snapshot,
184                clipped_cursor,
185                move || events,
186                can_collect_data,
187                git_info,
188                recent_files,
189                cx,
190            )
191        })?
192        .await;
193
194    // Disable data collection for these requests, as this is currently just used for evals
195    if let Ok(gather_context_output) = gather_context_output.as_mut() {
196        gather_context_output.body.can_collect_data = false
197    }
198
199    gather_context_output
200}
201
202pub async fn open_buffer_with_language_server(
203    worktree_path: &Path,
204    path: &Path,
205    app_state: &Arc<ZetaCliAppState>,
206    cx: &mut AsyncApp,
207) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
208    let project = cx.update(|cx| {
209        Project::local(
210            app_state.client.clone(),
211            app_state.node_runtime.clone(),
212            app_state.user_store.clone(),
213            app_state.languages.clone(),
214            app_state.fs.clone(),
215            None,
216            cx,
217        )
218    })?;
219
220    let worktree = project
221        .update(cx, |project, cx| {
222            project.create_worktree(worktree_path, true, cx)
223        })?
224        .await?;
225
226    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
227        worktree_id: worktree.id(),
228        path: path.to_path_buf().into(),
229    })?;
230
231    let buffer = project
232        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
233        .await?;
234
235    let lsp_open_handle = project.update(cx, |project, cx| {
236        project.register_buffer_with_language_servers(&buffer, cx)
237    })?;
238
239    let log_prefix = path.to_string_lossy().to_string();
240    wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
241
242    Ok((project, lsp_open_handle, buffer))
243}
244
245// TODO: Dedupe with similar function in crates/eval/src/instance.rs
246pub fn wait_for_lang_server(
247    project: &Entity<Project>,
248    buffer: &Entity<Buffer>,
249    log_prefix: String,
250    cx: &mut AsyncApp,
251) -> Task<Result<()>> {
252    println!("{}⏵ Waiting for language server", log_prefix);
253
254    let (mut tx, mut rx) = mpsc::channel(1);
255
256    let lsp_store = project
257        .read_with(cx, |project, _| project.lsp_store())
258        .unwrap();
259
260    let has_lang_server = buffer
261        .update(cx, |buffer, cx| {
262            lsp_store.update(cx, |lsp_store, cx| {
263                lsp_store
264                    .language_servers_for_local_buffer(buffer, cx)
265                    .next()
266                    .is_some()
267            })
268        })
269        .unwrap_or(false);
270
271    if has_lang_server {
272        project
273            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
274            .unwrap()
275            .detach();
276    }
277
278    let subscriptions = [
279        cx.subscribe(&lsp_store, {
280            let log_prefix = log_prefix.clone();
281            move |_, event, _| {
282                if let project::LspStoreEvent::LanguageServerUpdate {
283                    message:
284                        client::proto::update_language_server::Variant::WorkProgress(
285                            client::proto::LspWorkProgress {
286                                message: Some(message),
287                                ..
288                            },
289                        ),
290                    ..
291                } = event
292                {
293                    println!("{}{message}", log_prefix)
294                }
295            }
296        }),
297        cx.subscribe(project, {
298            let buffer = buffer.clone();
299            move |project, event, cx| match event {
300                project::Event::LanguageServerAdded(_, _, _) => {
301                    let buffer = buffer.clone();
302                    project
303                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
304                        .detach();
305                }
306                project::Event::DiskBasedDiagnosticsFinished { .. } => {
307                    tx.try_send(()).ok();
308                }
309                _ => {}
310            }
311        }),
312    ];
313
314    cx.spawn(async move |cx| {
315        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
316        let result = futures::select! {
317            _ = rx.next() => {
318                println!("{}⚑ Language server idle", log_prefix);
319                anyhow::Ok(())
320            },
321            _ = timeout.fuse() => {
322                anyhow::bail!("LSP wait timed out after 5 minutes");
323            }
324        };
325        drop(subscriptions);
326        result
327    })
328}
329
330fn main() {
331    let args = ZetaCliArgs::parse();
332    let http_client = Arc::new(ReqwestClient::new());
333    let app = Application::headless().with_http_client(http_client);
334
335    app.run(move |cx| {
336        let app_state = Arc::new(headless::init(cx));
337        cx.spawn(async move |cx| {
338            let result = match args.command {
339                Commands::Context(context_args) => get_context(context_args, &app_state, cx)
340                    .await
341                    .map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
342                Commands::Predict {
343                    predict_edits_body,
344                    context_args,
345                } => {
346                    cx.spawn(async move |cx| {
347                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
348                        app_state.client.sign_in(true, cx).await?;
349                        let llm_token = LlmApiToken::default();
350                        llm_token.refresh(&app_state.client).await?;
351
352                        let predict_edits_body =
353                            if let Some(predict_edits_body) = predict_edits_body {
354                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
355                            } else if let Some(context_args) = context_args {
356                                get_context(context_args, &app_state, cx).await?.body
357                            } else {
358                                return Err(anyhow!(
359                                    "Expected either --predict-edits-body-file \
360                                    or the required args of the `context` command."
361                                ));
362                            };
363
364                        let (response, _usage) =
365                            Zeta::perform_predict_edits(PerformPredictEditsParams {
366                                client: app_state.client.clone(),
367                                llm_token,
368                                app_version,
369                                body: predict_edits_body,
370                            })
371                            .await?;
372
373                        Ok(response.output_excerpt)
374                    })
375                    .await
376                }
377            };
378            match result {
379                Ok(output) => {
380                    println!("{}", output);
381                    let _ = cx.update(|cx| cx.quit());
382                }
383                Err(e) => {
384                    eprintln!("Failed: {:?}", e);
385                    exit(1);
386                }
387            }
388        })
389        .detach();
390    });
391}