1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::{AppContext, BackgroundExecutor};
7use isahc::{http::StatusCode, Request, RequestExt};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::{
11 env,
12 fmt::{self, Display},
13 io,
14 sync::Arc,
15};
16use util::ResultExt;
17
18use crate::{
19 auth::{CredentialProvider, ProviderCredential},
20 completion::{CompletionProvider, CompletionRequest},
21 models::LanguageModel,
22};
23
24use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
25
26#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
27#[serde(rename_all = "lowercase")]
28pub enum Role {
29 User,
30 Assistant,
31 System,
32}
33
34impl Role {
35 pub fn cycle(&mut self) {
36 *self = match self {
37 Role::User => Role::Assistant,
38 Role::Assistant => Role::System,
39 Role::System => Role::User,
40 }
41 }
42}
43
44impl Display for Role {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Role::User => write!(f, "User"),
48 Role::Assistant => write!(f, "Assistant"),
49 Role::System => write!(f, "System"),
50 }
51 }
52}
53
54#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
55pub struct RequestMessage {
56 pub role: Role,
57 pub content: String,
58}
59
60#[derive(Debug, Default, Serialize)]
61pub struct OpenAiRequest {
62 pub model: String,
63 pub messages: Vec<RequestMessage>,
64 pub stream: bool,
65 pub stop: Vec<String>,
66 pub temperature: f32,
67}
68
69impl CompletionRequest for OpenAiRequest {
70 fn data(&self) -> serde_json::Result<String> {
71 serde_json::to_string(self)
72 }
73}
74
75#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
76pub struct ResponseMessage {
77 pub role: Option<Role>,
78 pub content: Option<String>,
79}
80
81#[derive(Deserialize, Debug)]
82pub struct OpenAiUsage {
83 pub prompt_tokens: u32,
84 pub completion_tokens: u32,
85 pub total_tokens: u32,
86}
87
88#[derive(Deserialize, Debug)]
89pub struct ChatChoiceDelta {
90 pub index: u32,
91 pub delta: ResponseMessage,
92 pub finish_reason: Option<String>,
93}
94
95#[derive(Deserialize, Debug)]
96pub struct OpenAiResponseStreamEvent {
97 pub id: Option<String>,
98 pub object: String,
99 pub created: u32,
100 pub model: String,
101 pub choices: Vec<ChatChoiceDelta>,
102 pub usage: Option<OpenAiUsage>,
103}
104
105pub async fn stream_completion(
106 api_url: String,
107 credential: ProviderCredential,
108 executor: BackgroundExecutor,
109 request: Box<dyn CompletionRequest>,
110) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
111 let api_key = match credential {
112 ProviderCredential::Credentials { api_key } => api_key,
113 _ => {
114 return Err(anyhow!("no credentials provider for completion"));
115 }
116 };
117
118 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
119
120 let json_data = request.data()?;
121 let mut response = Request::post(format!("{api_url}/chat/completions"))
122 .header("Content-Type", "application/json")
123 .header("Authorization", format!("Bearer {}", api_key))
124 .body(json_data)?
125 .send_async()
126 .await?;
127
128 let status = response.status();
129 if status == StatusCode::OK {
130 executor
131 .spawn(async move {
132 let mut lines = BufReader::new(response.body_mut()).lines();
133
134 fn parse_line(
135 line: Result<String, io::Error>,
136 ) -> Result<Option<OpenAiResponseStreamEvent>> {
137 if let Some(data) = line?.strip_prefix("data: ") {
138 let event = serde_json::from_str(data)?;
139 Ok(Some(event))
140 } else {
141 Ok(None)
142 }
143 }
144
145 while let Some(line) = lines.next().await {
146 if let Some(event) = parse_line(line).transpose() {
147 let done = event.as_ref().map_or(false, |event| {
148 event
149 .choices
150 .last()
151 .map_or(false, |choice| choice.finish_reason.is_some())
152 });
153 if tx.unbounded_send(event).is_err() {
154 break;
155 }
156
157 if done {
158 break;
159 }
160 }
161 }
162
163 anyhow::Ok(())
164 })
165 .detach();
166
167 Ok(rx)
168 } else {
169 let mut body = String::new();
170 response.body_mut().read_to_string(&mut body).await?;
171
172 #[derive(Deserialize)]
173 struct OpenAiResponse {
174 error: OpenAiError,
175 }
176
177 #[derive(Deserialize)]
178 struct OpenAiError {
179 message: String,
180 }
181
182 match serde_json::from_str::<OpenAiResponse>(&body) {
183 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
184 "Failed to connect to OpenAI API: {}",
185 response.error.message,
186 )),
187
188 _ => Err(anyhow!(
189 "Failed to connect to OpenAI API: {} {}",
190 response.status(),
191 body,
192 )),
193 }
194 }
195}
196
197#[derive(Clone)]
198pub struct OpenAiCompletionProvider {
199 api_url: String,
200 model: OpenAiLanguageModel,
201 credential: Arc<RwLock<ProviderCredential>>,
202 executor: BackgroundExecutor,
203}
204
205impl OpenAiCompletionProvider {
206 pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
207 let model = executor
208 .spawn(async move { OpenAiLanguageModel::load(&model_name) })
209 .await;
210 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
211 Self {
212 api_url,
213 model,
214 credential,
215 executor,
216 }
217 }
218}
219
220impl CredentialProvider for OpenAiCompletionProvider {
221 fn has_credentials(&self) -> bool {
222 match *self.credential.read() {
223 ProviderCredential::Credentials { .. } => true,
224 _ => false,
225 }
226 }
227
228 fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
229 let existing_credential = self.credential.read().clone();
230 let retrieved_credential = match existing_credential {
231 ProviderCredential::Credentials { .. } => {
232 return async move { existing_credential }.boxed()
233 }
234 _ => {
235 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
236 async move { ProviderCredential::Credentials { api_key } }.boxed()
237 } else {
238 let credentials = cx.read_credentials(OPEN_AI_API_URL);
239 async move {
240 if let Some(Some((_, api_key))) = credentials.await.log_err() {
241 if let Some(api_key) = String::from_utf8(api_key).log_err() {
242 ProviderCredential::Credentials { api_key }
243 } else {
244 ProviderCredential::NoCredentials
245 }
246 } else {
247 ProviderCredential::NoCredentials
248 }
249 }
250 .boxed()
251 }
252 }
253 };
254
255 async move {
256 let retrieved_credential = retrieved_credential.await;
257 *self.credential.write() = retrieved_credential.clone();
258 retrieved_credential
259 }
260 .boxed()
261 }
262
263 fn save_credentials(
264 &self,
265 cx: &mut AppContext,
266 credential: ProviderCredential,
267 ) -> BoxFuture<()> {
268 *self.credential.write() = credential.clone();
269 let credential = credential.clone();
270 let write_credentials = match credential {
271 ProviderCredential::Credentials { api_key } => {
272 Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
273 }
274 _ => None,
275 };
276
277 async move {
278 if let Some(write_credentials) = write_credentials {
279 write_credentials.await.log_err();
280 }
281 }
282 .boxed()
283 }
284
285 fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
286 *self.credential.write() = ProviderCredential::NoCredentials;
287 let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
288 async move {
289 delete_credentials.await.log_err();
290 }
291 .boxed()
292 }
293}
294
295impl CompletionProvider for OpenAiCompletionProvider {
296 fn base_model(&self) -> Box<dyn LanguageModel> {
297 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
298 model
299 }
300 fn complete(
301 &self,
302 prompt: Box<dyn CompletionRequest>,
303 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
304 // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
305 // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
306 // which is currently model based, due to the language model.
307 // At some point in the future we should rectify this.
308 let credential = self.credential.read().clone();
309 let api_url = self.api_url.clone();
310 let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
311 async move {
312 let response = request.await?;
313 let stream = response
314 .filter_map(|response| async move {
315 match response {
316 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
317 Err(error) => Some(Err(error)),
318 }
319 })
320 .boxed();
321 Ok(stream)
322 }
323 .boxed()
324 }
325 fn box_clone(&self) -> Box<dyn CompletionProvider> {
326 Box::new((*self).clone())
327 }
328}