sync_deployments.rs

  1use anyhow::{Context as _, Result};
  2use http_client::{AsyncBody, HttpClient, Method, Request};
  3use serde::Deserialize;
  4use serde_json::{Value as JsonValue, json};
  5use std::collections::HashMap;
  6use std::sync::Arc;
  7
  8use crate::pull_examples::{
  9    self, MAX_POLL_ATTEMPTS, POLL_INTERVAL, SNOWFLAKE_ASYNC_IN_PROGRESS_CODE,
 10    SNOWFLAKE_SUCCESS_CODE,
 11};
 12
 13const DEFAULT_BASETEN_MODEL_NAME: &str = "zeta-2";
 14const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
 15pub(crate) const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
 16
 17#[derive(Debug, Clone, Deserialize)]
 18struct BasetenModelsResponse {
 19    models: Vec<BasetenModel>,
 20}
 21
 22#[derive(Debug, Clone, Deserialize)]
 23struct BasetenModel {
 24    id: String,
 25    name: String,
 26}
 27
 28#[derive(Debug, Clone, Deserialize)]
 29struct BasetenDeploymentsResponse {
 30    deployments: Vec<BasetenDeployment>,
 31}
 32
 33#[derive(Debug, Clone, Deserialize)]
 34struct BasetenDeployment {
 35    id: String,
 36    name: String,
 37    #[serde(default)]
 38    status: Option<String>,
 39    #[serde(default)]
 40    created_at: Option<String>,
 41    #[serde(default)]
 42    environment: Option<String>,
 43}
 44
 45#[derive(Debug, Clone)]
 46struct DeploymentRecord {
 47    model_id: String,
 48    model_version_id: String,
 49    experiment_name: String,
 50    environment: String,
 51    status: String,
 52    created_at: String,
 53}
 54
 55#[derive(Debug, Clone)]
 56struct ExistingDeployment {
 57    experiment_name: String,
 58    environment: String,
 59}
 60
 61async fn fetch_baseten_models(
 62    http_client: &Arc<dyn HttpClient>,
 63    api_key: &str,
 64) -> Result<Vec<BasetenModel>> {
 65    let request = Request::builder()
 66        .method(Method::GET)
 67        .uri("https://api.baseten.co/v1/models")
 68        .header("Authorization", format!("Api-Key {api_key}"))
 69        .header("Accept", "application/json")
 70        .body(AsyncBody::empty())?;
 71
 72    let response = http_client
 73        .send(request)
 74        .await
 75        .context("failed to fetch baseten models")?;
 76
 77    let status = response.status();
 78    let body_bytes = {
 79        use futures::AsyncReadExt as _;
 80        let mut body = response.into_body();
 81        let mut bytes = Vec::new();
 82        body.read_to_end(&mut bytes)
 83            .await
 84            .context("failed to read baseten models response")?;
 85        bytes
 86    };
 87
 88    if !status.is_success() {
 89        let body_text = String::from_utf8_lossy(&body_bytes);
 90        anyhow::bail!("baseten models API http {}: {}", status.as_u16(), body_text);
 91    }
 92
 93    let parsed: BasetenModelsResponse =
 94        serde_json::from_slice(&body_bytes).context("failed to parse baseten models response")?;
 95    Ok(parsed.models)
 96}
 97
 98async fn fetch_baseten_deployments(
 99    http_client: &Arc<dyn HttpClient>,
100    api_key: &str,
101    model_id: &str,
102) -> Result<Vec<BasetenDeployment>> {
103    let url = format!("https://api.baseten.co/v1/models/{model_id}/deployments");
104    let request = Request::builder()
105        .method(Method::GET)
106        .uri(url.as_str())
107        .header("Authorization", format!("Api-Key {api_key}"))
108        .header("Accept", "application/json")
109        .body(AsyncBody::empty())?;
110
111    let response = http_client
112        .send(request)
113        .await
114        .context("failed to fetch baseten deployments")?;
115
116    let status = response.status();
117    let body_bytes = {
118        use futures::AsyncReadExt as _;
119        let mut body = response.into_body();
120        let mut bytes = Vec::new();
121        body.read_to_end(&mut bytes)
122            .await
123            .context("failed to read baseten deployments response")?;
124        bytes
125    };
126
127    if !status.is_success() {
128        let body_text = String::from_utf8_lossy(&body_bytes);
129        anyhow::bail!(
130            "baseten deployments API http {}: {}",
131            status.as_u16(),
132            body_text
133        );
134    }
135
136    let parsed: BasetenDeploymentsResponse =
137        serde_json::from_slice(&body_bytes).context("failed to parse deployments response")?;
138    Ok(parsed.deployments)
139}
140
141fn collect_deployment_records(
142    model_id: &str,
143    deployments: &[BasetenDeployment],
144) -> Vec<DeploymentRecord> {
145    deployments
146        .iter()
147        .map(|deployment| DeploymentRecord {
148            model_id: model_id.to_string(),
149            model_version_id: deployment.id.clone(),
150            experiment_name: deployment.name.clone(),
151            environment: deployment
152                .environment
153                .clone()
154                .unwrap_or_else(|| "none".to_string()),
155            status: deployment
156                .status
157                .clone()
158                .unwrap_or_else(|| "unknown".to_string()),
159            created_at: deployment
160                .created_at
161                .clone()
162                .unwrap_or_else(|| "unknown".to_string()),
163        })
164        .collect()
165}
166
167async fn run_sql_with_polling(
168    http_client: Arc<dyn HttpClient>,
169    base_url: &str,
170    token: &str,
171    request: &serde_json::Value,
172) -> Result<pull_examples::SnowflakeStatementResponse> {
173    let mut response =
174        pull_examples::run_sql(http_client.clone(), base_url, token, request).await?;
175
176    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
177        let statement_handle = response
178            .statement_handle
179            .as_ref()
180            .context("async query response missing statementHandle")?
181            .clone();
182
183        for _attempt in 1..=MAX_POLL_ATTEMPTS {
184            std::thread::sleep(POLL_INTERVAL);
185
186            response = pull_examples::fetch_partition(
187                http_client.clone(),
188                base_url,
189                token,
190                &statement_handle,
191                0,
192            )
193            .await?;
194
195            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
196                break;
197            }
198        }
199
200        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
201            anyhow::bail!(
202                "query still running after {} poll attempts ({} seconds)",
203                MAX_POLL_ATTEMPTS,
204                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
205            );
206        }
207    }
208
209    if let Some(code) = &response.code {
210        if code != SNOWFLAKE_SUCCESS_CODE {
211            anyhow::bail!(
212                "snowflake error: code={} message={}",
213                code,
214                response.message.as_deref().unwrap_or("<no message>")
215            );
216        }
217    }
218
219    Ok(response)
220}
221
222async fn fetch_existing_deployments(
223    http_client: &Arc<dyn HttpClient>,
224    base_url: &str,
225    token: &str,
226    role: &Option<String>,
227) -> Result<HashMap<String, ExistingDeployment>> {
228    let statement = format!(
229        r#"
230SELECT
231    event_properties:model_version_id::string AS model_version_id,
232    event_properties:experiment_name::string AS experiment_name,
233    event_properties:environment::string AS environment
234FROM events
235WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
236"#
237    );
238
239    let request = json!({
240        "statement": statement,
241        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
242        "database": "EVENTS",
243        "schema": "PUBLIC",
244        "warehouse": "DBT",
245        "role": role,
246    });
247
248    let response = run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
249
250    let col_names = ["model_version_id", "experiment_name", "environment"];
251    let column_indices =
252        pull_examples::get_column_indices(&response.result_set_meta_data, &col_names);
253
254    let mut existing = HashMap::new();
255
256    for data_row in &response.data {
257        let get_string = |name: &str| -> Option<String> {
258            let &index = column_indices.get(name)?;
259            match data_row.get(index) {
260                Some(JsonValue::String(s)) => Some(s.clone()),
261                _ => None,
262            }
263        };
264
265        let Some(model_version_id) = get_string("model_version_id") else {
266            continue;
267        };
268        let experiment_name = get_string("experiment_name").unwrap_or_default();
269        let environment = get_string("environment").unwrap_or_default();
270
271        existing.insert(
272            model_version_id,
273            ExistingDeployment {
274                experiment_name,
275                environment,
276            },
277        );
278    }
279
280    Ok(existing)
281}
282
283async fn insert_deployment(
284    http_client: &Arc<dyn HttpClient>,
285    base_url: &str,
286    token: &str,
287    role: &Option<String>,
288    record: &DeploymentRecord,
289) -> Result<()> {
290    let event_properties = json!({
291        "model_id": record.model_id,
292        "model_version_id": record.model_version_id,
293        "experiment_name": record.experiment_name,
294        "environment": record.environment,
295        "status": record.status,
296        "created_at": record.created_at,
297    });
298
299    let event_properties_str =
300        serde_json::to_string(&event_properties).context("failed to serialize event_properties")?;
301
302    let statement = r#"
303INSERT INTO events (event_type, event_properties, device_id, time)
304VALUES (?, PARSE_JSON(?), 'ep-cli', CURRENT_TIMESTAMP())
305"#;
306
307    let bindings = json!({
308        "1": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
309        "2": { "type": "TEXT", "value": event_properties_str }
310    });
311
312    let request = json!({
313        "statement": statement,
314        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
315        "database": "EVENTS",
316        "schema": "PUBLIC",
317        "warehouse": "DBT",
318        "role": role,
319        "bindings": bindings
320    });
321
322    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
323    Ok(())
324}
325
326async fn update_deployment(
327    http_client: &Arc<dyn HttpClient>,
328    base_url: &str,
329    token: &str,
330    role: &Option<String>,
331    record: &DeploymentRecord,
332) -> Result<()> {
333    let statement = format!(
334        r#"
335UPDATE events
336SET
337    event_properties = OBJECT_INSERT(
338        OBJECT_INSERT(event_properties, 'environment', ?::VARIANT, true),
339        'experiment_name', ?::VARIANT, true
340    ),
341    time = CURRENT_TIMESTAMP()
342WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
343    AND event_properties:model_version_id::string = ?
344"#
345    );
346
347    let bindings = json!({
348        "1": { "type": "TEXT", "value": record.environment },
349        "2": { "type": "TEXT", "value": record.experiment_name },
350        "3": { "type": "TEXT", "value": record.model_version_id }
351    });
352
353    let request = json!({
354        "statement": statement,
355        "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
356        "database": "EVENTS",
357        "schema": "PUBLIC",
358        "warehouse": "DBT",
359        "role": role,
360        "bindings": bindings
361    });
362
363    run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
364    Ok(())
365}
366
367fn display_deployments(existing: &HashMap<String, ExistingDeployment>) {
368    let col_names = ["version_id", "experiment", "environment"];
369
370    let mut col_widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
371    let mut rows: Vec<[String; 3]> = Vec::new();
372
373    for (version_id, deployment) in existing {
374        let row = [
375            version_id.clone(),
376            deployment.experiment_name.clone(),
377            deployment.environment.clone(),
378        ];
379        for (i, val) in row.iter().enumerate() {
380            col_widths[i] = col_widths[i].max(val.len());
381        }
382        rows.push(row);
383    }
384
385    rows.sort_by(|a, b| a[2].cmp(&b[2]).then_with(|| a[1].cmp(&b[1])));
386
387    let print_row = |values: &[&str]| {
388        for (i, val) in values.iter().enumerate() {
389            if i > 0 {
390                eprint!("  ");
391            }
392            eprint!("{:width$}", val, width = col_widths[i]);
393        }
394        eprintln!();
395    };
396
397    eprintln!();
398    print_row(&col_names);
399
400    let separators: Vec<String> = col_widths.iter().map(|w| "".repeat(*w)).collect();
401    let separator_refs: Vec<&str> = separators.iter().map(|s| s.as_str()).collect();
402    print_row(&separator_refs);
403
404    for row in &rows {
405        let refs: Vec<&str> = row.iter().map(|s| s.as_str()).collect();
406        print_row(&refs);
407    }
408}
409
410pub async fn run_sync_deployments(
411    http_client: Arc<dyn HttpClient>,
412    model_name: Option<String>,
413) -> Result<()> {
414    let baseten_api_key = std::env::var("BASETEN_API_KEY")
415        .context("missing required environment variable BASETEN_API_KEY")?;
416    let snowflake_token = std::env::var("EP_SNOWFLAKE_API_KEY")
417        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
418    let snowflake_base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
419        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
420    )?;
421    let snowflake_role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
422
423    let model_name = model_name.unwrap_or_else(|| DEFAULT_BASETEN_MODEL_NAME.to_string());
424
425    let models = fetch_baseten_models(&http_client, &baseten_api_key).await?;
426
427    let model = models
428        .iter()
429        .find(|m| m.name == model_name)
430        .with_context(|| {
431            let available: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
432            format!(
433                "model '{}' not found on baseten. Available: {:?}",
434                model_name, available
435            )
436        })?;
437
438    eprintln!("Fetching existing deployments from Snowflake...");
439    let mut existing = fetch_existing_deployments(
440        &http_client,
441        &snowflake_base_url,
442        &snowflake_token,
443        &snowflake_role,
444    )
445    .await
446    .context("failed to fetch existing deployments from Snowflake")?;
447
448    eprintln!(
449        "Found {} existing deployment(s) in Snowflake.",
450        existing.len()
451    );
452
453    let baseten_deployments = fetch_baseten_deployments(&http_client, &baseten_api_key, &model.id)
454        .await
455        .with_context(|| format!("failed to fetch deployments for model '{}'", model.name))?;
456
457    let records = collect_deployment_records(&model.id, &baseten_deployments);
458
459    if records.is_empty() {
460        eprintln!("No deployments found on Baseten.");
461        return Ok(());
462    }
463
464    eprintln!(
465        "Found {} deployment(s) on Baseten for model '{}'.",
466        records.len(),
467        model.name
468    );
469
470    let mut inserts = Vec::new();
471    let mut updates = Vec::new();
472    let mut unchanged = 0;
473
474    for record in &records {
475        match existing.get(&record.model_version_id) {
476            Some(existing_deployment) => {
477                let environment_changed = existing_deployment.environment != record.environment;
478                let experiment_changed =
479                    existing_deployment.experiment_name != record.experiment_name;
480
481                if environment_changed || experiment_changed {
482                    updates.push(record);
483                } else {
484                    unchanged += 1;
485                }
486            }
487            None => {
488                inserts.push(record);
489            }
490        }
491    }
492
493    eprintln!(
494        "Diff: {} insert(s), {} update(s), {} unchanged",
495        inserts.len(),
496        updates.len(),
497        unchanged,
498    );
499
500    for (i, record) in inserts.iter().enumerate() {
501        eprintln!(
502            "  INSERT [{}/{}] {} -> {} (version_id={})",
503            i + 1,
504            inserts.len(),
505            record.experiment_name,
506            record.environment,
507            record.model_version_id,
508        );
509        insert_deployment(
510            &http_client,
511            &snowflake_base_url,
512            &snowflake_token,
513            &snowflake_role,
514            record,
515        )
516        .await
517        .with_context(|| {
518            format!(
519                "failed to insert deployment '{}' (model_version_id={})",
520                record.experiment_name, record.model_version_id
521            )
522        })?;
523
524        existing.insert(
525            record.model_version_id.clone(),
526            ExistingDeployment {
527                experiment_name: record.experiment_name.clone(),
528                environment: record.environment.clone(),
529            },
530        );
531    }
532
533    for (i, record) in updates.iter().enumerate() {
534        let existing_deployment = existing
535            .get(&record.model_version_id)
536            .context("update record missing from existing map")?;
537        eprintln!(
538            "  UPDATE [{}/{}] version_id={}: environment '{}' -> '{}', experiment '{}' -> '{}'",
539            i + 1,
540            updates.len(),
541            record.model_version_id,
542            existing_deployment.environment,
543            record.environment,
544            existing_deployment.experiment_name,
545            record.experiment_name,
546        );
547        update_deployment(
548            &http_client,
549            &snowflake_base_url,
550            &snowflake_token,
551            &snowflake_role,
552            record,
553        )
554        .await
555        .with_context(|| {
556            format!(
557                "failed to update deployment '{}' (model_version_id={})",
558                record.experiment_name, record.model_version_id
559            )
560        })?;
561
562        existing.insert(
563            record.model_version_id.clone(),
564            ExistingDeployment {
565                experiment_name: record.experiment_name.clone(),
566                environment: record.environment.clone(),
567            },
568        );
569    }
570
571    if inserts.is_empty() && updates.is_empty() {
572        eprintln!("All deployments up to date, no writes needed.");
573    }
574
575    display_deployments(&existing);
576
577    Ok(())
578}