anthropic.rs

  1use crate::{
  2    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
  3    Role,
  4};
  5use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
  6use anthropic::{stream_completion, Request, RequestMessage};
  7use anyhow::{anyhow, Result};
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 11use http::HttpClient;
 12use settings::Settings;
 13use std::time::Duration;
 14use std::{env, sync::Arc};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::prelude::*;
 18use util::ResultExt;
 19
 20pub struct AnthropicCompletionProvider {
 21    api_key: Option<String>,
 22    api_url: String,
 23    model: AnthropicModel,
 24    http_client: Arc<dyn HttpClient>,
 25    low_speed_timeout: Option<Duration>,
 26    settings_version: usize,
 27}
 28
 29impl AnthropicCompletionProvider {
 30    pub fn new(
 31        model: AnthropicModel,
 32        api_url: String,
 33        http_client: Arc<dyn HttpClient>,
 34        low_speed_timeout: Option<Duration>,
 35        settings_version: usize,
 36    ) -> Self {
 37        Self {
 38            api_key: None,
 39            api_url,
 40            model,
 41            http_client,
 42            low_speed_timeout,
 43            settings_version,
 44        }
 45    }
 46
 47    pub fn update(
 48        &mut self,
 49        model: AnthropicModel,
 50        api_url: String,
 51        low_speed_timeout: Option<Duration>,
 52        settings_version: usize,
 53    ) {
 54        self.model = model;
 55        self.api_url = api_url;
 56        self.low_speed_timeout = low_speed_timeout;
 57        self.settings_version = settings_version;
 58    }
 59
 60    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
 61        AnthropicModel::iter()
 62    }
 63
 64    pub fn settings_version(&self) -> usize {
 65        self.settings_version
 66    }
 67
 68    pub fn is_authenticated(&self) -> bool {
 69        self.api_key.is_some()
 70    }
 71
 72    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 73        if self.is_authenticated() {
 74            Task::ready(Ok(()))
 75        } else {
 76            let api_url = self.api_url.clone();
 77            cx.spawn(|mut cx| async move {
 78                let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
 79                    api_key
 80                } else {
 81                    let (_, api_key) = cx
 82                        .update(|cx| cx.read_credentials(&api_url))?
 83                        .await?
 84                        .ok_or_else(|| anyhow!("credentials not found"))?;
 85                    String::from_utf8(api_key)?
 86                };
 87                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 88                    if let CompletionProvider::Anthropic(provider) = provider {
 89                        provider.api_key = Some(api_key);
 90                    }
 91                })
 92            })
 93        }
 94    }
 95
 96    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 97        let delete_credentials = cx.delete_credentials(&self.api_url);
 98        cx.spawn(|mut cx| async move {
 99            delete_credentials.await.log_err();
100            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
101                if let CompletionProvider::Anthropic(provider) = provider {
102                    provider.api_key = None;
103                }
104            })
105        })
106    }
107
108    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
109        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
110            .into()
111    }
112
113    pub fn model(&self) -> AnthropicModel {
114        self.model.clone()
115    }
116
117    pub fn count_tokens(
118        &self,
119        request: LanguageModelRequest,
120        cx: &AppContext,
121    ) -> BoxFuture<'static, Result<usize>> {
122        count_open_ai_tokens(request, cx.background_executor())
123    }
124
125    pub fn complete(
126        &self,
127        request: LanguageModelRequest,
128    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
129        let request = self.to_anthropic_request(request);
130
131        let http_client = self.http_client.clone();
132        let api_key = self.api_key.clone();
133        let api_url = self.api_url.clone();
134        let low_speed_timeout = self.low_speed_timeout;
135        async move {
136            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
137            let request = stream_completion(
138                http_client.as_ref(),
139                &api_url,
140                &api_key,
141                request,
142                low_speed_timeout,
143            );
144            let response = request.await?;
145            let stream = response
146                .filter_map(|response| async move {
147                    match response {
148                        Ok(response) => match response {
149                            anthropic::ResponseEvent::ContentBlockStart {
150                                content_block, ..
151                            } => match content_block {
152                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
153                            },
154                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
155                                match delta {
156                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
157                                }
158                            }
159                            _ => None,
160                        },
161                        Err(error) => Some(Err(error)),
162                    }
163                })
164                .boxed();
165            Ok(stream)
166        }
167        .boxed()
168    }
169
170    fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
171        preprocess_anthropic_request(&mut request);
172
173        let model = match request.model {
174            LanguageModel::Anthropic(model) => model,
175            _ => self.model(),
176        };
177
178        let mut system_message = String::new();
179        if request
180            .messages
181            .first()
182            .map_or(false, |message| message.role == Role::System)
183        {
184            system_message = request.messages.remove(0).content;
185        }
186
187        Request {
188            model,
189            messages: request
190                .messages
191                .iter()
192                .map(|msg| RequestMessage {
193                    role: match msg.role {
194                        Role::User => anthropic::Role::User,
195                        Role::Assistant => anthropic::Role::Assistant,
196                        Role::System => unreachable!("filtered out by preprocess_request"),
197                    },
198                    content: msg.content.clone(),
199                })
200                .collect(),
201            stream: true,
202            system: system_message,
203            max_tokens: 4092,
204        }
205    }
206}
207
208pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
209    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
210    let mut system_message = String::new();
211
212    for message in request.messages.drain(..) {
213        if message.content.is_empty() {
214            continue;
215        }
216
217        match message.role {
218            Role::User | Role::Assistant => {
219                if let Some(last_message) = new_messages.last_mut() {
220                    if last_message.role == message.role {
221                        last_message.content.push_str("\n\n");
222                        last_message.content.push_str(&message.content);
223                        continue;
224                    }
225                }
226
227                new_messages.push(message);
228            }
229            Role::System => {
230                if !system_message.is_empty() {
231                    system_message.push_str("\n\n");
232                }
233                system_message.push_str(&message.content);
234            }
235        }
236    }
237
238    if !system_message.is_empty() {
239        new_messages.insert(
240            0,
241            LanguageModelRequestMessage {
242                role: Role::System,
243                content: system_message,
244            },
245        );
246    }
247
248    request.messages = new_messages;
249}
250
251struct AuthenticationPrompt {
252    api_key: View<Editor>,
253    api_url: String,
254}
255
256impl AuthenticationPrompt {
257    fn new(api_url: String, cx: &mut WindowContext) -> Self {
258        Self {
259            api_key: cx.new_view(|cx| {
260                let mut editor = Editor::single_line(cx);
261                editor.set_placeholder_text(
262                    "sk-000000000000000000000000000000000000000000000000",
263                    cx,
264                );
265                editor
266            }),
267            api_url,
268        }
269    }
270
271    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
272        let api_key = self.api_key.read(cx).text(cx);
273        if api_key.is_empty() {
274            return;
275        }
276
277        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
278        cx.spawn(|_, mut cx| async move {
279            write_credentials.await?;
280            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
281                if let CompletionProvider::Anthropic(provider) = provider {
282                    provider.api_key = Some(api_key);
283                }
284            })
285        })
286        .detach_and_log_err(cx);
287    }
288
289    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
290        let settings = ThemeSettings::get_global(cx);
291        let text_style = TextStyle {
292            color: cx.theme().colors().text,
293            font_family: settings.ui_font.family.clone(),
294            font_features: settings.ui_font.features.clone(),
295            font_size: rems(0.875).into(),
296            font_weight: settings.ui_font.weight,
297            font_style: FontStyle::Normal,
298            line_height: relative(1.3),
299            background_color: None,
300            underline: None,
301            strikethrough: None,
302            white_space: WhiteSpace::Normal,
303        };
304        EditorElement::new(
305            &self.api_key,
306            EditorStyle {
307                background: cx.theme().colors().editor_background,
308                local_player: cx.theme().players().local(),
309                text: text_style,
310                ..Default::default()
311            },
312        )
313    }
314}
315
316impl Render for AuthenticationPrompt {
317    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
318        const INSTRUCTIONS: [&str; 4] = [
319            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
320            "You can create an API key at: https://console.anthropic.com/settings/keys",
321            "",
322            "Paste your Anthropic API key below and hit enter to use the assistant:",
323        ];
324
325        v_flex()
326            .p_4()
327            .size_full()
328            .on_action(cx.listener(Self::save_api_key))
329            .children(
330                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
331            )
332            .child(
333                h_flex()
334                    .w_full()
335                    .my_2()
336                    .px_2()
337                    .py_1()
338                    .bg(cx.theme().colors().editor_background)
339                    .rounded_md()
340                    .child(self.render_api_key_editor(cx)),
341            )
342            .child(
343                Label::new(
344                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
345                )
346                .size(LabelSize::Small),
347            )
348            .child(
349                h_flex()
350                    .gap_2()
351                    .child(Label::new("Click on").size(LabelSize::Small))
352                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
353                    .child(
354                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
355                    ),
356            )
357            .into_any()
358    }
359}