zeta: Report Fireworks request data to Snowflake (#22973)

Thorsten Ball , Antonio Scandurra , and Conrad created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad <conrad@zed.dev>

Change summary

Cargo.lock                        |  12 ++
Cargo.toml                        |   2 
crates/collab/Cargo.toml          |   1 
crates/collab/src/llm.rs          |  31 +++++
crates/fireworks/Cargo.toml       |  19 +++
crates/fireworks/LICENSE-GPL      |   1 
crates/fireworks/src/fireworks.rs | 173 +++++++++++++++++++++++++++++++++
7 files changed, 236 insertions(+), 3 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2666,6 +2666,7 @@ dependencies = [
  "envy",
  "extension",
  "file_finder",
+ "fireworks",
  "fs",
  "futures 0.3.31",
  "git",
@@ -4590,6 +4591,17 @@ dependencies = [
  "windows-sys 0.59.0",
 ]
 
+[[package]]
+name = "fireworks"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"

Cargo.toml 🔗

@@ -40,6 +40,7 @@ members = [
     "crates/feedback",
     "crates/file_finder",
     "crates/file_icons",
+    "crates/fireworks",
     "crates/fs",
     "crates/fsevent",
     "crates/fuzzy",
@@ -222,6 +223,7 @@ feature_flags = { path = "crates/feature_flags" }
 feedback = { path = "crates/feedback" }
 file_finder = { path = "crates/file_finder" }
 file_icons = { path = "crates/file_icons" }
+fireworks = { path = "crates/fireworks" }
 fs = { path = "crates/fs" }
 fsevent = { path = "crates/fsevent" }
 fuzzy = { path = "crates/fuzzy" }

crates/collab/Cargo.toml 🔗

@@ -34,6 +34,7 @@ collections.workspace = true
 dashmap.workspace = true
 derive_more.workspace = true
 envy = "0.4.2"
+fireworks.workspace = true
 futures.workspace = true
 google_ai.workspace = true
 hex.workspace = true

crates/collab/src/llm.rs 🔗

@@ -470,23 +470,48 @@ async fn predict_edits(
         .replace("<outline>", &outline_prefix)
         .replace("<events>", &params.input_events)
         .replace("<excerpt>", &params.input_excerpt);
-    let mut response = open_ai::complete_text(
+    let mut response = fireworks::complete(
         &state.http_client,
         api_url,
         api_key,
-        open_ai::CompletionRequest {
+        fireworks::CompletionRequest {
             model: model.to_string(),
             prompt: prompt.clone(),
             max_tokens: 2048,
             temperature: 0.,
-            prediction: Some(open_ai::Prediction::Content {
+            prediction: Some(fireworks::Prediction::Content {
                 content: params.input_excerpt,
             }),
             rewrite_speculation: Some(true),
         },
     )
     .await?;
+
+    state.executor.spawn_detached({
+        let kinesis_client = state.kinesis_client.clone();
+        let kinesis_stream = state.config.kinesis_stream.clone();
+        let headers = response.headers.clone();
+        let model = model.clone();
+
+        async move {
+            SnowflakeRow::new(
+                "Fireworks Completion Requested",
+                claims.metrics_id,
+                claims.is_staff,
+                claims.system_id.clone(),
+                json!({
+                    "model": model.to_string(),
+                    "headers": headers,
+                }),
+            )
+            .write(&kinesis_client, &kinesis_stream)
+            .await
+            .log_err();
+        }
+    });
+
     let choice = response
+        .completion
         .choices
         .pop()
         .context("no output from completion response")?;

crates/fireworks/Cargo.toml 🔗

@@ -0,0 +1,19 @@
+[package]
+name = "fireworks"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/fireworks.rs"
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http_client.workspace = true
+serde.workspace = true
+serde_json.workspace = true

crates/fireworks/src/fireworks.rs 🔗

@@ -0,0 +1,173 @@
+use anyhow::{anyhow, Result};
+use futures::AsyncReadExt;
+use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+pub const FIREWORKS_API_URL: &str = "https://api.openai.com/v1";
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CompletionRequest {
+    pub model: String,
+    pub prompt: String,
+    pub max_tokens: u32,
+    pub temperature: f32,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub prediction: Option<Prediction>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub rewrite_speculation: Option<bool>,
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Prediction {
+    Content { content: String },
+}
+
+#[derive(Debug)]
+pub struct Response {
+    pub completion: CompletionResponse,
+    pub headers: Headers,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<CompletionChoice>,
+    pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CompletionChoice {
+    pub text: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Debug, Clone, Default, Serialize)]
+pub struct Headers {
+    pub server_processing_time: Option<f64>,
+    pub request_id: Option<String>,
+    pub prompt_tokens: Option<u32>,
+    pub speculation_generated_tokens: Option<u32>,
+    pub cached_prompt_tokens: Option<u32>,
+    pub backend_host: Option<String>,
+    pub num_concurrent_requests: Option<u32>,
+    pub deployment: Option<String>,
+    pub tokenizer_queue_duration: Option<f64>,
+    pub tokenizer_duration: Option<f64>,
+    pub prefill_queue_duration: Option<f64>,
+    pub prefill_duration: Option<f64>,
+    pub generation_queue_duration: Option<f64>,
+}
+
+impl Headers {
+    pub fn parse(headers: &HeaderMap) -> Self {
+        Headers {
+            request_id: headers
+                .get("x-request-id")
+                .and_then(|v| v.to_str().ok())
+                .map(String::from),
+            server_processing_time: headers
+                .get("fireworks-server-processing-time")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            prompt_tokens: headers
+                .get("fireworks-prompt-tokens")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            speculation_generated_tokens: headers
+                .get("fireworks-speculation-generated-tokens")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            cached_prompt_tokens: headers
+                .get("fireworks-cached-prompt-tokens")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            backend_host: headers
+                .get("fireworks-backend-host")
+                .and_then(|v| v.to_str().ok())
+                .map(String::from),
+            num_concurrent_requests: headers
+                .get("fireworks-num-concurrent-requests")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            deployment: headers
+                .get("fireworks-deployment")
+                .and_then(|v| v.to_str().ok())
+                .map(String::from),
+            tokenizer_queue_duration: headers
+                .get("fireworks-tokenizer-queue-duration")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            tokenizer_duration: headers
+                .get("fireworks-tokenizer-duration")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            prefill_queue_duration: headers
+                .get("fireworks-prefill-queue-duration")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            prefill_duration: headers
+                .get("fireworks-prefill-duration")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+            generation_queue_duration: headers
+                .get("fireworks-generation-queue-duration")
+                .and_then(|v| v.to_str().ok()?.parse().ok()),
+        }
+    }
+}
+
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CompletionRequest,
+) -> Result<Response> {
+    let uri = format!("{api_url}/completions");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key));
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let headers = Headers::parse(response.headers());
+
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Ok(Response {
+            completion: serde_json::from_str(&body)?,
+            headers,
+        })
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct FireworksResponse {
+            error: FireworksError,
+        }
+
+        #[derive(Deserialize)]
+        struct FireworksError {
+            message: String,
+        }
+
+        match serde_json::from_str::<FireworksResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to Fireworks API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to Fireworks API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}