From d312d59ace3feda4a0be5c63aa4f61f2d9ecca19 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Mon, 8 Dec 2025 15:13:22 +0200 Subject: [PATCH] Add `zeta distill` command (#44369) This PR partially implements a knowledge distillation data pipeline. `zeta distill` gets a dataset of chronologically ordered commits and generates synthetic predictions with a teacher model (one-shot Claude Sonnet). `zeta distill --batches cache.db` will enable Message Batches API. Under the first run, this command will collect all LLM requests and upload a batch of them to Anthropic. On subsequent runs, it will check the batch status. If ready, it will download the result and put them into the local cache. Release Notes: - N/A --------- Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Co-authored-by: Ben Kunkle --- Cargo.lock | 4 + Cargo.toml | 25 -- crates/anthropic/src/anthropic.rs | 143 ++++-- crates/anthropic/src/batches.rs | 190 ++++++++ crates/edit_prediction_cli/Cargo.toml | 6 +- crates/edit_prediction_cli/src/example.rs | 251 ++++++----- crates/edit_prediction_cli/src/main.rs | 21 +- .../src/training/context.rs | 89 ++++ .../src/training/distill.rs | 94 ++++ .../src/training/llm_client.rs | 417 ++++++++++++++++++ .../edit_prediction_cli/src/training/mod.rs | 4 + .../src/training/teacher.prompt.md | 48 ++ .../src/training/teacher.rs | 266 +++++++++++ 13 files changed, 1370 insertions(+), 188 deletions(-) create mode 100644 crates/anthropic/src/batches.rs create mode 100644 crates/edit_prediction_cli/src/training/context.rs create mode 100644 crates/edit_prediction_cli/src/training/distill.rs create mode 100644 crates/edit_prediction_cli/src/training/llm_client.rs create mode 100644 crates/edit_prediction_cli/src/training/mod.rs create mode 100644 crates/edit_prediction_cli/src/training/teacher.prompt.md create mode 100644 crates/edit_prediction_cli/src/training/teacher.rs diff --git a/Cargo.lock b/Cargo.lock index 2fa19c12df90ad026223020beb3772ec5215fc80..a671c1797f7b9abd9a7ec5262965a68132cef8ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5167,6 +5167,7 @@ dependencies = [ name = "edit_prediction_cli" version = "0.1.0" dependencies = [ + "anthropic", "anyhow", "chrono", "clap", @@ -5182,6 +5183,7 @@ dependencies = [ "futures 0.3.31", "gpui", "gpui_tokio", + "http_client", "indoc", "language", "language_extension", @@ -5202,6 +5204,8 @@ dependencies = [ "settings", "shellexpand 2.1.2", "smol", + "sqlez", + "sqlez_macros", "terminal_view", "toml 0.8.23", "util", diff --git a/Cargo.toml b/Cargo.toml index 2d203499c208d7ad3f366296b227807b496bf8ca..0ad4d2b14523988aa0dd6e3bfc935f84bcd0d8d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -244,7 +244,6 @@ activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } agent_servers = { path = "crates/agent_servers" } -ai = { path = "crates/ai" } ai_onboarding = { path = "crates/ai_onboarding" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } @@ -254,7 +253,6 @@ assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } -auto_update_helper = { path = "crates/auto_update_helper" } auto_update_ui = { path = "crates/auto_update_ui" } aws_http_client = { path = "crates/aws_http_client" } bedrock = { path = "crates/bedrock" } @@ -269,7 +267,6 @@ cloud_api_client = { path = "crates/cloud_api_client" } cloud_api_types = { path = "crates/cloud_api_types" } cloud_llm_client = { path = "crates/cloud_llm_client" } cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" } -collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections", version = "0.1.0" } command_palette = { path = "crates/command_palette" } @@ -358,8 +355,6 @@ panel = { path = "crates/panel" } paths = { path = "crates/paths" } perf = { path = "tooling/perf" } picker = { path = "crates/picker" } -plugin = { path = "crates/plugin" } -plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } @@ -370,12 +365,10 @@ proto = { path = "crates/proto" } recent_projects = { path = "crates/recent_projects" } refineable = { path = "crates/refineable" } release_channel = { path = "crates/release_channel" } -scheduler = { path = "crates/scheduler" } remote = { path = "crates/remote" } remote_server = { path = "crates/remote_server" } repl = { path = "crates/repl" } reqwest_client = { path = "crates/reqwest_client" } -rich_text = { path = "crates/rich_text" } rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] } rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } @@ -392,7 +385,6 @@ snippets_ui = { path = "crates/snippets_ui" } sqlez = { path = "crates/sqlez" } sqlez_macros = { path = "crates/sqlez_macros" } story = { path = "crates/story" } -storybook = { path = "crates/storybook" } streaming_diff = { path = "crates/streaming_diff" } sum_tree = { path = "crates/sum_tree" } supermaven = { path = "crates/supermaven" } @@ -409,7 +401,6 @@ terminal_view = { path = "crates/terminal_view" } text = { path = "crates/text" } theme = { path = "crates/theme" } theme_extension = { path = "crates/theme_extension" } -theme_importer = { path = "crates/theme_importer" } theme_selector = { path = "crates/theme_selector" } time_format = { path = "crates/time_format" } title_bar = { path = "crates/title_bar" } @@ -510,13 +501,11 @@ exec = "0.3.1" fancy-regex = "0.16.0" fork = "0.4.0" futures = "0.3" -futures-batch = "0.6.1" futures-lite = "1.13" gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" } git2 = { version = "0.20.1", default-features = false } globset = "0.4" handlebars = "4.3" -hashbrown = "0.15.3" heck = "0.5" heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" @@ -552,7 +541,6 @@ nanoid = "0.4" nbformat = "0.15.0" nix = "0.29" num-format = "0.4.4" -num-traits = "0.2" objc = "0.2" objc2-foundation = { version = "=0.3.1", default-features = false, features = [ "NSArray", @@ -591,7 +579,6 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } -pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } @@ -631,7 +618,6 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197 schemars = { version = "1.0", features = ["indexmap2"] } semver = { version = "1.0", features = ["serde"] } serde = { version = "1.0.221", features = ["derive", "rc"] } -serde_derive = "1.0.221" serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] } serde_json_lenient = { version = "0.2", features = [ "preserve_order", @@ -643,7 +629,6 @@ serde_urlencoded = "0.7" sha2 = "0.10" shellexpand = "2.1.0" shlex = "1.3.0" -similar = "2.6" simplelog = "0.12.2" slotmap = "1.0.6" smallvec = { version = "1.6", features = ["union"] } @@ -722,7 +707,6 @@ wasmtime-wasi = "29" wax = "0.6" which = "6.0.0" windows-core = "0.61" -wit-component = "0.221" yawc = "0.2.5" zeroize = "1.8" zstd = "0.11" @@ -804,20 +788,13 @@ settings_macros = { opt-level = 3 } sqlez_macros = { opt-level = 3, codegen-units = 1 } ui_macros = { opt-level = 3 } util_macros = { opt-level = 3 } -serde_derive = { opt-level = 3 } quote = { opt-level = 3 } syn = { opt-level = 3 } proc-macro2 = { opt-level = 3 } # proc-macros end taffy = { opt-level = 3 } -cranelift-codegen = { opt-level = 3 } -cranelift-codegen-meta = { opt-level = 3 } -cranelift-codegen-shared = { opt-level = 3 } resvg = { opt-level = 3 } -rustybuzz = { opt-level = 3 } -ttf-parser = { opt-level = 3 } -wasmtime-cranelift = { opt-level = 3 } wasmtime = { opt-level = 3 } # Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster activity_indicator = { codegen-units = 1 } @@ -826,7 +803,6 @@ breadcrumbs = { codegen-units = 1 } collections = { codegen-units = 1 } command_palette = { codegen-units = 1 } command_palette_hooks = { codegen-units = 1 } -extension_cli = { codegen-units = 1 } feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } @@ -846,7 +822,6 @@ project_symbols = { codegen-units = 1 } refineable = { codegen-units = 1 } release_channel = { codegen-units = 1 } reqwest_client = { codegen-units = 1 } -rich_text = { codegen-units = 1 } session = { codegen-units = 1 } snippet = { codegen-units = 1 } snippets_ui = { codegen-units = 1 } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 041401418c427251a944fc39bb8ac83a0e22bc13..09b293b122624274b7484026f35d1bcc8e265ece 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -12,6 +12,8 @@ pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode}; use strum::{EnumIter, EnumString}; use thiserror::Error; +pub mod batches; + pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -465,6 +467,7 @@ impl Model { } } +/// Generate completion with streaming. pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, @@ -477,6 +480,101 @@ pub async fn stream_completion( .map(|output| output.0) } +/// Generate completion without streaming. +pub async fn non_streaming_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, + beta_headers: Option, +) -> Result { + let (mut response, rate_limits) = + send_request(client, api_url, api_key, &request, beta_headers).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(handle_error_response(response, rate_limits).await) + } +} + +async fn send_request( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: impl Serialize, + beta_headers: Option, +) -> Result<(http::Response, RateLimitInfo), AnthropicError> { + let uri = format!("{api_url}/v1/messages"); + + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()) + .header("Content-Type", "application/json"); + + if let Some(beta_headers) = beta_headers { + request_builder = request_builder.header("Anthropic-Beta", beta_headers); + } + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let response = client + .send(request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + Ok((response, rate_limits)) +} + +async fn handle_error_response( + mut response: http::Response, + rate_limits: RateLimitInfo, +) -> AnthropicError { + if response.status().as_u16() == 529 { + return AnthropicError::ServerOverloaded { + retry_after: rate_limits.retry_after, + }; + } + + if let Some(retry_after) = rate_limits.retry_after { + return AnthropicError::RateLimit { retry_after }; + } + + let mut body = String::new(); + let read_result = response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse); + + if let Err(err) = read_result { + return err; + } + + match serde_json::from_str::(&body) { + Ok(Event::Error { error }) => AnthropicError::ApiError(error), + Ok(_) | Err(_) => AnthropicError::HttpResponseError { + status_code: response.status(), + message: body, + }, + } +} + /// An individual rate limit. #[derive(Debug)] pub struct RateLimit { @@ -580,30 +678,10 @@ pub async fn stream_completion_with_rate_limit_info( base: request, stream: true, }; - let uri = format!("{api_url}/v1/messages"); - let mut request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Anthropic-Version", "2023-06-01") - .header("X-Api-Key", api_key.trim()) - .header("Content-Type", "application/json"); + let (response, rate_limits) = + send_request(client, api_url, api_key, &request, beta_headers).await?; - if let Some(beta_headers) = beta_headers { - request_builder = request_builder.header("Anthropic-Beta", beta_headers); - } - - let serialized_request = - serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; - let request = request_builder - .body(AsyncBody::from(serialized_request)) - .map_err(AnthropicError::BuildRequestBody)?; - - let mut response = client - .send(request) - .await - .map_err(AnthropicError::HttpSend)?; - let rate_limits = RateLimitInfo::from_headers(response.headers()); if response.status().is_success() { let reader = BufReader::new(response.into_body()); let stream = reader @@ -622,27 +700,8 @@ pub async fn stream_completion_with_rate_limit_info( }) .boxed(); Ok((stream, Some(rate_limits))) - } else if response.status().as_u16() == 529 { - Err(AnthropicError::ServerOverloaded { - retry_after: rate_limits.retry_after, - }) - } else if let Some(retry_after) = rate_limits.retry_after { - Err(AnthropicError::RateLimit { retry_after }) } else { - let mut body = String::new(); - response - .body_mut() - .read_to_string(&mut body) - .await - .map_err(AnthropicError::ReadResponse)?; - - match serde_json::from_str::(&body) { - Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)), - Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError { - status_code: response.status(), - message: body, - }), - } + Err(handle_error_response(response, rate_limits).await) } } diff --git a/crates/anthropic/src/batches.rs b/crates/anthropic/src/batches.rs new file mode 100644 index 0000000000000000000000000000000000000000..5fb594348d45c84e8c246c2611f7cde3aa77a18d --- /dev/null +++ b/crates/anthropic/src/batches.rs @@ -0,0 +1,190 @@ +use anyhow::Result; +use futures::AsyncReadExt; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; + +use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchRequest { + pub custom_id: String, + pub params: Request, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateBatchRequest { + pub requests: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageBatchRequestCounts { + pub processing: u64, + pub succeeded: u64, + pub errored: u64, + pub canceled: u64, + pub expired: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageBatch { + pub id: String, + #[serde(rename = "type")] + pub batch_type: String, + pub processing_status: String, + pub request_counts: MessageBatchRequestCounts, + pub ended_at: Option, + pub created_at: String, + pub expires_at: String, + pub archived_at: Option, + pub cancel_initiated_at: Option, + pub results_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum BatchResult { + #[serde(rename = "succeeded")] + Succeeded { message: Response }, + #[serde(rename = "errored")] + Errored { error: ApiError }, + #[serde(rename = "canceled")] + Canceled, + #[serde(rename = "expired")] + Expired, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchIndividualResponse { + pub custom_id: String, + pub result: BatchResult, +} + +pub async fn create_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CreateBatchRequest, +) -> Result { + let uri = format!("{api_url}/v1/messages/batches"); + + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()) + .header("Content-Type", "application/json"); + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let http_request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} + +pub async fn retrieve_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + message_batch_id: &str, +) -> Result { + let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}"); + + let request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()); + + let http_request = request_builder + .body(AsyncBody::default()) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} + +pub async fn retrieve_batch_results( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + message_batch_id: &str, +) -> Result, AnthropicError> { + let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results"); + + let request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()); + + let http_request = request_builder + .body(AsyncBody::default()) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + let mut results = Vec::new(); + for line in body.lines() { + if line.trim().is_empty() { + continue; + } + let result: BatchIndividualResponse = + serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?; + results.push(result); + } + + Ok(results) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index d1b0b3f912ed2143b6c75ae39e94c2f7780ec4fe..26a060994d75a2c194cc159c33d88fbc296dfa47 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -13,8 +13,9 @@ name = "ep_cli" path = "src/main.rs" [dependencies] - anyhow.workspace = true +anthropic.workspace = true +http_client.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true @@ -28,6 +29,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true gpui_tokio.workspace = true +indoc.workspace = true language.workspace = true language_extension.workspace = true language_model.workspace = true @@ -46,6 +48,8 @@ serde_json.workspace = true settings.workspace = true shellexpand.workspace = true smol.workspace = true +sqlez.workspace = true +sqlez_macros.workspace = true terminal_view.workspace = true toml.workspace = true util.workspace = true diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 2f52b89c552b65072f753432eb63b656624fdf61..4f8c1867cd57d7fb5dbb9c2c08b63dccf2b97d30 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -3,6 +3,8 @@ use std::{ cell::RefCell, fmt::{self, Display}, fs, + hash::Hash, + hash::Hasher, io::Write, mem, path::{Path, PathBuf}, @@ -43,7 +45,7 @@ pub struct NamedExample { pub example: Example, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Hash, Serialize, Deserialize)] pub struct Example { pub repository_url: String, pub revision: String, @@ -54,6 +56,134 @@ pub struct Example { pub expected_patch: String, } +impl Example { + fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { + // git@github.com:owner/repo.git + if self.repository_url.contains('@') { + let (owner, repo) = self + .repository_url + .split_once(':') + .context("expected : in git url")? + .1 + .split_once('/') + .context("expected / in git url")?; + Ok(( + Cow::Borrowed(owner), + Cow::Borrowed(repo.trim_end_matches(".git")), + )) + // http://github.com/owner/repo.git + } else { + let url = Url::parse(&self.repository_url)?; + let mut segments = url.path_segments().context("empty http url")?; + let owner = segments + .next() + .context("expected owner path segment")? + .to_string(); + let repo = segments + .next() + .context("expected repo path segment")? + .trim_end_matches(".git") + .to_string(); + assert!(segments.next().is_none()); + + Ok((owner.into(), repo.into())) + } + } + + pub async fn setup_worktree(&self, file_name: String) -> Result { + let (repo_owner, repo_name) = self.repo_name()?; + + let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); + let repo_lock = lock_repo(&repo_dir).await; + + if !repo_dir.is_dir() { + fs::create_dir_all(&repo_dir)?; + run_git(&repo_dir, &["init"]).await?; + run_git( + &repo_dir, + &["remote", "add", "origin", &self.repository_url], + ) + .await?; + } + + // Resolve the example to a revision, fetching it if needed. + let revision = run_git( + &repo_dir, + &["rev-parse", &format!("{}^{{commit}}", self.revision)], + ) + .await; + let revision = if let Ok(revision) = revision { + revision + } else { + if run_git( + &repo_dir, + &["fetch", "--depth", "1", "origin", &self.revision], + ) + .await + .is_err() + { + run_git(&repo_dir, &["fetch", "origin"]).await?; + } + let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; + if revision != self.revision { + run_git(&repo_dir, &["tag", &self.revision, &revision]).await?; + } + revision + }; + + // Create the worktree for this example if needed. + let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref()); + if worktree_path.is_dir() { + run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", revision.as_str()]).await?; + } else { + let worktree_path_string = worktree_path.to_string_lossy(); + run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; + run_git( + &repo_dir, + &["worktree", "add", "-f", &worktree_path_string, &file_name], + ) + .await?; + } + drop(repo_lock); + + // Apply the uncommitted diff for this example. + if !self.uncommitted_diff.is_empty() { + let mut apply_process = smol::process::Command::new("git") + .current_dir(&worktree_path) + .args(&["apply", "-"]) + .stdin(std::process::Stdio::piped()) + .spawn()?; + + let mut stdin = apply_process.stdin.take().unwrap(); + stdin.write_all(self.uncommitted_diff.as_bytes()).await?; + stdin.close().await?; + drop(stdin); + + let apply_result = apply_process.output().await?; + if !apply_result.status.success() { + anyhow::bail!( + "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", + apply_result.status, + String::from_utf8_lossy(&apply_result.stderr), + String::from_utf8_lossy(&apply_result.stdout), + ); + } + } + + Ok(worktree_path) + } + + pub fn unique_name(&self) -> String { + let mut hasher = std::hash::DefaultHasher::new(); + self.hash(&mut hasher); + let disambiguator = hasher.finish(); + let hash = format!("{:04x}", disambiguator); + format!("{}_{}", &self.revision[..8], &hash[..4]) + } +} + pub type ActualExcerpt = Excerpt; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -292,90 +422,7 @@ impl NamedExample { } pub async fn setup_worktree(&self) -> Result { - let (repo_owner, repo_name) = self.repo_name()?; - let file_name = self.file_name(); - - let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); - let repo_lock = lock_repo(&repo_dir).await; - - if !repo_dir.is_dir() { - fs::create_dir_all(&repo_dir)?; - run_git(&repo_dir, &["init"]).await?; - run_git( - &repo_dir, - &["remote", "add", "origin", &self.example.repository_url], - ) - .await?; - } - - // Resolve the example to a revision, fetching it if needed. - let revision = run_git( - &repo_dir, - &[ - "rev-parse", - &format!("{}^{{commit}}", self.example.revision), - ], - ) - .await; - let revision = if let Ok(revision) = revision { - revision - } else { - run_git( - &repo_dir, - &["fetch", "--depth", "1", "origin", &self.example.revision], - ) - .await?; - let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; - if revision != self.example.revision { - run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?; - } - revision - }; - - // Create the worktree for this example if needed. - let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref()); - if worktree_path.is_dir() { - run_git(&worktree_path, &["clean", "--force", "-d"]).await?; - run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", revision.as_str()]).await?; - } else { - let worktree_path_string = worktree_path.to_string_lossy(); - run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; - run_git( - &repo_dir, - &["worktree", "add", "-f", &worktree_path_string, &file_name], - ) - .await?; - } - drop(repo_lock); - - // Apply the uncommitted diff for this example. - if !self.example.uncommitted_diff.is_empty() { - let mut apply_process = smol::process::Command::new("git") - .current_dir(&worktree_path) - .args(&["apply", "-"]) - .stdin(std::process::Stdio::piped()) - .spawn()?; - - let mut stdin = apply_process.stdin.take().unwrap(); - stdin - .write_all(self.example.uncommitted_diff.as_bytes()) - .await?; - stdin.close().await?; - drop(stdin); - - let apply_result = apply_process.output().await?; - if !apply_result.status.success() { - anyhow::bail!( - "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", - apply_result.status, - String::from_utf8_lossy(&apply_result.stderr), - String::from_utf8_lossy(&apply_result.stdout), - ); - } - } - - Ok(worktree_path) + self.example.setup_worktree(self.file_name()).await } pub fn file_name(&self) -> String { @@ -391,40 +438,6 @@ impl NamedExample { .collect() } - fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { - // git@github.com:owner/repo.git - if self.example.repository_url.contains('@') { - let (owner, repo) = self - .example - .repository_url - .split_once(':') - .context("expected : in git url")? - .1 - .split_once('/') - .context("expected / in git url")?; - Ok(( - Cow::Borrowed(owner), - Cow::Borrowed(repo.trim_end_matches(".git")), - )) - // http://github.com/owner/repo.git - } else { - let url = Url::parse(&self.example.repository_url)?; - let mut segments = url.path_segments().context("empty http url")?; - let owner = segments - .next() - .context("expected owner path segment")? - .to_string(); - let repo = segments - .next() - .context("expected repo path segment")? - .trim_end_matches(".git") - .to_string(); - assert!(segments.next().is_none()); - - Ok((owner.into(), repo.into())) - } - } - pub async fn cursor_position( &self, project: &Entity, diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index f2887b98a0ce829a58374fdd10c3e346b6f5d16a..00086777f1f03112b92f11923ad2d025276699f5 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -5,6 +5,7 @@ mod metrics; mod paths; mod predict; mod source_location; +mod training; mod util; use crate::{ @@ -13,9 +14,10 @@ use crate::{ headless::ZetaCliAppState, predict::run_predict, source_location::SourceLocation, + training::{context::ContextType, distill::run_distill}, util::{open_buffer, open_buffer_with_language_server}, }; -use ::util::paths::PathStyle; +use ::util::{ResultExt, paths::PathStyle}; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand, ValueEnum}; use cloud_llm_client::predict_edits_v3; @@ -43,6 +45,7 @@ enum Command { Context(ContextArgs), Predict(PredictArguments), Eval(EvaluateArguments), + Distill(DistillArguments), ConvertExample { path: PathBuf, #[arg(long, value_enum, default_value_t = ExampleFormat::Md)] @@ -111,6 +114,15 @@ pub struct PredictArguments { options: PredictionOptions, } +#[derive(Debug, Args)] +pub struct DistillArguments { + split_commit_dataset: PathBuf, + #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)] + context_type: ContextType, + #[clap(long)] + batch: Option, +} + #[derive(Clone, Debug, Args)] pub struct PredictionOptions { #[clap(flatten)] @@ -468,6 +480,13 @@ fn main() { Some(Command::Eval(arguments)) => { run_evaluate(arguments, &app_state, cx).await; } + Some(Command::Distill(arguments)) => { + let _guard = cx + .update(|cx| gpui_tokio::Tokio::handle(cx)) + .unwrap() + .enter(); + run_distill(arguments).await.log_err(); + } Some(Command::ConvertExample { path, output_format, diff --git a/crates/edit_prediction_cli/src/training/context.rs b/crates/edit_prediction_cli/src/training/context.rs new file mode 100644 index 0000000000000000000000000000000000000000..7b6d9cc19c1c3750bbf03158ceec5c79a9df0340 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/context.rs @@ -0,0 +1,89 @@ +use std::path::Path; + +use crate::{source_location::SourceLocation, training::teacher::TeacherModel}; + +#[derive(Debug, Clone, Default, clap::ValueEnum)] +pub enum ContextType { + #[default] + CurrentFile, +} + +const MAX_CONTEXT_SIZE: usize = 32768; + +pub fn collect_context( + context_type: &ContextType, + worktree_dir: &Path, + cursor: SourceLocation, +) -> String { + let context = match context_type { + ContextType::CurrentFile => { + let file_path = worktree_dir.join(cursor.path.as_std_path()); + let context = std::fs::read_to_string(&file_path).unwrap_or_default(); + + let context = add_special_tags(&context, worktree_dir, cursor); + context + } + }; + + let region_end_offset = context.find(TeacherModel::REGION_END); + + if context.len() <= MAX_CONTEXT_SIZE { + return context; + } + + if let Some(region_end_offset) = region_end_offset + && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE + { + let to_truncate = context.len() - MAX_CONTEXT_SIZE; + format!( + "[...{} bytes truncated]\n{}\n", + to_truncate, + &context[to_truncate..] + ) + } else { + format!( + "{}\n[...{} bytes truncated]\n", + &context[..MAX_CONTEXT_SIZE], + context.len() - MAX_CONTEXT_SIZE + ) + } +} + +/// Add <|editable_region_start/end|> tags +fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String { + let path = worktree_dir.join(cursor.path.as_std_path()); + let file = std::fs::read_to_string(&path).unwrap_or_default(); + let lines = file.lines().collect::>(); + let cursor_row = cursor.point.row as usize; + let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE); + let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len()); + + let snippet = lines[start_line..end_line].join("\n"); + + if context.contains(&snippet) { + let mut cursor_line = lines[cursor_row].to_string(); + cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR); + + let mut snippet_with_tags_lines = vec![]; + snippet_with_tags_lines.push(TeacherModel::REGION_START); + snippet_with_tags_lines.extend(&lines[start_line..cursor_row]); + snippet_with_tags_lines.push(&cursor_line); + snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]); + snippet_with_tags_lines.push(TeacherModel::REGION_END); + let snippet_with_tags = snippet_with_tags_lines.join("\n"); + + context.replace(&snippet, &snippet_with_tags) + } else { + log::warn!( + "Can't find area around the cursor in the context; proceeding without special tags" + ); + context.to_string() + } +} + +pub fn strip_special_tags(context: &str) -> String { + context + .replace(TeacherModel::REGION_START, "") + .replace(TeacherModel::REGION_END, "") + .replace(TeacherModel::USER_CURSOR, "") +} diff --git a/crates/edit_prediction_cli/src/training/distill.rs b/crates/edit_prediction_cli/src/training/distill.rs new file mode 100644 index 0000000000000000000000000000000000000000..277e35551a9fbce43982de832de5ccecf8d6e92e --- /dev/null +++ b/crates/edit_prediction_cli/src/training/distill.rs @@ -0,0 +1,94 @@ +use serde::Deserialize; +use std::sync::Arc; + +use crate::{ + DistillArguments, + example::Example, + source_location::SourceLocation, + training::{ + context::ContextType, + llm_client::LlmClient, + teacher::{TeacherModel, TeacherOutput}, + }, +}; +use anyhow::Result; +use reqwest_client::ReqwestClient; + +#[derive(Debug, Deserialize)] +pub struct SplitCommit { + repo_url: String, + commit_sha: String, + edit_history: String, + expected_patch: String, + cursor_position: String, +} + +pub async fn run_distill(arguments: DistillArguments) -> Result<()> { + let split_commits: Vec = std::fs::read_to_string(&arguments.split_commit_dataset) + .expect("Failed to read split commit dataset") + .lines() + .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line")) + .collect(); + + let http_client: Arc = Arc::new(ReqwestClient::new()); + + let llm_client = if let Some(cache_path) = arguments.batch { + LlmClient::batch(&cache_path, http_client)? + } else { + LlmClient::plain(http_client)? + }; + + let mut teacher = TeacherModel::new( + "claude-sonnet-4-5".to_string(), + ContextType::CurrentFile, + llm_client, + ); + + let mut num_marked_for_batching = 0; + + for commit in split_commits { + if let Some(distilled) = distill_one(&mut teacher, commit).await? { + println!("{}", serde_json::to_string(&distilled)?); + } else { + if num_marked_for_batching == 0 { + log::warn!("Marked for batching"); + } + num_marked_for_batching += 1; + } + } + + eprintln!( + "{} requests are marked for batching", + num_marked_for_batching + ); + let llm_client = teacher.client; + llm_client.sync_batches().await?; + + Ok(()) +} + +pub async fn distill_one( + teacher: &mut TeacherModel, + commit: SplitCommit, +) -> Result> { + let cursor: SourceLocation = commit + .cursor_position + .parse() + .expect("Failed to parse cursor position"); + + let path = cursor.path.to_rel_path_buf(); + + let example = Example { + repository_url: commit.repo_url, + revision: commit.commit_sha, + uncommitted_diff: commit.edit_history.clone(), + cursor_path: path.as_std_path().to_path_buf(), + cursor_position: commit.cursor_position, + edit_history: commit.edit_history, // todo: trim + expected_patch: commit.expected_patch, + }; + + let prediction = teacher.predict(example).await; + + prediction +} diff --git a/crates/edit_prediction_cli/src/training/llm_client.rs b/crates/edit_prediction_cli/src/training/llm_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..ebecbe915d36a9a456296e818e559c654370f939 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/llm_client.rs @@ -0,0 +1,417 @@ +use anthropic::{ + ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent, + Response as AnthropicResponse, Role, non_streaming_completion, +}; +use anyhow::Result; +use http_client::HttpClient; +use indoc::indoc; +use sqlez::bindable::Bind; +use sqlez::bindable::StaticColumnCount; +use sqlez_macros::sql; +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +pub struct PlainLlmClient { + http_client: Arc, + api_key: String, +} + +impl PlainLlmClient { + fn new(http_client: Arc) -> Result { + let api_key = std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; + Ok(Self { + http_client, + api_key, + }) + } + + async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result { + let request = AnthropicRequest { + model, + max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + top_p: None, + }; + + let response = non_streaming_completion( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + request, + None, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + Ok(response) + } +} + +pub struct BatchingLlmClient { + connection: sqlez::connection::Connection, + http_client: Arc, + api_key: String, +} + +struct CacheRow { + request_hash: String, + request: Option, + response: Option, + batch_id: Option, +} + +impl StaticColumnCount for CacheRow { + fn column_count() -> usize { + 4 + } +} + +impl Bind for CacheRow { + fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result { + let next_index = statement.bind(&self.request_hash, start_index)?; + let next_index = statement.bind(&self.request, next_index)?; + let next_index = statement.bind(&self.response, next_index)?; + let next_index = statement.bind(&self.batch_id, next_index)?; + Ok(next_index) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableRequest { + model: String, + max_tokens: u64, + messages: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableMessage { + role: String, + content: String, +} + +impl BatchingLlmClient { + fn new(cache_path: &str, http_client: Arc) -> Result { + let api_key = std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; + + let connection = sqlez::connection::Connection::open_file(&cache_path); + let mut statement = sqlez::statement::Statement::prepare( + &connection, + indoc! {" + CREATE TABLE IF NOT EXISTS cache ( + request_hash TEXT PRIMARY KEY, + request TEXT, + response TEXT, + batch_id TEXT + ); + "}, + )?; + statement.exec()?; + drop(statement); + + Ok(Self { + connection, + http_client, + api_key, + }) + } + + pub fn lookup( + &self, + model: &str, + max_tokens: u64, + messages: &[Message], + ) -> Result> { + let request_hash_str = Self::request_hash(model, max_tokens, messages); + let response: Vec = self.connection.select_bound( + &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;), + )?(request_hash_str.as_str())?; + Ok(response + .into_iter() + .next() + .and_then(|text| serde_json::from_str(&text).ok())) + } + + pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> { + let request_hash = Self::request_hash(model, max_tokens, messages); + + let serializable_messages: Vec = messages + .iter() + .map(|msg| SerializableMessage { + role: match msg.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + }, + content: message_content_to_string(&msg.content), + }) + .collect(); + + let serializable_request = SerializableRequest { + model: model.to_string(), + max_tokens, + messages: serializable_messages, + }; + + let request = Some(serde_json::to_string(&serializable_request)?); + let cache_row = CacheRow { + request_hash, + request, + response: None, + batch_id: None, + }; + self.connection.exec_bound(sql!( + INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?( + cache_row, + ) + } + + async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result> { + let response = self.lookup(&model, max_tokens, &messages)?; + if let Some(response) = response { + return Ok(Some(response)); + } + + self.mark_for_batch(&model, max_tokens, &messages)?; + + Ok(None) + } + + /// Uploads pending requests as a new batch; downloads finished batches if any. + async fn sync_batches(&self) -> Result<()> { + self.upload_pending_requests().await?; + self.download_finished_batches().await + } + + async fn download_finished_batches(&self) -> Result<()> { + let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL); + let batch_ids: Vec = self.connection.select(q)?()?; + + for batch_id in batch_ids { + let batch_status = anthropic::batches::retrieve_batch( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + &batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + log::info!( + "Batch {} status: {}", + batch_id, + batch_status.processing_status + ); + + if batch_status.processing_status == "ended" { + let results = anthropic::batches::retrieve_batch_results( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + &batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let mut success_count = 0; + for result in results { + let request_hash = result + .custom_id + .strip_prefix("req_hash_") + .unwrap_or(&result.custom_id) + .to_string(); + + match result.result { + anthropic::batches::BatchResult::Succeeded { message } => { + let response_json = serde_json::to_string(&message)?; + let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?); + self.connection.exec_bound(q)?((response_json, request_hash))?; + success_count += 1; + } + anthropic::batches::BatchResult::Errored { error } => { + log::error!("Batch request {} failed: {:?}", request_hash, error); + } + anthropic::batches::BatchResult::Canceled => { + log::warn!("Batch request {} was canceled", request_hash); + } + anthropic::batches::BatchResult::Expired => { + log::warn!("Batch request {} expired", request_hash); + } + } + } + log::info!("Uploaded {} successful requests", success_count); + } + } + + Ok(()) + } + + async fn upload_pending_requests(&self) -> Result { + let q = sql!( + SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL + ); + + let rows: Vec<(String, String)> = self.connection.select(q)?()?; + + if rows.is_empty() { + return Ok(String::new()); + } + + let batch_requests = rows + .iter() + .map(|(hash, request_str)| { + let serializable_request: SerializableRequest = + serde_json::from_str(&request_str).unwrap(); + + let messages: Vec = serializable_request + .messages + .into_iter() + .map(|msg| Message { + role: match msg.role.as_str() { + "user" => Role::User, + "assistant" => Role::Assistant, + _ => Role::User, + }, + content: vec![RequestContent::Text { + text: msg.content, + cache_control: None, + }], + }) + .collect(); + + let params = AnthropicRequest { + model: serializable_request.model, + max_tokens: serializable_request.max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + top_p: None, + }; + + let custom_id = format!("req_hash_{}", hash); + anthropic::batches::BatchRequest { custom_id, params } + }) + .collect::>(); + + let batch_len = batch_requests.len(); + let batch = anthropic::batches::create_batch( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + anthropic::batches::CreateBatchRequest { + requests: batch_requests, + }, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let q = sql!( + UPDATE cache SET batch_id = ? WHERE batch_id is NULL + ); + self.connection.exec_bound(q)?(batch.id.as_str())?; + + log::info!("Uploaded batch with {} requests", batch_len); + + Ok(batch.id) + } + + fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String { + let mut hasher = std::hash::DefaultHasher::new(); + model.hash(&mut hasher); + max_tokens.hash(&mut hasher); + for msg in messages { + message_content_to_string(&msg.content).hash(&mut hasher); + } + let request_hash = hasher.finish(); + format!("{request_hash:016x}") + } +} + +fn message_content_to_string(content: &[RequestContent]) -> String { + content + .iter() + .filter_map(|c| match c { + RequestContent::Text { text, .. } => Some(text.clone()), + _ => None, + }) + .collect::>() + .join("\n") +} + +pub enum LlmClient { + // No batching + Plain(PlainLlmClient), + Batch(BatchingLlmClient), + Dummy, +} + +impl LlmClient { + pub fn plain(http_client: Arc) -> Result { + Ok(Self::Plain(PlainLlmClient::new(http_client)?)) + } + + pub fn batch(cache_path: &str, http_client: Arc) -> Result { + Ok(Self::Batch(BatchingLlmClient::new( + cache_path, + http_client, + )?)) + } + + #[allow(dead_code)] + pub fn dummy() -> Self { + Self::Dummy + } + + pub async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result> { + match self { + LlmClient::Plain(plain_llm_client) => plain_llm_client + .generate(model, max_tokens, messages) + .await + .map(Some), + LlmClient::Batch(batching_llm_client) => { + batching_llm_client + .generate(model, max_tokens, messages) + .await + } + LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + } + } + + pub async fn sync_batches(&self) -> Result<()> { + match self { + LlmClient::Plain(_) => Ok(()), + LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await, + LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + } + } +} diff --git a/crates/edit_prediction_cli/src/training/mod.rs b/crates/edit_prediction_cli/src/training/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..dc564c4dc86c8e095e8e93ccbdfb29d3313e922a --- /dev/null +++ b/crates/edit_prediction_cli/src/training/mod.rs @@ -0,0 +1,4 @@ +pub mod context; +pub mod distill; +pub mod llm_client; +pub mod teacher; diff --git a/crates/edit_prediction_cli/src/training/teacher.prompt.md b/crates/edit_prediction_cli/src/training/teacher.prompt.md new file mode 100644 index 0000000000000000000000000000000000000000..af67c871ef31a21a8744bf71375a50128d9699b6 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/teacher.prompt.md @@ -0,0 +1,48 @@ +# Instructions + +You are a code completion assistant helping a programmer finish their work. Your task is to: + +1. Analyze the edit history to understand what the programmer is trying to achieve +2. Identify any incomplete refactoring or changes that need to be finished +3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections) +4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere. + +Focus on: +- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs) +- Completing any partially-applied changes across the codebase +- Ensuring consistency with the programming style and patterns already established +- Making edits that maintain or improve code quality +- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances +- Don't write a lot of code if you're not sure what to do + +Rules: +- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. +- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. + +Input format: +- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant. +- Never modify the context code. +- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region. +- The cursor position is marked with <|user_cursor|>. + +Output format: +- Return the entire editable region, applying any edits you make. +- Remove the <|user_cursor|> marker. +- Wrap the edited code in a block of exactly five backticks. + +Output example: +````` + // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass + if let Some(socket) = &args.askpass {{ + askpass::main(socket); + return Ok(()); + }} +````` + +## User Edits History + +{{edit_history}} + +## Code Context + +{{context}} diff --git a/crates/edit_prediction_cli/src/training/teacher.rs b/crates/edit_prediction_cli/src/training/teacher.rs new file mode 100644 index 0000000000000000000000000000000000000000..99672db8f99a87b99a43c8876db2fd0c2f307b21 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/teacher.rs @@ -0,0 +1,266 @@ +use crate::{ + example::Example, + source_location::SourceLocation, + training::{ + context::{ContextType, collect_context, strip_special_tags}, + llm_client::LlmClient, + }, +}; +use anthropic::{Message, RequestContent, ResponseContent, Role}; +use anyhow::Result; + +pub struct TeacherModel { + pub llm_name: String, + pub context: ContextType, + pub client: LlmClient, +} + +#[derive(Debug, serde::Serialize)] +pub struct TeacherOutput { + parsed_output: String, + prompt: String, + raw_llm_response: String, + context: String, + diff: String, +} + +impl TeacherModel { + const PROMPT: &str = include_str!("teacher.prompt.md"); + pub(crate) const REGION_START: &str = "<|editable_region_start|>\n"; + pub(crate) const REGION_END: &str = "<|editable_region_end|>"; + pub(crate) const USER_CURSOR: &str = "<|user_cursor|>"; + + /// Number of lines to include before the cursor position + pub(crate) const LEFT_CONTEXT_SIZE: usize = 5; + + /// Number of lines to include after the cursor position + pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5; + + /// Truncate edit history to this number of last lines + const MAX_HISTORY_LINES: usize = 128; + + pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self { + TeacherModel { + llm_name, + context, + client, + } + } + + pub async fn predict(&self, input: Example) -> Result> { + let name = input.unique_name(); + let worktree_dir = input.setup_worktree(name).await?; + let cursor: SourceLocation = input + .cursor_position + .parse() + .expect("Failed to parse cursor position"); + + let context = collect_context(&self.context, &worktree_dir, cursor.clone()); + let edit_history = Self::format_edit_history(&input.edit_history); + + let prompt = Self::PROMPT + .replace("{{context}}", &context) + .replace("{{edit_history}}", &edit_history); + + let messages = vec![Message { + role: Role::User, + content: vec![RequestContent::Text { + text: prompt.clone(), + cache_control: None, + }], + }]; + + let Some(response) = self + .client + .generate(self.llm_name.clone(), 16384, messages) + .await? + else { + return Ok(None); + }; + + let response_text = response + .content + .into_iter() + .filter_map(|content| match content { + ResponseContent::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join("\n"); + + let parsed_output = self.parse_response(&response_text); + + let original_editable_region = Self::extract_editable_region(&context); + let context_after_edit = context.replace(&original_editable_region, &parsed_output); + let context_after_edit = strip_special_tags(&context_after_edit); + let context_before_edit = strip_special_tags(&context); + let diff = language::unified_diff(&context_before_edit, &context_after_edit); + + // zeta distill --batch batch_results.txt + // zeta distill + // 1. Run `zeta distill <2000 examples <- all examples>` for the first time + // - store LLM requests in a batch, don't actual send the request + // - send the batch (2000 requests) after all inputs are processed + // 2. `zeta send-batches` + // - upload the batch to Anthropic + + // https://platform.claude.com/docs/en/build-with-claude/batch-processing + // https://crates.io/crates/anthropic-sdk-rust + + // - poll for results + // - when ready, store results in cache (a database) + // 3. `zeta distill` again + // - use the cached results this time + + Ok(Some(TeacherOutput { + parsed_output, + prompt, + raw_llm_response: response_text, + context, + diff, + })) + } + + fn parse_response(&self, content: &str) -> String { + let codeblock = Self::extract_last_codeblock(content); + let editable_region = Self::extract_editable_region(&codeblock); + + editable_region + } + + /// Extract content from the last code-fenced block if any, or else return content as is + fn extract_last_codeblock(text: &str) -> String { + let mut last_block = None; + let mut search_start = 0; + + while let Some(start) = text[search_start..].find("```") { + let start = start + search_start; + let bytes = text.as_bytes(); + let mut backtick_end = start; + + while backtick_end < bytes.len() && bytes[backtick_end] == b'`' { + backtick_end += 1; + } + + let backtick_count = backtick_end - start; + let closing_backticks = "`".repeat(backtick_count); + + if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) { + let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1]; + last_block = Some(code_block.to_string()); + search_start = backtick_end + end_pos + backtick_count; + } else { + break; + } + } + + last_block.unwrap_or_else(|| text.to_string()) + } + + fn extract_editable_region(text: &str) -> String { + let start = text + .find(Self::REGION_START) + .map_or(0, |pos| pos + Self::REGION_START.len()); + let end = text.find(Self::REGION_END).unwrap_or(text.len()); + + text[start..end].to_string() + } + + /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines) + fn format_edit_history(edit_history: &str) -> String { + let lines = edit_history + .lines() + .filter(|&s| Self::is_content_line(s)) + .collect::>(); + + let history_lines = if lines.len() > Self::MAX_HISTORY_LINES { + &lines[lines.len() - Self::MAX_HISTORY_LINES..] + } else { + &lines + }; + history_lines.join("\n") + } + + fn is_content_line(s: &str) -> bool { + s.starts_with("-") + || s.starts_with("+") + || s.starts_with(" ") + || s.starts_with("---") + || s.starts_with("+++") + || s.starts_with("@@") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_response() { + let teacher = TeacherModel::new( + "test".to_string(), + ContextType::CurrentFile, + LlmClient::dummy(), + ); + let response = "This is a test response."; + let parsed = teacher.parse_response(response); + assert_eq!(parsed, response.to_string()); + + let response = indoc::indoc! {" + Some thinking + + ````` + actual response + ````` + "}; + let parsed = teacher.parse_response(response); + assert_eq!(parsed, "actual response"); + } + + #[test] + fn test_extract_last_code_block() { + let text = indoc::indoc! {" + Some thinking + + ``` + first block + ``` + + ````` + last block + ````` + "}; + let last_block = TeacherModel::extract_last_codeblock(text); + assert_eq!(last_block, "last block"); + } + + #[test] + fn test_extract_editable_region() { + let teacher = TeacherModel::new( + "test".to_string(), + ContextType::CurrentFile, + LlmClient::dummy(), + ); + let response = indoc::indoc! {" + some lines + are + here + <|editable_region_start|> + one + two three + + <|editable_region_end|> + more + lines here + "}; + let parsed = teacher.parse_response(response); + assert_eq!( + parsed, + indoc::indoc! {" + one + two three + + "} + ); + } +}