1use anyhow::{anyhow, Result};
2use futures::{
3 future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
4 Stream, StreamExt,
5};
6use gpui::executor::Background;
7use isahc::{http::StatusCode, Request, RequestExt};
8use serde::{Deserialize, Serialize};
9use std::{
10 fmt::{self, Display},
11 io,
12 sync::Arc,
13};
14
15use crate::{
16 completion::{CompletionProvider, CompletionRequest},
17 models::LanguageModel,
18};
19
20use super::OpenAILanguageModel;
21
/// Base URL for all OpenAI REST API calls (v1).
// `'static` is implied on `const` items; spelling it out trips clippy's
// `redundant_static_lifetimes` lint.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
23
/// The author of a chat message, mirroring the `role` field of the OpenAI
/// chat completions API. Serialized in lowercase ("user", "assistant",
/// "system") to match the wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
31
32impl Role {
33 pub fn cycle(&mut self) {
34 *self = match self {
35 Role::User => Role::Assistant,
36 Role::Assistant => Role::System,
37 Role::System => Role::User,
38 }
39 }
40}
41
42impl Display for Role {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Role::User => write!(f, "User"),
46 Role::Assistant => write!(f, "Assistant"),
47 Role::System => write!(f, "System"),
48 }
49 }
50}
51
/// One entry in the `messages` array of a chat completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    // Who authored the message.
    pub role: Role,
    // The full message text.
    pub content: String,
}
57
/// Request body for the OpenAI `/chat/completions` endpoint.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    // Model identifier sent to the API (e.g. a GPT model name).
    pub model: String,
    // Conversation history, oldest first.
    pub messages: Vec<RequestMessage>,
    // When true the API responds with server-sent events instead of a
    // single JSON body; `stream_completion` below relies on this.
    pub stream: bool,
    // Sequences at which the model should stop generating.
    pub stop: Vec<String>,
    // Sampling temperature.
    pub temperature: f32,
}
66
impl CompletionRequest for OpenAIRequest {
    /// Serialize this request into the JSON string used as the HTTP body.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
72
/// A (possibly partial) message as it appears in a streamed response delta.
/// Both fields are optional because each streamed chunk may carry only a
/// fragment — e.g. a role with no content, or content with no role.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
78
/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    // Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    // Tokens generated in the completion.
    pub completion_tokens: u32,
    // Sum of the two above.
    pub total_tokens: u32,
}
85
/// One streamed choice within a response event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    // Index of this choice within the response.
    pub index: u32,
    // The incremental piece of the message for this chunk.
    pub delta: ResponseMessage,
    // Set on the final chunk of a choice; `stream_completion` uses this to
    // detect end-of-stream.
    pub finish_reason: Option<String>,
}
92
/// A single server-sent event parsed from the streaming completions response.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    // Unix timestamp of event creation.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    // Only present on some events; absent on intermediate deltas.
    pub usage: Option<OpenAIUsage>,
}
102
103pub async fn stream_completion(
104 api_key: String,
105 executor: Arc<Background>,
106 request: Box<dyn CompletionRequest>,
107) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
108 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
109
110 let json_data = request.data()?;
111 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
112 .header("Content-Type", "application/json")
113 .header("Authorization", format!("Bearer {}", api_key))
114 .body(json_data)?
115 .send_async()
116 .await?;
117
118 let status = response.status();
119 if status == StatusCode::OK {
120 executor
121 .spawn(async move {
122 let mut lines = BufReader::new(response.body_mut()).lines();
123
124 fn parse_line(
125 line: Result<String, io::Error>,
126 ) -> Result<Option<OpenAIResponseStreamEvent>> {
127 if let Some(data) = line?.strip_prefix("data: ") {
128 let event = serde_json::from_str(&data)?;
129 Ok(Some(event))
130 } else {
131 Ok(None)
132 }
133 }
134
135 while let Some(line) = lines.next().await {
136 if let Some(event) = parse_line(line).transpose() {
137 let done = event.as_ref().map_or(false, |event| {
138 event
139 .choices
140 .last()
141 .map_or(false, |choice| choice.finish_reason.is_some())
142 });
143 if tx.unbounded_send(event).is_err() {
144 break;
145 }
146
147 if done {
148 break;
149 }
150 }
151 }
152
153 anyhow::Ok(())
154 })
155 .detach();
156
157 Ok(rx)
158 } else {
159 let mut body = String::new();
160 response.body_mut().read_to_string(&mut body).await?;
161
162 #[derive(Deserialize)]
163 struct OpenAIResponse {
164 error: OpenAIError,
165 }
166
167 #[derive(Deserialize)]
168 struct OpenAIError {
169 message: String,
170 }
171
172 match serde_json::from_str::<OpenAIResponse>(&body) {
173 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
174 "Failed to connect to OpenAI API: {}",
175 response.error.message,
176 )),
177
178 _ => Err(anyhow!(
179 "Failed to connect to OpenAI API: {} {}",
180 response.status(),
181 body,
182 )),
183 }
184 }
185}
186
/// A `CompletionProvider` backed by the OpenAI chat completions API.
pub struct OpenAICompletionProvider {
    // Language-model descriptor loaded from the configured model name.
    model: OpenAILanguageModel,
    // Sent as a bearer token in the Authorization header of each request.
    api_key: String,
    // Executor on which the response-stream reader task is spawned.
    executor: Arc<Background>,
}
192
193impl OpenAICompletionProvider {
194 pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
195 let model = OpenAILanguageModel::load(model_name);
196 Self {
197 model,
198 api_key,
199 executor,
200 }
201 }
202}
203
204impl CompletionProvider for OpenAICompletionProvider {
205 fn base_model(&self) -> Box<dyn LanguageModel> {
206 let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
207 model
208 }
209 fn complete(
210 &self,
211 prompt: Box<dyn CompletionRequest>,
212 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
213 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
214 async move {
215 let response = request.await?;
216 let stream = response
217 .filter_map(|response| async move {
218 match response {
219 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
220 Err(error) => Some(Err(error)),
221 }
222 })
223 .boxed();
224 Ok(stream)
225 }
226 .boxed()
227 }
228}