1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::{executor::Background, AppContext};
7use isahc::{http::StatusCode, Request, RequestExt};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::{
11 env,
12 fmt::{self, Display},
13 io,
14 sync::Arc,
15};
16use util::ResultExt;
17
18use crate::{
19 auth::{CredentialProvider, ProviderCredential},
20 completion::{CompletionProvider, CompletionRequest},
21 models::LanguageModel,
22};
23
24use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
25
/// Author of a chat message.
///
/// Serialized/deserialized in lowercase (`"user"`, `"assistant"`, `"system"`) because of
/// `rename_all = "lowercase"`; the `Display` impl, by contrast, renders capitalized names.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
33
34impl Role {
35 pub fn cycle(&mut self) {
36 *self = match self {
37 Role::User => Role::Assistant,
38 Role::Assistant => Role::System,
39 Role::System => Role::User,
40 }
41 }
42}
43
44impl Display for Role {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Role::User => write!(f, "User"),
48 Role::Assistant => write!(f, "Assistant"),
49 Role::System => write!(f, "System"),
50 }
51 }
52}
53
/// One entry of the `messages` array sent in an [`OpenAIRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Who authored the message; serialized in lowercase.
    pub role: Role,
    /// The message text.
    pub content: String,
}
59
/// Body of a chat-completions request, serialized to JSON by its `CompletionRequest` impl.
///
/// Fields are serialized in declaration order. `Default` yields an empty model name,
/// no messages, `stream: false`, no stop strings and a temperature of `0.0`.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    /// Model name sent with the request; this, not the provider, decides which model runs.
    pub model: String,
    /// Conversation history to complete.
    pub messages: Vec<RequestMessage>,
    /// Whether to request a streamed response. NOTE(review): `stream_completion` parses the
    /// body as `data: ` lines, so presumably this must be `true` there — confirm with callers.
    pub stream: bool,
    /// Presumably OpenAI `stop` sequences; serialized as `[]` when empty.
    pub stop: Vec<String>,
    /// Sampling temperature forwarded to the API.
    pub temperature: f32,
}
68
69impl CompletionRequest for OpenAIRequest {
70 fn data(&self) -> serde_json::Result<String> {
71 serde_json::to_string(self)
72 }
73}
74
/// The `delta` payload of one streamed choice.
///
/// Both fields are optional, so a chunk may omit either of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    /// Role of the message, if present in this chunk.
    pub role: Option<Role>,
    /// Text fragment carried by this chunk, if any.
    pub content: Option<String>,
}
80
/// Token accounting attached to a stream event when the server includes it.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Presumably `prompt_tokens + completion_tokens` as reported by the server.
    pub total_tokens: u32,
}
87
/// One entry of the `choices` array in a streamed event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    /// Position of this choice among the requested alternatives.
    pub index: u32,
    /// Incremental message content for this choice.
    pub delta: ResponseMessage,
    /// Set once generation for this choice stops. `stream_completion` ends the stream
    /// when the last choice of an event has this set.
    pub finish_reason: Option<String>,
}
94
/// One event of a streamed chat completion: the JSON payload of a single `data: ` line.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    /// Completion id, when the server provides one.
    pub id: Option<String>,
    /// Object type string reported by the server.
    pub object: String,
    /// Creation time as reported by the server (presumably a Unix timestamp in seconds).
    pub created: u32,
    /// Model that produced this event.
    pub model: String,
    /// Per-choice deltas carried by this event.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token usage, if the server attached it to this event.
    pub usage: Option<OpenAIUsage>,
}
104
105pub async fn stream_completion(
106 credential: ProviderCredential,
107 executor: Arc<Background>,
108 request: Box<dyn CompletionRequest>,
109) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
110 let api_key = match credential {
111 ProviderCredential::Credentials { api_key } => api_key,
112 _ => {
113 return Err(anyhow!("no credentials provider for completion"));
114 }
115 };
116
117 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
118
119 let json_data = request.data()?;
120 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
121 .header("Content-Type", "application/json")
122 .header("Authorization", format!("Bearer {}", api_key))
123 .body(json_data)?
124 .send_async()
125 .await?;
126
127 let status = response.status();
128 if status == StatusCode::OK {
129 executor
130 .spawn(async move {
131 let mut lines = BufReader::new(response.body_mut()).lines();
132
133 fn parse_line(
134 line: Result<String, io::Error>,
135 ) -> Result<Option<OpenAIResponseStreamEvent>> {
136 if let Some(data) = line?.strip_prefix("data: ") {
137 let event = serde_json::from_str(&data)?;
138 Ok(Some(event))
139 } else {
140 Ok(None)
141 }
142 }
143
144 while let Some(line) = lines.next().await {
145 if let Some(event) = parse_line(line).transpose() {
146 let done = event.as_ref().map_or(false, |event| {
147 event
148 .choices
149 .last()
150 .map_or(false, |choice| choice.finish_reason.is_some())
151 });
152 if tx.unbounded_send(event).is_err() {
153 break;
154 }
155
156 if done {
157 break;
158 }
159 }
160 }
161
162 anyhow::Ok(())
163 })
164 .detach();
165
166 Ok(rx)
167 } else {
168 let mut body = String::new();
169 response.body_mut().read_to_string(&mut body).await?;
170
171 #[derive(Deserialize)]
172 struct OpenAIResponse {
173 error: OpenAIError,
174 }
175
176 #[derive(Deserialize)]
177 struct OpenAIError {
178 message: String,
179 }
180
181 match serde_json::from_str::<OpenAIResponse>(&body) {
182 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
183 "Failed to connect to OpenAI API: {}",
184 response.error.message,
185 )),
186
187 _ => Err(anyhow!(
188 "Failed to connect to OpenAI API: {} {}",
189 response.status(),
190 body,
191 )),
192 }
193 }
194}
195
/// Completion provider backed by the OpenAI chat-completions API.
///
/// Cloning is cheap for `credential` and `executor`, which are `Arc`s. Clones share the
/// same credential cache, so credentials saved or deleted through one clone are seen by
/// all of them.
#[derive(Clone)]
pub struct OpenAICompletionProvider {
    /// Language model loaded by name in `new`; returned by `base_model`.
    model: OpenAILanguageModel,
    /// Cached credential, starting as `NoCredentials` until retrieved or saved.
    credential: Arc<RwLock<ProviderCredential>>,
    /// Executor on which streamed responses are read.
    executor: Arc<Background>,
}
202
203impl OpenAICompletionProvider {
204 pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
205 let model = OpenAILanguageModel::load(model_name);
206 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
207 Self {
208 model,
209 credential,
210 executor,
211 }
212 }
213}
214
215impl CredentialProvider for OpenAICompletionProvider {
216 fn has_credentials(&self) -> bool {
217 match *self.credential.read() {
218 ProviderCredential::Credentials { .. } => true,
219 _ => false,
220 }
221 }
222 fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
223 let mut credential = self.credential.write();
224 match *credential {
225 ProviderCredential::Credentials { .. } => {
226 return credential.clone();
227 }
228 _ => {
229 if let Ok(api_key) = env::var("OPENAI_API_KEY") {
230 *credential = ProviderCredential::Credentials { api_key };
231 } else if let Some((_, api_key)) = cx
232 .platform()
233 .read_credentials(OPENAI_API_URL)
234 .log_err()
235 .flatten()
236 {
237 if let Some(api_key) = String::from_utf8(api_key).log_err() {
238 *credential = ProviderCredential::Credentials { api_key };
239 }
240 } else {
241 };
242 }
243 }
244
245 credential.clone()
246 }
247
248 fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
249 match credential.clone() {
250 ProviderCredential::Credentials { api_key } => {
251 cx.platform()
252 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
253 .log_err();
254 }
255 _ => {}
256 }
257
258 *self.credential.write() = credential;
259 }
260 fn delete_credentials(&self, cx: &AppContext) {
261 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
262 *self.credential.write() = ProviderCredential::NoCredentials;
263 }
264}
265
266impl CompletionProvider for OpenAICompletionProvider {
267 fn base_model(&self) -> Box<dyn LanguageModel> {
268 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
269 model
270 }
271 fn complete(
272 &self,
273 prompt: Box<dyn CompletionRequest>,
274 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
275 // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
276 // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
277 // which is currently model based, due to the langauge model.
278 // At some point in the future we should rectify this.
279 let credential = self.credential.read().clone();
280 let request = stream_completion(credential, self.executor.clone(), prompt);
281 async move {
282 let response = request.await?;
283 let stream = response
284 .filter_map(|response| async move {
285 match response {
286 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
287 Err(error) => Some(Err(error)),
288 }
289 })
290 .boxed();
291 Ok(stream)
292 }
293 .boxed()
294 }
295 fn box_clone(&self) -> Box<dyn CompletionProvider> {
296 Box::new((*self).clone())
297 }
298}