diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index e62350b6da7ff55833664c17fce32cb32cf4dd36..01845b13c5621973c148988a168c89efb7a46210 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -22,7 +22,7 @@ mod reversal_tracking; mod score; mod split_commit; mod split_dataset; -mod sync_deployments; + mod synthesize; mod truncate_expected_patch; mod word_diff; @@ -210,8 +210,6 @@ enum Command { Repair(repair::RepairArgs), /// Print all valid zeta formats (lowercase, one per line) PrintZetaFormats, - /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments - SyncDeployments(SyncDeploymentsArgs), } impl Display for Command { @@ -257,20 +255,10 @@ impl Display for Command { Command::PrintZetaFormats => { write!(f, "print-zeta-formats") } - Command::SyncDeployments(_) => { - write!(f, "sync-deployments") - } } } } -#[derive(Debug, Args, Clone)] -struct SyncDeploymentsArgs { - /// BaseTen model name (default: "zeta-2") - #[arg(long)] - model: Option, -} - #[derive(Debug, Args, Clone)] struct FormatPromptArgs { #[clap(long, short('p'), default_value_t = PredictionProvider::default())] @@ -777,19 +765,7 @@ fn main() { } return; } - Command::SyncDeployments(sync_args) => { - let http_client: Arc = Arc::new(ReqwestClient::new()); - smol::block_on(async { - if let Err(e) = - sync_deployments::run_sync_deployments(http_client, sync_args.model.clone()) - .await - { - eprintln!("Error: {:?}", e); - std::process::exit(1); - } - }); - return; - } + Command::Synthesize(synth_args) => { let Some(output_dir) = args.output else { panic!("output dir is required"); @@ -1025,8 +1001,7 @@ fn main() { | Command::TruncatePatch(_) | Command::FilterLanguages(_) | Command::ImportBatch(_) - | Command::PrintZetaFormats - | Command::SyncDeployments(_) => { + | Command::PrintZetaFormats => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index 91202fda3e2a9fdb0858397f2c000d34803e8b94..91ea85e019973ec04cc747606e105c5e9ef50988 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -14,7 +14,7 @@ use zeta_prompt::ZetaPromptInput; use crate::example::Example; use crate::progress::{InfoStyle, Progress, Step}; -use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT; +const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment"; use edit_prediction::example_spec::{ CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec, TelemetrySource, diff --git a/crates/edit_prediction_cli/src/sync_deployments.rs b/crates/edit_prediction_cli/src/sync_deployments.rs deleted file mode 100644 index b55923eda60fdf1c966f418d0f9d08762617987c..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/sync_deployments.rs +++ /dev/null @@ -1,578 +0,0 @@ -use anyhow::{Context as _, Result}; -use http_client::{AsyncBody, HttpClient, Method, Request}; -use serde::Deserialize; -use serde_json::{Value as JsonValue, json}; -use std::collections::HashMap; -use std::sync::Arc; - -use crate::pull_examples::{ - self, MAX_POLL_ATTEMPTS, POLL_INTERVAL, SNOWFLAKE_ASYNC_IN_PROGRESS_CODE, - SNOWFLAKE_SUCCESS_CODE, -}; - -const DEFAULT_BASETEN_MODEL_NAME: &str = "zeta-2"; -const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120; -pub(crate) const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment"; - -#[derive(Debug, Clone, Deserialize)] -struct BasetenModelsResponse { - models: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -struct BasetenModel { - id: String, - name: String, -} - -#[derive(Debug, Clone, Deserialize)] -struct BasetenDeploymentsResponse { - deployments: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -struct BasetenDeployment { - id: String, - name: String, - #[serde(default)] - status: Option, - #[serde(default)] - created_at: Option, - #[serde(default)] - environment: Option, -} - -#[derive(Debug, Clone)] -struct DeploymentRecord { - model_id: String, - model_version_id: String, - experiment_name: String, - environment: String, - status: String, - created_at: String, -} - -#[derive(Debug, Clone)] -struct ExistingDeployment { - experiment_name: String, - environment: String, -} - -async fn fetch_baseten_models( - http_client: &Arc, - api_key: &str, -) -> Result> { - let request = Request::builder() - .method(Method::GET) - .uri("https://api.baseten.co/v1/models") - .header("Authorization", format!("Api-Key {api_key}")) - .header("Accept", "application/json") - .body(AsyncBody::empty())?; - - let response = http_client - .send(request) - .await - .context("failed to fetch baseten models")?; - - let status = response.status(); - let body_bytes = { - use futures::AsyncReadExt as _; - let mut body = response.into_body(); - let mut bytes = Vec::new(); - body.read_to_end(&mut bytes) - .await - .context("failed to read baseten models response")?; - bytes - }; - - if !status.is_success() { - let body_text = String::from_utf8_lossy(&body_bytes); - anyhow::bail!("baseten models API http {}: {}", status.as_u16(), body_text); - } - - let parsed: BasetenModelsResponse = - serde_json::from_slice(&body_bytes).context("failed to parse baseten models response")?; - Ok(parsed.models) -} - -async fn fetch_baseten_deployments( - http_client: &Arc, - api_key: &str, - model_id: &str, -) -> Result> { - let url = format!("https://api.baseten.co/v1/models/{model_id}/deployments"); - let request = Request::builder() - .method(Method::GET) - .uri(url.as_str()) - .header("Authorization", format!("Api-Key {api_key}")) - .header("Accept", "application/json") - .body(AsyncBody::empty())?; - - let response = http_client - .send(request) - .await - .context("failed to fetch baseten deployments")?; - - let status = response.status(); - let body_bytes = { - use futures::AsyncReadExt as _; - let mut body = response.into_body(); - let mut bytes = Vec::new(); - body.read_to_end(&mut bytes) - .await - .context("failed to read baseten deployments response")?; - bytes - }; - - if !status.is_success() { - let body_text = String::from_utf8_lossy(&body_bytes); - anyhow::bail!( - "baseten deployments API http {}: {}", - status.as_u16(), - body_text - ); - } - - let parsed: BasetenDeploymentsResponse = - serde_json::from_slice(&body_bytes).context("failed to parse deployments response")?; - Ok(parsed.deployments) -} - -fn collect_deployment_records( - model_id: &str, - deployments: &[BasetenDeployment], -) -> Vec { - deployments - .iter() - .map(|deployment| DeploymentRecord { - model_id: model_id.to_string(), - model_version_id: deployment.id.clone(), - experiment_name: deployment.name.clone(), - environment: deployment - .environment - .clone() - .unwrap_or_else(|| "none".to_string()), - status: deployment - .status - .clone() - .unwrap_or_else(|| "unknown".to_string()), - created_at: deployment - .created_at - .clone() - .unwrap_or_else(|| "unknown".to_string()), - }) - .collect() -} - -async fn run_sql_with_polling( - http_client: Arc, - base_url: &str, - token: &str, - request: &serde_json::Value, -) -> Result { - let mut response = - pull_examples::run_sql(http_client.clone(), base_url, token, request).await?; - - if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) { - let statement_handle = response - .statement_handle - .as_ref() - .context("async query response missing statementHandle")? - .clone(); - - for _attempt in 1..=MAX_POLL_ATTEMPTS { - std::thread::sleep(POLL_INTERVAL); - - response = pull_examples::fetch_partition( - http_client.clone(), - base_url, - token, - &statement_handle, - 0, - ) - .await?; - - if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) { - break; - } - } - - if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) { - anyhow::bail!( - "query still running after {} poll attempts ({} seconds)", - MAX_POLL_ATTEMPTS, - MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs() - ); - } - } - - if let Some(code) = &response.code { - if code != SNOWFLAKE_SUCCESS_CODE { - anyhow::bail!( - "snowflake error: code={} message={}", - code, - response.message.as_deref().unwrap_or("") - ); - } - } - - Ok(response) -} - -async fn fetch_existing_deployments( - http_client: &Arc, - base_url: &str, - token: &str, - role: &Option, -) -> Result> { - let statement = format!( - r#" -SELECT - event_properties:model_version_id::string AS model_version_id, - event_properties:experiment_name::string AS experiment_name, - event_properties:environment::string AS environment -FROM events -WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}' -"# - ); - - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - }); - - let response = run_sql_with_polling(http_client.clone(), base_url, token, &request).await?; - - let col_names = ["model_version_id", "experiment_name", "environment"]; - let column_indices = - pull_examples::get_column_indices(&response.result_set_meta_data, &col_names); - - let mut existing = HashMap::new(); - - for data_row in &response.data { - let get_string = |name: &str| -> Option { - let &index = column_indices.get(name)?; - match data_row.get(index) { - Some(JsonValue::String(s)) => Some(s.clone()), - _ => None, - } - }; - - let Some(model_version_id) = get_string("model_version_id") else { - continue; - }; - let experiment_name = get_string("experiment_name").unwrap_or_default(); - let environment = get_string("environment").unwrap_or_default(); - - existing.insert( - model_version_id, - ExistingDeployment { - experiment_name, - environment, - }, - ); - } - - Ok(existing) -} - -async fn insert_deployment( - http_client: &Arc, - base_url: &str, - token: &str, - role: &Option, - record: &DeploymentRecord, -) -> Result<()> { - let event_properties = json!({ - "model_id": record.model_id, - "model_version_id": record.model_version_id, - "experiment_name": record.experiment_name, - "environment": record.environment, - "status": record.status, - "created_at": record.created_at, - }); - - let event_properties_str = - serde_json::to_string(&event_properties).context("failed to serialize event_properties")?; - - let statement = r#" -INSERT INTO events (event_type, event_properties, device_id, time) -VALUES (?, PARSE_JSON(?), 'ep-cli', CURRENT_TIMESTAMP()) -"#; - - let bindings = json!({ - "1": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT }, - "2": { "type": "TEXT", "value": event_properties_str } - }); - - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": bindings - }); - - run_sql_with_polling(http_client.clone(), base_url, token, &request).await?; - Ok(()) -} - -async fn update_deployment( - http_client: &Arc, - base_url: &str, - token: &str, - role: &Option, - record: &DeploymentRecord, -) -> Result<()> { - let statement = format!( - r#" -UPDATE events -SET - event_properties = OBJECT_INSERT( - OBJECT_INSERT(event_properties, 'environment', ?::VARIANT, true), - 'experiment_name', ?::VARIANT, true - ), - time = CURRENT_TIMESTAMP() -WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}' - AND event_properties:model_version_id::string = ? -"# - ); - - let bindings = json!({ - "1": { "type": "TEXT", "value": record.environment }, - "2": { "type": "TEXT", "value": record.experiment_name }, - "3": { "type": "TEXT", "value": record.model_version_id } - }); - - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": bindings - }); - - run_sql_with_polling(http_client.clone(), base_url, token, &request).await?; - Ok(()) -} - -fn display_deployments(existing: &HashMap) { - let col_names = ["version_id", "experiment", "environment"]; - - let mut col_widths: Vec = col_names.iter().map(|n| n.len()).collect(); - let mut rows: Vec<[String; 3]> = Vec::new(); - - for (version_id, deployment) in existing { - let row = [ - version_id.clone(), - deployment.experiment_name.clone(), - deployment.environment.clone(), - ]; - for (i, val) in row.iter().enumerate() { - col_widths[i] = col_widths[i].max(val.len()); - } - rows.push(row); - } - - rows.sort_by(|a, b| a[2].cmp(&b[2]).then_with(|| a[1].cmp(&b[1]))); - - let print_row = |values: &[&str]| { - for (i, val) in values.iter().enumerate() { - if i > 0 { - eprint!(" "); - } - eprint!("{:width$}", val, width = col_widths[i]); - } - eprintln!(); - }; - - eprintln!(); - print_row(&col_names); - - let separators: Vec = col_widths.iter().map(|w| "─".repeat(*w)).collect(); - let separator_refs: Vec<&str> = separators.iter().map(|s| s.as_str()).collect(); - print_row(&separator_refs); - - for row in &rows { - let refs: Vec<&str> = row.iter().map(|s| s.as_str()).collect(); - print_row(&refs); - } -} - -pub async fn run_sync_deployments( - http_client: Arc, - model_name: Option, -) -> Result<()> { - let baseten_api_key = std::env::var("BASETEN_API_KEY") - .context("missing required environment variable BASETEN_API_KEY")?; - let snowflake_token = std::env::var("EP_SNOWFLAKE_API_KEY") - .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; - let snowflake_base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( - "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", - )?; - let snowflake_role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); - - let model_name = model_name.unwrap_or_else(|| DEFAULT_BASETEN_MODEL_NAME.to_string()); - - let models = fetch_baseten_models(&http_client, &baseten_api_key).await?; - - let model = models - .iter() - .find(|m| m.name == model_name) - .with_context(|| { - let available: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); - format!( - "model '{}' not found on baseten. Available: {:?}", - model_name, available - ) - })?; - - eprintln!("Fetching existing deployments from Snowflake..."); - let mut existing = fetch_existing_deployments( - &http_client, - &snowflake_base_url, - &snowflake_token, - &snowflake_role, - ) - .await - .context("failed to fetch existing deployments from Snowflake")?; - - eprintln!( - "Found {} existing deployment(s) in Snowflake.", - existing.len() - ); - - let baseten_deployments = fetch_baseten_deployments(&http_client, &baseten_api_key, &model.id) - .await - .with_context(|| format!("failed to fetch deployments for model '{}'", model.name))?; - - let records = collect_deployment_records(&model.id, &baseten_deployments); - - if records.is_empty() { - eprintln!("No deployments found on Baseten."); - return Ok(()); - } - - eprintln!( - "Found {} deployment(s) on Baseten for model '{}'.", - records.len(), - model.name - ); - - let mut inserts = Vec::new(); - let mut updates = Vec::new(); - let mut unchanged = 0; - - for record in &records { - match existing.get(&record.model_version_id) { - Some(existing_deployment) => { - let environment_changed = existing_deployment.environment != record.environment; - let experiment_changed = - existing_deployment.experiment_name != record.experiment_name; - - if environment_changed || experiment_changed { - updates.push(record); - } else { - unchanged += 1; - } - } - None => { - inserts.push(record); - } - } - } - - eprintln!( - "Diff: {} insert(s), {} update(s), {} unchanged", - inserts.len(), - updates.len(), - unchanged, - ); - - for (i, record) in inserts.iter().enumerate() { - eprintln!( - " INSERT [{}/{}] {} -> {} (version_id={})", - i + 1, - inserts.len(), - record.experiment_name, - record.environment, - record.model_version_id, - ); - insert_deployment( - &http_client, - &snowflake_base_url, - &snowflake_token, - &snowflake_role, - record, - ) - .await - .with_context(|| { - format!( - "failed to insert deployment '{}' (model_version_id={})", - record.experiment_name, record.model_version_id - ) - })?; - - existing.insert( - record.model_version_id.clone(), - ExistingDeployment { - experiment_name: record.experiment_name.clone(), - environment: record.environment.clone(), - }, - ); - } - - for (i, record) in updates.iter().enumerate() { - let existing_deployment = existing - .get(&record.model_version_id) - .context("update record missing from existing map")?; - eprintln!( - " UPDATE [{}/{}] version_id={}: environment '{}' -> '{}', experiment '{}' -> '{}'", - i + 1, - updates.len(), - record.model_version_id, - existing_deployment.environment, - record.environment, - existing_deployment.experiment_name, - record.experiment_name, - ); - update_deployment( - &http_client, - &snowflake_base_url, - &snowflake_token, - &snowflake_role, - record, - ) - .await - .with_context(|| { - format!( - "failed to update deployment '{}' (model_version_id={})", - record.experiment_name, record.model_version_id - ) - })?; - - existing.insert( - record.model_version_id.clone(), - ExistingDeployment { - experiment_name: record.experiment_name.clone(), - environment: record.environment.clone(), - }, - ); - } - - if inserts.is_empty() && updates.is_empty() { - eprintln!("All deployments up to date, no writes needed."); - } - - display_deployments(&existing); - - Ok(()) -}