open_ai.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use serde::{Deserialize, Serialize};
  4use std::convert::TryFrom;
  5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6
  7#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
  8#[serde(rename_all = "lowercase")]
  9pub enum Role {
 10    User,
 11    Assistant,
 12    System,
 13}
 14
 15impl TryFrom<String> for Role {
 16    type Error = anyhow::Error;
 17
 18    fn try_from(value: String) -> Result<Self> {
 19        match value.as_str() {
 20            "user" => Ok(Self::User),
 21            "assistant" => Ok(Self::Assistant),
 22            "system" => Ok(Self::System),
 23            _ => Err(anyhow!("invalid role '{value}'")),
 24        }
 25    }
 26}
 27
 28impl From<Role> for String {
 29    fn from(val: Role) -> Self {
 30        match val {
 31            Role::User => "user".to_owned(),
 32            Role::Assistant => "assistant".to_owned(),
 33            Role::System => "system".to_owned(),
 34        }
 35    }
 36}
 37
 38#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 39#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 40pub enum Model {
 41    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
 42    ThreePointFiveTurbo,
 43    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
 44    Four,
 45    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
 46    #[default]
 47    FourTurbo,
 48}
 49
 50impl Model {
 51    pub fn from_id(id: &str) -> Result<Self> {
 52        match id {
 53            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 54            "gpt-4" => Ok(Self::Four),
 55            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 56            _ => Err(anyhow!("invalid model id")),
 57        }
 58    }
 59
 60    pub fn id(&self) -> &'static str {
 61        match self {
 62            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 63            Self::Four => "gpt-4",
 64            Self::FourTurbo => "gpt-4-turbo-preview",
 65        }
 66    }
 67
 68    pub fn display_name(&self) -> &'static str {
 69        match self {
 70            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 71            Self::Four => "gpt-4",
 72            Self::FourTurbo => "gpt-4-turbo",
 73        }
 74    }
 75}
 76
 77#[derive(Debug, Serialize)]
 78pub struct Request {
 79    pub model: Model,
 80    pub messages: Vec<RequestMessage>,
 81    pub stream: bool,
 82    pub stop: Vec<String>,
 83    pub temperature: f32,
 84}
 85
 86#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 87pub struct RequestMessage {
 88    pub role: Role,
 89    pub content: String,
 90}
 91
 92#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 93pub struct ResponseMessage {
 94    pub role: Option<Role>,
 95    pub content: Option<String>,
 96}
 97
 98#[derive(Deserialize, Debug)]
 99pub struct Usage {
100    pub prompt_tokens: u32,
101    pub completion_tokens: u32,
102    pub total_tokens: u32,
103}
104
105#[derive(Deserialize, Debug)]
106pub struct ChoiceDelta {
107    pub index: u32,
108    pub delta: ResponseMessage,
109    pub finish_reason: Option<String>,
110}
111
112#[derive(Deserialize, Debug)]
113pub struct ResponseStreamEvent {
114    pub created: u32,
115    pub model: String,
116    pub choices: Vec<ChoiceDelta>,
117    pub usage: Option<Usage>,
118}
119
120pub async fn stream_completion(
121    client: &dyn HttpClient,
122    api_url: &str,
123    api_key: &str,
124    request: Request,
125) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
126    let uri = format!("{api_url}/chat/completions");
127    let request = HttpRequest::builder()
128        .method(Method::POST)
129        .uri(uri)
130        .header("Content-Type", "application/json")
131        .header("Authorization", format!("Bearer {}", api_key))
132        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
133    let mut response = client.send(request).await?;
134    if response.status().is_success() {
135        let reader = BufReader::new(response.into_body());
136        Ok(reader
137            .lines()
138            .filter_map(|line| async move {
139                match line {
140                    Ok(line) => {
141                        let line = line.strip_prefix("data: ")?;
142                        if line == "[DONE]" {
143                            None
144                        } else {
145                            match serde_json::from_str(line) {
146                                Ok(response) => Some(Ok(response)),
147                                Err(error) => Some(Err(anyhow!(error))),
148                            }
149                        }
150                    }
151                    Err(error) => Some(Err(anyhow!(error))),
152                }
153            })
154            .boxed())
155    } else {
156        let mut body = String::new();
157        response.body_mut().read_to_string(&mut body).await?;
158
159        #[derive(Deserialize)]
160        struct OpenAiResponse {
161            error: OpenAiError,
162        }
163
164        #[derive(Deserialize)]
165        struct OpenAiError {
166            message: String,
167        }
168
169        match serde_json::from_str::<OpenAiResponse>(&body) {
170            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
171                "Failed to connect to OpenAI API: {}",
172                response.error.message,
173            )),
174
175            _ => Err(anyhow!(
176                "Failed to connect to OpenAI API: {} {}",
177                response.status(),
178                body,
179            )),
180        }
181    }
182}