1use std::{
2 env,
3 fmt::{self, Display},
4 io,
5 sync::Arc,
6};
7
8use anyhow::{anyhow, Result};
9use futures::{
10 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
11 Stream, StreamExt,
12};
13use gpui::{AppContext, BackgroundExecutor};
14use isahc::{http::StatusCode, Request, RequestExt};
15use parking_lot::RwLock;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use util::ResultExt;
19
20use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
21use crate::{
22 auth::{CredentialProvider, ProviderCredential},
23 completion::{CompletionProvider, CompletionRequest},
24 models::LanguageModel,
25};
26
/// The author of a chat message, serialized in lowercase
/// ("user" / "assistant" / "system") to match the OpenAI chat API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
34
35impl Role {
36 pub fn cycle(&mut self) {
37 *self = match self {
38 Role::User => Role::Assistant,
39 Role::Assistant => Role::System,
40 Role::System => Role::User,
41 }
42 }
43}
44
45impl Display for Role {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Role::User => write!(f, "User"),
49 Role::Assistant => write!(f, "Assistant"),
50 Role::System => write!(f, "System"),
51 }
52 }
53}
54
/// A single chat message sent to the completions endpoint.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
60
/// Request body for the OpenAI chat-completions endpoint.
/// Field names mirror the API's JSON schema, so this serializes directly.
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // `true` requests a server-sent-event stream rather than a single body.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}
69
impl CompletionRequest for OpenAiRequest {
    /// Serializes the request to the JSON body sent over the wire.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
75
/// A (possibly partial) message delta from a streamed response.
/// Both fields are optional because stream chunks may carry only the role,
/// only new content, or neither (e.g. the final chunk).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
81
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
88
/// One choice within a streamed completion chunk.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    // Populated (e.g. "stop", "length") only on the final chunk of a choice.
    pub finish_reason: Option<String>,
}
95
/// One decoded server-sent event from the streaming completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    // Usage is typically absent on intermediate stream chunks.
    pub usage: Option<OpenAiUsage>,
}
105
/// Opens a streaming chat-completion request against an OpenAI-compatible API.
///
/// Sends `request` to the endpoint selected by `kind` (OpenAI proper or Azure
/// OpenAI). On a 200 response, spawns a detached background task that reads
/// the server-sent-event body line by line and forwards each decoded
/// [`OpenAiResponseStreamEvent`] (or per-line parse error) through an
/// unbounded channel; the receiving half is returned as the stream.
///
/// Returns an error before any streaming starts when the credential is not an
/// API key, serialization or transport fails, or the server replies with a
/// non-200 status (in which case the body is read and included in the error).
async fn stream_completion(
    api_url: String,
    kind: OpenAiCompletionProviderKind,
    credential: ProviderCredential,
    executor: BackgroundExecutor,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
    // Only explicit API-key credentials can authenticate this request.
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provider for completion"));
        }
    };

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();

    // Header name/value differ between OpenAI ("Authorization: Bearer …")
    // and Azure ("Api-Key: …"); `kind` encapsulates that choice.
    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
    let json_data = request.data()?;
    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
        .header("Content-Type", "application/json")
        .header(auth_header_name, auth_header_value)
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                // SSE payload lines look like `data: {json}`. Lines without
                // that prefix (blank keep-alives, etc.) yield `Ok(None)` and
                // are skipped; I/O or JSON failures propagate as `Err`.
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAiResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        // A populated `finish_reason` on the last choice marks
                        // the end of the stream; checked before the event is
                        // moved into the channel.
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        // Send failure means the receiver was dropped — stop
                        // reading the body.
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        // Non-200: consume the whole body and try to surface a structured
        // OpenAI error message; otherwise fall back to status + raw body.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
