1use crate::{
2 assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
3 Role,
4};
5use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
6use anthropic::{stream_completion, Request, RequestMessage};
7use anyhow::{anyhow, Result};
8use editor::{Editor, EditorElement, EditorStyle};
9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
11use http::HttpClient;
12use settings::Settings;
13use std::time::Duration;
14use std::{env, sync::Arc};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::prelude::*;
18use util::ResultExt;
19
20pub struct AnthropicCompletionProvider {
21 api_key: Option<String>,
22 api_url: String,
23 model: AnthropicModel,
24 http_client: Arc<dyn HttpClient>,
25 low_speed_timeout: Option<Duration>,
26 settings_version: usize,
27}
28
29impl AnthropicCompletionProvider {
30 pub fn new(
31 model: AnthropicModel,
32 api_url: String,
33 http_client: Arc<dyn HttpClient>,
34 low_speed_timeout: Option<Duration>,
35 settings_version: usize,
36 ) -> Self {
37 Self {
38 api_key: None,
39 api_url,
40 model,
41 http_client,
42 low_speed_timeout,
43 settings_version,
44 }
45 }
46
47 pub fn update(
48 &mut self,
49 model: AnthropicModel,
50 api_url: String,
51 low_speed_timeout: Option<Duration>,
52 settings_version: usize,
53 ) {
54 self.model = model;
55 self.api_url = api_url;
56 self.low_speed_timeout = low_speed_timeout;
57 self.settings_version = settings_version;
58 }
59
60 pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
61 AnthropicModel::iter()
62 }
63
64 pub fn settings_version(&self) -> usize {
65 self.settings_version
66 }
67
68 pub fn is_authenticated(&self) -> bool {
69 self.api_key.is_some()
70 }
71
72 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
73 if self.is_authenticated() {
74 Task::ready(Ok(()))
75 } else {
76 let api_url = self.api_url.clone();
77 cx.spawn(|mut cx| async move {
78 let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
79 api_key
80 } else {
81 let (_, api_key) = cx
82 .update(|cx| cx.read_credentials(&api_url))?
83 .await?
84 .ok_or_else(|| anyhow!("credentials not found"))?;
85 String::from_utf8(api_key)?
86 };
87 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
88 if let CompletionProvider::Anthropic(provider) = provider {
89 provider.api_key = Some(api_key);
90 }
91 })
92 })
93 }
94 }
95
96 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
97 let delete_credentials = cx.delete_credentials(&self.api_url);
98 cx.spawn(|mut cx| async move {
99 delete_credentials.await.log_err();
100 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
101 if let CompletionProvider::Anthropic(provider) = provider {
102 provider.api_key = None;
103 }
104 })
105 })
106 }
107
108 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
109 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
110 .into()
111 }
112
113 pub fn model(&self) -> AnthropicModel {
114 self.model.clone()
115 }
116
117 pub fn count_tokens(
118 &self,
119 request: LanguageModelRequest,
120 cx: &AppContext,
121 ) -> BoxFuture<'static, Result<usize>> {
122 count_open_ai_tokens(request, cx.background_executor())
123 }
124
125 pub fn complete(
126 &self,
127 request: LanguageModelRequest,
128 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
129 let request = self.to_anthropic_request(request);
130
131 let http_client = self.http_client.clone();
132 let api_key = self.api_key.clone();
133 let api_url = self.api_url.clone();
134 let low_speed_timeout = self.low_speed_timeout;
135 async move {
136 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
137 let request = stream_completion(
138 http_client.as_ref(),
139 &api_url,
140 &api_key,
141 request,
142 low_speed_timeout,
143 );
144 let response = request.await?;
145 let stream = response
146 .filter_map(|response| async move {
147 match response {
148 Ok(response) => match response {
149 anthropic::ResponseEvent::ContentBlockStart {
150 content_block, ..
151 } => match content_block {
152 anthropic::ContentBlock::Text { text } => Some(Ok(text)),
153 },
154 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
155 match delta {
156 anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
157 }
158 }
159 _ => None,
160 },
161 Err(error) => Some(Err(error)),
162 }
163 })
164 .boxed();
165 Ok(stream)
166 }
167 .boxed()
168 }
169
170 fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
171 preprocess_anthropic_request(&mut request);
172
173 let model = match request.model {
174 LanguageModel::Anthropic(model) => model,
175 _ => self.model(),
176 };
177
178 let mut system_message = String::new();
179 if request
180 .messages
181 .first()
182 .map_or(false, |message| message.role == Role::System)
183 {
184 system_message = request.messages.remove(0).content;
185 }
186
187 Request {
188 model,
189 messages: request
190 .messages
191 .iter()
192 .map(|msg| RequestMessage {
193 role: match msg.role {
194 Role::User => anthropic::Role::User,
195 Role::Assistant => anthropic::Role::Assistant,
196 Role::System => unreachable!("filtered out by preprocess_request"),
197 },
198 content: msg.content.clone(),
199 })
200 .collect(),
201 stream: true,
202 system: system_message,
203 max_tokens: 4092,
204 }
205 }
206}
207
208pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
209 let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
210 let mut system_message = String::new();
211
212 for message in request.messages.drain(..) {
213 if message.content.is_empty() {
214 continue;
215 }
216
217 match message.role {
218 Role::User | Role::Assistant => {
219 if let Some(last_message) = new_messages.last_mut() {
220 if last_message.role == message.role {
221 last_message.content.push_str("\n\n");
222 last_message.content.push_str(&message.content);
223 continue;
224 }
225 }
226
227 new_messages.push(message);
228 }
229 Role::System => {
230 if !system_message.is_empty() {
231 system_message.push_str("\n\n");
232 }
233 system_message.push_str(&message.content);
234 }
235 }
236 }
237
238 if !system_message.is_empty() {
239 new_messages.insert(
240 0,
241 LanguageModelRequestMessage {
242 role: Role::System,
243 content: system_message,
244 },
245 );
246 }
247
248 request.messages = new_messages;
249}
250
251struct AuthenticationPrompt {
252 api_key: View<Editor>,
253 api_url: String,
254}
255
256impl AuthenticationPrompt {
257 fn new(api_url: String, cx: &mut WindowContext) -> Self {
258 Self {
259 api_key: cx.new_view(|cx| {
260 let mut editor = Editor::single_line(cx);
261 editor.set_placeholder_text(
262 "sk-000000000000000000000000000000000000000000000000",
263 cx,
264 );
265 editor
266 }),
267 api_url,
268 }
269 }
270
271 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
272 let api_key = self.api_key.read(cx).text(cx);
273 if api_key.is_empty() {
274 return;
275 }
276
277 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
278 cx.spawn(|_, mut cx| async move {
279 write_credentials.await?;
280 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
281 if let CompletionProvider::Anthropic(provider) = provider {
282 provider.api_key = Some(api_key);
283 }
284 })
285 })
286 .detach_and_log_err(cx);
287 }
288
289 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
290 let settings = ThemeSettings::get_global(cx);
291 let text_style = TextStyle {
292 color: cx.theme().colors().text,
293 font_family: settings.ui_font.family.clone(),
294 font_features: settings.ui_font.features.clone(),
295 font_size: rems(0.875).into(),
296 font_weight: settings.ui_font.weight,
297 font_style: FontStyle::Normal,
298 line_height: relative(1.3),
299 background_color: None,
300 underline: None,
301 strikethrough: None,
302 white_space: WhiteSpace::Normal,
303 };
304 EditorElement::new(
305 &self.api_key,
306 EditorStyle {
307 background: cx.theme().colors().editor_background,
308 local_player: cx.theme().players().local(),
309 text: text_style,
310 ..Default::default()
311 },
312 )
313 }
314}
315
316impl Render for AuthenticationPrompt {
317 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
318 const INSTRUCTIONS: [&str; 4] = [
319 "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
320 "You can create an API key at: https://console.anthropic.com/settings/keys",
321 "",
322 "Paste your Anthropic API key below and hit enter to use the assistant:",
323 ];
324
325 v_flex()
326 .p_4()
327 .size_full()
328 .on_action(cx.listener(Self::save_api_key))
329 .children(
330 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
331 )
332 .child(
333 h_flex()
334 .w_full()
335 .my_2()
336 .px_2()
337 .py_1()
338 .bg(cx.theme().colors().editor_background)
339 .rounded_md()
340 .child(self.render_api_key_editor(cx)),
341 )
342 .child(
343 Label::new(
344 "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
345 )
346 .size(LabelSize::Small),
347 )
348 .child(
349 h_flex()
350 .gap_2()
351 .child(Label::new("Click on").size(LabelSize::Small))
352 .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
353 .child(
354 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
355 ),
356 )
357 .into_any()
358 }
359}