1use anyhow::{anyhow, Result};
2use futures::AsyncReadExt;
3use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5
/// Default base URL of the Fireworks inference API, without a trailing slash.
///
/// [`complete`] appends `/completions` to whatever base URL it is given.
pub const FIREWORKS_API_URL: &str = "https://api.fireworks.ai/inference/v1";
7
/// JSON request body sent by [`complete`] to the `/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Sent as `model`; the model to run the completion against.
    pub model: String,
    /// Sent as `prompt`; the text to be completed.
    pub prompt: String,
    /// Sent as `max_tokens`; presumably the cap on generated tokens (API-defined).
    pub max_tokens: u32,
    /// Sent as `temperature`; sampling temperature (API-defined).
    pub temperature: f32,
    /// Optional predicted-output hint.
    ///
    /// It is omitted from the serialized JSON when `None`. It defaults to `None` when the
    /// field is absent during deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    /// Sent as `rewrite_speculation` when set, and omitted from the JSON when `None`.
    /// NOTE(review): the flag's meaning is defined by the Fireworks API; confirm it against their docs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
19
/// Predicted-output hint attached to a [`CompletionRequest`].
///
/// Serialized as an internally tagged enum with snake_case variant names. For example,
/// `Prediction::Content { content }` becomes `{"type": "content", "content": "..."}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    /// Literal text the caller expects the output to resemble.
    Content { content: String },
}
25
/// Result of a successful [`complete`] call.
#[derive(Debug)]
pub struct Response {
    /// Decoded JSON body of the response.
    pub completion: CompletionResponse,
    /// Diagnostic values extracted from the response headers via [`Headers::parse`].
    pub headers: Headers,
}
31
/// JSON body of a successful `/completions` response.
///
/// Only the fields below are decoded. There is no `deny_unknown_fields`, so any other
/// keys in the payload are ignored. All listed fields are required, and a missing one
/// fails deserialization.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    /// Server-assigned identifier of this completion.
    pub id: String,
    /// Object type string reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API (presumably Unix seconds; confirm).
    pub created: u64,
    /// Model that served the request.
    pub model: String,
    /// Generated candidates.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
41
/// One generated candidate from a [`CompletionResponse`]; only its `text` is decoded.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    /// The generated completion text.
    pub text: String,
}
46
/// Token counts reported in the `usage` object of a [`CompletionResponse`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported by the API.
    pub total_tokens: u32,
}
53
/// Diagnostic metadata carried in Fireworks response headers.
///
/// Built by [`Headers::parse`]. A field is `None` when its header is missing, cannot be
/// read as text, or (for numeric fields) does not parse as the field's type.
/// NOTE(review): the units of the `*_time` / `*_duration` fields are not established
/// here. They are presumably seconds; confirm against the Fireworks docs.
#[derive(Debug, Clone, Default, Serialize)]
pub struct Headers {
    /// From `fireworks-server-processing-time`.
    pub server_processing_time: Option<f64>,
    /// From `x-request-id`.
    pub request_id: Option<String>,
    /// From `fireworks-prompt-tokens`.
    pub prompt_tokens: Option<u32>,
    /// From `fireworks-speculation-generated-tokens`.
    pub speculation_generated_tokens: Option<u32>,
    /// From `fireworks-cached-prompt-tokens`.
    pub cached_prompt_tokens: Option<u32>,
    /// From `fireworks-backend-host`.
    pub backend_host: Option<String>,
    /// From `fireworks-num-concurrent-requests`.
    pub num_concurrent_requests: Option<u32>,
    /// From `fireworks-deployment`.
    pub deployment: Option<String>,
    /// From `fireworks-tokenizer-queue-duration`.
    pub tokenizer_queue_duration: Option<f64>,
    /// From `fireworks-tokenizer-duration`.
    pub tokenizer_duration: Option<f64>,
    /// From `fireworks-prefill-queue-duration`.
    pub prefill_queue_duration: Option<f64>,
    /// From `fireworks-prefill-duration`.
    pub prefill_duration: Option<f64>,
    /// From `fireworks-generation-queue-duration`.
    pub generation_queue_duration: Option<f64>,
}
70
71impl Headers {
72 pub fn parse(headers: &HeaderMap) -> Self {
73 Headers {
74 request_id: headers
75 .get("x-request-id")
76 .and_then(|v| v.to_str().ok())
77 .map(String::from),
78 server_processing_time: headers
79 .get("fireworks-server-processing-time")
80 .and_then(|v| v.to_str().ok()?.parse().ok()),
81 prompt_tokens: headers
82 .get("fireworks-prompt-tokens")
83 .and_then(|v| v.to_str().ok()?.parse().ok()),
84 speculation_generated_tokens: headers
85 .get("fireworks-speculation-generated-tokens")
86 .and_then(|v| v.to_str().ok()?.parse().ok()),
87 cached_prompt_tokens: headers
88 .get("fireworks-cached-prompt-tokens")
89 .and_then(|v| v.to_str().ok()?.parse().ok()),
90 backend_host: headers
91 .get("fireworks-backend-host")
92 .and_then(|v| v.to_str().ok())
93 .map(String::from),
94 num_concurrent_requests: headers
95 .get("fireworks-num-concurrent-requests")
96 .and_then(|v| v.to_str().ok()?.parse().ok()),
97 deployment: headers
98 .get("fireworks-deployment")
99 .and_then(|v| v.to_str().ok())
100 .map(String::from),
101 tokenizer_queue_duration: headers
102 .get("fireworks-tokenizer-queue-duration")
103 .and_then(|v| v.to_str().ok()?.parse().ok()),
104 tokenizer_duration: headers
105 .get("fireworks-tokenizer-duration")
106 .and_then(|v| v.to_str().ok()?.parse().ok()),
107 prefill_queue_duration: headers
108 .get("fireworks-prefill-queue-duration")
109 .and_then(|v| v.to_str().ok()?.parse().ok()),
110 prefill_duration: headers
111 .get("fireworks-prefill-duration")
112 .and_then(|v| v.to_str().ok()?.parse().ok()),
113 generation_queue_duration: headers
114 .get("fireworks-generation-queue-duration")
115 .and_then(|v| v.to_str().ok()?.parse().ok()),
116 }
117 }
118}
119
120pub async fn complete(
121 client: &dyn HttpClient,
122 api_url: &str,
123 api_key: &str,
124 request: CompletionRequest,
125) -> Result<Response> {
126 let uri = format!("{api_url}/completions");
127 let request_builder = HttpRequest::builder()
128 .method(Method::POST)
129 .uri(uri)
130 .header("Content-Type", "application/json")
131 .header("Authorization", format!("Bearer {}", api_key));
132
133 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
134 let mut response = client.send(request).await?;
135
136 if response.status().is_success() {
137 let headers = Headers::parse(response.headers());
138
139 let mut body = String::new();
140 response.body_mut().read_to_string(&mut body).await?;
141
142 Ok(Response {
143 completion: serde_json::from_str(&body)?,
144 headers,
145 })
146 } else {
147 let mut body = String::new();
148 response.body_mut().read_to_string(&mut body).await?;
149
150 #[derive(Deserialize)]
151 struct FireworksResponse {
152 error: FireworksError,
153 }
154
155 #[derive(Deserialize)]
156 struct FireworksError {
157 message: String,
158 }
159
160 match serde_json::from_str::<FireworksResponse>(&body) {
161 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
162 "Failed to connect to Fireworks API: {}",
163 response.error.message,
164 )),
165
166 _ => Err(anyhow!(
167 "Failed to connect to Fireworks API: {} {}",
168 response.status(),
169 body,
170 )),
171 }
172 }
173}