1use crate::{
2 assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
3 Role,
4};
5use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
6use anthropic::{stream_completion, Request, RequestMessage};
7use anyhow::{anyhow, Result};
8use editor::{Editor, EditorElement, EditorStyle};
9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
11use http::HttpClient;
12use settings::Settings;
13use std::time::Duration;
14use std::{env, sync::Arc};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::prelude::*;
18use util::ResultExt;
19
20pub struct AnthropicCompletionProvider {
21 api_key: Option<String>,
22 api_url: String,
23 model: AnthropicModel,
24 http_client: Arc<dyn HttpClient>,
25 low_speed_timeout: Option<Duration>,
26 settings_version: usize,
27}
28
29impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
30 fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
31 AnthropicModel::iter()
32 .map(LanguageModel::Anthropic)
33 .collect()
34 }
35
36 fn settings_version(&self) -> usize {
37 self.settings_version
38 }
39
40 fn is_authenticated(&self) -> bool {
41 self.api_key.is_some()
42 }
43
44 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
45 if self.is_authenticated() {
46 Task::ready(Ok(()))
47 } else {
48 let api_url = self.api_url.clone();
49 cx.spawn(|mut cx| async move {
50 let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
51 api_key
52 } else {
53 let (_, api_key) = cx
54 .update(|cx| cx.read_credentials(&api_url))?
55 .await?
56 .ok_or_else(|| anyhow!("credentials not found"))?;
57 String::from_utf8(api_key)?
58 };
59 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
60 provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
61 provider.api_key = Some(api_key);
62 });
63 })
64 })
65 }
66 }
67
68 fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
69 let delete_credentials = cx.delete_credentials(&self.api_url);
70 cx.spawn(|mut cx| async move {
71 delete_credentials.await.log_err();
72 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
73 provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
74 provider.api_key = None;
75 });
76 })
77 })
78 }
79
80 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
81 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
82 .into()
83 }
84
85 fn model(&self) -> LanguageModel {
86 LanguageModel::Anthropic(self.model.clone())
87 }
88
89 fn count_tokens(
90 &self,
91 request: LanguageModelRequest,
92 cx: &AppContext,
93 ) -> BoxFuture<'static, Result<usize>> {
94 count_open_ai_tokens(request, cx.background_executor())
95 }
96
97 fn complete(
98 &self,
99 request: LanguageModelRequest,
100 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
101 let request = self.to_anthropic_request(request);
102
103 let http_client = self.http_client.clone();
104 let api_key = self.api_key.clone();
105 let api_url = self.api_url.clone();
106 let low_speed_timeout = self.low_speed_timeout;
107 async move {
108 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
109 let request = stream_completion(
110 http_client.as_ref(),
111 &api_url,
112 &api_key,
113 request,
114 low_speed_timeout,
115 );
116 let response = request.await?;
117 let stream = response
118 .filter_map(|response| async move {
119 match response {
120 Ok(response) => match response {
121 anthropic::ResponseEvent::ContentBlockStart {
122 content_block, ..
123 } => match content_block {
124 anthropic::ContentBlock::Text { text } => Some(Ok(text)),
125 },
126 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
127 match delta {
128 anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
129 }
130 }
131 _ => None,
132 },
133 Err(error) => Some(Err(error)),
134 }
135 })
136 .boxed();
137 Ok(stream)
138 }
139 .boxed()
140 }
141
142 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
143 self
144 }
145}
146
147impl AnthropicCompletionProvider {
148 pub fn new(
149 model: AnthropicModel,
150 api_url: String,
151 http_client: Arc<dyn HttpClient>,
152 low_speed_timeout: Option<Duration>,
153 settings_version: usize,
154 ) -> Self {
155 Self {
156 api_key: None,
157 api_url,
158 model,
159 http_client,
160 low_speed_timeout,
161 settings_version,
162 }
163 }
164
165 pub fn update(
166 &mut self,
167 model: AnthropicModel,
168 api_url: String,
169 low_speed_timeout: Option<Duration>,
170 settings_version: usize,
171 ) {
172 self.model = model;
173 self.api_url = api_url;
174 self.low_speed_timeout = low_speed_timeout;
175 self.settings_version = settings_version;
176 }
177
178 fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
179 preprocess_anthropic_request(&mut request);
180
181 let model = match request.model {
182 LanguageModel::Anthropic(model) => model,
183 _ => self.model.clone(),
184 };
185
186 let mut system_message = String::new();
187 if request
188 .messages
189 .first()
190 .map_or(false, |message| message.role == Role::System)
191 {
192 system_message = request.messages.remove(0).content;
193 }
194
195 Request {
196 model,
197 messages: request
198 .messages
199 .iter()
200 .map(|msg| RequestMessage {
201 role: match msg.role {
202 Role::User => anthropic::Role::User,
203 Role::Assistant => anthropic::Role::Assistant,
204 Role::System => unreachable!("filtered out by preprocess_request"),
205 },
206 content: msg.content.clone(),
207 })
208 .collect(),
209 stream: true,
210 system: system_message,
211 max_tokens: 4092,
212 }
213 }
214}
215
216pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
217 let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
218 let mut system_message = String::new();
219
220 for message in request.messages.drain(..) {
221 if message.content.is_empty() {
222 continue;
223 }
224
225 match message.role {
226 Role::User | Role::Assistant => {
227 if let Some(last_message) = new_messages.last_mut() {
228 if last_message.role == message.role {
229 last_message.content.push_str("\n\n");
230 last_message.content.push_str(&message.content);
231 continue;
232 }
233 }
234
235 new_messages.push(message);
236 }
237 Role::System => {
238 if !system_message.is_empty() {
239 system_message.push_str("\n\n");
240 }
241 system_message.push_str(&message.content);
242 }
243 }
244 }
245
246 if !system_message.is_empty() {
247 new_messages.insert(
248 0,
249 LanguageModelRequestMessage {
250 role: Role::System,
251 content: system_message,
252 },
253 );
254 }
255
256 request.messages = new_messages;
257}
258
259struct AuthenticationPrompt {
260 api_key: View<Editor>,
261 api_url: String,
262}
263
264impl AuthenticationPrompt {
265 fn new(api_url: String, cx: &mut WindowContext) -> Self {
266 Self {
267 api_key: cx.new_view(|cx| {
268 let mut editor = Editor::single_line(cx);
269 editor.set_placeholder_text(
270 "sk-000000000000000000000000000000000000000000000000",
271 cx,
272 );
273 editor
274 }),
275 api_url,
276 }
277 }
278
279 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
280 let api_key = self.api_key.read(cx).text(cx);
281 if api_key.is_empty() {
282 return;
283 }
284
285 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
286 cx.spawn(|_, mut cx| async move {
287 write_credentials.await?;
288 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
289 provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
290 provider.api_key = Some(api_key);
291 });
292 })
293 })
294 .detach_and_log_err(cx);
295 }
296
297 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
298 let settings = ThemeSettings::get_global(cx);
299 let text_style = TextStyle {
300 color: cx.theme().colors().text,
301 font_family: settings.ui_font.family.clone(),
302 font_features: settings.ui_font.features.clone(),
303 font_size: rems(0.875).into(),
304 font_weight: settings.ui_font.weight,
305 font_style: FontStyle::Normal,
306 line_height: relative(1.3),
307 background_color: None,
308 underline: None,
309 strikethrough: None,
310 white_space: WhiteSpace::Normal,
311 };
312 EditorElement::new(
313 &self.api_key,
314 EditorStyle {
315 background: cx.theme().colors().editor_background,
316 local_player: cx.theme().players().local(),
317 text: text_style,
318 ..Default::default()
319 },
320 )
321 }
322}
323
324impl Render for AuthenticationPrompt {
325 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
326 const INSTRUCTIONS: [&str; 4] = [
327 "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
328 "You can create an API key at: https://console.anthropic.com/settings/keys",
329 "",
330 "Paste your Anthropic API key below and hit enter to use the assistant:",
331 ];
332
333 v_flex()
334 .p_4()
335 .size_full()
336 .on_action(cx.listener(Self::save_api_key))
337 .children(
338 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
339 )
340 .child(
341 h_flex()
342 .w_full()
343 .my_2()
344 .px_2()
345 .py_1()
346 .bg(cx.theme().colors().editor_background)
347 .rounded_md()
348 .child(self.render_api_key_editor(cx)),
349 )
350 .child(
351 Label::new(
352 "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
353 )
354 .size(LabelSize::Small),
355 )
356 .child(
357 h_flex()
358 .gap_2()
359 .child(Label::new("Click on").size(LabelSize::Small))
360 .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
361 .child(
362 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
363 ),
364 )
365 .into_any()
366 }
367}