anthropic.rs

  1use crate::count_open_ai_tokens;
  2use crate::{
  3    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
  4    Role,
  5};
  6use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
  7use anyhow::{anyhow, Result};
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 11use http::HttpClient;
 12use settings::Settings;
 13use std::time::Duration;
 14use std::{env, sync::Arc};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::prelude::*;
 18use util::ResultExt;
 19
 20pub struct AnthropicCompletionProvider {
 21    api_key: Option<String>,
 22    api_url: String,
 23    model: AnthropicModel,
 24    http_client: Arc<dyn HttpClient>,
 25    low_speed_timeout: Option<Duration>,
 26    settings_version: usize,
 27}
 28
 29impl AnthropicCompletionProvider {
 30    pub fn new(
 31        model: AnthropicModel,
 32        api_url: String,
 33        http_client: Arc<dyn HttpClient>,
 34        low_speed_timeout: Option<Duration>,
 35        settings_version: usize,
 36    ) -> Self {
 37        Self {
 38            api_key: None,
 39            api_url,
 40            model,
 41            http_client,
 42            low_speed_timeout,
 43            settings_version,
 44        }
 45    }
 46
 47    pub fn update(
 48        &mut self,
 49        model: AnthropicModel,
 50        api_url: String,
 51        low_speed_timeout: Option<Duration>,
 52        settings_version: usize,
 53    ) {
 54        self.model = model;
 55        self.api_url = api_url;
 56        self.low_speed_timeout = low_speed_timeout;
 57        self.settings_version = settings_version;
 58    }
 59
 60    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
 61        AnthropicModel::iter()
 62    }
 63
 64    pub fn settings_version(&self) -> usize {
 65        self.settings_version
 66    }
 67
 68    pub fn is_authenticated(&self) -> bool {
 69        self.api_key.is_some()
 70    }
 71
 72    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 73        if self.is_authenticated() {
 74            Task::ready(Ok(()))
 75        } else {
 76            let api_url = self.api_url.clone();
 77            cx.spawn(|mut cx| async move {
 78                let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
 79                    api_key
 80                } else {
 81                    let (_, api_key) = cx
 82                        .update(|cx| cx.read_credentials(&api_url))?
 83                        .await?
 84                        .ok_or_else(|| anyhow!("credentials not found"))?;
 85                    String::from_utf8(api_key)?
 86                };
 87                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 88                    if let CompletionProvider::Anthropic(provider) = provider {
 89                        provider.api_key = Some(api_key);
 90                    }
 91                })
 92            })
 93        }
 94    }
 95
 96    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 97        let delete_credentials = cx.delete_credentials(&self.api_url);
 98        cx.spawn(|mut cx| async move {
 99            delete_credentials.await.log_err();
100            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
101                if let CompletionProvider::Anthropic(provider) = provider {
102                    provider.api_key = None;
103                }
104            })
105        })
106    }
107
108    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
109        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
110            .into()
111    }
112
113    pub fn model(&self) -> AnthropicModel {
114        self.model.clone()
115    }
116
117    pub fn count_tokens(
118        &self,
119        request: LanguageModelRequest,
120        cx: &AppContext,
121    ) -> BoxFuture<'static, Result<usize>> {
122        count_open_ai_tokens(request, cx.background_executor())
123    }
124
125    pub fn complete(
126        &self,
127        request: LanguageModelRequest,
128    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
129        let request = self.to_anthropic_request(request);
130
131        let http_client = self.http_client.clone();
132        let api_key = self.api_key.clone();
133        let api_url = self.api_url.clone();
134        let low_speed_timeout = self.low_speed_timeout;
135        async move {
136            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
137            let request = stream_completion(
138                http_client.as_ref(),
139                &api_url,
140                &api_key,
141                request,
142                low_speed_timeout,
143            );
144            let response = request.await?;
145            let stream = response
146                .filter_map(|response| async move {
147                    match response {
148                        Ok(response) => match response {
149                            anthropic::ResponseEvent::ContentBlockStart {
150                                content_block, ..
151                            } => match content_block {
152                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
153                            },
154                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
155                                match delta {
156                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
157                                }
158                            }
159                            _ => None,
160                        },
161                        Err(error) => Some(Err(error)),
162                    }
163                })
164                .boxed();
165            Ok(stream)
166        }
167        .boxed()
168    }
169
170    fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
171        let model = match request.model {
172            LanguageModel::Anthropic(model) => model,
173            _ => self.model(),
174        };
175
176        let mut system_message = String::new();
177
178        let mut messages: Vec<RequestMessage> = Vec::new();
179        for message in request.messages {
180            if message.content.is_empty() {
181                continue;
182            }
183
184            match message.role {
185                Role::User | Role::Assistant => {
186                    let role = match message.role {
187                        Role::User => AnthropicRole::User,
188                        Role::Assistant => AnthropicRole::Assistant,
189                        _ => unreachable!(),
190                    };
191
192                    if let Some(last_message) = messages.last_mut() {
193                        if last_message.role == role {
194                            last_message.content.push_str("\n\n");
195                            last_message.content.push_str(&message.content);
196                            continue;
197                        }
198                    }
199
200                    messages.push(RequestMessage {
201                        role,
202                        content: message.content,
203                    });
204                }
205                Role::System => {
206                    if !system_message.is_empty() {
207                        system_message.push_str("\n\n");
208                    }
209                    system_message.push_str(&message.content);
210                }
211            }
212        }
213
214        Request {
215            model,
216            messages,
217            stream: true,
218            system: system_message,
219            max_tokens: 4092,
220        }
221    }
222}
223
224struct AuthenticationPrompt {
225    api_key: View<Editor>,
226    api_url: String,
227}
228
229impl AuthenticationPrompt {
230    fn new(api_url: String, cx: &mut WindowContext) -> Self {
231        Self {
232            api_key: cx.new_view(|cx| {
233                let mut editor = Editor::single_line(cx);
234                editor.set_placeholder_text(
235                    "sk-000000000000000000000000000000000000000000000000",
236                    cx,
237                );
238                editor
239            }),
240            api_url,
241        }
242    }
243
244    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
245        let api_key = self.api_key.read(cx).text(cx);
246        if api_key.is_empty() {
247            return;
248        }
249
250        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
251        cx.spawn(|_, mut cx| async move {
252            write_credentials.await?;
253            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
254                if let CompletionProvider::Anthropic(provider) = provider {
255                    provider.api_key = Some(api_key);
256                }
257            })
258        })
259        .detach_and_log_err(cx);
260    }
261
262    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
263        let settings = ThemeSettings::get_global(cx);
264        let text_style = TextStyle {
265            color: cx.theme().colors().text,
266            font_family: settings.ui_font.family.clone(),
267            font_features: settings.ui_font.features.clone(),
268            font_size: rems(0.875).into(),
269            font_weight: settings.ui_font.weight,
270            font_style: FontStyle::Normal,
271            line_height: relative(1.3),
272            background_color: None,
273            underline: None,
274            strikethrough: None,
275            white_space: WhiteSpace::Normal,
276        };
277        EditorElement::new(
278            &self.api_key,
279            EditorStyle {
280                background: cx.theme().colors().editor_background,
281                local_player: cx.theme().players().local(),
282                text: text_style,
283                ..Default::default()
284            },
285        )
286    }
287}
288
289impl Render for AuthenticationPrompt {
290    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
291        const INSTRUCTIONS: [&str; 4] = [
292            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
293            "You can create an API key at: https://console.anthropic.com/settings/keys",
294            "",
295            "Paste your Anthropic API key below and hit enter to use the assistant:",
296        ];
297
298        v_flex()
299            .p_4()
300            .size_full()
301            .on_action(cx.listener(Self::save_api_key))
302            .children(
303                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
304            )
305            .child(
306                h_flex()
307                    .w_full()
308                    .my_2()
309                    .px_2()
310                    .py_1()
311                    .bg(cx.theme().colors().editor_background)
312                    .rounded_md()
313                    .child(self.render_api_key_editor(cx)),
314            )
315            .child(
316                Label::new(
317                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
318                )
319                .size(LabelSize::Small),
320            )
321            .child(
322                h_flex()
323                    .gap_2()
324                    .child(Label::new("Click on").size(LabelSize::Small))
325                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
326                    .child(
327                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
328                    ),
329            )
330            .into_any()
331    }
332}