199
/// Supported `api-version` values for the Azure OpenAI service.
/// The serde rename strings are the literal values sent in the query string.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
pub enum AzureOpenAiApiVersion {
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-03-15-preview")]
    V2023_03_15Preview,
    #[serde(rename = "2023-05-15")]
    V2023_05_15,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-06-01-preview")]
    V2023_06_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-07-01-preview")]
    V2023_07_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-08-01-preview")]
    V2023_08_01Preview,
    /// Retiring April 2, 2024.
    #[serde(rename = "2023-09-01-preview")]
    V2023_09_01Preview,
    #[serde(rename = "2023-12-01-preview")]
    V2023_12_01Preview,
    #[serde(rename = "2024-02-15-preview")]
    V2024_02_15Preview,
}
224
225impl fmt::Display for AzureOpenAiApiVersion {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 write!(
228 f,
229 "{}",
230 match self {
231 Self::V2023_03_15Preview => "2023-03-15-preview",
232 Self::V2023_05_15 => "2023-05-15",
233 Self::V2023_06_01Preview => "2023-06-01-preview",
234 Self::V2023_07_01Preview => "2023-07-01-preview",
235 Self::V2023_08_01Preview => "2023-08-01-preview",
236 Self::V2023_09_01Preview => "2023-09-01-preview",
237 Self::V2023_12_01Preview => "2023-12-01-preview",
238 Self::V2024_02_15Preview => "2024-02-15-preview",
239 }
240 )
241 }
242}
243
/// Which flavor of OpenAI-compatible API the provider talks to; determines
/// both the endpoint URL shape and the authentication header.
#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
    OpenAi,
    AzureOpenAi {
        deployment_id: String,
        api_version: AzureOpenAiApiVersion,
    },
}
252
253impl OpenAiCompletionProviderKind {
254 /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
255 fn completions_endpoint_url(&self, api_url: &str) -> String {
256 match self {
257 Self::OpenAi => {
258 // https://platform.openai.com/docs/api-reference/chat/create
259 format!("{api_url}/chat/completions")
260 }
261 Self::AzureOpenAi {
262 deployment_id,
263 api_version,
264 } => {
265 // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
266 format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
267 }
268 }
269 }
270
271 /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
272 fn auth_header(&self, api_key: String) -> (&'static str, String) {
273 match self {
274 Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
275 Self::AzureOpenAi { .. } => ("Api-Key", api_key),
276 }
277 }
278}
279
/// A [`CompletionProvider`] backed by the OpenAI (or Azure OpenAI) chat API.
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
    api_url: String,
    kind: OpenAiCompletionProviderKind,
    model: OpenAiLanguageModel,
    // Shared, mutable credential cache; clones of this provider share it.
    credential: Arc<RwLock<ProviderCredential>>,
    executor: BackgroundExecutor,
}
288
289impl OpenAiCompletionProvider {
290 pub async fn new(
291 api_url: String,
292 kind: OpenAiCompletionProviderKind,
293 model_name: String,
294 executor: BackgroundExecutor,
295 ) -> Self {
296 let model = executor
297 .spawn(async move { OpenAiLanguageModel::load(&model_name) })
298 .await;
299 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
300 Self {
301 api_url,
302 kind,
303 model,
304 credential,
305 executor,
306 }
307 }
308}
309
310impl CredentialProvider for OpenAiCompletionProvider {
311 fn has_credentials(&self) -> bool {
312 match *self.credential.read() {
313 ProviderCredential::Credentials { .. } => true,
314 _ => false,
315 }
316 }
317
318 fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
319 let existing_credential = self.credential.read().clone();
320 let retrieved_credential = match existing_credential {
321 ProviderCredential::Credentials { .. } => {
322 return async move { existing_credential }.boxed()
323 }
324 _ => {
325 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
326 async move { ProviderCredential::Credentials { api_key } }.boxed()
327 } else {
328 let credentials = cx.read_credentials(OPEN_AI_API_URL);
329 async move {
330 if let Some(Some((_, api_key))) = credentials.await.log_err() {
331 if let Some(api_key) = String::from_utf8(api_key).log_err() {
332 ProviderCredential::Credentials { api_key }
333 } else {
334 ProviderCredential::NoCredentials
335 }
336 } else {
337 ProviderCredential::NoCredentials
338 }
339 }
340 .boxed()
341 }
342 }
343 };
344
345 async move {
346 let retrieved_credential = retrieved_credential.await;
347 *self.credential.write() = retrieved_credential.clone();
348 retrieved_credential
349 }
350 .boxed()
351 }
352
353 fn save_credentials(
354 &self,
355 cx: &mut AppContext,
356 credential: ProviderCredential,
357 ) -> BoxFuture<()> {
358 *self.credential.write() = credential.clone();
359 let credential = credential.clone();
360 let write_credentials = match credential {
361 ProviderCredential::Credentials { api_key } => {
362 Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
363 }
364 _ => None,
365 };
366
367 async move {
368 if let Some(write_credentials) = write_credentials {
369 write_credentials.await.log_err();
370 }
371 }
372 .boxed()
373 }
374
375 fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
376 *self.credential.write() = ProviderCredential::NoCredentials;
377 let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
378 async move {
379 delete_credentials.await.log_err();
380 }
381 .boxed()
382 }
383}
384
impl CompletionProvider for OpenAiCompletionProvider {
    /// Returns a boxed clone of the underlying language model.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    /// Streams completion text for `prompt`.
    ///
    /// Kicks off [`stream_completion`] with the currently cached credential
    /// and maps each stream event to the content of its last choice's delta.
    /// Events whose last choice has no content (or no choices at all) are
    /// dropped from the stream; errors are passed through.
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
        // which is currently model based, due to the language model.
        // At some point in the future we should rectify this.
        let credential = self.credential.read().clone();
        let api_url = self.api_url.clone();
        let kind = self.kind.clone();
        // Everything is cloned before the async block so the returned future
        // is 'static and does not borrow `self`.
        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // `pop()?` takes the last choice; `content?` skips
                        // deltas that carry no text (e.g. role-only chunks).
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    /// Boxes a clone of this provider for dynamic dispatch.
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}