1use anyhow::{Context as _, Result};
2use http_client::{AsyncBody, HttpClient, Method, Request};
3use serde::Deserialize;
4use serde_json::{Value as JsonValue, json};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::pull_examples::{
9 self, MAX_POLL_ATTEMPTS, POLL_INTERVAL, SNOWFLAKE_ASYNC_IN_PROGRESS_CODE,
10 SNOWFLAKE_SUCCESS_CODE,
11};
12
/// Baseten model name that `run_sync_deployments` syncs when the caller passes no name.
const DEFAULT_BASETEN_MODEL_NAME: &str = "zeta-2";

/// Value sent as the `timeout` field of every Snowflake statement request built in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;

/// `event_type` under which deployment rows are written to and read from the `events` table.
pub(crate) const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
16
17#[derive(Debug, Clone, Deserialize)]
18struct BasetenModelsResponse {
19 models: Vec<BasetenModel>,
20}
21
22#[derive(Debug, Clone, Deserialize)]
23struct BasetenModel {
24 id: String,
25 name: String,
26}
27
28#[derive(Debug, Clone, Deserialize)]
29struct BasetenDeploymentsResponse {
30 deployments: Vec<BasetenDeployment>,
31}
32
33#[derive(Debug, Clone, Deserialize)]
34struct BasetenDeployment {
35 id: String,
36 name: String,
37 #[serde(default)]
38 status: Option<String>,
39 #[serde(default)]
40 created_at: Option<String>,
41 #[serde(default)]
42 environment: Option<String>,
43}
44
/// A Baseten deployment flattened into the fields that are written to Snowflake as
/// `event_properties`. Every field is a plain string; missing upstream values have already
/// been replaced by placeholders when the record is built.
#[derive(Debug, Clone)]
struct DeploymentRecord {
    /// Identifier of the model the deployment belongs to.
    model_id: String,
    /// Identifier of the deployment itself; the key used to match rows in Snowflake.
    model_version_id: String,
    /// Deployment name.
    experiment_name: String,
    /// Environment name, or `"none"` when the deployment has no environment.
    environment: String,
    /// Deployment status, or `"unknown"` when not reported.
    status: String,
    /// Creation timestamp, or `"unknown"` when not reported.
    created_at: String,
}
54
/// The subset of a deployment row read back from Snowflake, used to decide whether the row
/// needs an update and to render the final table.
#[derive(Debug, Clone)]
struct ExistingDeployment {
    /// Value of `event_properties:experiment_name` for the row.
    experiment_name: String,
    /// Value of `event_properties:environment` for the row.
    environment: String,
}
60
61async fn fetch_baseten_models(
62 http_client: &Arc<dyn HttpClient>,
63 api_key: &str,
64) -> Result<Vec<BasetenModel>> {
65 let request = Request::builder()
66 .method(Method::GET)
67 .uri("https://api.baseten.co/v1/models")
68 .header("Authorization", format!("Api-Key {api_key}"))
69 .header("Accept", "application/json")
70 .body(AsyncBody::empty())?;
71
72 let response = http_client
73 .send(request)
74 .await
75 .context("failed to fetch baseten models")?;
76
77 let status = response.status();
78 let body_bytes = {
79 use futures::AsyncReadExt as _;
80 let mut body = response.into_body();
81 let mut bytes = Vec::new();
82 body.read_to_end(&mut bytes)
83 .await
84 .context("failed to read baseten models response")?;
85 bytes
86 };
87
88 if !status.is_success() {
89 let body_text = String::from_utf8_lossy(&body_bytes);
90 anyhow::bail!("baseten models API http {}: {}", status.as_u16(), body_text);
91 }
92
93 let parsed: BasetenModelsResponse =
94 serde_json::from_slice(&body_bytes).context("failed to parse baseten models response")?;
95 Ok(parsed.models)
96}
97
98async fn fetch_baseten_deployments(
99 http_client: &Arc<dyn HttpClient>,
100 api_key: &str,
101 model_id: &str,
102) -> Result<Vec<BasetenDeployment>> {
103 let url = format!("https://api.baseten.co/v1/models/{model_id}/deployments");
104 let request = Request::builder()
105 .method(Method::GET)
106 .uri(url.as_str())
107 .header("Authorization", format!("Api-Key {api_key}"))
108 .header("Accept", "application/json")
109 .body(AsyncBody::empty())?;
110
111 let response = http_client
112 .send(request)
113 .await
114 .context("failed to fetch baseten deployments")?;
115
116 let status = response.status();
117 let body_bytes = {
118 use futures::AsyncReadExt as _;
119 let mut body = response.into_body();
120 let mut bytes = Vec::new();
121 body.read_to_end(&mut bytes)
122 .await
123 .context("failed to read baseten deployments response")?;
124 bytes
125 };
126
127 if !status.is_success() {
128 let body_text = String::from_utf8_lossy(&body_bytes);
129 anyhow::bail!(
130 "baseten deployments API http {}: {}",
131 status.as_u16(),
132 body_text
133 );
134 }
135
136 let parsed: BasetenDeploymentsResponse =
137 serde_json::from_slice(&body_bytes).context("failed to parse deployments response")?;
138 Ok(parsed.deployments)
139}
140
141fn collect_deployment_records(
142 model_id: &str,
143 deployments: &[BasetenDeployment],
144) -> Vec<DeploymentRecord> {
145 deployments
146 .iter()
147 .map(|deployment| DeploymentRecord {
148 model_id: model_id.to_string(),
149 model_version_id: deployment.id.clone(),
150 experiment_name: deployment.name.clone(),
151 environment: deployment
152 .environment
153 .clone()
154 .unwrap_or_else(|| "none".to_string()),
155 status: deployment
156 .status
157 .clone()
158 .unwrap_or_else(|| "unknown".to_string()),
159 created_at: deployment
160 .created_at
161 .clone()
162 .unwrap_or_else(|| "unknown".to_string()),
163 })
164 .collect()
165}
166
167async fn run_sql_with_polling(
168 http_client: Arc<dyn HttpClient>,
169 base_url: &str,
170 token: &str,
171 request: &serde_json::Value,
172) -> Result<pull_examples::SnowflakeStatementResponse> {
173 let mut response =
174 pull_examples::run_sql(http_client.clone(), base_url, token, request).await?;
175
176 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
177 let statement_handle = response
178 .statement_handle
179 .as_ref()
180 .context("async query response missing statementHandle")?
181 .clone();
182
183 for _attempt in 1..=MAX_POLL_ATTEMPTS {
184 std::thread::sleep(POLL_INTERVAL);
185
186 response = pull_examples::fetch_partition(
187 http_client.clone(),
188 base_url,
189 token,
190 &statement_handle,
191 0,
192 )
193 .await?;
194
195 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
196 break;
197 }
198 }
199
200 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
201 anyhow::bail!(
202 "query still running after {} poll attempts ({} seconds)",
203 MAX_POLL_ATTEMPTS,
204 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
205 );
206 }
207 }
208
209 if let Some(code) = &response.code {
210 if code != SNOWFLAKE_SUCCESS_CODE {
211 anyhow::bail!(
212 "snowflake error: code={} message={}",
213 code,
214 response.message.as_deref().unwrap_or("<no message>")
215 );
216 }
217 }
218
219 Ok(response)
220}
221
222async fn fetch_existing_deployments(
223 http_client: &Arc<dyn HttpClient>,
224 base_url: &str,
225 token: &str,
226 role: &Option<String>,
227) -> Result<HashMap<String, ExistingDeployment>> {
228 let statement = format!(
229 r#"
230SELECT
231 event_properties:model_version_id::string AS model_version_id,
232 event_properties:experiment_name::string AS experiment_name,
233 event_properties:environment::string AS environment
234FROM events
235WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
236"#
237 );
238
239 let request = json!({
240 "statement": statement,
241 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
242 "database": "EVENTS",
243 "schema": "PUBLIC",
244 "warehouse": "DBT",
245 "role": role,
246 });
247
248 let response = run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
249
250 let col_names = ["model_version_id", "experiment_name", "environment"];
251 let column_indices =
252 pull_examples::get_column_indices(&response.result_set_meta_data, &col_names);
253
254 let mut existing = HashMap::new();
255
256 for data_row in &response.data {
257 let get_string = |name: &str| -> Option<String> {
258 let &index = column_indices.get(name)?;
259 match data_row.get(index) {
260 Some(JsonValue::String(s)) => Some(s.clone()),
261 _ => None,
262 }
263 };
264
265 let Some(model_version_id) = get_string("model_version_id") else {
266 continue;
267 };
268 let experiment_name = get_string("experiment_name").unwrap_or_default();
269 let environment = get_string("environment").unwrap_or_default();
270
271 existing.insert(
272 model_version_id,
273 ExistingDeployment {
274 experiment_name,
275 environment,
276 },
277 );
278 }
279
280 Ok(existing)
281}
282
283async fn insert_deployment(
284 http_client: &Arc<dyn HttpClient>,
285 base_url: &str,
286 token: &str,
287 role: &Option<String>,
288 record: &DeploymentRecord,
289) -> Result<()> {
290 let event_properties = json!({
291 "model_id": record.model_id,
292 "model_version_id": record.model_version_id,
293 "experiment_name": record.experiment_name,
294 "environment": record.environment,
295 "status": record.status,
296 "created_at": record.created_at,
297 });
298
299 let event_properties_str =
300 serde_json::to_string(&event_properties).context("failed to serialize event_properties")?;
301
302 let statement = r#"
303INSERT INTO events (event_type, event_properties, device_id, time)
304VALUES (?, PARSE_JSON(?), 'ep-cli', CURRENT_TIMESTAMP())
305"#;
306
307 let bindings = json!({
308 "1": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
309 "2": { "type": "TEXT", "value": event_properties_str }
310 });
311
312 let request = json!({
313 "statement": statement,
314 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
315 "database": "EVENTS",
316 "schema": "PUBLIC",
317 "warehouse": "DBT",
318 "role": role,
319 "bindings": bindings
320 });
321
322 run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
323 Ok(())
324}
325
326async fn update_deployment(
327 http_client: &Arc<dyn HttpClient>,
328 base_url: &str,
329 token: &str,
330 role: &Option<String>,
331 record: &DeploymentRecord,
332) -> Result<()> {
333 let statement = format!(
334 r#"
335UPDATE events
336SET
337 event_properties = OBJECT_INSERT(
338 OBJECT_INSERT(event_properties, 'environment', ?::VARIANT, true),
339 'experiment_name', ?::VARIANT, true
340 ),
341 time = CURRENT_TIMESTAMP()
342WHERE event_type = '{EDIT_PREDICTION_DEPLOYMENT_EVENT}'
343 AND event_properties:model_version_id::string = ?
344"#
345 );
346
347 let bindings = json!({
348 "1": { "type": "TEXT", "value": record.environment },
349 "2": { "type": "TEXT", "value": record.experiment_name },
350 "3": { "type": "TEXT", "value": record.model_version_id }
351 });
352
353 let request = json!({
354 "statement": statement,
355 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
356 "database": "EVENTS",
357 "schema": "PUBLIC",
358 "warehouse": "DBT",
359 "role": role,
360 "bindings": bindings
361 });
362
363 run_sql_with_polling(http_client.clone(), base_url, token, &request).await?;
364 Ok(())
365}
366
367fn display_deployments(existing: &HashMap<String, ExistingDeployment>) {
368 let col_names = ["version_id", "experiment", "environment"];
369
370 let mut col_widths: Vec<usize> = col_names.iter().map(|n| n.len()).collect();
371 let mut rows: Vec<[String; 3]> = Vec::new();
372
373 for (version_id, deployment) in existing {
374 let row = [
375 version_id.clone(),
376 deployment.experiment_name.clone(),
377 deployment.environment.clone(),
378 ];
379 for (i, val) in row.iter().enumerate() {
380 col_widths[i] = col_widths[i].max(val.len());
381 }
382 rows.push(row);
383 }
384
385 rows.sort_by(|a, b| a[2].cmp(&b[2]).then_with(|| a[1].cmp(&b[1])));
386
387 let print_row = |values: &[&str]| {
388 for (i, val) in values.iter().enumerate() {
389 if i > 0 {
390 eprint!(" ");
391 }
392 eprint!("{:width$}", val, width = col_widths[i]);
393 }
394 eprintln!();
395 };
396
397 eprintln!();
398 print_row(&col_names);
399
400 let separators: Vec<String> = col_widths.iter().map(|w| "─".repeat(*w)).collect();
401 let separator_refs: Vec<&str> = separators.iter().map(|s| s.as_str()).collect();
402 print_row(&separator_refs);
403
404 for row in &rows {
405 let refs: Vec<&str> = row.iter().map(|s| s.as_str()).collect();
406 print_row(&refs);
407 }
408}
409
410pub async fn run_sync_deployments(
411 http_client: Arc<dyn HttpClient>,
412 model_name: Option<String>,
413) -> Result<()> {
414 let baseten_api_key = std::env::var("BASETEN_API_KEY")
415 .context("missing required environment variable BASETEN_API_KEY")?;
416 let snowflake_token = std::env::var("EP_SNOWFLAKE_API_KEY")
417 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
418 let snowflake_base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
419 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
420 )?;
421 let snowflake_role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
422
423 let model_name = model_name.unwrap_or_else(|| DEFAULT_BASETEN_MODEL_NAME.to_string());
424
425 let models = fetch_baseten_models(&http_client, &baseten_api_key).await?;
426
427 let model = models
428 .iter()
429 .find(|m| m.name == model_name)
430 .with_context(|| {
431 let available: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
432 format!(
433 "model '{}' not found on baseten. Available: {:?}",
434 model_name, available
435 )
436 })?;
437
438 eprintln!("Fetching existing deployments from Snowflake...");
439 let mut existing = fetch_existing_deployments(
440 &http_client,
441 &snowflake_base_url,
442 &snowflake_token,
443 &snowflake_role,
444 )
445 .await
446 .context("failed to fetch existing deployments from Snowflake")?;
447
448 eprintln!(
449 "Found {} existing deployment(s) in Snowflake.",
450 existing.len()
451 );
452
453 let baseten_deployments = fetch_baseten_deployments(&http_client, &baseten_api_key, &model.id)
454 .await
455 .with_context(|| format!("failed to fetch deployments for model '{}'", model.name))?;
456
457 let records = collect_deployment_records(&model.id, &baseten_deployments);
458
459 if records.is_empty() {
460 eprintln!("No deployments found on Baseten.");
461 return Ok(());
462 }
463
464 eprintln!(
465 "Found {} deployment(s) on Baseten for model '{}'.",
466 records.len(),
467 model.name
468 );
469
470 let mut inserts = Vec::new();
471 let mut updates = Vec::new();
472 let mut unchanged = 0;
473
474 for record in &records {
475 match existing.get(&record.model_version_id) {
476 Some(existing_deployment) => {
477 let environment_changed = existing_deployment.environment != record.environment;
478 let experiment_changed =
479 existing_deployment.experiment_name != record.experiment_name;
480
481 if environment_changed || experiment_changed {
482 updates.push(record);
483 } else {
484 unchanged += 1;
485 }
486 }
487 None => {
488 inserts.push(record);
489 }
490 }
491 }
492
493 eprintln!(
494 "Diff: {} insert(s), {} update(s), {} unchanged",
495 inserts.len(),
496 updates.len(),
497 unchanged,
498 );
499
500 for (i, record) in inserts.iter().enumerate() {
501 eprintln!(
502 " INSERT [{}/{}] {} -> {} (version_id={})",
503 i + 1,
504 inserts.len(),
505 record.experiment_name,
506 record.environment,
507 record.model_version_id,
508 );
509 insert_deployment(
510 &http_client,
511 &snowflake_base_url,
512 &snowflake_token,
513 &snowflake_role,
514 record,
515 )
516 .await
517 .with_context(|| {
518 format!(
519 "failed to insert deployment '{}' (model_version_id={})",
520 record.experiment_name, record.model_version_id
521 )
522 })?;
523
524 existing.insert(
525 record.model_version_id.clone(),
526 ExistingDeployment {
527 experiment_name: record.experiment_name.clone(),
528 environment: record.environment.clone(),
529 },
530 );
531 }
532
533 for (i, record) in updates.iter().enumerate() {
534 let existing_deployment = existing
535 .get(&record.model_version_id)
536 .context("update record missing from existing map")?;
537 eprintln!(
538 " UPDATE [{}/{}] version_id={}: environment '{}' -> '{}', experiment '{}' -> '{}'",
539 i + 1,
540 updates.len(),
541 record.model_version_id,
542 existing_deployment.environment,
543 record.environment,
544 existing_deployment.experiment_name,
545 record.experiment_name,
546 );
547 update_deployment(
548 &http_client,
549 &snowflake_base_url,
550 &snowflake_token,
551 &snowflake_role,
552 record,
553 )
554 .await
555 .with_context(|| {
556 format!(
557 "failed to update deployment '{}' (model_version_id={})",
558 record.experiment_name, record.model_version_id
559 )
560 })?;
561
562 existing.insert(
563 record.model_version_id.clone(),
564 ExistingDeployment {
565 experiment_name: record.experiment_name.clone(),
566 environment: record.environment.clone(),
567 },
568 );
569 }
570
571 if inserts.is_empty() && updates.is_empty() {
572 eprintln!("All deployments up to date, no writes needed.");
573 }
574
575 display_deployments(&existing);
576
577 Ok(())
578}