fireworks.rs

  1use anyhow::{anyhow, Result};
  2use futures::AsyncReadExt;
  3use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5
  6pub const FIREWORKS_API_URL: &str = "https://api.openai.com/v1";
  7
/// JSON request body that [`complete`] POSTs to `{api_url}/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Identifier of the model to run.
    pub model: String,
    /// Prompt text to complete.
    pub prompt: String,
    /// NOTE(review): presumably the upper bound on generated tokens — confirm against the API docs.
    pub max_tokens: u32,
    /// NOTE(review): presumably the sampling temperature — confirm against the API docs.
    pub temperature: f32,
    /// Optional [`Prediction`]. Omitted from the serialized JSON when `None`;
    /// defaults to `None` when the key is missing on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    /// Optional flag. Omitted from the serialized JSON when `None`;
    /// defaults to `None` when the key is missing on deserialization.
    /// NOTE(review): server-side meaning is not visible here — confirm against the API docs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
 19
/// Value of the optional `prediction` field of a [`CompletionRequest`].
///
/// Serialized as an internally tagged object: the variant name goes, in
/// snake_case, into a `type` key next to the variant's own fields, e.g.
/// `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    /// Carries a text payload.
    /// NOTE(review): presumably the anticipated output text — confirm against the API docs.
    Content { content: String },
}
 25
/// Successful result of [`complete`]: the decoded JSON body together with
/// the metadata parsed from the HTTP response headers.
#[derive(Debug)]
pub struct Response {
    /// Deserialized JSON response body.
    pub completion: CompletionResponse,
    /// Values extracted from the response headers by [`Headers::parse`].
    pub headers: Headers,
}
 31
/// JSON body of a successful completions response, as deserialized by [`complete`].
///
/// All fields are required: deserialization fails if any key is missing.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    /// Identifier string from the response.
    pub id: String,
    /// Value of the `object` key in the response.
    pub object: String,
    /// NOTE(review): presumably a Unix timestamp of creation — confirm against the API docs.
    pub created: u64,
    /// Model name as given in the response.
    pub model: String,
    /// Generated completions.
    pub choices: Vec<CompletionChoice>,
    /// Token counts given in the response.
    pub usage: Usage,
}
 41
/// One entry of [`CompletionResponse::choices`].
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    /// Text of this choice.
    pub text: String,
}
 46
/// Token counts carried in the `usage` field of a [`CompletionResponse`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens attributed to the completion.
    pub completion_tokens: u32,
    /// NOTE(review): presumably `prompt_tokens + completion_tokens`; nothing here enforces that.
    pub total_tokens: u32,
}
 53
