1use crate::{
2 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::{anyhow, Result};
5use editor::{Editor, EditorElement, EditorStyle};
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
9use settings::Settings;
10use std::{env, sync::Arc};
11use theme::ThemeSettings;
12use ui::prelude::*;
13use util::{http::HttpClient, ResultExt};
14
15pub struct OpenAiCompletionProvider {
16 api_key: Option<String>,
17 api_url: String,
18 default_model: OpenAiModel,
19 http_client: Arc<dyn HttpClient>,
20 settings_version: usize,
21}
22
23impl OpenAiCompletionProvider {
24 pub fn new(
25 default_model: OpenAiModel,
26 api_url: String,
27 http_client: Arc<dyn HttpClient>,
28 settings_version: usize,
29 ) -> Self {
30 Self {
31 api_key: None,
32 api_url,
33 default_model,
34 http_client,
35 settings_version,
36 }
37 }
38
39 pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
40 self.default_model = default_model;
41 self.api_url = api_url;
42 self.settings_version = settings_version;
43 }
44
45 pub fn settings_version(&self) -> usize {
46 self.settings_version
47 }
48
49 pub fn is_authenticated(&self) -> bool {
50 self.api_key.is_some()
51 }
52
53 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
54 if self.is_authenticated() {
55 Task::ready(Ok(()))
56 } else {
57 let api_url = self.api_url.clone();
58 cx.spawn(|mut cx| async move {
59 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
60 api_key
61 } else {
62 let (_, api_key) = cx
63 .update(|cx| cx.read_credentials(&api_url))?
64 .await?
65 .ok_or_else(|| anyhow!("credentials not found"))?;
66 String::from_utf8(api_key)?
67 };
68 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
69 if let CompletionProvider::OpenAi(provider) = provider {
70 provider.api_key = Some(api_key);
71 }
72 })
73 })
74 }
75 }
76
77 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
78 let delete_credentials = cx.delete_credentials(&self.api_url);
79 cx.spawn(|mut cx| async move {
80 delete_credentials.await.log_err();
81 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
82 if let CompletionProvider::OpenAi(provider) = provider {
83 provider.api_key = None;
84 }
85 })
86 })
87 }
88
89 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
90 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
91 .into()
92 }
93
94 pub fn default_model(&self) -> OpenAiModel {
95 self.default_model.clone()
96 }
97
98 pub fn count_tokens(
99 &self,
100 request: LanguageModelRequest,
101 cx: &AppContext,
102 ) -> BoxFuture<'static, Result<usize>> {
103 count_open_ai_tokens(request, cx.background_executor())
104 }
105
106 pub fn complete(
107 &self,
108 request: LanguageModelRequest,
109 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
110 let request = self.to_open_ai_request(request);
111
112 let http_client = self.http_client.clone();
113 let api_key = self.api_key.clone();
114 let api_url = self.api_url.clone();
115 async move {
116 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
117 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
118 let response = request.await?;
119 let stream = response
120 .filter_map(|response| async move {
121 match response {
122 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
123 Err(error) => Some(Err(error)),
124 }
125 })
126 .boxed();
127 Ok(stream)
128 }
129 .boxed()
130 }
131
132 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
133 let model = match request.model {
134 LanguageModel::ZedDotDev(_) => self.default_model(),
135 LanguageModel::OpenAi(model) => model,
136 };
137
138 Request {
139 model,
140 messages: request
141 .messages
142 .into_iter()
143 .map(|msg| RequestMessage {
144 role: msg.role.into(),
145 content: msg.content,
146 })
147 .collect(),
148 stream: true,
149 stop: request.stop,
150 temperature: request.temperature,
151 }
152 }
153}
154
155pub fn count_open_ai_tokens(
156 request: LanguageModelRequest,
157 background_executor: &gpui::BackgroundExecutor,
158) -> BoxFuture<'static, Result<usize>> {
159 background_executor
160 .spawn(async move {
161 let messages = request
162 .messages
163 .into_iter()
164 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
165 role: match message.role {
166 Role::User => "user".into(),
167 Role::Assistant => "assistant".into(),
168 Role::System => "system".into(),
169 },
170 content: Some(message.content),
171 name: None,
172 function_call: None,
173 })
174 .collect::<Vec<_>>();
175
176 tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
177 })
178 .boxed()
179}
180
181impl From<Role> for open_ai::Role {
182 fn from(val: Role) -> Self {
183 match val {
184 Role::User => OpenAiRole::User,
185 Role::Assistant => OpenAiRole::Assistant,
186 Role::System => OpenAiRole::System,
187 }
188 }
189}
190
191struct AuthenticationPrompt {
192 api_key: View<Editor>,
193 api_url: String,
194}
195
196impl AuthenticationPrompt {
197 fn new(api_url: String, cx: &mut WindowContext) -> Self {
198 Self {
199 api_key: cx.new_view(|cx| {
200 let mut editor = Editor::single_line(cx);
201 editor.set_placeholder_text(
202 "sk-000000000000000000000000000000000000000000000000",
203 cx,
204 );
205 editor
206 }),
207 api_url,
208 }
209 }
210
211 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
212 let api_key = self.api_key.read(cx).text(cx);
213 if api_key.is_empty() {
214 return;
215 }
216
217 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
218 cx.spawn(|_, mut cx| async move {
219 write_credentials.await?;
220 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
221 if let CompletionProvider::OpenAi(provider) = provider {
222 provider.api_key = Some(api_key);
223 }
224 })
225 })
226 .detach_and_log_err(cx);
227 }
228
229 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
230 let settings = ThemeSettings::get_global(cx);
231 let text_style = TextStyle {
232 color: cx.theme().colors().text,
233 font_family: settings.ui_font.family.clone(),
234 font_features: settings.ui_font.features,
235 font_size: rems(0.875).into(),
236 font_weight: FontWeight::NORMAL,
237 font_style: FontStyle::Normal,
238 line_height: relative(1.3),
239 background_color: None,
240 underline: None,
241 strikethrough: None,
242 white_space: WhiteSpace::Normal,
243 };
244 EditorElement::new(
245 &self.api_key,
246 EditorStyle {
247 background: cx.theme().colors().editor_background,
248 local_player: cx.theme().players().local(),
249 text: text_style,
250 ..Default::default()
251 },
252 )
253 }
254}
255
256impl Render for AuthenticationPrompt {
257 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
258 const INSTRUCTIONS: [&str; 6] = [
259 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
260 " - You can create an API key at: platform.openai.com/api-keys",
261 " - Make sure your OpenAI account has credits",
262 " - Having a subscription for another service like GitHub Copilot won't work.",
263 "",
264 "Paste your OpenAI API key below and hit enter to use the assistant:",
265 ];
266
267 v_flex()
268 .p_4()
269 .size_full()
270 .on_action(cx.listener(Self::save_api_key))
271 .children(
272 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
273 )
274 .child(
275 h_flex()
276 .w_full()
277 .my_2()
278 .px_2()
279 .py_1()
280 .bg(cx.theme().colors().editor_background)
281 .rounded_md()
282 .child(self.render_api_key_editor(cx)),
283 )
284 .child(
285 Label::new(
286 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
287 )
288 .size(LabelSize::Small),
289 )
290 .child(
291 h_flex()
292 .gap_2()
293 .child(Label::new("Click on").size(LabelSize::Small))
294 .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
295 .child(
296 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
297 ),
298 )
299 .into_any()
300 }
301}