Detailed changes
@@ -5167,6 +5167,7 @@ dependencies = [
name = "edit_prediction_cli"
version = "0.1.0"
dependencies = [
+ "anthropic",
"anyhow",
"chrono",
"clap",
@@ -5182,6 +5183,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"gpui_tokio",
+ "http_client",
"indoc",
"language",
"language_extension",
@@ -5202,6 +5204,8 @@ dependencies = [
"settings",
"shellexpand 2.1.2",
"smol",
+ "sqlez",
+ "sqlez_macros",
"terminal_view",
"toml 0.8.23",
"util",
@@ -244,7 +244,6 @@ activity_indicator = { path = "crates/activity_indicator" }
agent_ui = { path = "crates/agent_ui" }
agent_settings = { path = "crates/agent_settings" }
agent_servers = { path = "crates/agent_servers" }
-ai = { path = "crates/ai" }
ai_onboarding = { path = "crates/ai_onboarding" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
@@ -254,7 +253,6 @@ assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
-auto_update_helper = { path = "crates/auto_update_helper" }
auto_update_ui = { path = "crates/auto_update_ui" }
aws_http_client = { path = "crates/aws_http_client" }
bedrock = { path = "crates/bedrock" }
@@ -269,7 +267,6 @@ cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
-collab = { path = "crates/collab" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -358,8 +355,6 @@ panel = { path = "crates/panel" }
paths = { path = "crates/paths" }
perf = { path = "tooling/perf" }
picker = { path = "crates/picker" }
-plugin = { path = "crates/plugin" }
-plugin_macros = { path = "crates/plugin_macros" }
prettier = { path = "crates/prettier" }
settings_profile_selector = { path = "crates/settings_profile_selector" }
project = { path = "crates/project" }
@@ -370,12 +365,10 @@ proto = { path = "crates/proto" }
recent_projects = { path = "crates/recent_projects" }
refineable = { path = "crates/refineable" }
release_channel = { path = "crates/release_channel" }
-scheduler = { path = "crates/scheduler" }
remote = { path = "crates/remote" }
remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
reqwest_client = { path = "crates/reqwest_client" }
-rich_text = { path = "crates/rich_text" }
rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
@@ -392,7 +385,6 @@ snippets_ui = { path = "crates/snippets_ui" }
sqlez = { path = "crates/sqlez" }
sqlez_macros = { path = "crates/sqlez_macros" }
story = { path = "crates/story" }
-storybook = { path = "crates/storybook" }
streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree" }
supermaven = { path = "crates/supermaven" }
@@ -409,7 +401,6 @@ terminal_view = { path = "crates/terminal_view" }
text = { path = "crates/text" }
theme = { path = "crates/theme" }
theme_extension = { path = "crates/theme_extension" }
-theme_importer = { path = "crates/theme_importer" }
theme_selector = { path = "crates/theme_selector" }
time_format = { path = "crates/time_format" }
title_bar = { path = "crates/title_bar" }
@@ -510,13 +501,11 @@ exec = "0.3.1"
fancy-regex = "0.16.0"
fork = "0.4.0"
futures = "0.3"
-futures-batch = "0.6.1"
futures-lite = "1.13"
gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" }
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
-hashbrown = "0.15.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"
@@ -552,7 +541,6 @@ nanoid = "0.4"
nbformat = "0.15.0"
nix = "0.29"
num-format = "0.4.4"
-num-traits = "0.2"
objc = "0.2"
objc2-foundation = { version = "=0.3.1", default-features = false, features = [
"NSArray",
@@ -591,7 +579,6 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev =
pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
-pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
@@ -631,7 +618,6 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197
schemars = { version = "1.0", features = ["indexmap2"] }
semver = { version = "1.0", features = ["serde"] }
serde = { version = "1.0.221", features = ["derive", "rc"] }
-serde_derive = "1.0.221"
serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] }
serde_json_lenient = { version = "0.2", features = [
"preserve_order",
@@ -643,7 +629,6 @@ serde_urlencoded = "0.7"
sha2 = "0.10"
shellexpand = "2.1.0"
shlex = "1.3.0"
-similar = "2.6"
simplelog = "0.12.2"
slotmap = "1.0.6"
smallvec = { version = "1.6", features = ["union"] }
@@ -722,7 +707,6 @@ wasmtime-wasi = "29"
wax = "0.6"
which = "6.0.0"
windows-core = "0.61"
-wit-component = "0.221"
yawc = "0.2.5"
zeroize = "1.8"
zstd = "0.11"
@@ -804,20 +788,13 @@ settings_macros = { opt-level = 3 }
sqlez_macros = { opt-level = 3, codegen-units = 1 }
ui_macros = { opt-level = 3 }
util_macros = { opt-level = 3 }
-serde_derive = { opt-level = 3 }
quote = { opt-level = 3 }
syn = { opt-level = 3 }
proc-macro2 = { opt-level = 3 }
# proc-macros end
taffy = { opt-level = 3 }
-cranelift-codegen = { opt-level = 3 }
-cranelift-codegen-meta = { opt-level = 3 }
-cranelift-codegen-shared = { opt-level = 3 }
resvg = { opt-level = 3 }
-rustybuzz = { opt-level = 3 }
-ttf-parser = { opt-level = 3 }
-wasmtime-cranelift = { opt-level = 3 }
wasmtime = { opt-level = 3 }
# Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster
activity_indicator = { codegen-units = 1 }
@@ -826,7 +803,6 @@ breadcrumbs = { codegen-units = 1 }
collections = { codegen-units = 1 }
command_palette = { codegen-units = 1 }
command_palette_hooks = { codegen-units = 1 }
-extension_cli = { codegen-units = 1 }
feature_flags = { codegen-units = 1 }
file_icons = { codegen-units = 1 }
fsevent = { codegen-units = 1 }
@@ -846,7 +822,6 @@ project_symbols = { codegen-units = 1 }
refineable = { codegen-units = 1 }
release_channel = { codegen-units = 1 }
reqwest_client = { codegen-units = 1 }
-rich_text = { codegen-units = 1 }
session = { codegen-units = 1 }
snippet = { codegen-units = 1 }
snippets_ui = { codegen-units = 1 }
@@ -12,6 +12,8 @@ pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
use strum::{EnumIter, EnumString};
use thiserror::Error;
+pub mod batches;
+
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
@@ -465,6 +467,7 @@ impl Model {
}
}
+/// Generate completion with streaming.
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
@@ -477,6 +480,101 @@ pub async fn stream_completion(
.map(|output| output.0)
}
+/// Generate completion without streaming.
+///
+/// Sends the request via `send_request`, buffers the entire response body
+/// into memory, and deserializes it as a `Response`. Non-2xx statuses are
+/// converted into an `AnthropicError` by `handle_error_response`.
+pub async fn non_streaming_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+ beta_headers: Option<String>,
+) -> Result<Response, AnthropicError> {
+ let (mut response, rate_limits) =
+ send_request(client, api_url, api_key, &request, beta_headers).await?;
+
+ if response.status().is_success() {
+ // Read the whole body before parsing; no incremental/SSE handling here.
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(handle_error_response(response, rate_limits).await)
+ }
+}
+
+/// Build and send a POST to `{api_url}/v1/messages`, returning the raw HTTP
+/// response together with rate-limit info parsed from its headers.
+///
+/// Shared plumbing for the streaming and non-streaming completion paths;
+/// `request` is any serializable body (plain `Request` or the streaming
+/// wrapper). The optional `beta_headers` value is passed through verbatim
+/// as the `Anthropic-Beta` header.
+async fn send_request(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: impl Serialize,
+ beta_headers: Option<String>,
+) -> Result<(http::Response<AsyncBody>, RateLimitInfo), AnthropicError> {
+ let uri = format!("{api_url}/v1/messages");
+
+ let mut request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim())
+ .header("Content-Type", "application/json");
+
+ if let Some(beta_headers) = beta_headers {
+ request_builder = request_builder.header("Anthropic-Beta", beta_headers);
+ }
+
+ let serialized_request =
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+ let request = request_builder
+ .body(AsyncBody::from(serialized_request))
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let response = client
+ .send(request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ // Rate-limit headers are extracted eagerly so callers can report them
+ // even on error responses.
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ Ok((response, rate_limits))
+}
+
+/// Convert a non-success `/v1/messages` response into an `AnthropicError`.
+///
+/// Priority order: HTTP 529 maps to `ServerOverloaded`, a `retry-after`
+/// rate-limit header maps to `RateLimit`, and anything else reads the body
+/// and tries to parse it as an API error event, falling back to a generic
+/// `HttpResponseError` carrying the raw body text.
+async fn handle_error_response(
+ mut response: http::Response<AsyncBody>,
+ rate_limits: RateLimitInfo,
+) -> AnthropicError {
+ if response.status().as_u16() == 529 {
+ return AnthropicError::ServerOverloaded {
+ retry_after: rate_limits.retry_after,
+ };
+ }
+
+ if let Some(retry_after) = rate_limits.retry_after {
+ return AnthropicError::RateLimit { retry_after };
+ }
+
+ let mut body = String::new();
+ let read_result = response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse);
+
+ if let Err(err) = read_result {
+ return err;
+ }
+
+ // A structured API error event is preferred; any other (or unparsable)
+ // payload is surfaced verbatim with its status code.
+ match serde_json::from_str::<Event>(&body) {
+ Ok(Event::Error { error }) => AnthropicError::ApiError(error),
+ Ok(_) | Err(_) => AnthropicError::HttpResponseError {
+ status_code: response.status(),
+ message: body,
+ },
+ }
+}
+
/// An individual rate limit.
#[derive(Debug)]
pub struct RateLimit {
@@ -580,30 +678,10 @@ pub async fn stream_completion_with_rate_limit_info(
base: request,
stream: true,
};
- let uri = format!("{api_url}/v1/messages");
- let mut request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(uri)
- .header("Anthropic-Version", "2023-06-01")
- .header("X-Api-Key", api_key.trim())
- .header("Content-Type", "application/json");
+ let (response, rate_limits) =
+ send_request(client, api_url, api_key, &request, beta_headers).await?;
- if let Some(beta_headers) = beta_headers {
- request_builder = request_builder.header("Anthropic-Beta", beta_headers);
- }
-
- let serialized_request =
- serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
- let request = request_builder
- .body(AsyncBody::from(serialized_request))
- .map_err(AnthropicError::BuildRequestBody)?;
-
- let mut response = client
- .send(request)
- .await
- .map_err(AnthropicError::HttpSend)?;
- let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
let stream = reader
@@ -622,27 +700,8 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
- } else if response.status().as_u16() == 529 {
- Err(AnthropicError::ServerOverloaded {
- retry_after: rate_limits.retry_after,
- })
- } else if let Some(retry_after) = rate_limits.retry_after {
- Err(AnthropicError::RateLimit { retry_after })
} else {
- let mut body = String::new();
- response
- .body_mut()
- .read_to_string(&mut body)
- .await
- .map_err(AnthropicError::ReadResponse)?;
-
- match serde_json::from_str::<Event>(&body) {
- Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
- Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
- status_code: response.status(),
- message: body,
- }),
- }
+ Err(handle_error_response(response, rate_limits).await)
}
}
@@ -0,0 +1,190 @@
+use anyhow::Result;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
+
+/// One request within a message batch; `custom_id` correlates it with its
+/// entry in the batch results.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchRequest {
+ pub custom_id: String,
+ pub params: Request,
+}
+
+/// Body for `POST /v1/messages/batches`.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CreateBatchRequest {
+ pub requests: Vec<BatchRequest>,
+}
+
+/// Per-status request counts reported on a batch.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatchRequestCounts {
+ pub processing: u64,
+ pub succeeded: u64,
+ pub errored: u64,
+ pub canceled: u64,
+ pub expired: u64,
+}
+
+/// A message batch as returned by the batches API.
+///
+/// Timestamps are kept as raw strings rather than parsed datetimes;
+/// `results_url` is only present once the batch has ended.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatch {
+ pub id: String,
+ #[serde(rename = "type")]
+ pub batch_type: String,
+ pub processing_status: String,
+ pub request_counts: MessageBatchRequestCounts,
+ pub ended_at: Option<String>,
+ pub created_at: String,
+ pub expires_at: String,
+ pub archived_at: Option<String>,
+ pub cancel_initiated_at: Option<String>,
+ pub results_url: Option<String>,
+}
+
+/// Outcome of a single batched request, tagged by its `type` field.
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum BatchResult {
+ #[serde(rename = "succeeded")]
+ Succeeded { message: Response },
+ #[serde(rename = "errored")]
+ Errored { error: ApiError },
+ #[serde(rename = "canceled")]
+ Canceled,
+ #[serde(rename = "expired")]
+ Expired,
+}
+
+/// One line of the JSONL results stream: the originating `custom_id` plus
+/// its `BatchResult`.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchIndividualResponse {
+ pub custom_id: String,
+ pub result: BatchResult,
+}
+
+/// Create a message batch via `POST /v1/messages/batches`.
+///
+/// On success the parsed `MessageBatch` is returned; error responses are
+/// mapped through `crate::handle_error_response`.
+///
+/// NOTE(review): the request-building and body-reading plumbing here is
+/// duplicated across `create_batch`, `retrieve_batch`, and
+/// `retrieve_batch_results` — candidate for a shared helper like the
+/// `send_request` used by the non-batch endpoints.
+pub async fn create_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: CreateBatchRequest,
+) -> Result<MessageBatch, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim())
+ .header("Content-Type", "application/json");
+
+ let serialized_request =
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+ let http_request = request_builder
+ .body(AsyncBody::from(serialized_request))
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
+
+/// Fetch the current state of a batch via
+/// `GET /v1/messages/batches/{message_batch_id}`.
+///
+/// Callers can poll this until `processing_status` indicates the batch has
+/// ended and `results_url` becomes available.
+pub async fn retrieve_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ message_batch_id: &str,
+) -> Result<MessageBatch, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim());
+
+ let http_request = request_builder
+ .body(AsyncBody::default())
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
+
+/// Fetch per-request results via
+/// `GET /v1/messages/batches/{message_batch_id}/results`.
+///
+/// The endpoint returns JSONL: one `BatchIndividualResponse` per line.
+/// Blank lines are skipped; any unparsable line fails the whole call.
+/// The entire results payload is buffered in memory before parsing.
+pub async fn retrieve_batch_results(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ message_batch_id: &str,
+) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim());
+
+ let http_request = request_builder
+ .body(AsyncBody::default())
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ // Parse line-by-line: results are streamed as JSON Lines, not a
+ // single JSON array.
+ let mut results = Vec::new();
+ for line in body.lines() {
+ if line.trim().is_empty() {
+ continue;
+ }
+ let result: BatchIndividualResponse =
+ serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
+ results.push(result);
+ }
+
+ Ok(results)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
@@ -13,8 +13,9 @@ name = "ep_cli"
path = "src/main.rs"
[dependencies]
-
anyhow.workspace = true
+anthropic.workspace = true
+http_client.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true
@@ -28,6 +29,7 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
+indoc.workspace = true
language.workspace = true
language_extension.workspace = true
language_model.workspace = true
@@ -46,6 +48,8 @@ serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
+sqlez.workspace = true
+sqlez_macros.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
@@ -3,6 +3,8 @@ use std::{
cell::RefCell,
fmt::{self, Display},
fs,
+ hash::Hash,
+ hash::Hasher,
io::Write,
mem,
path::{Path, PathBuf},
@@ -43,7 +45,7 @@ pub struct NamedExample {
pub example: Example,
}
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
pub struct Example {
pub repository_url: String,
pub revision: String,
@@ -54,6 +56,134 @@ pub struct Example {
pub expected_patch: String,
}
+impl Example {
+ /// Parse `repository_url` into `(owner, repo)` with any `.git` suffix
+ /// stripped. Supports SSH-style `git@host:owner/repo.git` URLs and
+ /// HTTP(S) URLs with exactly two path segments.
+ ///
+ /// NOTE(review): the `assert!` below panics on HTTP URLs with extra
+ /// path segments — consider returning an error instead.
+ fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+ // git@github.com:owner/repo.git
+ if self.repository_url.contains('@') {
+ let (owner, repo) = self
+ .repository_url
+ .split_once(':')
+ .context("expected : in git url")?
+ .1
+ .split_once('/')
+ .context("expected / in git url")?;
+ Ok((
+ Cow::Borrowed(owner),
+ Cow::Borrowed(repo.trim_end_matches(".git")),
+ ))
+ // http://github.com/owner/repo.git
+ } else {
+ let url = Url::parse(&self.repository_url)?;
+ let mut segments = url.path_segments().context("empty http url")?;
+ let owner = segments
+ .next()
+ .context("expected owner path segment")?
+ .to_string();
+ let repo = segments
+ .next()
+ .context("expected repo path segment")?
+ .trim_end_matches(".git")
+ .to_string();
+ assert!(segments.next().is_none());
+
+ Ok((owner.into(), repo.into()))
+ }
+ }
+
+ /// Prepare an on-disk worktree for this example and return its path.
+ ///
+ /// Ensures a bare-ish clone exists under `REPOS_DIR/<owner>/<repo>`
+ /// (guarded by `lock_repo` for concurrent callers), resolves
+ /// `self.revision` to a commit (fetching shallowly first, then fully as
+ /// a fallback), creates or resets a worktree under
+ /// `WORKTREES_DIR/<file_name>/<repo>`, and finally applies
+ /// `uncommitted_diff` via `git apply` when non-empty.
+ pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
+ let (repo_owner, repo_name) = self.repo_name()?;
+
+ let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let repo_lock = lock_repo(&repo_dir).await;
+
+ if !repo_dir.is_dir() {
+ fs::create_dir_all(&repo_dir)?;
+ run_git(&repo_dir, &["init"]).await?;
+ run_git(
+ &repo_dir,
+ &["remote", "add", "origin", &self.repository_url],
+ )
+ .await?;
+ }
+
+ // Resolve the example to a revision, fetching it if needed.
+ let revision = run_git(
+ &repo_dir,
+ &["rev-parse", &format!("{}^{{commit}}", self.revision)],
+ )
+ .await;
+ let revision = if let Ok(revision) = revision {
+ revision
+ } else {
+ // Shallow fetch of just this revision first; fall back to a
+ // full fetch if the server rejects it.
+ if run_git(
+ &repo_dir,
+ &["fetch", "--depth", "1", "origin", &self.revision],
+ )
+ .await
+ .is_err()
+ {
+ run_git(&repo_dir, &["fetch", "origin"]).await?;
+ }
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
+ if revision != self.revision {
+ // Tag the fetched commit with the symbolic revision so later
+ // rev-parse calls resolve without re-fetching.
+ run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
+ }
+ revision
+ };
+
+ // Create the worktree for this example if needed.
+ let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
+ if worktree_path.is_dir() {
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+ } else {
+ let worktree_path_string = worktree_path.to_string_lossy();
+ run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
+ run_git(
+ &repo_dir,
+ &["worktree", "add", "-f", &worktree_path_string, &file_name],
+ )
+ .await?;
+ }
+ drop(repo_lock);
+
+ // Apply the uncommitted diff for this example.
+ if !self.uncommitted_diff.is_empty() {
+ let mut apply_process = smol::process::Command::new("git")
+ .current_dir(&worktree_path)
+ .args(&["apply", "-"])
+ .stdin(std::process::Stdio::piped())
+ .spawn()?;
+
+ let mut stdin = apply_process.stdin.take().unwrap();
+ stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
+ stdin.close().await?;
+ drop(stdin);
+
+ let apply_result = apply_process.output().await?;
+ if !apply_result.status.success() {
+ anyhow::bail!(
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
+ }
+ }
+
+ Ok(worktree_path)
+ }
+
+ /// Derive a short identifier: the first 8 chars of the revision plus a
+ /// 4-hex-digit hash of the whole example for disambiguation.
+ ///
+ /// NOTE(review): `&self.revision[..8]` panics if the revision is
+ /// shorter than 8 characters; `DefaultHasher`'s output is not
+ /// guaranteed stable across Rust releases — confirm this is acceptable
+ /// if these names are persisted (e.g. as cache keys).
+ pub fn unique_name(&self) -> String {
+ let mut hasher = std::hash::DefaultHasher::new();
+ self.hash(&mut hasher);
+ let disambiguator = hasher.finish();
+ let hash = format!("{:04x}", disambiguator);
+ format!("{}_{}", &self.revision[..8], &hash[..4])
+ }
+}
+
pub type ActualExcerpt = Excerpt;
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -292,90 +422,7 @@ impl NamedExample {
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
- let (repo_owner, repo_name) = self.repo_name()?;
- let file_name = self.file_name();
-
- let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
- let repo_lock = lock_repo(&repo_dir).await;
-
- if !repo_dir.is_dir() {
- fs::create_dir_all(&repo_dir)?;
- run_git(&repo_dir, &["init"]).await?;
- run_git(
- &repo_dir,
- &["remote", "add", "origin", &self.example.repository_url],
- )
- .await?;
- }
-
- // Resolve the example to a revision, fetching it if needed.
- let revision = run_git(
- &repo_dir,
- &[
- "rev-parse",
- &format!("{}^{{commit}}", self.example.revision),
- ],
- )
- .await;
- let revision = if let Ok(revision) = revision {
- revision
- } else {
- run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &self.example.revision],
- )
- .await?;
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
- if revision != self.example.revision {
- run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
- }
- revision
- };
-
- // Create the worktree for this example if needed.
- let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
- if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
- } else {
- let worktree_path_string = worktree_path.to_string_lossy();
- run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
- run_git(
- &repo_dir,
- &["worktree", "add", "-f", &worktree_path_string, &file_name],
- )
- .await?;
- }
- drop(repo_lock);
-
- // Apply the uncommitted diff for this example.
- if !self.example.uncommitted_diff.is_empty() {
- let mut apply_process = smol::process::Command::new("git")
- .current_dir(&worktree_path)
- .args(&["apply", "-"])
- .stdin(std::process::Stdio::piped())
- .spawn()?;
-
- let mut stdin = apply_process.stdin.take().unwrap();
- stdin
- .write_all(self.example.uncommitted_diff.as_bytes())
- .await?;
- stdin.close().await?;
- drop(stdin);
-
- let apply_result = apply_process.output().await?;
- if !apply_result.status.success() {
- anyhow::bail!(
- "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
- apply_result.status,
- String::from_utf8_lossy(&apply_result.stderr),
- String::from_utf8_lossy(&apply_result.stdout),
- );
- }
- }
-
- Ok(worktree_path)
+ self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
@@ -391,40 +438,6 @@ impl NamedExample {
.collect()
}
- fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
- // git@github.com:owner/repo.git
- if self.example.repository_url.contains('@') {
- let (owner, repo) = self
- .example
- .repository_url
- .split_once(':')
- .context("expected : in git url")?
- .1
- .split_once('/')
- .context("expected / in git url")?;
- Ok((
- Cow::Borrowed(owner),
- Cow::Borrowed(repo.trim_end_matches(".git")),
- ))
- // http://github.com/owner/repo.git
- } else {
- let url = Url::parse(&self.example.repository_url)?;
- let mut segments = url.path_segments().context("empty http url")?;
- let owner = segments
- .next()
- .context("expected owner path segment")?
- .to_string();
- let repo = segments
- .next()
- .context("expected repo path segment")?
- .trim_end_matches(".git")
- .to_string();
- assert!(segments.next().is_none());
-
- Ok((owner.into(), repo.into()))
- }
- }
-
pub async fn cursor_position(
&self,
project: &Entity<Project>,
@@ -5,6 +5,7 @@ mod metrics;
mod paths;
mod predict;
mod source_location;
+mod training;
mod util;
use crate::{
@@ -13,9 +14,10 @@ use crate::{
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
+ training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
-use ::util::paths::PathStyle;
+use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
@@ -43,6 +45,7 @@ enum Command {
Context(ContextArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
+ Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -111,6 +114,15 @@ pub struct PredictArguments {
options: PredictionOptions,
}
+// Arguments for the `distill` subcommand. Plain `//` comments are used here
+// (not `///`) so clap's derived --help output is unchanged.
+#[derive(Debug, Args)]
+pub struct DistillArguments {
+ // Path to a JSONL dataset of split commits, one JSON object per line.
+ split_commit_dataset: PathBuf,
+ #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
+ context_type: ContextType,
+ // When set, use the batching LLM client with this value as cache path.
+ #[clap(long)]
+ batch: Option<String>,
+}
+
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
@@ -468,6 +480,13 @@ fn main() {
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
+ Some(Command::Distill(arguments)) => {
+ let _guard = cx
+ .update(|cx| gpui_tokio::Tokio::handle(cx))
+ .unwrap()
+ .enter();
+ run_distill(arguments).await.log_err();
+ }
Some(Command::ConvertExample {
path,
output_format,
@@ -0,0 +1,89 @@
+use std::path::Path;
+
+use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
+
+// Strategy for assembling prompt context around the cursor. Plain `//`
+// comments keep clap's derived help text unchanged. Currently only the
+// current file's contents are supported.
+#[derive(Debug, Clone, Default, clap::ValueEnum)]
+pub enum ContextType {
+ #[default]
+ CurrentFile,
+}
+
+// Upper bound, in bytes, on the context returned by `collect_context`.
+const MAX_CONTEXT_SIZE: usize = 32768;
+
+/// Collect prompt context for the given cursor location, tagged with the
+/// editable-region/cursor markers and truncated to roughly
+/// `MAX_CONTEXT_SIZE` bytes.
+///
+/// Truncation keeps the side containing the `REGION_END` tag: if the tag
+/// would fall past the size limit, bytes are dropped from the front;
+/// otherwise the tail is dropped. A `[... N bytes truncated]` note is
+/// inserted in place of the removed bytes (so the result may slightly
+/// exceed the limit).
+///
+/// NOTE(review): `&context[to_truncate..]` and `&context[..MAX_CONTEXT_SIZE]`
+/// slice at byte offsets and will panic if the cut lands inside a
+/// multi-byte UTF-8 character — guard with `is_char_boundary` or confirm
+/// inputs are ASCII. Front-truncation can also drop the `REGION_START` tag.
+pub fn collect_context(
+ context_type: &ContextType,
+ worktree_dir: &Path,
+ cursor: SourceLocation,
+) -> String {
+ let context = match context_type {
+ ContextType::CurrentFile => {
+ let file_path = worktree_dir.join(cursor.path.as_std_path());
+ // Missing/unreadable files yield empty context rather than an error.
+ let context = std::fs::read_to_string(&file_path).unwrap_or_default();
+
+ let context = add_special_tags(&context, worktree_dir, cursor);
+ context
+ }
+ };
+
+ let region_end_offset = context.find(TeacherModel::REGION_END);
+
+ if context.len() <= MAX_CONTEXT_SIZE {
+ return context;
+ }
+
+ if let Some(region_end_offset) = region_end_offset
+ && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
+ {
+ let to_truncate = context.len() - MAX_CONTEXT_SIZE;
+ format!(
+ "[...{} bytes truncated]\n{}\n",
+ to_truncate,
+ &context[to_truncate..]
+ )
+ } else {
+ format!(
+ "{}\n[...{} bytes truncated]\n",
+ &context[..MAX_CONTEXT_SIZE],
+ context.len() - MAX_CONTEXT_SIZE
+ )
+ }
+}
+
+/// Add <|editable_region_start/end|> tags
+///
+/// Re-reads the cursor's file from disk, takes a window of
+/// `LEFT_CONTEXT_SIZE`/`RIGHT_CONTEXT_SIZE` lines around the cursor row,
+/// and — if that snippet occurs in `context` — replaces it with a tagged
+/// copy that also inserts the `USER_CURSOR` marker at the cursor column.
+/// Falls back to the untagged context (with a warning) when the snippet
+/// cannot be located.
+///
+/// NOTE(review): `lines[cursor_row]` panics if the cursor row is out of
+/// range for the file (including the empty-file case, where the empty
+/// snippet always matches); `insert_str` uses the column as a byte offset
+/// and can panic on a non-char boundary; `context.replace` substitutes
+/// ALL occurrences of the snippet, duplicating tags if it repeats.
+fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
+ let path = worktree_dir.join(cursor.path.as_std_path());
+ let file = std::fs::read_to_string(&path).unwrap_or_default();
+ let lines = file.lines().collect::<Vec<_>>();
+ let cursor_row = cursor.point.row as usize;
+ let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
+ let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
+
+ let snippet = lines[start_line..end_line].join("\n");
+
+ if context.contains(&snippet) {
+ let mut cursor_line = lines[cursor_row].to_string();
+ cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
+
+ let mut snippet_with_tags_lines = vec![];
+ snippet_with_tags_lines.push(TeacherModel::REGION_START);
+ snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
+ snippet_with_tags_lines.push(&cursor_line);
+ snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
+ snippet_with_tags_lines.push(TeacherModel::REGION_END);
+ let snippet_with_tags = snippet_with_tags_lines.join("\n");
+
+ context.replace(&snippet, &snippet_with_tags)
+ } else {
+ log::warn!(
+ "Can't find area around the cursor in the context; proceeding without special tags"
+ );
+ context.to_string()
+ }
+}
+
+/// Remove the editable-region and user-cursor marker tags from `context`,
+/// returning the plain text. Inverse of the tagging in `add_special_tags`.
+pub fn strip_special_tags(context: &str) -> String {
+ context
+ .replace(TeacherModel::REGION_START, "")
+ .replace(TeacherModel::REGION_END, "")
+ .replace(TeacherModel::USER_CURSOR, "")
+}
@@ -0,0 +1,94 @@
+use serde::Deserialize;
+use std::sync::Arc;
+
+use crate::{
+ DistillArguments,
+ example::Example,
+ source_location::SourceLocation,
+ training::{
+ context::ContextType,
+ llm_client::LlmClient,
+ teacher::{TeacherModel, TeacherOutput},
+ },
+};
+use anyhow::Result;
+use reqwest_client::ReqwestClient;
+
+/// One line of the split-commit JSONL dataset consumed by `run_distill`.
+#[derive(Debug, Deserialize)]
+pub struct SplitCommit {
+ pub repo_url: String,
+ pub commit_sha: String,
+ // Uncommitted edits leading up to the prediction point, as a diff.
+ pub edit_history: String,
+ // The patch the model is expected to produce.
+ pub expected_patch: String,
+ // Serialized `SourceLocation`; parsed in `distill_one`.
+ pub cursor_position: String,
+}
+
+/// Distill teacher-model predictions for every commit in the dataset.
+///
+/// Reads the JSONL dataset, builds either a plain or batching LLM client
+/// (depending on `--batch`), runs each commit through the teacher, and
+/// prints completed distillations to stdout as JSONL. Items the client
+/// deferred (returned `None`) are counted and synced as a batch at the end.
+///
+/// NOTE(review): `arguments.context_type` is accepted on the CLI but
+/// ignored — the teacher is hard-coded to `ContextType::CurrentFile`
+/// (currently the only variant, but this will silently drop future ones).
+/// Also, the `expect` calls panic on unreadable/invalid input instead of
+/// returning `Err` through the function's `Result`.
+pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
+ let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
+ .expect("Failed to read split commit dataset")
+ .lines()
+ .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
+ .collect();
+
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+
+ let llm_client = if let Some(cache_path) = arguments.batch {
+ LlmClient::batch(&cache_path, http_client)?
+ } else {
+ LlmClient::plain(http_client)?
+ };
+
+ let mut teacher = TeacherModel::new(
+ "claude-sonnet-4-5".to_string(),
+ ContextType::CurrentFile,
+ llm_client,
+ );
+
+ let mut num_marked_for_batching = 0;
+
+ for commit in split_commits {
+ if let Some(distilled) = distill_one(&mut teacher, commit).await? {
+ // Completed predictions go to stdout as one JSON object per line.
+ println!("{}", serde_json::to_string(&distilled)?);
+ } else {
+ // Warn only once, on the first deferred item.
+ if num_marked_for_batching == 0 {
+ log::warn!("Marked for batching");
+ }
+ num_marked_for_batching += 1;
+ }
+ }
+
+ eprintln!(
+ "{} requests are marked for batching",
+ num_marked_for_batching
+ );
+ // Flush any deferred requests through the batching client.
+ let llm_client = teacher.client;
+ llm_client.sync_batches().await?;
+
+ Ok(())
+}
+
+/// Distill a single split commit through the teacher model.
+///
+/// Returns `Ok(None)` when the prediction was deferred (presumably queued
+/// for batching by the client — see `run_distill`'s handling; confirm
+/// against `TeacherModel::predict`). Panics if `cursor_position` does not
+/// parse as a `SourceLocation`.
+pub async fn distill_one(
+ teacher: &mut TeacherModel,
+ commit: SplitCommit,
+) -> Result<Option<TeacherOutput>> {
+ let cursor: SourceLocation = commit
+ .cursor_position
+ .parse()
+ .expect("Failed to parse cursor position");
+
+ let path = cursor.path.to_rel_path_buf();
+
+ // The edit history doubles as the uncommitted diff applied to the
+ // worktree before prediction.
+ let example = Example {
+ repository_url: commit.repo_url,
+ revision: commit.commit_sha,
+ uncommitted_diff: commit.edit_history.clone(),
+ cursor_path: path.as_std_path().to_path_buf(),
+ cursor_position: commit.cursor_position,
+ edit_history: commit.edit_history, // todo: trim
+ expected_patch: commit.expected_patch,
+ };
+
+ let prediction = teacher.predict(example).await;
+
+ prediction
+}
@@ -0,0 +1,417 @@
+use anthropic::{
+ ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
+ Response as AnthropicResponse, Role, non_streaming_completion,
+};
+use anyhow::Result;
+use http_client::HttpClient;
+use indoc::indoc;
+use sqlez::bindable::Bind;
+use sqlez::bindable::StaticColumnCount;
+use sqlez_macros::sql;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::sync::Arc;
+
+/// Client that sends each completion request to the Anthropic API immediately.
+pub struct PlainLlmClient {
+    http_client: Arc<dyn HttpClient>,
+    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
+    api_key: String,
+}
+
+impl PlainLlmClient {
+    /// Creates a client, reading the API key from the `ANTHROPIC_API_KEY`
+    /// environment variable.
+    fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        let api_key = std::env::var("ANTHROPIC_API_KEY")
+            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+        Ok(Self {
+            http_client,
+            api_key,
+        })
+    }
+
+    /// Sends a single non-streaming completion request and returns the full
+    /// response.
+    async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<AnthropicResponse> {
+        // Plain chat request: no tools, system prompt, thinking, or sampling
+        // overrides.
+        let request = AnthropicRequest {
+            model,
+            max_tokens,
+            messages,
+            tools: Vec::new(),
+            thinking: None,
+            tool_choice: None,
+            system: None,
+            metadata: None,
+            stop_sequences: Vec::new(),
+            temperature: None,
+            top_k: None,
+            top_p: None,
+        };
+
+        let response = non_streaming_completion(
+            self.http_client.as_ref(),
+            ANTHROPIC_API_URL,
+            &self.api_key,
+            request,
+            None,
+        )
+        .await
+        // The anthropic error type is converted via Debug formatting here.
+        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+        Ok(response)
+    }
+}
+
+/// Client that defers requests to Anthropic's Message Batches API, caching
+/// requests and responses in a local SQLite database.
+pub struct BatchingLlmClient {
+    // SQLite cache; schema is created in `new`.
+    connection: sqlez::connection::Connection,
+    http_client: Arc<dyn HttpClient>,
+    // Read from the ANTHROPIC_API_KEY environment variable in `new`.
+    api_key: String,
+}
+
+/// One row of the `cache` table; a request's lifecycle is: inserted with only
+/// `request` set, then stamped with a `batch_id` on upload, then filled with a
+/// `response` once the batch finishes.
+struct CacheRow {
+    // Primary key: hex digest produced by `BatchingLlmClient::request_hash`.
+    request_hash: String,
+    // JSON-serialized `SerializableRequest`.
+    request: Option<String>,
+    // JSON-serialized Anthropic response, once available.
+    response: Option<String>,
+    // Anthropic batch this request was uploaded with, if any.
+    batch_id: Option<String>,
+}
+
+impl StaticColumnCount for CacheRow {
+    fn column_count() -> usize {
+        // Must match the number of binds in the `Bind` impl below.
+        4
+    }
+}
+
+impl Bind for CacheRow {
+    // Bind order must match the column order used in the INSERT statement:
+    // (request_hash, request, response, batch_id).
+    fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
+        let next_index = statement.bind(&self.request_hash, start_index)?;
+        let next_index = statement.bind(&self.request, next_index)?;
+        let next_index = statement.bind(&self.response, next_index)?;
+        let next_index = statement.bind(&self.batch_id, next_index)?;
+        Ok(next_index)
+    }
+}
+
+/// JSON-storable form of an Anthropic request, persisted in the cache between
+/// the run that queues a request and the run that uploads it.
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableRequest {
+    model: String,
+    max_tokens: u64,
+    messages: Vec<SerializableMessage>,
+}
+
+/// Flattened message: role as a string ("user"/"assistant") and text content
+/// only — non-text content is dropped during conversion.
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableMessage {
+    role: String,
+    content: String,
+}
+
+impl BatchingLlmClient {
+    /// Opens (or creates) the SQLite cache at `cache_path` and ensures the
+    /// `cache` table exists.
+    ///
+    /// Reads the API key from the `ANTHROPIC_API_KEY` environment variable.
+    fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        let api_key = std::env::var("ANTHROPIC_API_KEY")
+            .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+
+        let connection = sqlez::connection::Connection::open_file(&cache_path);
+        let mut statement = sqlez::statement::Statement::prepare(
+            &connection,
+            indoc! {"
+                CREATE TABLE IF NOT EXISTS cache (
+                    request_hash TEXT PRIMARY KEY,
+                    request TEXT,
+                    response TEXT,
+                    batch_id TEXT
+                );
+            "},
+        )?;
+        statement.exec()?;
+        // The statement borrows `connection`; drop it before moving the
+        // connection into `Self`.
+        drop(statement);
+
+        Ok(Self {
+            connection,
+            http_client,
+            api_key,
+        })
+    }
+
+    /// Returns the cached response for this request, if a finished batch has
+    /// already produced one.
+    pub fn lookup(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+    ) -> Result<Option<AnthropicResponse>> {
+        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let response: Vec<String> = self.connection.select_bound(
+            &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
+        )?(request_hash_str.as_str())?;
+        // NOTE(review): a stored response that fails to deserialize is
+        // silently treated as a cache miss.
+        Ok(response
+            .into_iter()
+            .next()
+            .and_then(|text| serde_json::from_str(&text).ok()))
+    }
+
+    /// Records the request in the cache (if not already present) so that the
+    /// next `sync_batches` call uploads it as part of a batch.
+    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages);
+
+        // Only the text parts of each message survive this conversion; other
+        // content kinds are dropped (mirrors `message_content_to_string`).
+        let serializable_messages: Vec<SerializableMessage> = messages
+            .iter()
+            .map(|msg| SerializableMessage {
+                role: match msg.role {
+                    Role::User => "user".to_string(),
+                    Role::Assistant => "assistant".to_string(),
+                },
+                content: message_content_to_string(&msg.content),
+            })
+            .collect();
+
+        let serializable_request = SerializableRequest {
+            model: model.to_string(),
+            max_tokens,
+            messages: serializable_messages,
+        };
+
+        let request = Some(serde_json::to_string(&serializable_request)?);
+        let cache_row = CacheRow {
+            request_hash,
+            request,
+            response: None,
+            batch_id: None,
+        };
+        // OR IGNORE: re-marking an already-queued (or answered) request is a
+        // no-op thanks to the request_hash primary key.
+        self.connection.exec_bound(sql!(
+            INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
+            cache_row,
+        )
+    }
+
+    /// Returns the cached response if present; otherwise queues the request
+    /// for batching and returns `None`.
+    async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<Option<AnthropicResponse>> {
+        let response = self.lookup(&model, max_tokens, &messages)?;
+        if let Some(response) = response {
+            return Ok(Some(response));
+        }
+
+        self.mark_for_batch(&model, max_tokens, &messages)?;
+
+        Ok(None)
+    }
+
+    /// Uploads pending requests as a new batch; downloads finished batches if any.
+    async fn sync_batches(&self) -> Result<()> {
+        self.upload_pending_requests().await?;
+        self.download_finished_batches().await
+    }
+
+    /// Polls every batch that still has unanswered rows and stores the
+    /// responses of batches that have finished processing.
+    async fn download_finished_batches(&self) -> Result<()> {
+        let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+        let batch_ids: Vec<String> = self.connection.select(q)?()?;
+
+        for batch_id in batch_ids {
+            let batch_status = anthropic::batches::retrieve_batch(
+                self.http_client.as_ref(),
+                ANTHROPIC_API_URL,
+                &self.api_key,
+                &batch_id,
+            )
+            .await
+            .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+            log::info!(
+                "Batch {} status: {}",
+                batch_id,
+                batch_status.processing_status
+            );
+
+            if batch_status.processing_status == "ended" {
+                let results = anthropic::batches::retrieve_batch_results(
+                    self.http_client.as_ref(),
+                    ANTHROPIC_API_URL,
+                    &self.api_key,
+                    &batch_id,
+                )
+                .await
+                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+                let mut success_count = 0;
+                for result in results {
+                    // Undo the custom_id prefix added in `upload_pending_requests`.
+                    let request_hash = result
+                        .custom_id
+                        .strip_prefix("req_hash_")
+                        .unwrap_or(&result.custom_id)
+                        .to_string();
+
+                    match result.result {
+                        anthropic::batches::BatchResult::Succeeded { message } => {
+                            let response_json = serde_json::to_string(&message)?;
+                            let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
+                            self.connection.exec_bound(q)?((response_json, request_hash))?;
+                            success_count += 1;
+                        }
+                        anthropic::batches::BatchResult::Errored { error } => {
+                            log::error!("Batch request {} failed: {:?}", request_hash, error);
+                        }
+                        anthropic::batches::BatchResult::Canceled => {
+                            log::warn!("Batch request {} was canceled", request_hash);
+                        }
+                        anthropic::batches::BatchResult::Expired => {
+                            log::warn!("Batch request {} expired", request_hash);
+                        }
+                    }
+                }
+                // This is the download path, so report stored responses
+                // (the previous message incorrectly said "Uploaded").
+                log::info!("Stored {} successful batch responses", success_count);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Uploads every queued-but-unsent request as one new batch and stamps
+    /// the rows with the resulting batch id. Returns the batch id, or an
+    /// empty string when there was nothing to upload.
+    async fn upload_pending_requests(&self) -> Result<String> {
+        let q = sql!(
+            SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+        );
+
+        let rows: Vec<(String, String)> = self.connection.select(q)?()?;
+
+        if rows.is_empty() {
+            return Ok(String::new());
+        }
+
+        let batch_requests = rows
+            .iter()
+            .map(|(hash, request_str)| {
+                let serializable_request: SerializableRequest = serde_json::from_str(&request_str)
+                    .expect("cached request was written by mark_for_batch and must be valid JSON");
+
+                let messages: Vec<Message> = serializable_request
+                    .messages
+                    .into_iter()
+                    .map(|msg| Message {
+                        role: match msg.role.as_str() {
+                            "user" => Role::User,
+                            "assistant" => Role::Assistant,
+                            // Unknown roles should not occur; fall back to User.
+                            _ => Role::User,
+                        },
+                        content: vec![RequestContent::Text {
+                            text: msg.content,
+                            cache_control: None,
+                        }],
+                    })
+                    .collect();
+
+                let params = AnthropicRequest {
+                    model: serializable_request.model,
+                    max_tokens: serializable_request.max_tokens,
+                    messages,
+                    tools: Vec::new(),
+                    thinking: None,
+                    tool_choice: None,
+                    system: None,
+                    metadata: None,
+                    stop_sequences: Vec::new(),
+                    temperature: None,
+                    top_k: None,
+                    top_p: None,
+                };
+
+                // Prefix is stripped again in `download_finished_batches`.
+                let custom_id = format!("req_hash_{}", hash);
+                anthropic::batches::BatchRequest { custom_id, params }
+            })
+            .collect::<Vec<_>>();
+
+        let batch_len = batch_requests.len();
+        let batch = anthropic::batches::create_batch(
+            self.http_client.as_ref(),
+            ANTHROPIC_API_URL,
+            &self.api_key,
+            anthropic::batches::CreateBatchRequest {
+                requests: batch_requests,
+            },
+        )
+        .await
+        .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+        // Stamp exactly the set selected above — previously this matched any
+        // row with a NULL batch_id, including ones that already had responses.
+        let q = sql!(
+            UPDATE cache SET batch_id = ? WHERE batch_id IS NULL AND response IS NULL
+        );
+        self.connection.exec_bound(q)?(batch.id.as_str())?;
+
+        log::info!("Uploaded batch with {} requests", batch_len);
+
+        Ok(batch.id)
+    }
+
+    /// Content hash used as the cache key: model, max_tokens, and each
+    /// message's text content (roles are not hashed).
+    ///
+    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be
+    /// stable across Rust releases, so a toolchain upgrade may invalidate the
+    /// cache — confirm this is acceptable.
+    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+        let mut hasher = std::hash::DefaultHasher::new();
+        model.hash(&mut hasher);
+        max_tokens.hash(&mut hasher);
+        for msg in messages {
+            message_content_to_string(&msg.content).hash(&mut hasher);
+        }
+        let request_hash = hasher.finish();
+        format!("{request_hash:016x}")
+    }
+}
+
+/// Concatenates every text part of a message, newline-separated; non-text
+/// content items are skipped.
+fn message_content_to_string(content: &[RequestContent]) -> String {
+    let mut parts: Vec<&str> = Vec::new();
+    for item in content {
+        if let RequestContent::Text { text, .. } = item {
+            parts.push(text.as_str());
+        }
+    }
+    parts.join("\n")
+}
+
+/// Front-end over the two request strategies: immediate requests or
+/// cache-backed batching.
+pub enum LlmClient {
+    // Sends each request immediately; no batching.
+    Plain(PlainLlmClient),
+    // Defers requests to the batch API via an on-disk cache.
+    Batch(BatchingLlmClient),
+    // Placeholder for tests; panics if asked to generate.
+    Dummy,
+}
+
+impl LlmClient {
+    /// Creates a non-batching client. Requires `ANTHROPIC_API_KEY` to be set.
+    pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+    }
+
+    /// Creates a batching client backed by the SQLite cache at `cache_path`.
+    /// Requires `ANTHROPIC_API_KEY` to be set.
+    pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+        Ok(Self::Batch(BatchingLlmClient::new(
+            cache_path,
+            http_client,
+        )?))
+    }
+
+    /// Client for tests that must never actually be invoked.
+    #[allow(dead_code)]
+    pub fn dummy() -> Self {
+        Self::Dummy
+    }
+
+    /// Produces a completion. Returns `None` only in batch mode, when the
+    /// request was queued rather than answered from the cache.
+    pub async fn generate(
+        &self,
+        model: String,
+        max_tokens: u64,
+        messages: Vec<Message>,
+    ) -> Result<Option<AnthropicResponse>> {
+        match self {
+            LlmClient::Plain(plain_llm_client) => plain_llm_client
+                .generate(model, max_tokens, messages)
+                .await
+                .map(Some),
+            LlmClient::Batch(batching_llm_client) => {
+                batching_llm_client
+                    .generate(model, max_tokens, messages)
+                    .await
+            }
+            // Deliberate panic: Dummy exists only to satisfy constructors in
+            // tests that never exercise the client.
+            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
+
+    /// Uploads queued requests / downloads finished batches. No-op for the
+    /// plain client.
+    pub async fn sync_batches(&self) -> Result<()> {
+        match self {
+            LlmClient::Plain(_) => Ok(()),
+            LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
+}
@@ -0,0 +1,4 @@
+pub mod context;
+pub mod distill;
+pub mod llm_client;
+pub mod teacher;
@@ -0,0 +1,48 @@
+# Instructions
+
+You are a code completion assistant helping a programmer finish their work. Your task is to:
+
+1. Analyze the edit history to understand what the programmer is trying to achieve
+2. Identify any incomplete refactoring or changes that need to be finished
+3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
+4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+
+Focus on:
+- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
+- Completing any partially-applied changes across the codebase
+- Ensuring consistency with the programming style and patterns already established
+- Making edits that maintain or improve code quality
+- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
+- Don't write a lot of code if you're not sure what to do
+
+Rules:
+- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+
+Input format:
+- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
+- Never modify the context code.
+- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
+- The cursor position is marked with <|user_cursor|>.
+
+Output format:
+- Return the entire editable region, applying any edits you make.
+- Remove the <|user_cursor|> marker.
+- Wrap the edited code in a block of exactly five backticks.
+
+Output example:
+`````
+ // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
+    if let Some(socket) = &args.askpass {
+        askpass::main(socket);
+        return Ok(());
+    }
+`````
+
+## User Edits History
+
+{{edit_history}}
+
+## Code Context
+
+{{context}}
@@ -0,0 +1,266 @@
+use crate::{
+ example::Example,
+ source_location::SourceLocation,
+ training::{
+ context::{ContextType, collect_context, strip_special_tags},
+ llm_client::LlmClient,
+ },
+};
+use anthropic::{Message, RequestContent, ResponseContent, Role};
+use anyhow::Result;
+
+/// Wraps an LLM client with the prompting/parsing logic used to generate
+/// distillation targets from examples.
+pub struct TeacherModel {
+    /// Anthropic model identifier passed to the client.
+    pub llm_name: String,
+    /// Strategy for collecting code context around the cursor.
+    pub context: ContextType,
+    pub client: LlmClient,
+}
+
+#[derive(Debug, serde::Serialize)]
+pub struct TeacherOutput {
+    // Editable region extracted from the model's last code block.
+    parsed_output: String,
+    // Fully-rendered prompt that was sent to the model.
+    prompt: String,
+    // Concatenated text content of the model response.
+    raw_llm_response: String,
+    // Code context (with region/cursor markers) the prompt was built from.
+    context: String,
+    // Unified diff between the context before and after applying the edit.
+    diff: String,
+}
+
+impl TeacherModel {
+    const PROMPT: &str = include_str!("teacher.prompt.md");
+    pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
+    pub(crate) const REGION_END: &str = "<|editable_region_end|>";
+    pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
+
+    /// Number of lines to include before the cursor position
+    pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
+
+    /// Number of lines to include after the cursor position
+    pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
+
+    /// Truncate edit history to this number of last lines
+    const MAX_HISTORY_LINES: usize = 128;
+
+    pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
+        TeacherModel {
+            llm_name,
+            context,
+            client,
+        }
+    }
+
+    /// Builds the prompt for one example, queries the LLM, and parses the
+    /// edited region plus a diff out of the response.
+    ///
+    /// Returns `Ok(None)` when the client only queued the request for
+    /// batching (see `LlmClient::generate`).
+    pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
+        let name = input.unique_name();
+        let worktree_dir = input.setup_worktree(name).await?;
+        let cursor: SourceLocation = input
+            .cursor_position
+            .parse()
+            .map_err(|_| anyhow::anyhow!("Failed to parse cursor position: {}", input.cursor_position))?;
+
+        let context = collect_context(&self.context, &worktree_dir, cursor.clone());
+        let edit_history = Self::format_edit_history(&input.edit_history);
+
+        let prompt = Self::PROMPT
+            .replace("{{context}}", &context)
+            .replace("{{edit_history}}", &edit_history);
+
+        let messages = vec![Message {
+            role: Role::User,
+            content: vec![RequestContent::Text {
+                text: prompt.clone(),
+                cache_control: None,
+            }],
+        }];
+
+        let Some(response) = self
+            .client
+            .generate(self.llm_name.clone(), 16384, messages)
+            .await?
+        else {
+            // Request was marked for batching; no output this run.
+            return Ok(None);
+        };
+
+        let response_text = response
+            .content
+            .into_iter()
+            .filter_map(|content| match content {
+                ResponseContent::Text { text } => Some(text),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        let parsed_output = self.parse_response(&response_text);
+
+        // Splice the model's edit back into the context to diff before/after.
+        // NOTE(review): `str::replace` substitutes every occurrence — assumes
+        // the editable region text appears exactly once in the context.
+        let original_editable_region = Self::extract_editable_region(&context);
+        let context_after_edit = context.replace(&original_editable_region, &parsed_output);
+        let context_after_edit = strip_special_tags(&context_after_edit);
+        let context_before_edit = strip_special_tags(&context);
+        let diff = language::unified_diff(&context_before_edit, &context_after_edit);
+
+        Ok(Some(TeacherOutput {
+            parsed_output,
+            prompt,
+            raw_llm_response: response_text,
+            context,
+            diff,
+        }))
+    }
+
+    /// Extracts the editable region from the last code block of a response.
+    fn parse_response(&self, content: &str) -> String {
+        let codeblock = Self::extract_last_codeblock(content);
+        Self::extract_editable_region(&codeblock)
+    }
+
+    /// Extracts the content of the last fenced code block, or returns the
+    /// text unchanged when it contains no complete fenced block.
+    ///
+    /// Fences of any length (three or more backticks) match a closing fence
+    /// of the same length; the rest of the opening fence line (e.g. a
+    /// language tag) is not part of the content.
+    fn extract_last_codeblock(text: &str) -> String {
+        let mut last_block = None;
+        let mut search_start = 0;
+
+        while let Some(relative_start) = text[search_start..].find("```") {
+            let fence_start = search_start + relative_start;
+            let bytes = text.as_bytes();
+
+            // Consume the full backtick run that forms the opening fence.
+            let mut fence_end = fence_start;
+            while fence_end < bytes.len() && bytes[fence_end] == b'`' {
+                fence_end += 1;
+            }
+            let fence_len = fence_end - fence_start;
+            let closing_fence = "`".repeat(fence_len);
+
+            if let Some(relative_close) = text[fence_end..].find(&closing_fence) {
+                let content = &text[fence_end..fence_end + relative_close];
+                // Trim the newline before the closing fence, then drop the
+                // remainder of the opening fence line. (The previous
+                // `+1`/`-1` slicing panicked on empty blocks and corrupted
+                // blocks with a language tag.)
+                let content = content.strip_suffix('\n').unwrap_or(content);
+                let content = match content.split_once('\n') {
+                    Some((_info_string, body)) => body,
+                    None => content,
+                };
+                last_block = Some(content.to_string());
+                search_start = fence_end + relative_close + fence_len;
+            } else {
+                break;
+            }
+        }
+
+        last_block.unwrap_or_else(|| text.to_string())
+    }
+
+    /// Returns the text between the editable-region markers; a missing start
+    /// marker falls back to the beginning, a missing (or preceding) end
+    /// marker falls back to the end of the input.
+    fn extract_editable_region(text: &str) -> String {
+        let start = text
+            .find(Self::REGION_START)
+            .map_or(0, |pos| pos + Self::REGION_START.len());
+        // Search from `start` so an end marker before the start marker can't
+        // produce an inverted (panicking) range.
+        let end = text[start..]
+            .find(Self::REGION_END)
+            .map_or(text.len(), |pos| start + pos);
+
+        text[start..end].to_string()
+    }
+
+    /// Keeps only unified-diff body lines and truncates the history to the
+    /// last `MAX_HISTORY_LINES` lines.
+    fn format_edit_history(edit_history: &str) -> String {
+        let lines = edit_history
+            .lines()
+            .filter(|&s| Self::is_content_line(s))
+            .collect::<Vec<_>>();
+
+        let start = lines.len().saturating_sub(Self::MAX_HISTORY_LINES);
+        lines[start..].join("\n")
+    }
+
+    /// True for lines that belong to a unified diff: added/removed/context
+    /// lines and hunk headers. ("---"/"+++" file headers are already covered
+    /// by the "-"/"+" prefixes.)
+    fn is_content_line(s: &str) -> bool {
+        s.starts_with('-') || s.starts_with('+') || s.starts_with(' ') || s.starts_with("@@")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Without code fences or region markers the response passes through
+    // unchanged; with fences, only the fenced content is kept.
+    #[test]
+    fn test_parse_response() {
+        let teacher = TeacherModel::new(
+            "test".to_string(),
+            ContextType::CurrentFile,
+            LlmClient::dummy(),
+        );
+        let response = "This is a test response.";
+        let parsed = teacher.parse_response(response);
+        assert_eq!(parsed, response.to_string());
+
+        let response = indoc::indoc! {"
+            Some thinking
+
+            `````
+            actual response
+            `````
+        "};
+        let parsed = teacher.parse_response(response);
+        assert_eq!(parsed, "actual response");
+    }
+
+    // When several fenced blocks are present, the last one wins.
+    #[test]
+    fn test_extract_last_code_block() {
+        let text = indoc::indoc! {"
+            Some thinking
+
+            ```
+            first block
+            ```
+
+            `````
+            last block
+            `````
+        "};
+        let last_block = TeacherModel::extract_last_codeblock(text);
+        assert_eq!(last_block, "last block");
+    }
+
+    // Only the text between the region markers (exclusive) is returned.
+    #[test]
+    fn test_extract_editable_region() {
+        let teacher = TeacherModel::new(
+            "test".to_string(),
+            ContextType::CurrentFile,
+            LlmClient::dummy(),
+        );
+        let response = indoc::indoc! {"
+            some lines
+            are
+            here
+            <|editable_region_start|>
+            one
+            two three
+
+            <|editable_region_end|>
+            more
+            lines here
+        "};
+        let parsed = teacher.parse_response(response);
+        assert_eq!(
+            parsed,
+            indoc::indoc! {"
+                one
+                two three
+
+            "}
+        );
+    }
+}