ep_cli: Add sync deployments command (#48880)

Created by Ben Kunkle and Tom

Adds a `sync-deployments` subcommand to the edit prediction CLI. It fetches deployment metadata for a Baseten model (default `zeta-2`, overridable with `--model`), diffs it against existing `Edit Prediction Deployment` events in Snowflake, then inserts new deployments and updates rows whose environment or experiment name changed, so predictions can be linked back to experiments. The command reads `BASETEN_API_KEY`, `EP_SNOWFLAKE_API_KEY`, and `EP_SNOWFLAKE_BASE_URL` from the environment, with `EP_SNOWFLAKE_ROLE` optionally setting the Snowflake role, and finishes by printing a summary table of the synced deployments.

Closes #ISSUE

- [ ] Tests or screenshots needed?
- [ ] Code Reviewed
- [ ] Manual QA

Release Notes:

- N/A

---------

Co-authored-by: Tom <tom@zed.dev>

Change summary

crates/edit_prediction_cli/src/main.rs             |  29 
crates/edit_prediction_cli/src/pull_examples.rs    |  28 
crates/edit_prediction_cli/src/sync_deployments.rs | 578 ++++++++++++++++
3 files changed, 620 insertions(+), 15 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs

@@ -22,6 +22,7 @@ mod reversal_tracking;
 mod score;
 mod split_commit;
 mod split_dataset;
+mod sync_deployments;
 mod synthesize;
 mod truncate_expected_patch;
 mod word_diff;
@@ -209,6 +210,8 @@ enum Command {
     Repair(repair::RepairArgs),
     /// Print all valid zeta formats (lowercase, one per line)
     PrintZetaFormats,
+    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
+    SyncDeployments(SyncDeploymentsArgs),
 }
 
 impl Display for Command {
@@ -254,10 +257,20 @@ impl Display for Command {
             Command::PrintZetaFormats => {
                 write!(f, "print-zeta-formats")
             }
+            Command::SyncDeployments(_) => {
+                write!(f, "sync-deployments")
+            }
         }
     }
 }
 
+#[derive(Debug, Args, Clone)]
+struct SyncDeploymentsArgs {
+    /// BaseTen model name (default: "zeta-2")
+    #[arg(long)]
+    model: Option<String>,
+}
+
 #[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
     #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
@@ -731,6 +744,19 @@ fn main() {
             }
             return;
         }
+        Command::SyncDeployments(sync_args) => {
+            let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+            smol::block_on(async {
+                if let Err(e) =
+                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
+                        .await
+                {
+                    eprintln!("Error: {:?}", e);
+                    std::process::exit(1);
+                }
+            });
+            return;
+        }
         Command::Synthesize(synth_args) => {
             let Some(output_dir) = args.output else {
                 panic!("output dir is required");
@@ -966,7 +992,8 @@ fn main() {
                                         | Command::TruncatePatch(_)
                                         | Command::FilterLanguages(_)
                                         | Command::ImportBatch(_)
-                                        | Command::PrintZetaFormats => {
+                                        | Command::PrintZetaFormats
+                                        | Command::SyncDeployments(_) => {
                                             unreachable!()
                                         }
                                     }

crates/edit_prediction_cli/src/pull_examples.rs

@@ -20,16 +20,16 @@ use edit_prediction::example_spec::{
 };
 use std::fmt::Write as _;
 
-const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
-const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
+pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
 const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
 const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
 const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
 
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
-const POLL_INTERVAL: Duration = Duration::from_secs(2);
-const MAX_POLL_ATTEMPTS: usize = 120;
+pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
+pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
 
 /// Parse an input token of the form `captured-after:{timestamp}`.
 pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -187,22 +187,22 @@ pub async fn fetch_captured_examples_after(
 
 #[derive(Debug, Clone, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct SnowflakeStatementResponse {
+pub(crate) struct SnowflakeStatementResponse {
     #[serde(default)]
-    data: Vec<Vec<JsonValue>>,
+    pub(crate) data: Vec<Vec<JsonValue>>,
     #[serde(default)]
-    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
+    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
     #[serde(default)]
-    code: Option<String>,
+    pub(crate) code: Option<String>,
     #[serde(default)]
-    message: Option<String>,
+    pub(crate) message: Option<String>,
     #[serde(default)]
-    statement_handle: Option<String>,
+    pub(crate) statement_handle: Option<String>,
 }
 
 #[derive(Debug, Clone, Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct SnowflakeResultSetMetaData {
+pub(crate) struct SnowflakeResultSetMetaData {
     #[serde(default, rename = "rowType")]
     row_type: Vec<SnowflakeColumnMeta>,
     #[serde(default)]
@@ -313,7 +313,7 @@ async fn run_sql_with_polling(
     Ok(response)
 }
 
-async fn fetch_partition(
+pub(crate) async fn fetch_partition(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
     token: &str,
@@ -402,7 +402,7 @@ async fn fetch_partition(
     })
 }
 
-async fn run_sql(
+pub(crate) async fn run_sql(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
     token: &str,
@@ -1344,7 +1344,7 @@ fn build_output_patch(
     patch
 }
 
-fn get_column_indices(
+pub(crate) fn get_column_indices(
     meta: &Option<SnowflakeResultSetMetaData>,
     names: &[&str],
 ) -> std::collections::HashMap<String, usize> {
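
These items are now `pub(crate)` so the new subcommand can reuse the existing Snowflake plumbing instead of duplicating it. A minimal sketch of that reuse pattern from another module in the crate (the query and column name are illustrative and not from this PR; the `database`/`schema`/`warehouse` values mirror the ones the crate already uses):

```rust
use std::sync::Arc;

use anyhow::Result;
use http_client::HttpClient;
use serde_json::{Value as JsonValue, json};

use crate::pull_examples;

// Illustrative only: run a statement through the shared helper and pull one
// string column out of the result rows. Skips the async-polling handling
// that the real callers layer on top of `run_sql`.
async fn list_event_types(
    http_client: Arc<dyn HttpClient>,
    base_url: &str,
    token: &str,
) -> Result<Vec<String>> {
    let request = json!({
        "statement": "SELECT DISTINCT event_type FROM events LIMIT 10",
        "timeout": 120,
        "database": "EVENTS",
        "schema": "PUBLIC",
        "warehouse": "DBT",
    });
    let response = pull_examples::run_sql(http_client, base_url, token, &request).await?;

    let indices =
        pull_examples::get_column_indices(&response.result_set_meta_data, &["event_type"]);
    let mut event_types = Vec::new();
    if let Some(&index) = indices.get("event_type") {
        for row in &response.data {
            if let Some(JsonValue::String(value)) = row.get(index) {
                event_types.push(value.clone());
            }
        }
    }
    Ok(event_types)
}
```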

crates/edit_prediction_cli/src/sync_deployments.rs

@@ -0,0 +1,578 @@
+use anyhow::{Context as _, Result};
+use http_client::{AsyncBody, HttpClient, Method, Request};
+use serde::Deserialize;
+use serde_json::{Value as JsonValue, json};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::pull_examples::{
+    self, MAX_POLL_ATTEMPTS, POLL_INTERVAL, SNOWFLAKE_ASYNC_IN_PROGRESS_CODE,
+    SNOWFLAKE_SUCCESS_CODE,
+};
+
+const DEFAULT_BASETEN_MODEL_NAME: &str = "zeta-2";
+const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenModelsResponse {
+    models: Vec<BasetenModel>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenModel {
+    id: String,
+    name: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenDeploymentsResponse {
+    deployments: Vec<BasetenDeployment>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct BasetenDeployment {
+    id: String,
+    name: String,
+    #[serde(default)]
+    status: Option<String>,
+    #[serde(default)]
+    created_at: Option<String>,
+    #[serde(default)]
+    environment: Option<String>,
+}
+
+#[derive(Debug, Clone)]
+struct DeploymentRecord {
+    model_id: String,
+    model_version_id: String,
+    experiment_name: String,
+    environment: String,
+    status: String,
+    created_at: String,
+}
+
+#[derive(Debug, Clone)]
+struct ExistingDeployment {
+    experiment_name: String,
+    environment: String,
+}
+
+async fn fetch_baseten_models(
+    http_client: &Arc<dyn HttpClient>,
+    api_key: &str,
+) -> Result<Vec<BasetenModel>> {
+    let request = Request::builder()
+        .method(Method::GET)
+        .uri("https://api.baseten.co/v1/models")
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .header("Accept", "application/json")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(request)
+        .await
+        .context("failed to fetch baseten models")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read baseten models response")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!("baseten models API http {}: {}", status.as_u16(), body_text);
+    }
+
+    let parsed: BasetenModelsResponse =
+        serde_json::from_slice(&body_bytes).context("failed to parse baseten models response")?;
+    Ok(parsed.models)
+}
+
+async fn fetch_baseten_deployments(
+    http_client: &Arc<dyn HttpClient>,
+    api_key: &str,
+    model_id: &str,
+) -> Result<Vec<BasetenDeployment>> {
+    let url = format!("https://api.baseten.co/v1/models/{model_id}/deployments");
+    let request = Request::builder()
+        .method(Method::GET)
+        .uri(url.as_str())
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .header("Accept", "application/json")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(request)
+        .await
+        .context("failed to fetch baseten deployments")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read baseten deployments response")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!(
+            "baseten deployments API http {}: {}",
+            status.as_u16(),
+            body_text
+        );
+    }
+
+    let parsed: BasetenDeploymentsResponse =
+        serde_json::from_slice(&body_bytes).context("failed to parse deployments response")?;
+    Ok(parsed.deployments)
+}
+
+fn collect_deployment_records(
+    model_id: &str,
+    deployments: &[BasetenDeployment],
+) -> Vec<DeploymentRecord> {
+    deployments
+        .iter()
+        .map(|deployment| DeploymentRecord {
+            model_id: model_id.to_string(),
+            model_version_id: deployment.id.clone(),
+            experiment_name: deployment.name.clone(),
+            environment: deployment
+                .environment
+                .clone()
+                .unwrap_or_else(|| "none".to_string()),
+            status: deployment
+                .status
+                .clone()
+                .unwrap_or_else(|| "unknown".to_string()),
+            created_at: deployment
+                .created_at
+                .clone()
+                .unwrap_or_else(|| "unknown".to_string()),
+        })
+        .collect()
+}
+
+async fn run_sql_with_polling(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+) -> Result<pull_examples::SnowflakeStatementResponse> {
+    let mut response =
+        pull_examples::run_sql(http_client.clone(), base_url, token, request).await?;
+
+    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+        let statement_handle = response
+            .statement_handle
+            .as_ref()
+            .context("async query response missing statementHandle")?
+            .clone();
+
+        for _attempt in 1..=MAX_POLL_ATTEMPTS {
+            std::thread::sleep(POLL_INTERVAL);
+
+            response = pull_examples::fetch_partition(
+                http_client.clone(),
+                base_url,
+                token,
+                &statement_handle,
+                0,
+            )
+            .await?;
+
+            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+                break;
+            }
+        }
+
+        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+            anyhow::bail!(
+                "query still running after {} poll attempts ({} seconds)",
+                MAX_POLL_ATTEMPTS,
+                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
+            );
+        }
+    }
+
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake error: code={} message={}",
+                code,
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    Ok(response)
+}
+
+async fn fetch_existing_deployments(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+) -> Result<HashMap<String, ExistingDeployment>> {
+    let statement = format!(
+        r#"
+SELECT
+    event_properties:model_version_id::string AS model_version_id,
+    event_properties:experiment_name::string AS experiment_name,
+    event_properties:environment::string AS environment
+FROM events
+WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
+"#
+    );
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+    });
+
+    let response = run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+
+    let col_names = ["model_version_id", "experiment_name", "environment"];
+    let column_indices =
+        pull_examples::get_column_indices(&response.result_set_meta_data, &col_names);
+
+    let mut existing = HashMap::new();
+
+    for data_row in &response.data {
+        let get_string = |name: &str| -> Option<String> {
+            let &index = column_indices.get(name)?;
+            match data_row.get(index) {
+                Some(JsonValue::String(s)) => Some(s.clone()),
+                _ => None,
+            }
+        };
+
+        let Some(model_version_id) = get_string("model_version_id") else {
+            continue;
+        };
+        let experiment_name = get_string("experiment_name").unwrap_or_default();
+        let environment = get_string("environment").unwrap_or_default();
+
+        existing.insert(
+            model_version_id,
+            ExistingDeployment {
+                experiment_name,
+                environment,
+            },
+        );
+    }
+
+    Ok(existing)
+}
+
+async fn insert_deployment(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+    record: &DeploymentRecord,
+) -> Result<()> {
+    let event_properties = json!({
+        "model_id": record.model_id,
+        "model_version_id": record.model_version_id,
+        "experiment_name": record.experiment_name,
+        "environment": record.environment,
+        "status": record.status,
+        "created_at": record.created_at,
+    });
+
+    let event_properties_str =
+        serde_json::to_string(&event_properties).context("failed to serialize event_properties")?;
+
+    let statement = r#"
+INSERT INTO events (event_type, event_properties, device_id, time)
+VALUES (?, PARSE_JSON(?), 'ep-cli', CURRENT_TIMESTAMP())
+"#;
+
+    let bindings = json!({
+        "1": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
+        "2": { "type": "TEXT", "value": event_properties_str }
+    });
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+        "bindings": bindings
+    });
+
+    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+    Ok(())
+}
+
+async fn update_deployment(
+    http_client: &Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    role: &Option<String>,
+    record: &DeploymentRecord,
+) -> Result<()> {
+    let statement = format!(
+        r#"
+UPDATE events
+SET
+    event_properties = OBJECT_INSERT(
+        OBJECT_INSERT(event_properties, 'environment', ?::VARIANT, true),
+        'experiment_name', ?::VARIANT, true
+    ),
+    time = CURRENT_TIMESTAMP()
+WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
+    AND event_properties:model_version_id::string = ?
+"#
+    );
+
+    let bindings = json!({
+        "1": { "type": "TEXT", "value": record.environment },
+        "2": { "type": "TEXT", "value": record.experiment_name },
+        "3": { "type": "TEXT", "value": record.model_version_id }
+    });
+
+    let request = json!({
+        "statement": statement,
+        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+        "database": "EVENTS",
+        "schema": "PUBLIC",
+        "warehouse": "DBT",
+        "role": role,
+        "bindings": bindings
+    });
+
+    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
+    Ok(())
+}
+
+fn display_deployments(existing: &HashMap<String, ExistingDeployment>) {
+    let col_names = ["version_id", "experiment", "environment"];
+
+    let mut col_widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
+    let mut rows: Vec<[String; 3]> = Vec::new();
+
+    for (version_id, deployment) in existing {
+        let row = [
+            version_id.clone(),
+            deployment.experiment_name.clone(),
+            deployment.environment.clone(),
+        ];
+        for (i, val) in row.iter().enumerate() {
+            col_widths[i] = col_widths[i].max(val.len());
+        }
+        rows.push(row);
+    }
+
+    rows.sort_by(|a, b| a[2].cmp(&b[2]).then_with(|| a[1].cmp(&b[1])));
+
+    let print_row = |values: &[&str]| {
+        for (i, val) in values.iter().enumerate() {
+            if i > 0 {
+                eprint!("  ");
+            }
+            eprint!("{:width$}", val, width = col_widths[i]);
+        }
+        eprintln!();
+    };
+
+    eprintln!();
+    print_row(&col_names);
+
+    let separators: Vec<String> = col_widths.iter().map(|w| "─".repeat(*w)).collect();
+    let separator_refs: Vec<&str> = separators.iter().map(|s| s.as_str()).collect();
+    print_row(&separator_refs);
+
+    for row in &rows {
+        let refs: Vec<&str> = row.iter().map(|s| s.as_str()).collect();
+        print_row(&refs);
+    }
+}
+
+pub async fn run_sync_deployments(
+    http_client: Arc<dyn HttpClient>,
+    model_name: Option<String>,
+) -> Result<()> {
+    let baseten_api_key = std::env::var("BASETEN_API_KEY")
+        .context("missing required environment variable BASETEN_API_KEY")?;
+    let snowflake_token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let snowflake_base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let snowflake_role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let model_name = model_name.unwrap_or_else(|| DEFAULT_BASETEN_MODEL_NAME.to_string());
+
+    let models = fetch_baseten_models(&http_client, &baseten_api_key).await?;
+
+    let model = models
+        .iter()
+        .find(|m| m.name == model_name)
+        .with_context(|| {
+            let available: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+            format!(
+                "model '{}' not found on baseten. Available: {:?}",
+                model_name, available
+            )
+        })?;
+
+    eprintln!("Fetching existing deployments from Snowflake...");
+    let mut existing = fetch_existing_deployments(
+        &http_client,
+        &snowflake_base_url,
+        &snowflake_token,
+        &snowflake_role,
+    )
+    .await
+    .context("failed to fetch existing deployments from Snowflake")?;
+
+    eprintln!(
+        "Found {} existing deployment(s) in Snowflake.",
+        existing.len()
+    );
+
+    let baseten_deployments = fetch_baseten_deployments(&http_client, &baseten_api_key, &model.id)
+        .await
+        .with_context(|| format!("failed to fetch deployments for model '{}'", model.name))?;
+
+    let records = collect_deployment_records(&model.id, &baseten_deployments);
+
+    if records.is_empty() {
+        eprintln!("No deployments found on Baseten.");
+        return Ok(());
+    }
+
+    eprintln!(
+        "Found {} deployment(s) on Baseten for model '{}'.",
+        records.len(),
+        model.name
+    );
+
+    let mut inserts = Vec::new();
+    let mut updates = Vec::new();
+    let mut unchanged = 0;
+
+    for record in &records {
+        match existing.get(&record.model_version_id) {
+            Some(existing_deployment) => {
+                let environment_changed = existing_deployment.environment != record.environment;
+                let experiment_changed =
+                    existing_deployment.experiment_name != record.experiment_name;
+
+                if environment_changed || experiment_changed {
+                    updates.push(record);
+                } else {
+                    unchanged += 1;
+                }
+            }
+            None => {
+                inserts.push(record);
+            }
+        }
+    }
+
+    eprintln!(
+        "Diff: {} insert(s), {} update(s), {} unchanged",
+        inserts.len(),
+        updates.len(),
+        unchanged,
+    );
+
+    for (i, record) in inserts.iter().enumerate() {
+        eprintln!(
+            "  INSERT [{}/{}] {} -> {} (version_id={})",
+            i + 1,
+            inserts.len(),
+            record.experiment_name,
+            record.environment,
+            record.model_version_id,
+        );
+        insert_deployment(
+            &http_client,
+            &snowflake_base_url,
+            &snowflake_token,
+            &snowflake_role,
+            record,
+        )
+        .await
+        .with_context(|| {
+            format!(
+                "failed to insert deployment '{}' (model_version_id={})",
+                record.experiment_name, record.model_version_id
+            )
+        })?;
+
+        existing.insert(
+            record.model_version_id.clone(),
+            ExistingDeployment {
+                experiment_name: record.experiment_name.clone(),
+                environment: record.environment.clone(),
+            },
+        );
+    }
+
+    for (i, record) in updates.iter().enumerate() {
+        let existing_deployment = existing
+            .get(&record.model_version_id)
+            .context("update record missing from existing map")?;
+        eprintln!(
+            "  UPDATE [{}/{}] version_id={}: environment '{}' -> '{}', experiment '{}' -> '{}'",
+            i + 1,
+            updates.len(),
+            record.model_version_id,
+            existing_deployment.environment,
+            record.environment,
+            existing_deployment.experiment_name,
+            record.experiment_name,
+        );
+        update_deployment(
+            &http_client,
+            &snowflake_base_url,
+            &snowflake_token,
+            &snowflake_role,
+            record,
+        )
+        .await
+        .with_context(|| {
+            format!(
+                "failed to update deployment '{}' (model_version_id={})",
+                record.experiment_name, record.model_version_id
+            )
+        })?;
+
+        existing.insert(
+            record.model_version_id.clone(),
+            ExistingDeployment {
+                experiment_name: record.experiment_name.clone(),
+                environment: record.environment.clone(),
+            },
+        );
+    }
+
+    if inserts.is_empty() && updates.is_empty() {
+        eprintln!("All deployments up to date, no writes needed.");
+    }
+
+    display_deployments(&existing);
+
+    Ok(())
+}
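
Not part of this diff, but a minimal unit-test sketch for the pure mapping step above, exercising the fallback values `collect_deployment_records` applies when Baseten omits optional fields (the deployment values are made up for illustration):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical test, not included in this PR: missing `status` and
    // `created_at` fall back to "unknown"; a present `environment` is kept.
    #[test]
    fn collect_records_applies_fallbacks() {
        let deployments = vec![BasetenDeployment {
            id: "v_abc123".to_string(),
            name: "zeta-2-experiment".to_string(),
            status: None,
            created_at: None,
            environment: Some("production".to_string()),
        }];

        let records = collect_deployment_records("model_xyz", &deployments);

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].model_id, "model_xyz");
        assert_eq!(records[0].model_version_id, "v_abc123");
        assert_eq!(records[0].experiment_name, "zeta-2-experiment");
        assert_eq!(records[0].environment, "production");
        assert_eq!(records[0].status, "unknown");
        assert_eq!(records[0].created_at, "unknown");
    }
}
```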