/// Metadata extracted from the HTTP headers of a completions response.
///
/// Populated by [`Headers::parse`]. Every field is `None` when its header is
/// absent or its value cannot be converted to the field's type. Units of the
/// numeric values are not established by this file.
#[derive(Debug, Clone, Default, Serialize)]
pub struct Headers {
    /// Parsed from `fireworks-server-processing-time`.
    pub server_processing_time: Option<f64>,
    /// Taken verbatim from `x-request-id`.
    pub request_id: Option<String>,
    /// Parsed from `fireworks-prompt-tokens`.
    pub prompt_tokens: Option<u32>,
    /// Parsed from `fireworks-speculation-generated-tokens`.
    pub speculation_generated_tokens: Option<u32>,
    /// Parsed from `fireworks-cached-prompt-tokens`.
    pub cached_prompt_tokens: Option<u32>,
    /// Taken verbatim from `fireworks-backend-host`.
    pub backend_host: Option<String>,
    /// Parsed from `fireworks-num-concurrent-requests`.
    pub num_concurrent_requests: Option<u32>,
    /// Taken verbatim from `fireworks-deployment`.
    pub deployment: Option<String>,
    /// Parsed from `fireworks-tokenizer-queue-duration`.
    pub tokenizer_queue_duration: Option<f64>,
    /// Parsed from `fireworks-tokenizer-duration`.
    pub tokenizer_duration: Option<f64>,
    /// Parsed from `fireworks-prefill-queue-duration`.
    pub prefill_queue_duration: Option<f64>,
    /// Parsed from `fireworks-prefill-duration`.
    pub prefill_duration: Option<f64>,
    /// Parsed from `fireworks-generation-queue-duration`.
    pub generation_queue_duration: Option<f64>,
}
 70
 71impl Headers {
 72    pub fn parse(headers: &HeaderMap) -> Self {
 73        Headers {
 74            request_id: headers
 75                .get("x-request-id")
 76                .and_then(|v| v.to_str().ok())
 77                .map(String::from),
 78            server_processing_time: headers
 79                .get("fireworks-server-processing-time")
 80                .and_then(|v| v.to_str().ok()?.parse().ok()),
 81            prompt_tokens: headers
 82                .get("fireworks-prompt-tokens")
 83                .and_then(|v| v.to_str().ok()?.parse().ok()),
 84            speculation_generated_tokens: headers
 85                .get("fireworks-speculation-generated-tokens")
 86                .and_then(|v| v.to_str().ok()?.parse().ok()),
 87            cached_prompt_tokens: headers
 88                .get("fireworks-cached-prompt-tokens")
 89                .and_then(|v| v.to_str().ok()?.parse().ok()),
 90            backend_host: headers
 91                .get("fireworks-backend-host")
 92                .and_then(|v| v.to_str().ok())
 93                .map(String::from),
 94            num_concurrent_requests: headers
 95                .get("fireworks-num-concurrent-requests")
 96                .and_then(|v| v.to_str().ok()?.parse().ok()),
 97            deployment: headers
 98                .get("fireworks-deployment")
 99                .and_then(|v| v.to_str().ok())
100                .map(String::from),
101            tokenizer_queue_duration: headers
102                .get("fireworks-tokenizer-queue-duration")
103                .and_then(|v| v.to_str().ok()?.parse().ok()),
104            tokenizer_duration: headers
105                .get("fireworks-tokenizer-duration")
106                .and_then(|v| v.to_str().ok()?.parse().ok()),
107            prefill_queue_duration: headers
108                .get("fireworks-prefill-queue-duration")
109                .and_then(|v| v.to_str().ok()?.parse().ok()),
110            prefill_duration: headers
111                .get("fireworks-prefill-duration")
112                .and_then(|v| v.to_str().ok()?.parse().ok()),
113            generation_queue_duration: headers
114                .get("fireworks-generation-queue-duration")
115                .and_then(|v| v.to_str().ok()?.parse().ok()),
116        }
117    }
118}
119
120pub async fn complete(
121    client: &dyn HttpClient,
122    api_url: &str,
123    api_key: &str,
124    request: CompletionRequest,
125) -> Result<Response> {
126    let uri = format!("{api_url}/completions");
127    let request_builder = HttpRequest::builder()
128        .method(Method::POST)
129        .uri(uri)
130        .header("Content-Type", "application/json")
131        .header("Authorization", format!("Bearer {}", api_key));
132
133    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
134    let mut response = client.send(request).await?;
135
136    if response.status().is_success() {
137        let headers = Headers::parse(response.headers());
138
139        let mut body = String::new();
140        response.body_mut().read_to_string(&mut body).await?;
141
142        Ok(Response {
143            completion: serde_json::from_str(&body)?,
144            headers,
145        })
146    } else {
147        let mut body = String::new();
148        response.body_mut().read_to_string(&mut body).await?;
149
150        #[derive(Deserialize)]
151        struct FireworksResponse {
152            error: FireworksError,
153        }
154
155        #[derive(Deserialize)]
156        struct FireworksError {
157            message: String,
158        }
159
160        match serde_json::from_str::<FireworksResponse>(&body) {
161            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
162                "Failed to connect to Fireworks API: {}",
163                response.error.message,
164            )),
165
166            _ => Err(anyhow!(
167                "Failed to connect to Fireworks API: {} {}",
168                response.status(),
169                body,
170            )),
171        }
172    }
173}