1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::convert::TryFrom;
5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6
7#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
8#[serde(rename_all = "lowercase")]
9pub enum Role {
10 User,
11 Assistant,
12 System,
13}
14
15impl TryFrom<String> for Role {
16 type Error = anyhow::Error;
17
18 fn try_from(value: String) -> Result<Self> {
19 match value.as_str() {
20 "user" => Ok(Self::User),
21 "assistant" => Ok(Self::Assistant),
22 "system" => Ok(Self::System),
23 _ => Err(anyhow!("invalid role '{value}'")),
24 }
25 }
26}
27
28impl From<Role> for String {
29 fn from(val: Role) -> Self {
30 match val {
31 Role::User => "user".to_owned(),
32 Role::Assistant => "assistant".to_owned(),
33 Role::System => "system".to_owned(),
34 }
35 }
36}
37
38#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
39#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
40pub enum Model {
41 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
42 ThreePointFiveTurbo,
43 #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
44 Four,
45 #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
46 #[default]
47 FourTurbo,
48}
49
50impl Model {
51 pub fn from_id(id: &str) -> Result<Self> {
52 match id {
53 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
54 "gpt-4" => Ok(Self::Four),
55 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
56 _ => Err(anyhow!("invalid model id")),
57 }
58 }
59
60 pub fn id(&self) -> &'static str {
61 match self {
62 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
63 Self::Four => "gpt-4",
64 Self::FourTurbo => "gpt-4-turbo-preview",
65 }
66 }
67
68 pub fn display_name(&self) -> &'static str {
69 match self {
70 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
71 Self::Four => "gpt-4",
72 Self::FourTurbo => "gpt-4-turbo",
73 }
74 }
75}
76
77#[derive(Debug, Serialize)]
78pub struct Request {
79 pub model: Model,
80 pub messages: Vec<RequestMessage>,
81 pub stream: bool,
82 pub stop: Vec<String>,
83 pub temperature: f32,
84}
85
86#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
87pub struct RequestMessage {
88 pub role: Role,
89 pub content: String,
90}
91
92#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
93pub struct ResponseMessage {
94 pub role: Option<Role>,
95 pub content: Option<String>,
96}
97
98#[derive(Deserialize, Debug)]
99pub struct Usage {
100 pub prompt_tokens: u32,
101 pub completion_tokens: u32,
102 pub total_tokens: u32,
103}
104
105#[derive(Deserialize, Debug)]
106pub struct ChoiceDelta {
107 pub index: u32,
108 pub delta: ResponseMessage,
109 pub finish_reason: Option<String>,
110}
111
112#[derive(Deserialize, Debug)]
113pub struct ResponseStreamEvent {
114 pub created: u32,
115 pub model: String,
116 pub choices: Vec<ChoiceDelta>,
117 pub usage: Option<Usage>,
118}
119
120pub async fn stream_completion(
121 client: &dyn HttpClient,
122 api_url: &str,
123 api_key: &str,
124 request: Request,
125) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
126 let uri = format!("{api_url}/chat/completions");
127 let request = HttpRequest::builder()
128 .method(Method::POST)
129 .uri(uri)
130 .header("Content-Type", "application/json")
131 .header("Authorization", format!("Bearer {}", api_key))
132 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
133 let mut response = client.send(request).await?;
134 if response.status().is_success() {
135 let reader = BufReader::new(response.into_body());
136 Ok(reader
137 .lines()
138 .filter_map(|line| async move {
139 match line {
140 Ok(line) => {
141 let line = line.strip_prefix("data: ")?;
142 if line == "[DONE]" {
143 None
144 } else {
145 match serde_json::from_str(line) {
146 Ok(response) => Some(Ok(response)),
147 Err(error) => Some(Err(anyhow!(error))),
148 }
149 }
150 }
151 Err(error) => Some(Err(anyhow!(error))),
152 }
153 })
154 .boxed())
155 } else {
156 let mut body = String::new();
157 response.body_mut().read_to_string(&mut body).await?;
158
159 #[derive(Deserialize)]
160 struct OpenAiResponse {
161 error: OpenAiError,
162 }
163
164 #[derive(Deserialize)]
165 struct OpenAiError {
166 message: String,
167 }
168
169 match serde_json::from_str::<OpenAiResponse>(&body) {
170 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
171 "Failed to connect to OpenAI API: {}",
172 response.error.message,
173 )),
174
175 _ => Err(anyhow!(
176 "Failed to connect to OpenAI API: {} {}",
177 response.status(),
178 body,
179 )),
180 }
181 }
182}