open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use isahc::config::Configurable;
  7use serde::{Deserialize, Serialize};
  8use serde_json::Value;
  9use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
 10use strum::EnumIter;
 11
 12pub use supported_countries::*;
 13
 14pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 15
 16fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 17    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 18}
 19
 20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 21#[serde(rename_all = "lowercase")]
 22pub enum Role {
 23    User,
 24    Assistant,
 25    System,
 26    Tool,
 27}
 28
 29impl TryFrom<String> for Role {
 30    type Error = anyhow::Error;
 31
 32    fn try_from(value: String) -> Result<Self> {
 33        match value.as_str() {
 34            "user" => Ok(Self::User),
 35            "assistant" => Ok(Self::Assistant),
 36            "system" => Ok(Self::System),
 37            "tool" => Ok(Self::Tool),
 38            _ => Err(anyhow!("invalid role '{value}'")),
 39        }
 40    }
 41}
 42
 43impl From<Role> for String {
 44    fn from(val: Role) -> Self {
 45        match val {
 46            Role::User => "user".to_owned(),
 47            Role::Assistant => "assistant".to_owned(),
 48            Role::System => "system".to_owned(),
 49            Role::Tool => "tool".to_owned(),
 50        }
 51    }
 52}
 53
 54#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 55#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 56pub enum Model {
 57    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
 58    ThreePointFiveTurbo,
 59    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
 60    Four,
 61    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
 62    FourTurbo,
 63    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
 64    #[default]
 65    FourOmni,
 66    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
 67    FourOmniMini,
 68    #[serde(rename = "custom")]
 69    Custom {
 70        name: String,
 71        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 72        display_name: Option<String>,
 73        max_tokens: usize,
 74        max_output_tokens: Option<u32>,
 75    },
 76}
 77
 78impl Model {
 79    pub fn from_id(id: &str) -> Result<Self> {
 80        match id {
 81            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 82            "gpt-4" => Ok(Self::Four),
 83            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 84            "gpt-4o" => Ok(Self::FourOmni),
 85            "gpt-4o-mini" => Ok(Self::FourOmniMini),
 86            _ => Err(anyhow!("invalid model id")),
 87        }
 88    }
 89
 90    pub fn id(&self) -> &str {
 91        match self {
 92            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 93            Self::Four => "gpt-4",
 94            Self::FourTurbo => "gpt-4-turbo-preview",
 95            Self::FourOmni => "gpt-4o",
 96            Self::FourOmniMini => "gpt-4o-mini",
 97            Self::Custom { name, .. } => name,
 98        }
 99    }
100
101    pub fn display_name(&self) -> &str {
102        match self {
103            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
104            Self::Four => "gpt-4",
105            Self::FourTurbo => "gpt-4-turbo",
106            Self::FourOmni => "gpt-4o",
107            Self::FourOmniMini => "gpt-4o-mini",
108            Self::Custom {
109                name, display_name, ..
110            } => display_name.as_ref().unwrap_or(name),
111        }
112    }
113
114    pub fn max_token_count(&self) -> usize {
115        match self {
116            Self::ThreePointFiveTurbo => 4096,
117            Self::Four => 8192,
118            Self::FourTurbo => 128000,
119            Self::FourOmni => 128000,
120            Self::FourOmniMini => 128000,
121            Self::Custom { max_tokens, .. } => *max_tokens,
122        }
123    }
124
125    pub fn max_output_tokens(&self) -> Option<u32> {
126        match self {
127            Self::Custom {
128                max_output_tokens, ..
129            } => *max_output_tokens,
130            _ => None,
131        }
132    }
133}
134
135#[derive(Debug, Serialize, Deserialize)]
136pub struct Request {
137    pub model: String,
138    pub messages: Vec<RequestMessage>,
139    pub stream: bool,
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub max_tokens: Option<u32>,
142    pub stop: Vec<String>,
143    pub temperature: f32,
144    #[serde(default, skip_serializing_if = "Option::is_none")]
145    pub tool_choice: Option<ToolChoice>,
146    #[serde(default, skip_serializing_if = "Vec::is_empty")]
147    pub tools: Vec<ToolDefinition>,
148}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(untagged)]
152pub enum ToolChoice {
153    Auto,
154    Required,
155    None,
156    Other(ToolDefinition),
157}
158
159#[derive(Clone, Deserialize, Serialize, Debug)]
160#[serde(tag = "type", rename_all = "snake_case")]
161pub enum ToolDefinition {
162    #[allow(dead_code)]
163    Function { function: FunctionDefinition },
164}
165
166#[derive(Clone, Debug, Serialize, Deserialize)]
167pub struct FunctionDefinition {
168    pub name: String,
169    pub description: Option<String>,
170    pub parameters: Option<Value>,
171}
172
173#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
174#[serde(tag = "role", rename_all = "lowercase")]
175pub enum RequestMessage {
176    Assistant {
177        content: Option<String>,
178        #[serde(default, skip_serializing_if = "Vec::is_empty")]
179        tool_calls: Vec<ToolCall>,
180    },
181    User {
182        content: String,
183    },
184    System {
185        content: String,
186    },
187    Tool {
188        content: String,
189        tool_call_id: String,
190    },
191}
192
193#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
194pub struct ToolCall {
195    pub id: String,
196    #[serde(flatten)]
197    pub content: ToolCallContent,
198}
199
200#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
201#[serde(tag = "type", rename_all = "lowercase")]
202pub enum ToolCallContent {
203    Function { function: FunctionContent },
204}
205
206#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
207pub struct FunctionContent {
208    pub name: String,
209    pub arguments: String,
210}
211
212#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
213pub struct ResponseMessageDelta {
214    pub role: Option<Role>,
215    pub content: Option<String>,
216    #[serde(default, skip_serializing_if = "is_none_or_empty")]
217    pub tool_calls: Option<Vec<ToolCallChunk>>,
218}
219
220#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
221pub struct ToolCallChunk {
222    pub index: usize,
223    pub id: Option<String>,
224
225    // There is also an optional `type` field that would determine if a
226    // function is there. Sometimes this streams in with the `function` before
227    // it streams in the `type`
228    pub function: Option<FunctionChunk>,
229}
230
231#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
232pub struct FunctionChunk {
233    pub name: Option<String>,
234    pub arguments: Option<String>,
235}
236
237#[derive(Serialize, Deserialize, Debug)]
238pub struct Usage {
239    pub prompt_tokens: u32,
240    pub completion_tokens: u32,
241    pub total_tokens: u32,
242}
243
244#[derive(Serialize, Deserialize, Debug)]
245pub struct ChoiceDelta {
246    pub index: u32,
247    pub delta: ResponseMessageDelta,
248    pub finish_reason: Option<String>,
249}
250
251#[derive(Serialize, Deserialize, Debug)]
252#[serde(untagged)]
253pub enum ResponseStreamResult {
254    Ok(ResponseStreamEvent),
255    Err { error: String },
256}
257
258#[derive(Serialize, Deserialize, Debug)]
259pub struct ResponseStreamEvent {
260    pub created: u32,
261    pub model: String,
262    pub choices: Vec<ChoiceDelta>,
263    pub usage: Option<Usage>,
264}
265
266pub async fn stream_completion(
267    client: &dyn HttpClient,
268    api_url: &str,
269    api_key: &str,
270    request: Request,
271    low_speed_timeout: Option<Duration>,
272) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
273    let uri = format!("{api_url}/chat/completions");
274    let mut request_builder = HttpRequest::builder()
275        .method(Method::POST)
276        .uri(uri)
277        .header("Content-Type", "application/json")
278        .header("Authorization", format!("Bearer {}", api_key));
279
280    if let Some(low_speed_timeout) = low_speed_timeout {
281        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
282    };
283
284    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
285    let mut response = client.send(request).await?;
286    if response.status().is_success() {
287        let reader = BufReader::new(response.into_body());
288        Ok(reader
289            .lines()
290            .filter_map(|line| async move {
291                match line {
292                    Ok(line) => {
293                        let line = line.strip_prefix("data: ")?;
294                        if line == "[DONE]" {
295                            None
296                        } else {
297                            match serde_json::from_str(line) {
298                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
299                                Ok(ResponseStreamResult::Err { error }) => {
300                                    Some(Err(anyhow!(error)))
301                                }
302                                Err(error) => Some(Err(anyhow!(error))),
303                            }
304                        }
305                    }
306                    Err(error) => Some(Err(anyhow!(error))),
307                }
308            })
309            .boxed())
310    } else {
311        let mut body = String::new();
312        response.body_mut().read_to_string(&mut body).await?;
313
314        #[derive(Deserialize)]
315        struct OpenAiResponse {
316            error: OpenAiError,
317        }
318
319        #[derive(Deserialize)]
320        struct OpenAiError {
321            message: String,
322        }
323
324        match serde_json::from_str::<OpenAiResponse>(&body) {
325            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
326                "Failed to connect to OpenAI API: {}",
327                response.error.message,
328            )),
329
330            _ => Err(anyhow!(
331                "Failed to connect to OpenAI API: {} {}",
332                response.status(),
333                body,
334            )),
335        }
336    }
337}
338
339#[derive(Copy, Clone, Serialize, Deserialize)]
340pub enum OpenAiEmbeddingModel {
341    #[serde(rename = "text-embedding-3-small")]
342    TextEmbedding3Small,
343    #[serde(rename = "text-embedding-3-large")]
344    TextEmbedding3Large,
345}
346
347#[derive(Serialize)]
348struct OpenAiEmbeddingRequest<'a> {
349    model: OpenAiEmbeddingModel,
350    input: Vec<&'a str>,
351}
352
353#[derive(Deserialize)]
354pub struct OpenAiEmbeddingResponse {
355    pub data: Vec<OpenAiEmbedding>,
356}
357
358#[derive(Deserialize)]
359pub struct OpenAiEmbedding {
360    pub embedding: Vec<f32>,
361}
362
363pub fn embed<'a>(
364    client: &dyn HttpClient,
365    api_url: &str,
366    api_key: &str,
367    model: OpenAiEmbeddingModel,
368    texts: impl IntoIterator<Item = &'a str>,
369) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
370    let uri = format!("{api_url}/embeddings");
371
372    let request = OpenAiEmbeddingRequest {
373        model,
374        input: texts.into_iter().collect(),
375    };
376    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
377    let request = HttpRequest::builder()
378        .method(Method::POST)
379        .uri(uri)
380        .header("Content-Type", "application/json")
381        .header("Authorization", format!("Bearer {}", api_key))
382        .body(body)
383        .map(|request| client.send(request));
384
385    async move {
386        let mut response = request?.await?;
387        let mut body = String::new();
388        response.body_mut().read_to_string(&mut body).await?;
389
390        if response.status().is_success() {
391            let response: OpenAiEmbeddingResponse =
392                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
393            Ok(response)
394        } else {
395            Err(anyhow!(
396                "error during embedding, status: {:?}, body: {:?}",
397                response.status(),
398                body
399            ))
400        }
401    }
402}
403
404pub async fn extract_tool_args_from_events(
405    tool_name: String,
406    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
407) -> Result<impl Send + Stream<Item = Result<String>>> {
408    let mut tool_use_index = None;
409    let mut first_chunk = None;
410    while let Some(event) = events.next().await {
411        let call = event?.choices.into_iter().find_map(|choice| {
412            choice.delta.tool_calls?.into_iter().find_map(|call| {
413                if call.function.as_ref()?.name.as_deref()? == tool_name {
414                    Some(call)
415                } else {
416                    None
417                }
418            })
419        });
420        if let Some(call) = call {
421            tool_use_index = Some(call.index);
422            first_chunk = call.function.and_then(|func| func.arguments);
423            break;
424        }
425    }
426
427    let Some(tool_use_index) = tool_use_index else {
428        return Err(anyhow!("tool not used"));
429    };
430
431    Ok(events.filter_map(move |event| {
432        let result = match event {
433            Err(error) => Some(Err(error)),
434            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
435                choice.delta.tool_calls?.into_iter().find_map(|call| {
436                    if call.index == tool_use_index {
437                        let func = call.function?;
438                        let mut arguments = func.arguments?;
439                        if let Some(mut first_chunk) = first_chunk.take() {
440                            first_chunk.push_str(&arguments);
441                            arguments = first_chunk
442                        }
443                        Some(Ok(arguments))
444                    } else {
445                        None
446                    }
447                })
448            }),
449        };
450
451        async move { result }
452    }))
453}
454
455pub fn extract_text_from_events(
456    response: impl Stream<Item = Result<ResponseStreamEvent>>,
457) -> impl Stream<Item = Result<String>> {
458    response.filter_map(|response| async move {
459        match response {
460            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
461            Err(error) => Some(Err(error)),
462        }
463    })
464}