Add `zeta distill` command (#44369)

Oleksiy Syvokon , Piotr Osiewicz , and Ben Kunkle created

This PR partially implements a knowledge distillation data pipeline.

`zeta distill` gets a dataset of chronologically ordered commits and
generates synthetic predictions with a teacher model (one-shot Claude
Sonnet).

`zeta distill --batches cache.db` will enable Message Batches API. Under
the first run, this command will collect all LLM requests and upload a
batch of them to Anthropic. On subsequent runs, it will check the batch
status. If ready, it will download the result and put them into the
local cache.


Release Notes:

- N/A

---------

Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

Cargo.lock                                                |   4 
Cargo.toml                                                |  25 
crates/anthropic/src/anthropic.rs                         | 143 ++
crates/anthropic/src/batches.rs                           | 190 ++++
crates/edit_prediction_cli/Cargo.toml                     |   6 
crates/edit_prediction_cli/src/example.rs                 | 251 ++--
crates/edit_prediction_cli/src/main.rs                    |  21 
crates/edit_prediction_cli/src/training/context.rs        |  89 +
crates/edit_prediction_cli/src/training/distill.rs        |  94 ++
crates/edit_prediction_cli/src/training/llm_client.rs     | 417 +++++++++
crates/edit_prediction_cli/src/training/mod.rs            |   4 
crates/edit_prediction_cli/src/training/teacher.prompt.md |  48 +
crates/edit_prediction_cli/src/training/teacher.rs        | 266 +++++
13 files changed, 1,370 insertions(+), 188 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5167,6 +5167,7 @@ dependencies = [
 name = "edit_prediction_cli"
 version = "0.1.0"
 dependencies = [
+ "anthropic",
  "anyhow",
  "chrono",
  "clap",
@@ -5182,6 +5183,7 @@ dependencies = [
  "futures 0.3.31",
  "gpui",
  "gpui_tokio",
+ "http_client",
  "indoc",
  "language",
  "language_extension",
@@ -5202,6 +5204,8 @@ dependencies = [
  "settings",
  "shellexpand 2.1.2",
  "smol",
+ "sqlez",
+ "sqlez_macros",
  "terminal_view",
  "toml 0.8.23",
  "util",

Cargo.toml 🔗

@@ -244,7 +244,6 @@ activity_indicator = { path = "crates/activity_indicator" }
 agent_ui = { path = "crates/agent_ui" }
 agent_settings = { path = "crates/agent_settings" }
 agent_servers = { path = "crates/agent_servers" }
-ai = { path = "crates/ai" }
 ai_onboarding = { path = "crates/ai_onboarding" }
 anthropic = { path = "crates/anthropic" }
 askpass = { path = "crates/askpass" }
@@ -254,7 +253,6 @@ assistant_slash_command = { path = "crates/assistant_slash_command" }
 assistant_slash_commands = { path = "crates/assistant_slash_commands" }
 audio = { path = "crates/audio" }
 auto_update = { path = "crates/auto_update" }
-auto_update_helper = { path = "crates/auto_update_helper" }
 auto_update_ui = { path = "crates/auto_update_ui" }
 aws_http_client = { path = "crates/aws_http_client" }
 bedrock = { path = "crates/bedrock" }
@@ -269,7 +267,6 @@ cloud_api_client = { path = "crates/cloud_api_client" }
 cloud_api_types = { path = "crates/cloud_api_types" }
 cloud_llm_client = { path = "crates/cloud_llm_client" }
 cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
-collab = { path = "crates/collab" }
 collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections", version = "0.1.0" }
 command_palette = { path = "crates/command_palette" }
@@ -358,8 +355,6 @@ panel = { path = "crates/panel" }
 paths = { path = "crates/paths" }
 perf = { path = "tooling/perf" }
 picker = { path = "crates/picker" }
-plugin = { path = "crates/plugin" }
-plugin_macros = { path = "crates/plugin_macros" }
 prettier = { path = "crates/prettier" }
 settings_profile_selector = { path = "crates/settings_profile_selector" }
 project = { path = "crates/project" }
@@ -370,12 +365,10 @@ proto = { path = "crates/proto" }
 recent_projects = { path = "crates/recent_projects" }
 refineable = { path = "crates/refineable" }
 release_channel = { path = "crates/release_channel" }
-scheduler = { path = "crates/scheduler" }
 remote = { path = "crates/remote" }
 remote_server = { path = "crates/remote_server" }
 repl = { path = "crates/repl" }
 reqwest_client = { path = "crates/reqwest_client" }
-rich_text = { path = "crates/rich_text" }
 rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] }
 rope = { path = "crates/rope" }
 rpc = { path = "crates/rpc" }
@@ -392,7 +385,6 @@ snippets_ui = { path = "crates/snippets_ui" }
 sqlez = { path = "crates/sqlez" }
 sqlez_macros = { path = "crates/sqlez_macros" }
 story = { path = "crates/story" }
-storybook = { path = "crates/storybook" }
 streaming_diff = { path = "crates/streaming_diff" }
 sum_tree = { path = "crates/sum_tree" }
 supermaven = { path = "crates/supermaven" }
@@ -409,7 +401,6 @@ terminal_view = { path = "crates/terminal_view" }
 text = { path = "crates/text" }
 theme = { path = "crates/theme" }
 theme_extension = { path = "crates/theme_extension" }
-theme_importer = { path = "crates/theme_importer" }
 theme_selector = { path = "crates/theme_selector" }
 time_format = { path = "crates/time_format" }
 title_bar = { path = "crates/title_bar" }
@@ -510,13 +501,11 @@ exec = "0.3.1"
 fancy-regex = "0.16.0"
 fork = "0.4.0"
 futures = "0.3"
-futures-batch = "0.6.1"
 futures-lite = "1.13"
 gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" }
 git2 = { version = "0.20.1", default-features = false }
 globset = "0.4"
 handlebars = "4.3"
-hashbrown = "0.15.3"
 heck = "0.5"
 heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
 hex = "0.4.3"
@@ -552,7 +541,6 @@ nanoid = "0.4"
 nbformat = "0.15.0"
 nix = "0.29"
 num-format = "0.4.4"
-num-traits = "0.2"
 objc = "0.2"
 objc2-foundation = { version = "=0.3.1", default-features = false, features = [
     "NSArray",
@@ -591,7 +579,6 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev =
 pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
 pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
 pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
-pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
 pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
 pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
 pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
@@ -631,7 +618,6 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197
 schemars = { version = "1.0", features = ["indexmap2"] }
 semver = { version = "1.0", features = ["serde"] }
 serde = { version = "1.0.221", features = ["derive", "rc"] }
-serde_derive = "1.0.221"
 serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] }
 serde_json_lenient = { version = "0.2", features = [
     "preserve_order",
@@ -643,7 +629,6 @@ serde_urlencoded = "0.7"
 sha2 = "0.10"
 shellexpand = "2.1.0"
 shlex = "1.3.0"
-similar = "2.6"
 simplelog = "0.12.2"
 slotmap = "1.0.6"
 smallvec = { version = "1.6", features = ["union"] }
@@ -722,7 +707,6 @@ wasmtime-wasi = "29"
 wax = "0.6"
 which = "6.0.0"
 windows-core = "0.61"
-wit-component = "0.221"
 yawc = "0.2.5"
 zeroize = "1.8"
 zstd = "0.11"
@@ -804,20 +788,13 @@ settings_macros = { opt-level = 3 }
 sqlez_macros = { opt-level = 3, codegen-units = 1 }
 ui_macros = { opt-level = 3 }
 util_macros = { opt-level = 3 }
-serde_derive = { opt-level = 3 }
 quote = { opt-level = 3 }
 syn = { opt-level = 3 }
 proc-macro2 = { opt-level = 3 }
 # proc-macros end
 
 taffy = { opt-level = 3 }
-cranelift-codegen = { opt-level = 3 }
-cranelift-codegen-meta = { opt-level = 3 }
-cranelift-codegen-shared = { opt-level = 3 }
 resvg = { opt-level = 3 }
-rustybuzz = { opt-level = 3 }
-ttf-parser = { opt-level = 3 }
-wasmtime-cranelift = { opt-level = 3 }
 wasmtime = { opt-level = 3 }
 # Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster
 activity_indicator = { codegen-units = 1 }
@@ -826,7 +803,6 @@ breadcrumbs = { codegen-units = 1 }
 collections = { codegen-units = 1 }
 command_palette = { codegen-units = 1 }
 command_palette_hooks = { codegen-units = 1 }
-extension_cli = { codegen-units = 1 }
 feature_flags = { codegen-units = 1 }
 file_icons = { codegen-units = 1 }
 fsevent = { codegen-units = 1 }
@@ -846,7 +822,6 @@ project_symbols = { codegen-units = 1 }
 refineable = { codegen-units = 1 }
 release_channel = { codegen-units = 1 }
 reqwest_client = { codegen-units = 1 }
-rich_text = { codegen-units = 1 }
 session = { codegen-units = 1 }
 snippet = { codegen-units = 1 }
 snippets_ui = { codegen-units = 1 }

crates/anthropic/src/anthropic.rs 🔗

@@ -12,6 +12,8 @@ pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
 
+pub mod batches;
+
 pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
@@ -465,6 +467,7 @@ impl Model {
     }
 }
 
+/// Generate completion with streaming.
 pub async fn stream_completion(
     client: &dyn HttpClient,
     api_url: &str,
@@ -477,6 +480,101 @@ pub async fn stream_completion(
         .map(|output| output.0)
 }
 
+/// Generate completion without streaming.
+pub async fn non_streaming_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+    beta_headers: Option<String>,
+) -> Result<Response, AnthropicError> {
+    let (mut response, rate_limits) =
+        send_request(client, api_url, api_key, &request, beta_headers).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(AnthropicError::ReadResponse)?;
+
+        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+    } else {
+        Err(handle_error_response(response, rate_limits).await)
+    }
+}
+
+async fn send_request(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: impl Serialize,
+    beta_headers: Option<String>,
+) -> Result<(http::Response<AsyncBody>, RateLimitInfo), AnthropicError> {
+    let uri = format!("{api_url}/v1/messages");
+
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("X-Api-Key", api_key.trim())
+        .header("Content-Type", "application/json");
+
+    if let Some(beta_headers) = beta_headers {
+        request_builder = request_builder.header("Anthropic-Beta", beta_headers);
+    }
+
+    let serialized_request =
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+    let request = request_builder
+        .body(AsyncBody::from(serialized_request))
+        .map_err(AnthropicError::BuildRequestBody)?;
+
+    let response = client
+        .send(request)
+        .await
+        .map_err(AnthropicError::HttpSend)?;
+
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+    Ok((response, rate_limits))
+}
+
+async fn handle_error_response(
+    mut response: http::Response<AsyncBody>,
+    rate_limits: RateLimitInfo,
+) -> AnthropicError {
+    if response.status().as_u16() == 529 {
+        return AnthropicError::ServerOverloaded {
+            retry_after: rate_limits.retry_after,
+        };
+    }
+
+    if let Some(retry_after) = rate_limits.retry_after {
+        return AnthropicError::RateLimit { retry_after };
+    }
+
+    let mut body = String::new();
+    let read_result = response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .map_err(AnthropicError::ReadResponse);
+
+    if let Err(err) = read_result {
+        return err;
+    }
+
+    match serde_json::from_str::<Event>(&body) {
+        Ok(Event::Error { error }) => AnthropicError::ApiError(error),
+        Ok(_) | Err(_) => AnthropicError::HttpResponseError {
+            status_code: response.status(),
+            message: body,
+        },
+    }
+}
+
 /// An individual rate limit.
 #[derive(Debug)]
 pub struct RateLimit {
@@ -580,30 +678,10 @@ pub async fn stream_completion_with_rate_limit_info(
         base: request,
         stream: true,
     };
-    let uri = format!("{api_url}/v1/messages");
 
-    let mut request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Anthropic-Version", "2023-06-01")
-        .header("X-Api-Key", api_key.trim())
-        .header("Content-Type", "application/json");
+    let (response, rate_limits) =
+        send_request(client, api_url, api_key, &request, beta_headers).await?;
 
-    if let Some(beta_headers) = beta_headers {
-        request_builder = request_builder.header("Anthropic-Beta", beta_headers);
-    }
-
-    let serialized_request =
-        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
-    let request = request_builder
-        .body(AsyncBody::from(serialized_request))
-        .map_err(AnthropicError::BuildRequestBody)?;
-
-    let mut response = client
-        .send(request)
-        .await
-        .map_err(AnthropicError::HttpSend)?;
-    let rate_limits = RateLimitInfo::from_headers(response.headers());
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
         let stream = reader
@@ -622,27 +700,8 @@ pub async fn stream_completion_with_rate_limit_info(
             })
             .boxed();
         Ok((stream, Some(rate_limits)))
-    } else if response.status().as_u16() == 529 {
-        Err(AnthropicError::ServerOverloaded {
-            retry_after: rate_limits.retry_after,
-        })
-    } else if let Some(retry_after) = rate_limits.retry_after {
-        Err(AnthropicError::RateLimit { retry_after })
     } else {
-        let mut body = String::new();
-        response
-            .body_mut()
-            .read_to_string(&mut body)
-            .await
-            .map_err(AnthropicError::ReadResponse)?;
-
-        match serde_json::from_str::<Event>(&body) {
-            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
-            Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
-                status_code: response.status(),
-                message: body,
-            }),
-        }
+        Err(handle_error_response(response, rate_limits).await)
     }
 }
 

