diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 0ade5e3f3fd30ef8139bc90e1c96e7b325395d5c..a2b640df0bd3ec028b86b91a8d6a6530f176a506 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -22,6 +22,7 @@ mod reversal_tracking;
 mod score;
 mod split_commit;
 mod split_dataset;
+mod sync_deployments;
 mod synthesize;
 mod truncate_expected_patch;
 mod word_diff;
@@ -209,6 +210,8 @@ enum Command {
     Repair(repair::RepairArgs),
     /// Print all valid zeta formats (lowercase, one per line)
     PrintZetaFormats,
+    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
+    SyncDeployments(SyncDeploymentsArgs),
 }
 
 impl Display for Command {
@@ -254,10 +257,20 @@ impl Display for Command {
             Command::PrintZetaFormats => {
                 write!(f, "print-zeta-formats")
             }
+            Command::SyncDeployments(_) => {
+                write!(f, "sync-deployments")
+            }
         }
     }
 }
 
+#[derive(Debug, Args, Clone)]
+struct SyncDeploymentsArgs {
+    /// BaseTen model name (default: "zeta-2")
+    #[arg(long)]
+    model: Option<String>,
+}
+
 #[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
     #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
@@ -731,6 +744,19 @@ fn main() {
             }
             return;
         }
+        Command::SyncDeployments(sync_args) => {
+            let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
+            smol::block_on(async {
+                if let Err(e) =
+                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
+                        .await
+                {
+                    eprintln!("Error: {:?}", e);
+                    std::process::exit(1);
+                }
+            });
+            return;
+        }
         Command::Synthesize(synth_args) => {
             let Some(output_dir) = args.output else {
                 panic!("output dir is required");
@@ -966,7 +992,8 @@ fn main() {
             | Command::TruncatePatch(_)
             | Command::FilterLanguages(_)
             | Command::ImportBatch(_)
-            | Command::PrintZetaFormats => {
+            | Command::PrintZetaFormats
+            | Command::SyncDeployments(_) => {
                 unreachable!()
             }
         }
diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs
index 198848f1d6c088b1103b516f71249ac9f81516d7..96ca30ee5c0094f43b8477591d414df9a77fca64 100644
--- a/crates/edit_prediction_cli/src/pull_examples.rs
+++ b/crates/edit_prediction_cli/src/pull_examples.rs
@@ -20,16 +20,16 @@ use edit_prediction::example_spec::{
 };
 use std::fmt::Write as _;
 
-const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
-const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
+pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 
 const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
 const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
 const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
 const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
-const POLL_INTERVAL: Duration = Duration::from_secs(2);
-const MAX_POLL_ATTEMPTS: usize = 120;
+pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
+pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
 
 /// Parse an input token of the form `captured-after:{timestamp}`.
 pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -187,22 +187,22 @@ pub async fn fetch_captured_examples_after(
 
 #[derive(Debug, Clone, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct SnowflakeStatementResponse {
+pub(crate) struct SnowflakeStatementResponse {
     #[serde(default)]
-    data: Vec<Vec<serde_json::Value>>,
+    pub(crate) data: Vec<Vec<serde_json::Value>>,
     #[serde(default)]
-    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
+    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
     #[serde(default)]
-    code: Option<String>,
+    pub(crate) code: Option<String>,
     #[serde(default)]
-    message: Option<String>,
+    pub(crate) message: Option<String>,
     #[serde(default)]
-    statement_handle: Option<String>,
+    pub(crate) statement_handle: Option<String>,
 }
 
 #[derive(Debug, Clone, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct SnowflakeResultSetMetaData {
+pub(crate) struct SnowflakeResultSetMetaData {
     #[serde(default, rename = "rowType")]
     row_type: Vec<SnowflakeRowType>,
     #[serde(default)]
@@ -313,7 +313,7 @@ async fn run_sql_with_polling(
     Ok(response)
 }
 
-async fn fetch_partition(
+pub(crate) async fn fetch_partition(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
     token: &str,
@@ -402,7 +402,7 @@ async fn fetch_partition(
     })
 }
 
-async fn run_sql(
+pub(crate) async fn run_sql(
    http_client: Arc<dyn HttpClient>,
     base_url: &str,
     token: &str,
@@ -1344,7 +1344,7 @@ fn build_output_patch(
     patch
 }
 
-fn get_column_indices(
+pub(crate) fn get_column_indices(
     meta: &Option<SnowflakeResultSetMetaData>,
     names: &[&str],
 ) -> std::collections::HashMap<String, usize> {
diff --git a/crates/edit_prediction_cli/src/sync_deployments.rs b/crates/edit_prediction_cli/src/sync_deployments.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b3104c8462f5d7fbd6bca0c9cde7943d4664da62
--- /dev/null
+++ b/crates/edit_prediction_cli/src/sync_deployments.rs
@@ -0,0 +1,578 @@
+use anyhow::{Context as _, Result};
+use http_client::{AsyncBody, HttpClient, Method, Request};
+use serde::Deserialize;
+use serde_json::{Value as JsonValue, json};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::pull_examples::{
+    self, MAX_POLL_ATTEMPTS, POLL_INTERVAL, SNOWFLAKE_ASYNC_IN_PROGRESS_CODE,
+    SNOWFLAKE_SUCCESS_CODE,
+};
+
+const DEFAULT_BASETEN_MODEL_NAME: &str = "zeta-2";
+const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenModelsResponse {
+    models: Vec<BasetenModel>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenModel {
+    id: String,
+    name: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenDeploymentsResponse {
+    deployments: Vec<BasetenDeployment>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenDeployment {
+    id: String,
+    name: String,
+    #[serde(default)]
+    status: Option<String>,
+    #[serde(default)]
+    created_at: Option<String>,
+    #[serde(default)]
+    environment: Option<String>,
+}
+
+#[derive(Debug, Clone)]
+struct DeploymentRecord {
+    model_id: String,
+    model_version_id: String,
+    experiment_name: String,
+    environment: String,
+    status: String,
+    created_at: String,
+}
+
+#[derive(Debug, Clone)]
+struct ExistingDeployment {
+    experiment_name: String,
+    environment: String,
+}
+
+async fn fetch_baseten_models(
+    http_client: &Arc<dyn HttpClient>,
+    api_key: &str,
+) -> Result<Vec<BasetenModel>> {
+    let request = Request::builder()
+        .method(Method::GET)
+        .uri("https://api.baseten.co/v1/models")
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .header("Accept", "application/json")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(request)
+        .await
+        .context("failed to fetch baseten models")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read baseten models response")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!("baseten models API http {}: {}", status.as_u16(), body_text);
+    }
+
+    let parsed: BasetenModelsResponse =
+        serde_json::from_slice(&body_bytes).context("failed to parse baseten models response")?;
+    Ok(parsed.models)
+}
+
+async fn fetch_baseten_deployments(
+    http_client: &Arc<dyn HttpClient>,
+    api_key: &str,
+    model_id: &str,
+) -> Result<Vec<BasetenDeployment>> {
+    let url = format!("https://api.baseten.co/v1/models/{model_id}/deployments");
+    let request = Request::builder()
+        .method(Method::GET)
+        .uri(url.as_str())
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .header("Accept", "application/json")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(request)
+        .await
+        .context("failed to fetch baseten deployments")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read baseten deployments response")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!(
+            "baseten deployments API http {}: {}",
+            status.as_u16(),
+            body_text
+        );
+    }
+
+    let parsed: BasetenDeploymentsResponse =
+        serde_json::from_slice(&body_bytes).context("failed to parse deployments response")?;
+    Ok(parsed.deployments)
+}
+
+fn collect_deployment_records(
+    model_id: &str,
+    deployments: &[BasetenDeployment],
+) -> Vec<DeploymentRecord> {
+    deployments
+        .iter()
+        .map(|deployment| DeploymentRecord {
+            model_id: model_id.to_string(),
+            model_version_id: deployment.id.clone(),
+            experiment_name: deployment.name.clone(),
+            environment: deployment
+                .environment
+                .clone()
+                .unwrap_or_else(|| "none".to_string()),
+            status: deployment
+                .status
+                .clone()
+                .unwrap_or_else(|| "unknown".to_string()),
+            created_at: deployment
+                .created_at
+                .clone()
+                .unwrap_or_else(|| "unknown".to_string()),
+        })
+        .collect()
+}
+
+async fn run_sql_with_polling(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+) -> Result<pull_examples::SnowflakeStatementResponse> {
+    let mut response =
+        pull_examples::run_sql(http_client.clone(), base_url, token, request).await?;
+
+    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+        let statement_handle = response
+            .statement_handle
+            .as_ref()
+            .context("async query response missing statementHandle")?
+            .clone();
+
+        for _attempt in 1..=MAX_POLL_ATTEMPTS {
+            std::thread::sleep(POLL_INTERVAL);
+
+            response = pull_examples::fetch_partition(
+                http_client.clone(),
+                base_url,
+                token,
+                &statement_handle,
+                0,
+            )
+            .await?;
+
+            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+                break;
+            }
+        }
+
+        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+            anyhow::bail!(
+                "query still running after {} poll attempts ({} seconds)",
+                MAX_POLL_ATTEMPTS,
+                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
+            );
+        }
+    }
+
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake error: code={} message={}",
+                code,
+                response.message.as_deref().unwrap_or("")
+            );
+        }
+    }
+
+    Ok(response)
+}
+
+async fn fetch_existing_deployments(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+) -> Result<HashMap<String, ExistingDeployment>> {
+    let statement = format!(
+        r#"
+SELECT
+    event_properties:model_version_id::string AS model_version_id,
+    event_properties:experiment_name::string AS experiment_name,
+    event_properties:environment::string AS environment
+FROM events
+WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
+"#
+    );
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+    });
+
+    let response = run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+
+    let col_names = ["model_version_id", "experiment_name", "environment"];
+    let column_indices =
+        pull_examples::get_column_indices(&response.result_set_meta_data, &col_names);
+
+    let mut existing = HashMap::new();
+
+    for data_row in &response.data {
+        let get_string = |name: &str| -> Option<String> {
+            let &index = column_indices.get(name)?;
+            match data_row.get(index) {
+                Some(JsonValue::String(s)) => Some(s.clone()),
+                _ => None,
+            }
+        };
+
+        let Some(model_version_id) = get_string("model_version_id") else {
+            continue;
+        };
+        let experiment_name = get_string("experiment_name").unwrap_or_default();
+        let environment = get_string("environment").unwrap_or_default();
+
+        existing.insert(
+            model_version_id,
+            ExistingDeployment {
+                experiment_name,
+                environment,
+            },
+        );
+    }
+
+    Ok(existing)
+}
+
+async fn insert_deployment(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+    record: &DeploymentRecord,
+) -> Result<()> {
+    let event_properties = json!({
+        "model_id": record.model_id,
+        "model_version_id": record.model_version_id,
+        "experiment_name": record.experiment_name,
+        "environment": record.environment,
+        "status": record.status,
+        "created_at": record.created_at,
+    });
+
+    let event_properties_str =
+        serde_json::to_string(&event_properties).context("failed to serialize event_properties")?;
+
+    let statement = r#"
+INSERT INTO events (event_type, event_properties, device_id, time)
+VALUES (?, PARSE_JSON(?), 'ep-cli', CURRENT_TIMESTAMP())
+"#;
+
+    let bindings = json!({
+        "1": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
+        "2": { "type": "TEXT", "value": event_properties_str }
+    });
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+        "bindings": bindings
+    });
+
+    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+    Ok(())
+}
+
+async fn update_deployment(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+    record: &DeploymentRecord,
+) -> Result<()> {
+    let statement = format!(
+        r#"
+UPDATE events
+SET
+    event_properties = OBJECT_INSERT(
+        OBJECT_INSERT(event_properties, 'environment', ?::VARIANT, true),
+        'experiment_name', ?::VARIANT, true
+    ),
+    time = CURRENT_TIMESTAMP()
+WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
+  AND event_properties:model_version_id::string = ?
+"#
+    );
+
+    let bindings = json!({
+        "1": { "type": "TEXT", "value": record.environment },
+        "2": { "type": "TEXT", "value": record.experiment_name },
+        "3": { "type": "TEXT", "value": record.model_version_id }
+    });
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+        "bindings": bindings
+    });
+
+    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+    Ok(())
+}
+
+fn display_deployments(existing: &HashMap<String, ExistingDeployment>) {
+    let col_names = ["version_id", "experiment", "environment"];
+
+    let mut col_widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
+    let mut rows: Vec<[String; 3]> = Vec::new();
+
+    for (version_id, deployment) in existing {
+        let row = [
+            version_id.clone(),
+            deployment.experiment_name.clone(),
+            deployment.environment.clone(),
+        ];
+        for (i, val) in row.iter().enumerate() {
+            col_widths[i] = col_widths[i].max(val.len());
+        }
+        rows.push(row);
+    }
+
+    rows.sort_by(|a, b| a[2].cmp(&b[2]).then_with(|| a[1].cmp(&b[1])));
+
+    let print_row = |values: &[&str]| {
+        for (i, val) in values.iter().enumerate() {
+            if i > 0 {
+                eprint!(" ");
+            }
+            eprint!("{:width$}", val, width = col_widths[i]);
+        }
+        eprintln!();
+    };
+
+    eprintln!();
+    print_row(&col_names);
+
+    let separators: Vec<String> = col_widths.iter().map(|w| "─".repeat(*w)).collect();
+    let separator_refs: Vec<&str> = separators.iter().map(|s| s.as_str()).collect();
+    print_row(&separator_refs);
+
+    for row in &rows {
+        let refs: Vec<&str> = row.iter().map(|s| s.as_str()).collect();
+        print_row(&refs);
+    }
+}
+
+pub async fn run_sync_deployments(
+    http_client: Arc<dyn HttpClient>,
+    model_name: Option<String>,
+) -> Result<()> {
+    let baseten_api_key = std::env::var("BASETEN_API_KEY")
+        .context("missing required environment variable BASETEN_API_KEY")?;
+    let snowflake_token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let snowflake_base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let snowflake_role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let model_name = model_name.unwrap_or_else(|| DEFAULT_BASETEN_MODEL_NAME.to_string());
+
+    let models = fetch_baseten_models(&http_client, &baseten_api_key).await?;
+
+    let model = models
+        .iter()
+        .find(|m| m.name == model_name)
+        .with_context(|| {
+            let available: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+            format!(
+                "model '{}' not found on baseten. Available: {:?}",
+                model_name, available
+            )
+        })?;
+
+    eprintln!("Fetching existing deployments from Snowflake...");
+    let mut existing = fetch_existing_deployments(
+        &http_client,
+        &snowflake_base_url,
+        &snowflake_token,
+        &snowflake_role,
+    )
+    .await
+    .context("failed to fetch existing deployments from Snowflake")?;
+
+    eprintln!(
+        "Found {} existing deployment(s) in Snowflake.",
+        existing.len()
+    );
+
+    let baseten_deployments = fetch_baseten_deployments(&http_client, &baseten_api_key, &model.id)
+        .await
+        .with_context(|| format!("failed to fetch deployments for model '{}'", model.name))?;
+
+    let records = collect_deployment_records(&model.id, &baseten_deployments);
+
+    if records.is_empty() {
+        eprintln!("No deployments found on Baseten.");
+        return Ok(());
+    }
+
+    eprintln!(
+        "Found {} deployment(s) on Baseten for model '{}'.",
+        records.len(),
+        model.name
+    );
+
+    let mut inserts = Vec::new();
+    let mut updates = Vec::new();
+    let mut unchanged = 0;
+
+    for record in &records {
+        match existing.get(&record.model_version_id) {
+            Some(existing_deployment) => {
+                let environment_changed = existing_deployment.environment != record.environment;
+                let experiment_changed =
+                    existing_deployment.experiment_name != record.experiment_name;
+
+                if environment_changed || experiment_changed {
+                    updates.push(record);
+                } else {
+                    unchanged += 1;
+                }
+            }
+            None => {
+                inserts.push(record);
+            }
+        }
+    }
+
+    eprintln!(
+        "Diff: {} insert(s), {} update(s), {} unchanged",
+        inserts.len(),
+        updates.len(),
+        unchanged,
+    );
+
+    for (i, record) in inserts.iter().enumerate() {
+        eprintln!(
+            " INSERT [{}/{}] {} -> {} (version_id={})",
+            i + 1,
+            inserts.len(),
+            record.experiment_name,
+            record.environment,
+            record.model_version_id,
+        );
+        insert_deployment(
+            &http_client,
+            &snowflake_base_url,
+            &snowflake_token,
+            &snowflake_role,
+            record,
+        )
+        .await
+        .with_context(|| {
+            format!(
+                "failed to insert deployment '{}' (model_version_id={})",
+                record.experiment_name, record.model_version_id
+            )
+        })?;
+
+        existing.insert(
+            record.model_version_id.clone(),
+            ExistingDeployment {
+                experiment_name: record.experiment_name.clone(),
+                environment: record.environment.clone(),
+            },
+        );
+    }
+
+    for (i, record) in updates.iter().enumerate() {
+        let existing_deployment = existing
+            .get(&record.model_version_id)
+            .context("update record missing from existing map")?;
+        eprintln!(
+            " UPDATE [{}/{}] version_id={}: environment '{}' -> '{}', experiment '{}' -> '{}'",
+            i + 1,
+            updates.len(),
+            record.model_version_id,
+            existing_deployment.environment,
+            record.environment,
+            existing_deployment.experiment_name,
+            record.experiment_name,
+        );
+        update_deployment(
+            &http_client,
+            &snowflake_base_url,
+            &snowflake_token,
+            &snowflake_role,
+            record,
+        )
+        .await
+        .with_context(|| {
+            format!(
+                "failed to update deployment '{}' (model_version_id={})",
+                record.experiment_name, record.model_version_id
+            )
+        })?;
+
+        existing.insert(
+            record.model_version_id.clone(),
+            ExistingDeployment {
+                experiment_name: record.experiment_name.clone(),
+                environment: record.environment.clone(),
+            },
+        );
+    }
+
+    if inserts.is_empty() && updates.is_empty() {
+        eprintln!("All deployments up to date, no writes needed.");
+    }
+
+    display_deployments(&existing);
+
+    Ok(())
+}