open_ai.rs

use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};

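/// The author of a chat message, serialized in lowercase ("user",
/// "assistant", "system") to match the OpenAI chat format.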
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
        }
    }
}

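/// The OpenAI chat models this module knows about. The serde aliases map
/// dated snapshot ids (e.g. "gpt-4-0613") onto the same variants.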
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    #[default]
    FourTurbo,
}

impl Model {
    /// Parses a model from its canonical API id; dated aliases are not accepted here.
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    /// The id sent to the API in requests.
    pub fn id(&self) -> &'static str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo-preview",
        }
    }

    /// A human-readable name for the model.
    pub fn display_name(&self) -> &'static str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
        }
    }

    /// The model's context window size, in tokens.
    pub fn max_token_count(&self) -> usize {
        match self {
            Model::ThreePointFiveTurbo => 4096,
            Model::Four => 8192,
            Model::FourTurbo => 128000,
        }
    }
}

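/// The JSON body of a `/chat/completions` request.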
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}

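/// A single message in the conversation sent to the model.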
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}

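/// A message (or streamed fragment of one) returned by the API. Both fields
/// are optional because a streamed delta may carry only a role or only a
/// chunk of content.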
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

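/// One choice from a streamed response; `delta` holds the incremental part of
/// the message rather than the full text.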
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

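/// A single deserialized `data:` event from the streaming response.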
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

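/// Posts `request` to `{api_url}/chat/completions` and returns a stream of
/// parsed events. The response body is read line by line as server-sent
/// events: every `data: ` payload is deserialized into a [`ResponseStreamEvent`],
/// and the `[DONE]` sentinel ends the stream.
///
/// A minimal usage sketch, not taken from this crate's tests; it assumes a
/// `client: &dyn HttpClient`, an `api_key: &str`, an async context, and
/// `futures::StreamExt` in scope:
///
/// ```ignore
/// let request = Request {
///     model: Model::FourTurbo,
///     messages: vec![RequestMessage {
///         role: Role::User,
///         content: "Hello!".into(),
///     }],
///     stream: true,
///     stop: Vec::new(),
///     temperature: 1.0,
/// };
/// let mut events =
///     stream_completion(client, "https://api.openai.com/v1", api_key, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(content) = choice.delta.content {
///             print!("{content}");
///         }
///     }
/// }
/// ```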
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Each event payload is prefixed with "data: "; any
                        // other line (e.g. blank separators) is skipped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // The sentinel marks the end of the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(response) => Some(Ok(response)),
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // On a non-success status, read the whole body and surface the API's
        // error message when the body parses as the standard error shape.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}