1use crate::{
2 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::{anyhow, Result};
5use editor::{Editor, EditorElement, EditorStyle};
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
9use settings::Settings;
10use std::{env, sync::Arc};
11use theme::ThemeSettings;
12use ui::prelude::*;
13use util::{http::HttpClient, ResultExt};
14
15pub struct OpenAiCompletionProvider {
16 api_key: Option<String>,
17 api_url: String,
18 default_model: OpenAiModel,
19 http_client: Arc<dyn HttpClient>,
20 settings_version: usize,
21}
22
23impl OpenAiCompletionProvider {
24 pub fn new(
25 default_model: OpenAiModel,
26 api_url: String,
27 http_client: Arc<dyn HttpClient>,
28 settings_version: usize,
29 ) -> Self {
30 Self {
31 api_key: None,
32 api_url,
33 default_model,
34 http_client,
35 settings_version,
36 }
37 }
38
39 pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
40 self.default_model = default_model;
41 self.api_url = api_url;
42 self.settings_version = settings_version;
43 }
44
45 pub fn settings_version(&self) -> usize {
46 self.settings_version
47 }
48
49 pub fn is_authenticated(&self) -> bool {
50 self.api_key.is_some()
51 }
52
53 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
54 if self.is_authenticated() {
55 Task::ready(Ok(()))
56 } else {
57 let api_url = self.api_url.clone();
58 cx.spawn(|mut cx| async move {
59 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
60 api_key
61 } else {
62 let (_, api_key) = cx
63 .update(|cx| cx.read_credentials(&api_url))?
64 .await?
65 .ok_or_else(|| anyhow!("credentials not found"))?;
66 String::from_utf8(api_key)?
67 };
68 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
69 if let CompletionProvider::OpenAi(provider) = provider {
70 provider.api_key = Some(api_key);
71 }
72 })
73 })
74 }
75 }
76
77 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
78 let delete_credentials = cx.delete_credentials(&self.api_url);
79 cx.spawn(|mut cx| async move {
80 delete_credentials.await.log_err();
81 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
82 if let CompletionProvider::OpenAi(provider) = provider {
83 provider.api_key = None;
84 }
85 })
86 })
87 }
88
89 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
90 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
91 .into()
92 }
93
94 pub fn default_model(&self) -> OpenAiModel {
95 self.default_model.clone()
96 }
97
98 pub fn count_tokens(
99 &self,
100 request: LanguageModelRequest,
101 cx: &AppContext,
102 ) -> BoxFuture<'static, Result<usize>> {
103 count_open_ai_tokens(request, cx.background_executor())
104 }
105
106 pub fn complete(
107 &self,
108 request: LanguageModelRequest,
109 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
110 let request = self.to_open_ai_request(request);
111
112 let http_client = self.http_client.clone();
113 let api_key = self.api_key.clone();
114 let api_url = self.api_url.clone();
115 async move {
116 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
117 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
118 let response = request.await?;
119 let stream = response
120 .filter_map(|response| async move {
121 match response {
122 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
123 Err(error) => Some(Err(error)),
124 }
125 })
126 .boxed();
127 Ok(stream)
128 }
129 .boxed()
130 }
131
132 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
133 let model = match request.model {
134 LanguageModel::ZedDotDev(_) => self.default_model(),
135 LanguageModel::OpenAi(model) => model,
136 };
137
138 Request {
139 model,
140 messages: request
141 .messages
142 .into_iter()
143 .map(|msg| match msg.role {
144 Role::User => RequestMessage::User {
145 content: msg.content,
146 },
147 Role::Assistant => RequestMessage::Assistant {
148 content: Some(msg.content),
149 tool_calls: Vec::new(),
150 },
151 Role::System => RequestMessage::System {
152 content: msg.content,
153 },
154 })
155 .collect(),
156 stream: true,
157 stop: request.stop,
158 temperature: request.temperature,
159 tools: Vec::new(),
160 tool_choice: None,
161 }
162 }
163}
164
165pub fn count_open_ai_tokens(
166 request: LanguageModelRequest,
167 background_executor: &gpui::BackgroundExecutor,
168) -> BoxFuture<'static, Result<usize>> {
169 background_executor
170 .spawn(async move {
171 let messages = request
172 .messages
173 .into_iter()
174 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
175 role: match message.role {
176 Role::User => "user".into(),
177 Role::Assistant => "assistant".into(),
178 Role::System => "system".into(),
179 },
180 content: Some(message.content),
181 name: None,
182 function_call: None,
183 })
184 .collect::<Vec<_>>();
185
186 tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
187 })
188 .boxed()
189}
190
191impl From<Role> for open_ai::Role {
192 fn from(val: Role) -> Self {
193 match val {
194 Role::User => OpenAiRole::User,
195 Role::Assistant => OpenAiRole::Assistant,
196 Role::System => OpenAiRole::System,
197 }
198 }
199}
200
201struct AuthenticationPrompt {
202 api_key: View<Editor>,
203 api_url: String,
204}
205
206impl AuthenticationPrompt {
207 fn new(api_url: String, cx: &mut WindowContext) -> Self {
208 Self {
209 api_key: cx.new_view(|cx| {
210 let mut editor = Editor::single_line(cx);
211 editor.set_placeholder_text(
212 "sk-000000000000000000000000000000000000000000000000",
213 cx,
214 );
215 editor
216 }),
217 api_url,
218 }
219 }
220
221 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
222 let api_key = self.api_key.read(cx).text(cx);
223 if api_key.is_empty() {
224 return;
225 }
226
227 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
228 cx.spawn(|_, mut cx| async move {
229 write_credentials.await?;
230 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
231 if let CompletionProvider::OpenAi(provider) = provider {
232 provider.api_key = Some(api_key);
233 }
234 })
235 })
236 .detach_and_log_err(cx);
237 }
238
239 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
240 let settings = ThemeSettings::get_global(cx);
241 let text_style = TextStyle {
242 color: cx.theme().colors().text,
243 font_family: settings.ui_font.family.clone(),
244 font_features: settings.ui_font.features,
245 font_size: rems(0.875).into(),
246 font_weight: FontWeight::NORMAL,
247 font_style: FontStyle::Normal,
248 line_height: relative(1.3),
249 background_color: None,
250 underline: None,
251 strikethrough: None,
252 white_space: WhiteSpace::Normal,
253 };
254 EditorElement::new(
255 &self.api_key,
256 EditorStyle {
257 background: cx.theme().colors().editor_background,
258 local_player: cx.theme().players().local(),
259 text: text_style,
260 ..Default::default()
261 },
262 )
263 }
264}
265
266impl Render for AuthenticationPrompt {
267 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
268 const INSTRUCTIONS: [&str; 6] = [
269 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
270 " - You can create an API key at: platform.openai.com/api-keys",
271 " - Make sure your OpenAI account has credits",
272 " - Having a subscription for another service like GitHub Copilot won't work.",
273 "",
274 "Paste your OpenAI API key below and hit enter to use the assistant:",
275 ];
276
277 v_flex()
278 .p_4()
279 .size_full()
280 .on_action(cx.listener(Self::save_api_key))
281 .children(
282 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
283 )
284 .child(
285 h_flex()
286 .w_full()
287 .my_2()
288 .px_2()
289 .py_1()
290 .bg(cx.theme().colors().editor_background)
291 .rounded_md()
292 .child(self.render_api_key_editor(cx)),
293 )
294 .child(
295 Label::new(
296 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
297 )
298 .size(LabelSize::Small),
299 )
300 .child(
301 h_flex()
302 .gap_2()
303 .child(Label::new("Click on").size(LabelSize::Small))
304 .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
305 .child(
306 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
307 ),
308 )
309 .into_any()
310 }
311}