open_ai.rs

mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

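/// The set of OpenAI chat models this crate knows about, plus a `Custom`
/// escape hatch for models not listed here.
///
/// A minimal usage sketch (illustrative only, so the doctest is ignored):
/// ```ignore
/// let model = Model::from_id("gpt-4o")?;
/// assert_eq!(model.id(), "gpt-4o");
/// assert_eq!(model.max_token_count(), 128_000);
/// ```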
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
    FourOmniMini,
    #[serde(rename = "custom")]
    Custom {
        name: String,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
    },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo-preview",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 4096,
            Self::Four => 8192,
            Self::FourTurbo => 128000,
            Self::FourOmni => 128000,
            Self::FourOmniMini => 128000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }

    pub fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::Custom {
                max_output_tokens, ..
            } => *max_output_tokens,
            _ => None,
        }
    }
}

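/// Body of a `/chat/completions` request.
///
/// A minimal construction sketch (illustrative only, so the doctest is ignored):
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().to_string(),
///     messages: vec![RequestMessage::User {
///         content: "Hello!".to_string(),
///     }],
///     stream: true,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 1.0,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// ```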
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a
    // function call is present. The stream sometimes delivers the `function`
    // payload before the `type` field, so `type` is not modeled here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

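/// Streams a chat completion from the OpenAI-compatible endpoint at `api_url`,
/// yielding one `ResponseStreamEvent` per server-sent `data:` line.
///
/// A usage sketch (illustrative only; assumes some `client: Arc<dyn HttpClient>`
/// and a valid `api_key` are already available, so the doctest is ignored):
/// ```ignore
/// let mut events = stream_completion(
///     client.as_ref(),
///     OPEN_AI_API_URL,
///     &api_key,
///     request,
///     Some(Duration::from_secs(30)),
/// )
/// .await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```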
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

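/// Requests embeddings for `texts` from the `/embeddings` endpoint.
///
/// A usage sketch (illustrative only; assumes a `client: Arc<dyn HttpClient>`
/// and `api_key` are already in scope, so the doctest is ignored):
/// ```ignore
/// let response = embed(
///     client.as_ref(),
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["first document", "second document"],
/// )
/// .await?;
/// assert_eq!(response.data.len(), 2);
/// ```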
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

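/// Filters a completion stream down to the argument fragments of the tool call
/// whose function name matches `tool_name`, yielding the argument text chunk
/// by chunk.
///
/// A usage sketch (illustrative only; assumes `events` came from
/// `stream_completion`, so the doctest is ignored):
/// ```ignore
/// let args = extract_tool_args_from_events("my_tool".to_string(), events).await?;
/// futures::pin_mut!(args);
/// let mut json = String::new();
/// while let Some(chunk) = args.next().await {
///     json.push_str(&chunk?);
/// }
/// ```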
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk;
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}

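/// Reduces a completion stream to just the text deltas, dropping roles, tool
/// calls, and usage information.
///
/// A usage sketch (illustrative only; assumes `events` came from
/// `stream_completion`, so the doctest is ignored):
/// ```ignore
/// let text = extract_text_from_events(events);
/// futures::pin_mut!(text);
/// while let Some(chunk) = text.next().await {
///     print!("{}", chunk?);
/// }
/// ```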
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}