1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use futures::{
4 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
5 Stream, StreamExt,
6};
7use gpui2::{AppContext, Executor};
8use isahc::{http::StatusCode, Request, RequestExt};
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::{
12 env,
13 fmt::{self, Display},
14 io,
15 sync::Arc,
16};
17use util::ResultExt;
18
19use crate::{
20 auth::{CredentialProvider, ProviderCredential},
21 completion::{CompletionProvider, CompletionRequest},
22 models::LanguageModel,
23};
24
25use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
26
27#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum Role {
30 User,
31 Assistant,
32 System,
33}
34
35impl Role {
36 pub fn cycle(&mut self) {
37 *self = match self {
38 Role::User => Role::Assistant,
39 Role::Assistant => Role::System,
40 Role::System => Role::User,
41 }
42 }
43}
44
45impl Display for Role {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Role::User => write!(f, "User"),
49 Role::Assistant => write!(f, "Assistant"),
50 Role::System => write!(f, "System"),
51 }
52 }
53}
54
55#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
56pub struct RequestMessage {
57 pub role: Role,
58 pub content: String,
59}
60
61#[derive(Debug, Default, Serialize)]
62pub struct OpenAIRequest {
63 pub model: String,
64 pub messages: Vec<RequestMessage>,
65 pub stream: bool,
66 pub stop: Vec<String>,
67 pub temperature: f32,
68}
69
70impl CompletionRequest for OpenAIRequest {
71 fn data(&self) -> serde_json::Result<String> {
72 serde_json::to_string(self)
73 }
74}
75
76#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
77pub struct ResponseMessage {
78 pub role: Option<Role>,
79 pub content: Option<String>,
80}
81
82#[derive(Deserialize, Debug)]
83pub struct OpenAIUsage {
84 pub prompt_tokens: u32,
85 pub completion_tokens: u32,
86 pub total_tokens: u32,
87}
88
89#[derive(Deserialize, Debug)]
90pub struct ChatChoiceDelta {
91 pub index: u32,
92 pub delta: ResponseMessage,
93 pub finish_reason: Option<String>,
94}
95
96#[derive(Deserialize, Debug)]
97pub struct OpenAIResponseStreamEvent {
98 pub id: Option<String>,
99 pub object: String,
100 pub created: u32,
101 pub model: String,
102 pub choices: Vec<ChatChoiceDelta>,
103 pub usage: Option<OpenAIUsage>,
104}
105
106pub async fn stream_completion(
107 credential: ProviderCredential,
108 executor: Arc<Executor>,
109 request: Box<dyn CompletionRequest>,
110) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
111 let api_key = match credential {
112 ProviderCredential::Credentials { api_key } => api_key,
113 _ => {
114 return Err(anyhow!("no credentials provider for completion"));
115 }
116 };
117
118 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
119
120 let json_data = request.data()?;
121 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
122 .header("Content-Type", "application/json")
123 .header("Authorization", format!("Bearer {}", api_key))
124 .body(json_data)?
125 .send_async()
126 .await?;
127
128 let status = response.status();
129 if status == StatusCode::OK {
130 executor
131 .spawn(async move {
132 let mut lines = BufReader::new(response.body_mut()).lines();
133
134 fn parse_line(
135 line: Result<String, io::Error>,
136 ) -> Result<Option<OpenAIResponseStreamEvent>> {
137 if let Some(data) = line?.strip_prefix("data: ") {
138 let event = serde_json::from_str(&data)?;
139 Ok(Some(event))
140 } else {
141 Ok(None)
142 }
143 }
144
145 while let Some(line) = lines.next().await {
146 if let Some(event) = parse_line(line).transpose() {
147 let done = event.as_ref().map_or(false, |event| {
148 event
149 .choices
150 .last()
151 .map_or(false, |choice| choice.finish_reason.is_some())
152 });
153 if tx.unbounded_send(event).is_err() {
154 break;
155 }
156
157 if done {
158 break;
159 }
160 }
161 }
162
163 anyhow::Ok(())
164 })
165 .detach();
166
167 Ok(rx)
168 } else {
169 let mut body = String::new();
170 response.body_mut().read_to_string(&mut body).await?;
171
172 #[derive(Deserialize)]
173 struct OpenAIResponse {
174 error: OpenAIError,
175 }
176
177 #[derive(Deserialize)]
178 struct OpenAIError {
179 message: String,
180 }
181
182 match serde_json::from_str::<OpenAIResponse>(&body) {
183 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
184 "Failed to connect to OpenAI API: {}",
185 response.error.message,
186 )),
187
188 _ => Err(anyhow!(
189 "Failed to connect to OpenAI API: {} {}",
190 response.status(),
191 body,
192 )),
193 }
194 }
195}
196
197#[derive(Clone)]
198pub struct OpenAICompletionProvider {
199 model: OpenAILanguageModel,
200 credential: Arc<RwLock<ProviderCredential>>,
201 executor: Arc<Executor>,
202}
203
204impl OpenAICompletionProvider {
205 pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
206 let model = OpenAILanguageModel::load(model_name);
207 let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
208 Self {
209 model,
210 credential,
211 executor,
212 }
213 }
214}
215
216#[async_trait]
217impl CredentialProvider for OpenAICompletionProvider {
218 fn has_credentials(&self) -> bool {
219 match *self.credential.read() {
220 ProviderCredential::Credentials { .. } => true,
221 _ => false,
222 }
223 }
224 async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
225 let existing_credential = self.credential.read().clone();
226
227 let retrieved_credential = cx
228 .run_on_main(move |cx| match existing_credential {
229 ProviderCredential::Credentials { .. } => {
230 return existing_credential.clone();
231 }
232 _ => {
233 if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
234 return ProviderCredential::Credentials { api_key };
235 }
236
237 if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
238 {
239 if let Some(api_key) = String::from_utf8(api_key).log_err() {
240 return ProviderCredential::Credentials { api_key };
241 } else {
242 return ProviderCredential::NoCredentials;
243 }
244 } else {
245 return ProviderCredential::NoCredentials;
246 }
247 }
248 })
249 .await;
250
251 *self.credential.write() = retrieved_credential.clone();
252 retrieved_credential
253 }
254
255 async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
256 *self.credential.write() = credential.clone();
257 let credential = credential.clone();
258 cx.run_on_main(move |cx| match credential {
259 ProviderCredential::Credentials { api_key } => {
260 cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
261 .log_err();
262 }
263 _ => {}
264 })
265 .await;
266 }
267 async fn delete_credentials(&self, cx: &mut AppContext) {
268 cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
269 .await;
270 *self.credential.write() = ProviderCredential::NoCredentials;
271 }
272}
273
274impl CompletionProvider for OpenAICompletionProvider {
275 fn base_model(&self) -> Box<dyn LanguageModel> {
276 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
277 model
278 }
279 fn complete(
280 &self,
281 prompt: Box<dyn CompletionRequest>,
282 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
283 // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
284 // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
285 // which is currently model based, due to the langauge model.
286 // At some point in the future we should rectify this.
287 let credential = self.credential.read().clone();
288 let request = stream_completion(credential, self.executor.clone(), prompt);
289 async move {
290 let response = request.await?;
291 let stream = response
292 .filter_map(|response| async move {
293 match response {
294 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
295 Err(error) => Some(Err(error)),
296 }
297 })
298 .boxed();
299 Ok(stream)
300 }
301 .boxed()
302 }
303 fn box_clone(&self) -> Box<dyn CompletionProvider> {
304 Box::new((*self).clone())
305 }
306}