anthropic.rs

  1use crate::{
  2    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
  3    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  4    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
  5};
  6use anyhow::{anyhow, Context as _, Result};
  7use collections::BTreeMap;
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{
 11    AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
 12    View, WhiteSpace,
 13};
 14use http_client::HttpClient;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::{sync::Arc, time::Duration};
 19use strum::IntoEnumIterator;
 20use theme::ThemeSettings;
 21use ui::{prelude::*, Indicator};
 22use util::ResultExt;
 23
 24const PROVIDER_ID: &str = "anthropic";
 25const PROVIDER_NAME: &str = "Anthropic";
 26
 27#[derive(Default, Clone, Debug, PartialEq)]
 28pub struct AnthropicSettings {
 29    pub api_url: String,
 30    pub low_speed_timeout: Option<Duration>,
 31    pub available_models: Vec<AvailableModel>,
 32    pub needs_setting_migration: bool,
 33}
 34
 35#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 36pub struct AvailableModel {
 37    pub name: String,
 38    pub max_tokens: usize,
 39    pub tool_override: Option<String>,
 40}
 41
 42pub struct AnthropicLanguageModelProvider {
 43    http_client: Arc<dyn HttpClient>,
 44    state: gpui::Model<State>,
 45}
 46
 47pub struct State {
 48    api_key: Option<String>,
 49    _subscription: Subscription,
 50}
 51
 52impl State {
 53    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 54        let delete_credentials =
 55            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
 56        cx.spawn(|this, mut cx| async move {
 57            delete_credentials.await.ok();
 58            this.update(&mut cx, |this, cx| {
 59                this.api_key = None;
 60                cx.notify();
 61            })
 62        })
 63    }
 64
 65    fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 66        let write_credentials = cx.write_credentials(
 67            AllLanguageModelSettings::get_global(cx)
 68                .anthropic
 69                .api_url
 70                .as_str(),
 71            "Bearer",
 72            api_key.as_bytes(),
 73        );
 74        cx.spawn(|this, mut cx| async move {
 75            write_credentials.await?;
 76
 77            this.update(&mut cx, |this, cx| {
 78                this.api_key = Some(api_key);
 79                cx.notify();
 80            })
 81        })
 82    }
 83
 84    fn is_authenticated(&self) -> bool {
 85        self.api_key.is_some()
 86    }
 87
 88    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 89        if self.is_authenticated() {
 90            Task::ready(Ok(()))
 91        } else {
 92            let api_url = AllLanguageModelSettings::get_global(cx)
 93                .anthropic
 94                .api_url
 95                .clone();
 96
 97            cx.spawn(|this, mut cx| async move {
 98                let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
 99                    api_key
100                } else {
101                    let (_, api_key) = cx
102                        .update(|cx| cx.read_credentials(&api_url))?
103                        .await?
104                        .ok_or_else(|| anyhow!("credentials not found"))?;
105                    String::from_utf8(api_key)?
106                };
107
108                this.update(&mut cx, |this, cx| {
109                    this.api_key = Some(api_key);
110                    cx.notify();
111                })
112            })
113        }
114    }
115}
116
117impl AnthropicLanguageModelProvider {
118    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
119        let state = cx.new_model(|cx| State {
120            api_key: None,
121            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
122                cx.notify();
123            }),
124        });
125
126        Self { http_client, state }
127    }
128}
129
130impl LanguageModelProviderState for AnthropicLanguageModelProvider {
131    type ObservableEntity = State;
132
133    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
134        Some(self.state.clone())
135    }
136}
137
138impl LanguageModelProvider for AnthropicLanguageModelProvider {
139    fn id(&self) -> LanguageModelProviderId {
140        LanguageModelProviderId(PROVIDER_ID.into())
141    }
142
143    fn name(&self) -> LanguageModelProviderName {
144        LanguageModelProviderName(PROVIDER_NAME.into())
145    }
146
147    fn icon(&self) -> IconName {
148        IconName::AiAnthropic
149    }
150
151    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
152        let mut models = BTreeMap::default();
153
154        // Add base models from anthropic::Model::iter()
155        for model in anthropic::Model::iter() {
156            if !matches!(model, anthropic::Model::Custom { .. }) {
157                models.insert(model.id().to_string(), model);
158            }
159        }
160
161        // Override with available models from settings
162        for model in AllLanguageModelSettings::get_global(cx)
163            .anthropic
164            .available_models
165            .iter()
166        {
167            models.insert(
168                model.name.clone(),
169                anthropic::Model::Custom {
170                    name: model.name.clone(),
171                    max_tokens: model.max_tokens,
172                    tool_override: model.tool_override.clone(),
173                },
174            );
175        }
176
177        models
178            .into_values()
179            .map(|model| {
180                Arc::new(AnthropicModel {
181                    id: LanguageModelId::from(model.id().to_string()),
182                    model,
183                    state: self.state.clone(),
184                    http_client: self.http_client.clone(),
185                    request_limiter: RateLimiter::new(4),
186                }) as Arc<dyn LanguageModel>
187            })
188            .collect()
189    }
190
191    fn is_authenticated(&self, cx: &AppContext) -> bool {
192        self.state.read(cx).is_authenticated()
193    }
194
195    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
196        self.state.update(cx, |state, cx| state.authenticate(cx))
197    }
198
199    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
200        cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
201            .into()
202    }
203
204    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
205        self.state.update(cx, |state, cx| state.reset_api_key(cx))
206    }
207}
208
209pub struct AnthropicModel {
210    id: LanguageModelId,
211    model: anthropic::Model,
212    state: gpui::Model<State>,
213    http_client: Arc<dyn HttpClient>,
214    request_limiter: RateLimiter,
215}
216
217pub fn count_anthropic_tokens(
218    request: LanguageModelRequest,
219    cx: &AppContext,
220) -> BoxFuture<'static, Result<usize>> {
221    cx.background_executor()
222        .spawn(async move {
223            let messages = request
224                .messages
225                .into_iter()
226                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
227                    role: match message.role {
228                        Role::User => "user".into(),
229                        Role::Assistant => "assistant".into(),
230                        Role::System => "system".into(),
231                    },
232                    content: Some(message.content),
233                    name: None,
234                    function_call: None,
235                })
236                .collect::<Vec<_>>();
237
238            // Tiktoken doesn't yet support these models, so we manually use the
239            // same tokenizer as GPT-4.
240            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
241        })
242        .boxed()
243}
244
245impl AnthropicModel {
246    fn request_completion(
247        &self,
248        request: anthropic::Request,
249        cx: &AsyncAppContext,
250    ) -> BoxFuture<'static, Result<anthropic::Response>> {
251        let http_client = self.http_client.clone();
252
253        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
254            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
255            (state.api_key.clone(), settings.api_url.clone())
256        }) else {
257            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
258        };
259
260        async move {
261            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
262            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
263        }
264        .boxed()
265    }
266
267    fn stream_completion(
268        &self,
269        request: anthropic::Request,
270        cx: &AsyncAppContext,
271    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
272        let http_client = self.http_client.clone();
273
274        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
275            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
276            (
277                state.api_key.clone(),
278                settings.api_url.clone(),
279                settings.low_speed_timeout,
280            )
281        }) else {
282            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
283        };
284
285        async move {
286            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
287            let request = anthropic::stream_completion(
288                http_client.as_ref(),
289                &api_url,
290                &api_key,
291                request,
292                low_speed_timeout,
293            );
294            request.await
295        }
296        .boxed()
297    }
298}
299
300impl LanguageModel for AnthropicModel {
301    fn id(&self) -> LanguageModelId {
302        self.id.clone()
303    }
304
305    fn name(&self) -> LanguageModelName {
306        LanguageModelName::from(self.model.display_name().to_string())
307    }
308
309    fn provider_id(&self) -> LanguageModelProviderId {
310        LanguageModelProviderId(PROVIDER_ID.into())
311    }
312
313    fn provider_name(&self) -> LanguageModelProviderName {
314        LanguageModelProviderName(PROVIDER_NAME.into())
315    }
316
317    fn telemetry_id(&self) -> String {
318        format!("anthropic/{}", self.model.id())
319    }
320
321    fn max_token_count(&self) -> usize {
322        self.model.max_token_count()
323    }
324
325    fn count_tokens(
326        &self,
327        request: LanguageModelRequest,
328        cx: &AppContext,
329    ) -> BoxFuture<'static, Result<usize>> {
330        count_anthropic_tokens(request, cx)
331    }
332
333    fn stream_completion(
334        &self,
335        request: LanguageModelRequest,
336        cx: &AsyncAppContext,
337    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
338        let request = request.into_anthropic(self.model.id().into());
339        let request = self.stream_completion(request, cx);
340        let future = self.request_limiter.stream(async move {
341            let response = request.await?;
342            Ok(anthropic::extract_text_from_events(response))
343        });
344        async move { Ok(future.await?.boxed()) }.boxed()
345    }
346
347    fn use_any_tool(
348        &self,
349        request: LanguageModelRequest,
350        tool_name: String,
351        tool_description: String,
352        input_schema: serde_json::Value,
353        cx: &AsyncAppContext,
354    ) -> BoxFuture<'static, Result<serde_json::Value>> {
355        let mut request = request.into_anthropic(self.model.tool_model_id().into());
356        request.tool_choice = Some(anthropic::ToolChoice::Tool {
357            name: tool_name.clone(),
358        });
359        request.tools = vec![anthropic::Tool {
360            name: tool_name.clone(),
361            description: tool_description,
362            input_schema,
363        }];
364
365        let response = self.request_completion(request, cx);
366        self.request_limiter
367            .run(async move {
368                let response = response.await?;
369                response
370                    .content
371                    .into_iter()
372                    .find_map(|content| {
373                        if let anthropic::Content::ToolUse { name, input, .. } = content {
374                            if name == tool_name {
375                                Some(input)
376                            } else {
377                                None
378                            }
379                        } else {
380                            None
381                        }
382                    })
383                    .context("tool not used")
384            })
385            .boxed()
386    }
387}
388
389struct ConfigurationView {
390    api_key_editor: View<Editor>,
391    state: gpui::Model<State>,
392    load_credentials_task: Option<Task<()>>,
393}
394
395impl ConfigurationView {
396    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
397
398    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
399        cx.observe(&state, |_, _, cx| {
400            cx.notify();
401        })
402        .detach();
403
404        let load_credentials_task = Some(cx.spawn({
405            let state = state.clone();
406            |this, mut cx| async move {
407                if let Some(task) = state
408                    .update(&mut cx, |state, cx| state.authenticate(cx))
409                    .log_err()
410                {
411                    // We don't log an error, because "not signed in" is also an error.
412                    let _ = task.await;
413                }
414                this.update(&mut cx, |this, cx| {
415                    this.load_credentials_task = None;
416                    cx.notify();
417                })
418                .log_err();
419            }
420        }));
421
422        Self {
423            api_key_editor: cx.new_view(|cx| {
424                let mut editor = Editor::single_line(cx);
425                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
426                editor
427            }),
428            state,
429            load_credentials_task,
430        }
431    }
432
433    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
434        let api_key = self.api_key_editor.read(cx).text(cx);
435        if api_key.is_empty() {
436            return;
437        }
438
439        let state = self.state.clone();
440        cx.spawn(|_, mut cx| async move {
441            state
442                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
443                .await
444        })
445        .detach_and_log_err(cx);
446
447        cx.notify();
448    }
449
450    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
451        self.api_key_editor
452            .update(cx, |editor, cx| editor.set_text("", cx));
453
454        let state = self.state.clone();
455        cx.spawn(|_, mut cx| async move {
456            state
457                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
458                .await
459        })
460        .detach_and_log_err(cx);
461
462        cx.notify();
463    }
464
465    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
466        let settings = ThemeSettings::get_global(cx);
467        let text_style = TextStyle {
468            color: cx.theme().colors().text,
469            font_family: settings.ui_font.family.clone(),
470            font_features: settings.ui_font.features.clone(),
471            font_fallbacks: settings.ui_font.fallbacks.clone(),
472            font_size: rems(0.875).into(),
473            font_weight: settings.ui_font.weight,
474            font_style: FontStyle::Normal,
475            line_height: relative(1.3),
476            background_color: None,
477            underline: None,
478            strikethrough: None,
479            white_space: WhiteSpace::Normal,
480        };
481        EditorElement::new(
482            &self.api_key_editor,
483            EditorStyle {
484                background: cx.theme().colors().editor_background,
485                local_player: cx.theme().players().local(),
486                text: text_style,
487                ..Default::default()
488            },
489        )
490    }
491
492    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
493        !self.state.read(cx).is_authenticated()
494    }
495}
496
497impl Render for ConfigurationView {
498    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
499        const INSTRUCTIONS: [&str; 4] = [
500            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
501            "You can create an API key at: https://console.anthropic.com/settings/keys",
502            "",
503            "Paste your Anthropic API key below and hit enter to use the assistant:",
504        ];
505
506        if self.load_credentials_task.is_some() {
507            div().child(Label::new("Loading credentials...")).into_any()
508        } else if self.should_render_editor(cx) {
509            v_flex()
510                .size_full()
511                .on_action(cx.listener(Self::save_api_key))
512                .children(
513                    INSTRUCTIONS.map(|instruction| Label::new(instruction)),
514                )
515                .child(
516                    h_flex()
517                        .w_full()
518                        .my_2()
519                        .px_2()
520                        .py_1()
521                        .bg(cx.theme().colors().editor_background)
522                        .rounded_md()
523                        .child(self.render_api_key_editor(cx)),
524                )
525                .child(
526                    Label::new(
527                        "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
528                    )
529                    .size(LabelSize::Small),
530                )
531                .into_any()
532        } else {
533            h_flex()
534                .size_full()
535                .justify_between()
536                .child(
537                    h_flex()
538                        .gap_2()
539                        .child(Indicator::dot().color(Color::Success))
540                        .child(Label::new("API key configured").size(LabelSize::Small)),
541                )
542                .child(
543                    Button::new("reset-key", "Reset key")
544                        .icon(Some(IconName::Trash))
545                        .icon_size(IconSize::Small)
546                        .icon_position(IconPosition::Start)
547                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
548                )
549                .into_any()
550        }
551    }
552}