anthropic.rs

  1use crate::{
  2    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
  3    Role,
  4};
  5use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
  6use anthropic::{stream_completion, Request, RequestMessage};
  7use anyhow::{anyhow, Result};
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 11use http::HttpClient;
 12use settings::Settings;
 13use std::time::Duration;
 14use std::{env, sync::Arc};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::prelude::*;
 18use util::ResultExt;
 19
 20pub struct AnthropicCompletionProvider {
 21    api_key: Option<String>,
 22    api_url: String,
 23    model: AnthropicModel,
 24    http_client: Arc<dyn HttpClient>,
 25    low_speed_timeout: Option<Duration>,
 26    settings_version: usize,
 27}
 28
 29impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
 30    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
 31        AnthropicModel::iter()
 32            .map(LanguageModel::Anthropic)
 33            .collect()
 34    }
 35
 36    fn settings_version(&self) -> usize {
 37        self.settings_version
 38    }
 39
 40    fn is_authenticated(&self) -> bool {
 41        self.api_key.is_some()
 42    }
 43
 44    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 45        if self.is_authenticated() {
 46            Task::ready(Ok(()))
 47        } else {
 48            let api_url = self.api_url.clone();
 49            cx.spawn(|mut cx| async move {
 50                let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
 51                    api_key
 52                } else {
 53                    let (_, api_key) = cx
 54                        .update(|cx| cx.read_credentials(&api_url))?
 55                        .await?
 56                        .ok_or_else(|| anyhow!("credentials not found"))?;
 57                    String::from_utf8(api_key)?
 58                };
 59                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 60                    provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
 61                        provider.api_key = Some(api_key);
 62                    });
 63                })
 64            })
 65        }
 66    }
 67
 68    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 69        let delete_credentials = cx.delete_credentials(&self.api_url);
 70        cx.spawn(|mut cx| async move {
 71            delete_credentials.await.log_err();
 72            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 73                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
 74                    provider.api_key = None;
 75                });
 76            })
 77        })
 78    }
 79
 80    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 81        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
 82            .into()
 83    }
 84
 85    fn model(&self) -> LanguageModel {
 86        LanguageModel::Anthropic(self.model.clone())
 87    }
 88
 89    fn count_tokens(
 90        &self,
 91        request: LanguageModelRequest,
 92        cx: &AppContext,
 93    ) -> BoxFuture<'static, Result<usize>> {
 94        count_open_ai_tokens(request, cx.background_executor())
 95    }
 96
 97    fn complete(
 98        &self,
 99        request: LanguageModelRequest,
100    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
101        let request = self.to_anthropic_request(request);
102
103        let http_client = self.http_client.clone();
104        let api_key = self.api_key.clone();
105        let api_url = self.api_url.clone();
106        let low_speed_timeout = self.low_speed_timeout;
107        async move {
108            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
109            let request = stream_completion(
110                http_client.as_ref(),
111                &api_url,
112                &api_key,
113                request,
114                low_speed_timeout,
115            );
116            let response = request.await?;
117            let stream = response
118                .filter_map(|response| async move {
119                    match response {
120                        Ok(response) => match response {
121                            anthropic::ResponseEvent::ContentBlockStart {
122                                content_block, ..
123                            } => match content_block {
124                                anthropic::ContentBlock::Text { text } => Some(Ok(text)),
125                            },
126                            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
127                                match delta {
128                                    anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
129                                }
130                            }
131                            _ => None,
132                        },
133                        Err(error) => Some(Err(error)),
134                    }
135                })
136                .boxed();
137            Ok(stream)
138        }
139        .boxed()
140    }
141
142    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
143        self
144    }
145}
146
147impl AnthropicCompletionProvider {
148    pub fn new(
149        model: AnthropicModel,
150        api_url: String,
151        http_client: Arc<dyn HttpClient>,
152        low_speed_timeout: Option<Duration>,
153        settings_version: usize,
154    ) -> Self {
155        Self {
156            api_key: None,
157            api_url,
158            model,
159            http_client,
160            low_speed_timeout,
161            settings_version,
162        }
163    }
164
165    pub fn update(
166        &mut self,
167        model: AnthropicModel,
168        api_url: String,
169        low_speed_timeout: Option<Duration>,
170        settings_version: usize,
171    ) {
172        self.model = model;
173        self.api_url = api_url;
174        self.low_speed_timeout = low_speed_timeout;
175        self.settings_version = settings_version;
176    }
177
178    fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
179        preprocess_anthropic_request(&mut request);
180
181        let model = match request.model {
182            LanguageModel::Anthropic(model) => model,
183            _ => self.model.clone(),
184        };
185
186        let mut system_message = String::new();
187        if request
188            .messages
189            .first()
190            .map_or(false, |message| message.role == Role::System)
191        {
192            system_message = request.messages.remove(0).content;
193        }
194
195        Request {
196            model,
197            messages: request
198                .messages
199                .iter()
200                .map(|msg| RequestMessage {
201                    role: match msg.role {
202                        Role::User => anthropic::Role::User,
203                        Role::Assistant => anthropic::Role::Assistant,
204                        Role::System => unreachable!("filtered out by preprocess_request"),
205                    },
206                    content: msg.content.clone(),
207                })
208                .collect(),
209            stream: true,
210            system: system_message,
211            max_tokens: 4092,
212        }
213    }
214}
215
216pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
217    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
218    let mut system_message = String::new();
219
220    for message in request.messages.drain(..) {
221        if message.content.is_empty() {
222            continue;
223        }
224
225        match message.role {
226            Role::User | Role::Assistant => {
227                if let Some(last_message) = new_messages.last_mut() {
228                    if last_message.role == message.role {
229                        last_message.content.push_str("\n\n");
230                        last_message.content.push_str(&message.content);
231                        continue;
232                    }
233                }
234
235                new_messages.push(message);
236            }
237            Role::System => {
238                if !system_message.is_empty() {
239                    system_message.push_str("\n\n");
240                }
241                system_message.push_str(&message.content);
242            }
243        }
244    }
245
246    if !system_message.is_empty() {
247        new_messages.insert(
248            0,
249            LanguageModelRequestMessage {
250                role: Role::System,
251                content: system_message,
252            },
253        );
254    }
255
256    request.messages = new_messages;
257}
258
259struct AuthenticationPrompt {
260    api_key: View<Editor>,
261    api_url: String,
262}
263
264impl AuthenticationPrompt {
265    fn new(api_url: String, cx: &mut WindowContext) -> Self {
266        Self {
267            api_key: cx.new_view(|cx| {
268                let mut editor = Editor::single_line(cx);
269                editor.set_placeholder_text(
270                    "sk-000000000000000000000000000000000000000000000000",
271                    cx,
272                );
273                editor
274            }),
275            api_url,
276        }
277    }
278
279    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
280        let api_key = self.api_key.read(cx).text(cx);
281        if api_key.is_empty() {
282            return;
283        }
284
285        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
286        cx.spawn(|_, mut cx| async move {
287            write_credentials.await?;
288            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
289                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
290                    provider.api_key = Some(api_key);
291                });
292            })
293        })
294        .detach_and_log_err(cx);
295    }
296
297    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
298        let settings = ThemeSettings::get_global(cx);
299        let text_style = TextStyle {
300            color: cx.theme().colors().text,
301            font_family: settings.ui_font.family.clone(),
302            font_features: settings.ui_font.features.clone(),
303            font_size: rems(0.875).into(),
304            font_weight: settings.ui_font.weight,
305            font_style: FontStyle::Normal,
306            line_height: relative(1.3),
307            background_color: None,
308            underline: None,
309            strikethrough: None,
310            white_space: WhiteSpace::Normal,
311        };
312        EditorElement::new(
313            &self.api_key,
314            EditorStyle {
315                background: cx.theme().colors().editor_background,
316                local_player: cx.theme().players().local(),
317                text: text_style,
318                ..Default::default()
319            },
320        )
321    }
322}
323
324impl Render for AuthenticationPrompt {
325    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
326        const INSTRUCTIONS: [&str; 4] = [
327            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
328            "You can create an API key at: https://console.anthropic.com/settings/keys",
329            "",
330            "Paste your Anthropic API key below and hit enter to use the assistant:",
331        ];
332
333        v_flex()
334            .p_4()
335            .size_full()
336            .on_action(cx.listener(Self::save_api_key))
337            .children(
338                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
339            )
340            .child(
341                h_flex()
342                    .w_full()
343                    .my_2()
344                    .px_2()
345                    .py_1()
346                    .bg(cx.theme().colors().editor_background)
347                    .rounded_md()
348                    .child(self.render_api_key_editor(cx)),
349            )
350            .child(
351                Label::new(
352                    "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
353                )
354                .size(LabelSize::Small),
355            )
356            .child(
357                h_flex()
358                    .gap_2()
359                    .child(Label::new("Click on").size(LabelSize::Small))
360                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
361                    .child(
362                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
363                    ),
364            )
365            .into_any()
366    }
367}