1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::convert::TryFrom;
5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6
7#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
8#[serde(rename_all = "lowercase")]
9pub enum Role {
10 User,
11 Assistant,
12 System,
13}
14
15impl TryFrom<String> for Role {
16 type Error = anyhow::Error;
17
18 fn try_from(value: String) -> Result<Self> {
19 match value.as_str() {
20 "user" => Ok(Self::User),
21 "assistant" => Ok(Self::Assistant),
22 "system" => Ok(Self::System),
23 _ => Err(anyhow!("invalid role '{value}'")),
24 }
25 }
26}
27
28impl From<Role> for String {
29 fn from(val: Role) -> Self {
30 match val {
31 Role::User => "user".to_owned(),
32 Role::Assistant => "assistant".to_owned(),
33 Role::System => "system".to_owned(),
34 }
35 }
36}
37
38#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
39#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
40pub enum Model {
41 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
42 ThreePointFiveTurbo,
43 #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
44 Four,
45 #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
46 #[default]
47 FourTurbo,
48}
49
50impl Model {
51 pub fn from_id(id: &str) -> Result<Self> {
52 match id {
53 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
54 "gpt-4" => Ok(Self::Four),
55 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
56 _ => Err(anyhow!("invalid model id")),
57 }
58 }
59
60 pub fn id(&self) -> &'static str {
61 match self {
62 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
63 Self::Four => "gpt-4",
64 Self::FourTurbo => "gpt-4-turbo-preview",
65 }
66 }
67
68 pub fn display_name(&self) -> &'static str {
69 match self {
70 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
71 Self::Four => "gpt-4",
72 Self::FourTurbo => "gpt-4-turbo",
73 }
74 }
75
76 pub fn max_token_count(&self) -> usize {
77 match self {
78 Model::ThreePointFiveTurbo => 4096,
79 Model::Four => 8192,
80 Model::FourTurbo => 128000,
81 }
82 }
83}
84
85#[derive(Debug, Serialize)]
86pub struct Request {
87 pub model: Model,
88 pub messages: Vec<RequestMessage>,
89 pub stream: bool,
90 pub stop: Vec<String>,
91 pub temperature: f32,
92}
93
94#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
95pub struct RequestMessage {
96 pub role: Role,
97 pub content: String,
98}
99
100#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
101pub struct ResponseMessage {
102 pub role: Option<Role>,
103 pub content: Option<String>,
104}
105
106#[derive(Deserialize, Debug)]
107pub struct Usage {
108 pub prompt_tokens: u32,
109 pub completion_tokens: u32,
110 pub total_tokens: u32,
111}
112
113#[derive(Deserialize, Debug)]
114pub struct ChoiceDelta {
115 pub index: u32,
116 pub delta: ResponseMessage,
117 pub finish_reason: Option<String>,
118}
119
120#[derive(Deserialize, Debug)]
121pub struct ResponseStreamEvent {
122 pub created: u32,
123 pub model: String,
124 pub choices: Vec<ChoiceDelta>,
125 pub usage: Option<Usage>,
126}
127
128pub async fn stream_completion(
129 client: &dyn HttpClient,
130 api_url: &str,
131 api_key: &str,
132 request: Request,
133) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
134 let uri = format!("{api_url}/chat/completions");
135 let request = HttpRequest::builder()
136 .method(Method::POST)
137 .uri(uri)
138 .header("Content-Type", "application/json")
139 .header("Authorization", format!("Bearer {}", api_key))
140 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
141 let mut response = client.send(request).await?;
142 if response.status().is_success() {
143 let reader = BufReader::new(response.into_body());
144 Ok(reader
145 .lines()
146 .filter_map(|line| async move {
147 match line {
148 Ok(line) => {
149 let line = line.strip_prefix("data: ")?;
150 if line == "[DONE]" {
151 None
152 } else {
153 match serde_json::from_str(line) {
154 Ok(response) => Some(Ok(response)),
155 Err(error) => Some(Err(anyhow!(error))),
156 }
157 }
158 }
159 Err(error) => Some(Err(anyhow!(error))),
160 }
161 })
162 .boxed())
163 } else {
164 let mut body = String::new();
165 response.body_mut().read_to_string(&mut body).await?;
166
167 #[derive(Deserialize)]
168 struct OpenAiResponse {
169 error: OpenAiError,
170 }
171
172 #[derive(Deserialize)]
173 struct OpenAiError {
174 message: String,
175 }
176
177 match serde_json::from_str::<OpenAiResponse>(&body) {
178 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
179 "Failed to connect to OpenAI API: {}",
180 response.error.message,
181 )),
182
183 _ => Err(anyhow!(
184 "Failed to connect to OpenAI API: {} {}",
185 response.status(),
186 body,
187 )),
188 }
189 }
190}