crates/anthropic/src/batches.rs 🔗

@@ -0,0 +1,190 @@
+use anyhow::Result;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchRequest {
+    pub custom_id: String,
+    pub params: Request,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CreateBatchRequest {
+    pub requests: Vec<BatchRequest>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatchRequestCounts {
+    pub processing: u64,
+    pub succeeded: u64,
+    pub errored: u64,
+    pub canceled: u64,
+    pub expired: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatch {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub batch_type: String,
+    pub processing_status: String,
+    pub request_counts: MessageBatchRequestCounts,
+    pub ended_at: Option<String>,
+    pub created_at: String,
+    pub expires_at: String,
+    pub archived_at: Option<String>,
+    pub cancel_initiated_at: Option<String>,
+    pub results_url: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum BatchResult {
+    #[serde(rename = "succeeded")]
+    Succeeded { message: Response },
+    #[serde(rename = "errored")]
+    Errored { error: ApiError },
+    #[serde(rename = "canceled")]
+    Canceled,
+    #[serde(rename = "expired")]
+    Expired,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchIndividualResponse {
+    pub custom_id: String,
+    pub result: BatchResult,
+}
+
+pub async fn create_batch(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CreateBatchRequest,
+) -> Result<MessageBatch, AnthropicError> {
+    let uri = format!("{api_url}/v1/messages/batches");
+
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("X-Api-Key", api_key.trim())
+        .header("Content-Type", "application/json");
+
+    let serialized_request =
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+    let http_request = request_builder
+        .body(AsyncBody::from(serialized_request))
+        .map_err(AnthropicError::BuildRequestBody)?;
+
+    let mut response = client
+        .send(http_request)
+        .await
+        .map_err(AnthropicError::HttpSend)?;
+
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(AnthropicError::ReadResponse)?;
+
+        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+    } else {
+        Err(crate::handle_error_response(response, rate_limits).await)
+    }
+}
+
+pub async fn retrieve_batch(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    message_batch_id: &str,
+) -> Result<MessageBatch, AnthropicError> {
+    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
+
+    let request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("X-Api-Key", api_key.trim());
+
+    let http_request = request_builder
+        .body(AsyncBody::default())
+        .map_err(AnthropicError::BuildRequestBody)?;
+
+    let mut response = client
+        .send(http_request)
+        .await
+        .map_err(AnthropicError::HttpSend)?;
+
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(AnthropicError::ReadResponse)?;
+
+        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+    } else {
+        Err(crate::handle_error_response(response, rate_limits).await)
+    }
+}
+
+pub async fn retrieve_batch_results(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    message_batch_id: &str,
+) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
+    let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
+
+    let request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("X-Api-Key", api_key.trim());
+
+    let http_request = request_builder
+        .body(AsyncBody::default())
+        .map_err(AnthropicError::BuildRequestBody)?;
+
+    let mut response = client
+        .send(http_request)
+        .await
+        .map_err(AnthropicError::HttpSend)?;
+
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(AnthropicError::ReadResponse)?;
+
+        let mut results = Vec::new();
+        for line in body.lines() {
+            if line.trim().is_empty() {
+                continue;
+            }
+            let result: BatchIndividualResponse =
+                serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
+            results.push(result);
+        }
+
+        Ok(results)
+    } else {
+        Err(crate::handle_error_response(response, rate_limits).await)
+    }
+}

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -13,8 +13,9 @@ name = "ep_cli"
 path = "src/main.rs"
 
 [dependencies]
-
 anyhow.workspace = true
+anthropic.workspace = true
+http_client.workspace = true
 chrono.workspace = true
 clap.workspace = true
 client.workspace = true
@@ -28,6 +29,7 @@ fs.workspace = true
 futures.workspace = true
 gpui.workspace = true
 gpui_tokio.workspace = true
+indoc.workspace = true
 language.workspace = true
 language_extension.workspace = true
 language_model.workspace = true
@@ -46,6 +48,8 @@ serde_json.workspace = true
 settings.workspace = true
 shellexpand.workspace = true
 smol.workspace = true
+sqlez.workspace = true
+sqlez_macros.workspace = true
 terminal_view.workspace = true
 toml.workspace = true
 util.workspace = true

crates/edit_prediction_cli/src/example.rs 🔗

@@ -3,6 +3,8 @@ use std::{
     cell::RefCell,
     fmt::{self, Display},
     fs,
+    hash::Hash,
+    hash::Hasher,
     io::Write,
     mem,
     path::{Path, PathBuf},
@@ -43,7 +45,7 @@ pub struct NamedExample {
     pub example: Example,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
 pub struct Example {
     pub repository_url: String,
     pub revision: String,
@@ -54,6 +56,134 @@ pub struct Example {
     pub expected_patch: String,
 }
 
+impl Example {
+    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+        // git@github.com:owner/repo.git
+        if self.repository_url.contains('@') {
+            let (owner, repo) = self
+                .repository_url
+                .split_once(':')
+                .context("expected : in git url")?
+                .1
+                .split_once('/')
+                .context("expected / in git url")?;
+            Ok((
+                Cow::Borrowed(owner),
+                Cow::Borrowed(repo.trim_end_matches(".git")),
+            ))
+        // http://github.com/owner/repo.git
+        } else {
+            let url = Url::parse(&self.repository_url)?;
+            let mut segments = url.path_segments().context("empty http url")?;
+            let owner = segments
+                .next()
+                .context("expected owner path segment")?
+                .to_string();
+            let repo = segments
+                .next()
+                .context("expected repo path segment")?
+                .trim_end_matches(".git")
+                .to_string();
+            assert!(segments.next().is_none());
+
+            Ok((owner.into(), repo.into()))
+        }
+    }
+
+    pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
+        let (repo_owner, repo_name) = self.repo_name()?;
+
+        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+        let repo_lock = lock_repo(&repo_dir).await;
+
+        if !repo_dir.is_dir() {
+            fs::create_dir_all(&repo_dir)?;
+            run_git(&repo_dir, &["init"]).await?;
+            run_git(
+                &repo_dir,
+                &["remote", "add", "origin", &self.repository_url],
+            )
+            .await?;
+        }
+
+        // Resolve the example to a revision, fetching it if needed.
+        let revision = run_git(
+            &repo_dir,
+            &["rev-parse", &format!("{}^{{commit}}", self.revision)],
+        )
+        .await;
+        let revision = if let Ok(revision) = revision {
+            revision
+        } else {
+            if run_git(
+                &repo_dir,
+                &["fetch", "--depth", "1", "origin", &self.revision],
+            )
+            .await
+            .is_err()
+            {
+                run_git(&repo_dir, &["fetch", "origin"]).await?;
+            }
+            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
+            if revision != self.revision {
+                run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
+            }
+            revision
+        };
+
+        // Create the worktree for this example if needed.
+        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
+        if worktree_path.is_dir() {
+            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+        } else {
+            let worktree_path_string = worktree_path.to_string_lossy();
+            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
+            run_git(
+                &repo_dir,
+                &["worktree", "add", "-f", &worktree_path_string, &file_name],
+            )
+            .await?;
+        }
+        drop(repo_lock);
+
+        // Apply the uncommitted diff for this example.
+        if !self.uncommitted_diff.is_empty() {
+            let mut apply_process = smol::process::Command::new("git")
+                .current_dir(&worktree_path)
+                .args(&["apply", "-"])
+                .stdin(std::process::Stdio::piped())
+                .spawn()?;
+
+            let mut stdin = apply_process.stdin.take().unwrap();
+            stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
+            stdin.close().await?;
+            drop(stdin);
+
+            let apply_result = apply_process.output().await?;
+            if !apply_result.status.success() {
+                anyhow::bail!(
+                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+                    apply_result.status,
+                    String::from_utf8_lossy(&apply_result.stderr),
+                    String::from_utf8_lossy(&apply_result.stdout),
+                );
+            }
+        }
+
+        Ok(worktree_path)
+    }
+
+    pub fn unique_name(&self) -> String {
+        let mut hasher = std::hash::DefaultHasher::new();
+        self.hash(&mut hasher);
+        let disambiguator = hasher.finish();
+        let hash = format!("{:04x}", disambiguator);
+        format!("{}_{}", &self.revision[..8], &hash[..4])
+    }
+}
+
 pub type ActualExcerpt = Excerpt;
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -292,90 +422,7 @@ impl NamedExample {
     }
 
     pub async fn setup_worktree(&self) -> Result<PathBuf> {
-        let (repo_owner, repo_name) = self.repo_name()?;
-        let file_name = self.file_name();
-
-        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
-        let repo_lock = lock_repo(&repo_dir).await;
-
-        if !repo_dir.is_dir() {
-            fs::create_dir_all(&repo_dir)?;
-            run_git(&repo_dir, &["init"]).await?;
-            run_git(
-                &repo_dir,
-                &["remote", "add", "origin", &self.example.repository_url],
-            )
-            .await?;
-        }
-
-        // Resolve the example to a revision, fetching it if needed.
-        let revision = run_git(
-            &repo_dir,
-            &[
-                "rev-parse",
-                &format!("{}^{{commit}}", self.example.revision),
-            ],
-        )
-        .await;
-        let revision = if let Ok(revision) = revision {
-            revision
-        } else {
-            run_git(
-                &repo_dir,
-                &["fetch", "--depth", "1", "origin", &self.example.revision],
-            )
-            .await?;
-            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
-            if revision != self.example.revision {
-                run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
-            }
-            revision
-        };
-
-        // Create the worktree for this example if needed.
-        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
-        if worktree_path.is_dir() {
-            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
-            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
-            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
-        } else {
-            let worktree_path_string = worktree_path.to_string_lossy();
-            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
-            run_git(
-                &repo_dir,
-                &["worktree", "add", "-f", &worktree_path_string, &file_name],
-            )
-            .await?;
-        }
-        drop(repo_lock);
-
-        // Apply the uncommitted diff for this example.
-        if !self.example.uncommitted_diff.is_empty() {
-            let mut apply_process = smol::process::Command::new("git")
-                .current_dir(&worktree_path)
-                .args(&["apply", "-"])
-                .stdin(std::process::Stdio::piped())
-                .spawn()?;
-
-            let mut stdin = apply_process.stdin.take().unwrap();
-            stdin
-                .write_all(self.example.uncommitted_diff.as_bytes())
-                .await?;
-            stdin.close().await?;
-            drop(stdin);
-
-            let apply_result = apply_process.output().await?;
-            if !apply_result.status.success() {
-                anyhow::bail!(
-                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
-                    apply_result.status,
-                    String::from_utf8_lossy(&apply_result.stderr),
-                    String::from_utf8_lossy(&apply_result.stdout),
-                );
-            }
-        }
-
-        Ok(worktree_path)
+        self.example.setup_worktree(self.file_name()).await
     }
 
     pub fn file_name(&self) -> String {
@@ -391,40 +438,6 @@ impl NamedExample {
             .collect()
     }
 
-    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
-        // git@github.com:owner/repo.git
-        if self.example.repository_url.contains('@') {
-            let (owner, repo) = self
-                .example
-                .repository_url
-                .split_once(':')
-                .context("expected : in git url")?
-                .1
-                .split_once('/')
-                .context("expected / in git url")?;
-            Ok((
-                Cow::Borrowed(owner),
-                Cow::Borrowed(repo.trim_end_matches(".git")),
-            ))
-        // http://github.com/owner/repo.git
-        } else {
-            let url = Url::parse(&self.example.repository_url)?;
-            let mut segments = url.path_segments().context("empty http url")?;
-            let owner = segments
-                .next()
-                .context("expected owner path segment")?
-                .to_string();
-            let repo = segments
-                .next()
-                .context("expected repo path segment")?
-                .trim_end_matches(".git")
-                .to_string();
-            assert!(segments.next().is_none());
-
-            Ok((owner.into(), repo.into()))
-        }
-    }
-
     pub async fn cursor_position(
         &self,
         project: &Entity<Project>,

crates/edit_prediction_cli/src/main.rs 🔗

@@ -5,6 +5,7 @@ mod metrics;
 mod paths;
 mod predict;
 mod source_location;
+mod training;
 mod util;
 
 use crate::{
@@ -13,9 +14,10 @@ use crate::{
     headless::ZetaCliAppState,
     predict::run_predict,
     source_location::SourceLocation,
+    training::{context::ContextType, distill::run_distill},
     util::{open_buffer, open_buffer_with_language_server},
 };
-use ::util::paths::PathStyle;
+use ::util::{ResultExt, paths::PathStyle};
 use anyhow::{Result, anyhow};
 use clap::{Args, Parser, Subcommand, ValueEnum};
 use cloud_llm_client::predict_edits_v3;
@@ -43,6 +45,7 @@ enum Command {
     Context(ContextArgs),
     Predict(PredictArguments),
     Eval(EvaluateArguments),
+    Distill(DistillArguments),
     ConvertExample {
         path: PathBuf,
         #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -111,6 +114,15 @@ pub struct PredictArguments {
     options: PredictionOptions,
 }
 
+#[derive(Debug, Args)]
+pub struct DistillArguments {
+    split_commit_dataset: PathBuf,
+    #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
+    context_type: ContextType,
+    #[clap(long)]
+    batch: Option<String>,
+}
+
 #[derive(Clone, Debug, Args)]
 pub struct PredictionOptions {
     #[clap(flatten)]
@@ -468,6 +480,13 @@ fn main() {
                 Some(Command::Eval(arguments)) => {
                     run_evaluate(arguments, &app_state, cx).await;
                 }
+                Some(Command::Distill(arguments)) => {
+                    let _guard = cx
+                        .update(|cx| gpui_tokio::Tokio::handle(cx))
+                        .unwrap()
+                        .enter();
+                    run_distill(arguments).await.log_err();
+                }
                 Some(Command::ConvertExample {
                     path,
                     output_format,

crates/edit_prediction_cli/src/training/context.rs 🔗

@@ -0,0 +1,89 @@
+use std::path::Path;
+
+use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
+
+#[derive(Debug, Clone, Default, clap::ValueEnum)]
+pub enum ContextType {
+    #[default]
+    CurrentFile,
+}
+
+const MAX_CONTEXT_SIZE: usize = 32768;
+
+pub fn collect_context(
+    context_type: &ContextType,
+    worktree_dir: &Path,
+    cursor: SourceLocation,
+) -> String {
+    let context = match context_type {
+        ContextType::CurrentFile => {
+            let file_path = worktree_dir.join(cursor.path.as_std_path());
+            let context = std::fs::read_to_string(&file_path).unwrap_or_default();
+
+            let context = add_special_tags(&context, worktree_dir, cursor);
+            context
+        }
+    };
+
+    let region_end_offset = context.find(TeacherModel::REGION_END);
+
+    if context.len() <= MAX_CONTEXT_SIZE {
+        return context;
+    }
+
+    if let Some(region_end_offset) = region_end_offset
+        && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
+    {
+        let to_truncate = context.len() - MAX_CONTEXT_SIZE;
+        format!(
+            "[...{} bytes truncated]\n{}\n",
+            to_truncate,
+            &context[to_truncate..]
+        )
+    } else {
+        format!(
+            "{}\n[...{} bytes truncated]\n",
+            &context[..MAX_CONTEXT_SIZE],
+            context.len() - MAX_CONTEXT_SIZE
+        )
+    }
+}
+
+/// Add <|editable_region_start/end|> tags
+fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
+    let path = worktree_dir.join(cursor.path.as_std_path());
+    let file = std::fs::read_to_string(&path).unwrap_or_default();
+    let lines = file.lines().collect::<Vec<_>>();
+    let cursor_row = cursor.point.row as usize;
+    let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
+    let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
+
+    let snippet = lines[start_line..end_line].join("\n");
+
+    if context.contains(&snippet) {
+        let mut cursor_line = lines[cursor_row].to_string();
+        cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
+
+        let mut snippet_with_tags_lines = vec![];
+        snippet_with_tags_lines.push(TeacherModel::REGION_START);
+        snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
+        snippet_with_tags_lines.push(&cursor_line);
+        snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
+        snippet_with_tags_lines.push(TeacherModel::REGION_END);
+        let snippet_with_tags = snippet_with_tags_lines.join("\n");
+
+        context.replace(&snippet, &snippet_with_tags)
+    } else {
+        log::warn!(
+            "Can't find area around the cursor in the context; proceeding without special tags"
+        );
+        context.to_string()
+    }
+}
+
+pub fn strip_special_tags(context: &str) -> String {
+    context
+        .replace(TeacherModel::REGION_START, "")
+        .replace(TeacherModel::REGION_END, "")
+        .replace(TeacherModel::USER_CURSOR, "")
+}

crates/edit_prediction_cli/src/training/distill.rs 🔗

@@ -0,0 +1,94 @@
+use serde::Deserialize;
+use std::sync::Arc;
+
+use crate::{
+    DistillArguments,
+    example::Example,
+    source_location::SourceLocation,
+    training::{
+        context::ContextType,
+        llm_client::LlmClient,
+        teacher::{TeacherModel, TeacherOutput},
+    },
+};
+use anyhow::Result;
+use reqwest_client::ReqwestClient;
+
+#[derive(Debug, Deserialize)]
+pub struct SplitCommit {
+    repo_url: String,
+    commit_sha: String,
+    edit_history: String,
+    expected_patch: String,
+    cursor_position: String,
+}
+
+pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
+    let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
+        .expect("Failed to read split commit dataset")
+        .lines()
+        .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
+        .collect();
+
+    let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+
+    let llm_client = if let Some(cache_path) = arguments.batch {
+        LlmClient::batch(&cache_path, http_client)?
+    } else {
+        LlmClient::plain(http_client)?
+    };
+
+    let mut teacher = TeacherModel::new(
+        "claude-sonnet-4-5".to_string(),
+        ContextType::CurrentFile,
+        llm_client,
+    );
+
+    let mut num_marked_for_batching = 0;
+
+    for commit in split_commits {
+        if let Some(distilled) = distill_one(&mut teacher, commit).await? {
+            println!("{}", serde_json::to_string(&distilled)?);
+        } else {
+            if num_marked_for_batching == 0 {
+                log::warn!("Marked for batching");
+            }
+            num_marked_for_batching += 1;
+        }
+    }
+
+    eprintln!(
+        "{} requests are marked for batching",
+        num_marked_for_batching
+    );
+    let llm_client = teacher.client;
+    llm_client.sync_batches().await?;
+
+    Ok(())
+}
+
+pub async fn distill_one(
+    teacher: &mut TeacherModel,
+    commit: SplitCommit,
+) -> Result<Option<TeacherOutput>> {
+    let cursor: SourceLocation = commit
+        .cursor_position
+        .parse()
+        .expect("Failed to parse cursor position");
+
+    let path = cursor.path.to_rel_path_buf();
+
+    let example = Example {
+        repository_url: commit.repo_url,
+        revision: commit.commit_sha,
+        uncommitted_diff: commit.edit_history.clone(),
+        cursor_path: path.as_std_path().to_path_buf(),
+        cursor_position: commit.cursor_position,
+        edit_history: commit.edit_history, // todo: trim
+        expected_patch: commit.expected_patch,
+    };
+
+    let prediction = teacher.predict(example).await;
+
+    prediction
+}

crates/edit_prediction_cli/src/training/llm_client.rs 🔗

@@ -0,0 +1,417 @@
+use anthropic::{
+    ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
+    Response as AnthropicResponse, Role, non_streaming_completion,
+};
+use anyhow::Result;
+use http_client::HttpClient;
+use indoc::indoc;
+use sqlez::bindable::Bind;
+use sqlez::bindable::StaticColumnCount;
+use sqlez_macros::sql;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::sync::Arc;
+
+pub struct PlainLlmClient {
+    http_client: Arc<dyn HttpClient>,
+    api_key: String,
+}
+
+impl PlainLlmClient {
+    fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        let api_key = std::env::var("ANTHROPIC_API_KEY")
+            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+        Ok(Self {
+            http_client,
+            api_key,
+        })
+    }
+
+    async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<AnthropicResponse> {
+        let request = AnthropicRequest {
+            model,
+            max_tokens,
+            messages,
+            tools: Vec::new(),
+            thinking: None,
+            tool_choice: None,
+            system: None,
+            metadata: None,
+            stop_sequences: Vec::new(),
+            temperature: None,
+            top_k: None,
+            top_p: None,
+        };
+
+        let response = non_streaming_completion(
+            self.http_client.as_ref(),
+            ANTHROPIC_API_URL,
+            &self.api_key,
+            request,
+            None,
+        )
+        .await
+        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+        Ok(response)
+    }
+}
+
+pub struct BatchingLlmClient {
+    connection: sqlez::connection::Connection,
+    http_client: Arc<dyn HttpClient>,
+    api_key: String,
+}
+
+struct CacheRow {
+    request_hash: String,
+    request: Option<String>,
+    response: Option<String>,
+    batch_id: Option<String>,
+}
+
+impl StaticColumnCount for CacheRow {
+    fn column_count() -> usize {
+        4
+    }
+}
+
+impl Bind for CacheRow {
+    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
+        let next_index = statement.bind(&self.request_hash, start_index)?;
+        let next_index = statement.bind(&self.request, next_index)?;
+        let next_index = statement.bind(&self.response, next_index)?;
+        let next_index = statement.bind(&self.batch_id, next_index)?;
+        Ok(next_index)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableRequest {
+    model: String,
+    max_tokens: u64,
+    messages: Vec<SerializableMessage>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableMessage {
+    role: String,
+    content: String,
+}
+
+impl BatchingLlmClient {
+    fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        let api_key = std::env::var("ANTHROPIC_API_KEY")
+            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+
+        let connection = sqlez::connection::Connection::open_file(&cache_path);
+        let mut statement = sqlez::statement::Statement::prepare(
+            &connection,
+            indoc! {"
+                CREATE TABLE IF NOT EXISTS cache (
+                    request_hash TEXT PRIMARY KEY,
+                    request TEXT,
+                    response TEXT,
+                    batch_id TEXT
+                );
+                "},
+        )?;
+        statement.exec()?;
+        drop(statement);
+
+        Ok(Self {
+            connection,
+            http_client,
+            api_key,
+        })
+    }
+
+    pub fn lookup(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+    ) -> Result<Option<AnthropicResponse>> {
+        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let response: Vec<String> = self.connection.select_bound(
+            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
+        )?(request_hash_str.as_str())?;
+        Ok(response
+            .into_iter()
+            .next()
+            .and_then(|text| serde_json::from_str(&text).ok()))
+    }
+
+    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages);
+
+        let serializable_messages: Vec<SerializableMessage> = messages
+            .iter()
+            .map(|msg| SerializableMessage {
+                role: match msg.role {
+                    Role::User => "user".to_string(),
+                    Role::Assistant => "assistant".to_string(),
+                },
+                content: message_content_to_string(&msg.content),
+            })
+            .collect();
+
+        let serializable_request = SerializableRequest {
+            model: model.to_string(),
+            max_tokens,
+            messages: serializable_messages,
+        };
+
+        let request = Some(serde_json::to_string(&serializable_request)?);
+        let cache_row = CacheRow {
+            request_hash,
+            request,
+            response: None,
+            batch_id: None,
+        };
+        self.connection.exec_bound(sql!(
+            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
+            cache_row,
+        )
+    }
+
+    async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<Option<AnthropicResponse>> {
+        let response = self.lookup(&model, max_tokens, &messages)?;
+        if let Some(response) = response {
+            return Ok(Some(response));
+        }
+
+        self.mark_for_batch(&model, max_tokens, &messages)?;
+
+        Ok(None)
+    }
+
+    /// Uploads pending requests as a new batch; downloads finished batches if any.
+    async fn sync_batches(&self) -> Result<()> {
+        self.upload_pending_requests().await?;
+        self.download_finished_batches().await
+    }
+
+    async fn download_finished_batches(&self) -> Result<()> {
+        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+        let batch_ids: Vec<String> = self.connection.select(q)?()?;
+
+        for batch_id in batch_ids {
+            let batch_status = anthropic::batches::retrieve_batch(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                &batch_id,
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+            log::info!(
+                "Batch {} status: {}",
+                batch_id,
+                batch_status.processing_status
+            );
+
+            if batch_status.processing_status == "ended" {
+                let results = anthropic::batches::retrieve_batch_results(
+                    self.http_client.as_ref(),
+                    ANTHROPIC_API_URL,
+                    &self.api_key,
+                    &batch_id,
+                )
+                .await
+                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+                let mut success_count = 0;
+                for result in results {
+                    let request_hash = result
+                        .custom_id
+                        .strip_prefix("req_hash_")
+                        .unwrap_or(&result.custom_id)
+                        .to_string();
+
+                    match result.result {
+                        anthropic::batches::BatchResult::Succeeded { message } => {
+                            let response_json = serde_json::to_string(&message)?;
+                            let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
+                            self.connection.exec_bound(q)?((response_json, request_hash))?;
+                            success_count += 1;
+                        }
+                        anthropic::batches::BatchResult::Errored { error } => {
+                            log::error!("Batch request {} failed: {:?}", request_hash, error);
+                        }
+                        anthropic::batches::BatchResult::Canceled => {
+                            log::warn!("Batch request {} was canceled", request_hash);
+                        }
+                        anthropic::batches::BatchResult::Expired => {
+                            log::warn!("Batch request {} expired", request_hash);
+                        }
+                    }
+                }
+                log::info!("Uploaded {} successful requests", success_count);
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn upload_pending_requests(&self) -> Result<String> {
+        let q = sql!(
+        SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+        );
+
+        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
+
+        if rows.is_empty() {
+            return Ok(String::new());
+        }
+
+        let batch_requests = rows
+            .iter()
+            .map(|(hash, request_str)| {
+                let serializable_request: SerializableRequest =
+                    serde_json::from_str(&request_str).unwrap();
+
+                let messages: Vec<Message> = serializable_request
+                    .messages
+                    .into_iter()
+                    .map(|msg| Message {
+                        role: match msg.role.as_str() {
+                            "user" => Role::User,
+                            "assistant" => Role::Assistant,
+                            _ => Role::User,
+                        },
+                        content: vec![RequestContent::Text {
+                            text: msg.content,
+                            cache_control: None,
+                        }],
+                    })
+                    .collect();
+
+                let params = AnthropicRequest {
+                    model: serializable_request.model,
+                    max_tokens: serializable_request.max_tokens,
+                    messages,
+                    tools: Vec::new(),
+                    thinking: None,
+                    tool_choice: None,
+                    system: None,
+                    metadata: None,
+                    stop_sequences: Vec::new(),
+                    temperature: None,
+                    top_k: None,
+                    top_p: None,
+                };
+
+                let custom_id = format!("req_hash_{}", hash);
+                anthropic::batches::BatchRequest { custom_id, params }
+            })
+            .collect::<Vec<_>>();
+
+        let batch_len = batch_requests.len();
+        let batch = anthropic::batches::create_batch(
+            self.http_client.as_ref(),
+            ANTHROPIC_API_URL,
+            &self.api_key,
+            anthropic::batches::CreateBatchRequest {
+                requests: batch_requests,
+            },
+        )
+        .await
+        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+        let q = sql!(
+            UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+        );
+        self.connection.exec_bound(q)?(batch.id.as_str())?;
+
+        log::info!("Uploaded batch with {} requests", batch_len);
+
+        Ok(batch.id)
+    }
+
+    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+        let mut hasher = std::hash::DefaultHasher::new();
+        model.hash(&mut hasher);
+        max_tokens.hash(&mut hasher);
+        for msg in messages {
+            message_content_to_string(&msg.content).hash(&mut hasher);
+        }
+        let request_hash = hasher.finish();
+        format!("{request_hash:016x}")
+    }
+}
+
+fn message_content_to_string(content: &[RequestContent]) -> String {
+    content
+        .iter()
+        .filter_map(|c| match c {
+            RequestContent::Text { text, .. } => Some(text.clone()),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join("\n")
+}
+
+pub enum LlmClient {
+    // No batching
+    Plain(PlainLlmClient),
+    Batch(BatchingLlmClient),
+    Dummy,
+}
+
+impl LlmClient {
+    pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+    }
+
+    pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        Ok(Self::Batch(BatchingLlmClient::new(
+            cache_path,
+            http_client,
+        )?))
+    }
+
+    #[allow(dead_code)]
+    pub fn dummy() -> Self {
+        Self::Dummy
+    }
+
+    pub async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<Option<AnthropicResponse>> {
+        match self {
+            LlmClient::Plain(plain_llm_client) => plain_llm_client
+                .generate(model, max_tokens, messages)
+                .await
+                .map(Some),
+            LlmClient::Batch(batching_llm_client) => {
+                batching_llm_client
+                    .generate(model, max_tokens, messages)
+                    .await
+            }
+            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
+
+    pub async fn sync_batches(&self) -> Result<()> {
+        match self {
+            LlmClient::Plain(_) => Ok(()),
+            LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
+}

crates/edit_prediction_cli/src/training/teacher.prompt.md 🔗

@@ -0,0 +1,48 @@
+# Instructions
+
+You are a code completion assistant helping a programmer finish their work. Your task is to:
+
+1. Analyze the edit history to understand what the programmer is trying to achieve
+2. Identify any incomplete refactoring or changes that need to be finished
+3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
+4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+
+Focus on:
+- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
+- Completing any partially-applied changes across the codebase
+- Ensuring consistency with the programming style and patterns already established
+- Making edits that maintain or improve code quality
+- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
+- Don't write a lot of code if you're not sure what to do
+
+Rules:
+- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+
+Input format:
+- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
+- Never modify the context code.
+- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
+- The cursor position is marked with <|user_cursor|>.
+
+Output format:
+- Return the entire editable region, applying any edits you make.
+- Remove the <|user_cursor|> marker.
+- Wrap the edited code in a block of exactly five backticks.
+
+Output example:
+`````
+    // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
+    if let Some(socket) = &args.askpass {{
+        askpass::main(socket);
+        return Ok(());
+    }}
+`````
+
+## User Edits History
+
+{{edit_history}}
+
+## Code Context
+
+{{context}}

crates/edit_prediction_cli/src/training/teacher.rs 🔗

@@ -0,0 +1,266 @@
+use crate::{
+    example::Example,
+    source_location::SourceLocation,
+    training::{
+        context::{ContextType, collect_context, strip_special_tags},
+        llm_client::LlmClient,
+    },
+};
+use anthropic::{Message, RequestContent, ResponseContent, Role};
+use anyhow::Result;
+
+pub struct TeacherModel {
+    pub llm_name: String,
+    pub context: ContextType,
+    pub client: LlmClient,
+}
+
+#[derive(Debug, serde::Serialize)]
+pub struct TeacherOutput {
+    parsed_output: String,
+    prompt: String,
+    raw_llm_response: String,
+    context: String,
+    diff: String,
+}
+
+impl TeacherModel {
+    const PROMPT: &str = include_str!("teacher.prompt.md");
+    pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
+    pub(crate) const REGION_END: &str = "<|editable_region_end|>";
+    pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
+
+    /// Number of lines to include before the cursor position
+    pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
+
+    /// Number of lines to include after the cursor position
+    pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
+
+    /// Truncate edit history to this number of last lines
+    const MAX_HISTORY_LINES: usize = 128;
+
+    pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
+        TeacherModel {
+            llm_name,
+            context,
+            client,
+        }
+    }
+
+    pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
+        let name = input.unique_name();
+        let worktree_dir = input.setup_worktree(name).await?;
+        let cursor: SourceLocation = input
+            .cursor_position
+            .parse()
+            .expect("Failed to parse cursor position");
+
+        let context = collect_context(&self.context, &worktree_dir, cursor.clone());
+        let edit_history = Self::format_edit_history(&input.edit_history);
+
+        let prompt = Self::PROMPT
+            .replace("{{context}}", &context)
+            .replace("{{edit_history}}", &edit_history);
+
+        let messages = vec![Message {
+            role: Role::User,
+            content: vec![RequestContent::Text {
+                text: prompt.clone(),
+                cache_control: None,
+            }],
+        }];
+
+        let Some(response) = self
+            .client
+            .generate(self.llm_name.clone(), 16384, messages)
+            .await?
+        else {
+            return Ok(None);
+        };
+
+        let response_text = response
+            .content
+            .into_iter()
+            .filter_map(|content| match content {
+                ResponseContent::Text { text } => Some(text),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        let parsed_output = self.parse_response(&response_text);
+
+        let original_editable_region = Self::extract_editable_region(&context);
+        let context_after_edit = context.replace(&original_editable_region, &parsed_output);
+        let context_after_edit = strip_special_tags(&context_after_edit);
+        let context_before_edit = strip_special_tags(&context);
+        let diff = language::unified_diff(&context_before_edit, &context_after_edit);
+
+        // zeta distill --batch batch_results.txt
+        // zeta distill
+        // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
+        //  - store LLM requests in a batch, don't actual send the request
+        //  - send the batch (2000 requests) after all inputs are processed
+        // 2. `zeta send-batches`
+        //   - upload the batch to Anthropic
+
+        // https://platform.claude.com/docs/en/build-with-claude/batch-processing
+        // https://crates.io/crates/anthropic-sdk-rust
+
+        //   - poll for results
+        //   - when ready, store results in cache (a database)
+        // 3. `zeta distill` again
+        //    - use the cached results this time
+
+        Ok(Some(TeacherOutput {
+            parsed_output,
+            prompt,
+            raw_llm_response: response_text,
+            context,
+            diff,
+        }))
+    }
+
+    fn parse_response(&self, content: &str) -> String {
+        let codeblock = Self::extract_last_codeblock(content);
+        let editable_region = Self::extract_editable_region(&codeblock);
+
+        editable_region
+    }
+
+    /// Extract content from the last code-fenced block if any, or else return content as is
+    fn extract_last_codeblock(text: &str) -> String {
+        let mut last_block = None;
+        let mut search_start = 0;
+
+        while let Some(start) = text[search_start..].find("```") {
+            let start = start + search_start;
+            let bytes = text.as_bytes();
+            let mut backtick_end = start;
+
+            while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
+                backtick_end += 1;
+            }
+
+            let backtick_count = backtick_end - start;
+            let closing_backticks = "`".repeat(backtick_count);
+
+            if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
+                let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+                last_block = Some(code_block.to_string());
+                search_start = backtick_end + end_pos + backtick_count;
+            } else {
+                break;
+            }
+        }
+
+        last_block.unwrap_or_else(|| text.to_string())
+    }
+
+    fn extract_editable_region(text: &str) -> String {
+        let start = text
+            .find(Self::REGION_START)
+            .map_or(0, |pos| pos + Self::REGION_START.len());
+        let end = text.find(Self::REGION_END).unwrap_or(text.len());
+
+        text[start..end].to_string()
+    }
+
+    /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
+    fn format_edit_history(edit_history: &str) -> String {
+        let lines = edit_history
+            .lines()
+            .filter(|&s| Self::is_content_line(s))
+            .collect::<Vec<_>>();
+
+        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
+            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
+        } else {
+            &lines
+        };
+        history_lines.join("\n")
+    }
+
+    fn is_content_line(s: &str) -> bool {
+        s.starts_with("-")
+            || s.starts_with("+")
+            || s.starts_with(" ")
+            || s.starts_with("---")
+            || s.starts_with("+++")
+            || s.starts_with("@@")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_response() {
+        let teacher = TeacherModel::new(
+            "test".to_string(),
+            ContextType::CurrentFile,
+            LlmClient::dummy(),
+        );
+        let response = "This is a test response.";
+        let parsed = teacher.parse_response(response);
+        assert_eq!(parsed, response.to_string());
+
+        let response = indoc::indoc! {"
+            Some thinking
+
+            `````
+            actual response
+            `````
+            "};
+        let parsed = teacher.parse_response(response);
+        assert_eq!(parsed, "actual response");
+    }
+
+    #[test]
+    fn test_extract_last_code_block() {
+        let text = indoc::indoc! {"
+            Some thinking
+
+            ```
+            first block
+            ```
+
+            `````
+            last block
+            `````
+            "};
+        let last_block = TeacherModel::extract_last_codeblock(text);
+        assert_eq!(last_block, "last block");
+    }
+
+    #[test]
+    fn test_extract_editable_region() {
+        let teacher = TeacherModel::new(
+            "test".to_string(),
+            ContextType::CurrentFile,
+            LlmClient::dummy(),
+        );
+        let response = indoc::indoc! {"
+            some lines
+            are
+            here
+            <|editable_region_start|>
+            one
+            two three
+
+            <|editable_region_end|>
+            more
+            lines here
+            "};
+        let parsed = teacher.parse_response(response);
+        assert_eq!(
+            parsed,
+            indoc::indoc! {"
+            one
+            two three
+
+            "}
+        );
+    }
+}