anthropic.rs

  1pub mod telemetry;
  2
  3use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
  4use anyhow::Result;
  5use collections::BTreeMap;
  6use credentials_provider::CredentialsProvider;
  7use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
  9use http_client::HttpClient;
 10use language_model::{
 11    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, ApiKeyState, AuthenticateError,
 12    ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel,
 13    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 14    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 15    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 16    LanguageModelToolChoice, RateLimiter, env_var,
 17};
 18use settings::{Settings, SettingsStore};
 19use std::sync::{Arc, LazyLock};
 20use strum::IntoEnumIterator;
 21use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 22use ui_input::InputField;
 23use util::ResultExt;
 24
 25pub use anthropic::completion::{
 26    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
 27    into_anthropic_count_tokens_request,
 28};
 29pub use settings::AnthropicAvailableModel as AvailableModel;
 30
/// Identifier reported by this provider and by every model it creates.
/// Aliases `language_model::ANTHROPIC_PROVIDER_ID`.
const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
/// Provider name reported by this provider and its models; also used when
/// building the `NoApiKey` completion error.
/// Aliases `language_model::ANTHROPIC_PROVIDER_NAME`.
const PROVIDER_NAME: LanguageModelProviderName = ANTHROPIC_PROVIDER_NAME;
 33
/// Settings for the Anthropic provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    /// Base URL of the Anthropic API. When this string is empty, the provider
    /// falls back to `ANTHROPIC_API_URL`.
    pub api_url: String,
    /// Extend Zed's list of Anthropic models.
    pub available_models: Vec<AvailableModel>,
}
 40
/// Language-model provider for Anthropic models.
pub struct AnthropicLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with the created models and the
    /// configuration view.
    state: Entity<State>,
}
 45
/// Name of the environment variable that can supply the API key; also shown in
/// the configuration UI messages.
const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
/// Lazily initialized `EnvVar` built from `API_KEY_ENV_VAR_NAME` via the
/// `env_var!` macro; a clone of it is handed to `ApiKeyState::new`.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 48
 49pub struct State {
 50    api_key_state: ApiKeyState,
 51    credentials_provider: Arc<dyn CredentialsProvider>,
 52}
 53
 54impl State {
 55    fn is_authenticated(&self) -> bool {
 56        self.api_key_state.has_key()
 57    }
 58
 59    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 60        let credentials_provider = self.credentials_provider.clone();
 61        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 62        self.api_key_state.store(
 63            api_url,
 64            api_key,
 65            |this| &mut this.api_key_state,
 66            credentials_provider,
 67            cx,
 68        )
 69    }
 70
 71    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 72        let credentials_provider = self.credentials_provider.clone();
 73        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 74        self.api_key_state.load_if_needed(
 75            api_url,
 76            |this| &mut this.api_key_state,
 77            credentials_provider,
 78            cx,
 79        )
 80    }
 81}
 82
 83impl AnthropicLanguageModelProvider {
 84    pub fn new(
 85        http_client: Arc<dyn HttpClient>,
 86        credentials_provider: Arc<dyn CredentialsProvider>,
 87        cx: &mut App,
 88    ) -> Self {
 89        let state = cx.new(|cx| {
 90            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 91                let credentials_provider = this.credentials_provider.clone();
 92                let api_url = Self::api_url(cx);
 93                this.api_key_state.handle_url_change(
 94                    api_url,
 95                    |this| &mut this.api_key_state,
 96                    credentials_provider,
 97                    cx,
 98                );
 99                cx.notify();
100            })
101            .detach();
102            State {
103                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
104                credentials_provider,
105            }
106        });
107
108        Self { http_client, state }
109    }
110
111    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
112        Arc::new(AnthropicModel {
113            id: LanguageModelId::from(model.id().to_string()),
114            model,
115            state: self.state.clone(),
116            http_client: self.http_client.clone(),
117            request_limiter: RateLimiter::new(4),
118        })
119    }
120
121    fn settings(cx: &App) -> &AnthropicSettings {
122        &crate::AllLanguageModelSettings::get_global(cx).anthropic
123    }
124
125    fn api_url(cx: &App) -> SharedString {
126        let api_url = &Self::settings(cx).api_url;
127        if api_url.is_empty() {
128            ANTHROPIC_API_URL.into()
129        } else {
130            SharedString::new(api_url.as_str())
131        }
132    }
133}
134
135impl LanguageModelProviderState for AnthropicLanguageModelProvider {
136    type ObservableEntity = State;
137
138    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
139        Some(self.state.clone())
140    }
141}
142
143impl LanguageModelProvider for AnthropicLanguageModelProvider {
144    fn id(&self) -> LanguageModelProviderId {
145        PROVIDER_ID
146    }
147
148    fn name(&self) -> LanguageModelProviderName {
149        PROVIDER_NAME
150    }
151
152    fn icon(&self) -> IconOrSvg {
153        IconOrSvg::Icon(IconName::AiAnthropic)
154    }
155
156    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
157        Some(self.create_language_model(anthropic::Model::default()))
158    }
159
160    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
161        Some(self.create_language_model(anthropic::Model::default_fast()))
162    }
163
164    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165        [anthropic::Model::ClaudeSonnet4_6]
166            .into_iter()
167            .map(|model| self.create_language_model(model))
168            .collect()
169    }
170
171    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
172        let mut models = BTreeMap::default();
173
174        // Add base models from anthropic::Model::iter()
175        for model in anthropic::Model::iter() {
176            if !matches!(model, anthropic::Model::Custom { .. }) {
177                models.insert(model.id().to_string(), model);
178            }
179        }
180
181        // Override with available models from settings
182        for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
183            models.insert(
184                model.name.clone(),
185                anthropic::Model::Custom {
186                    name: model.name.clone(),
187                    display_name: model.display_name.clone(),
188                    max_tokens: model.max_tokens,
189                    tool_override: model.tool_override.clone(),
190                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
191                        anthropic::AnthropicModelCacheConfiguration {
192                            max_cache_anchors: config.max_cache_anchors,
193                            should_speculate: config.should_speculate,
194                            min_total_token: config.min_total_token,
195                        }
196                    }),
197                    max_output_tokens: model.max_output_tokens,
198                    default_temperature: model.default_temperature,
199                    extra_beta_headers: model.extra_beta_headers.clone(),
200                    mode: match model.mode.unwrap_or_default() {
201                        settings::ModelMode::Default => AnthropicModelMode::Default,
202                        settings::ModelMode::Thinking { budget_tokens } => {
203                            AnthropicModelMode::Thinking { budget_tokens }
204                        }
205                    },
206                },
207            );
208        }
209
210        models
211            .into_values()
212            .map(|model| self.create_language_model(model))
213            .collect()
214    }
215
216    fn is_authenticated(&self, cx: &App) -> bool {
217        self.state.read(cx).is_authenticated()
218    }
219
220    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
221        self.state.update(cx, |state, cx| state.authenticate(cx))
222    }
223
224    fn configuration_view(
225        &self,
226        target_agent: ConfigurationViewTargetAgent,
227        window: &mut Window,
228        cx: &mut App,
229    ) -> AnyView {
230        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
231            .into()
232    }
233
234    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
235        self.state
236            .update(cx, |state, cx| state.set_api_key(None, cx))
237    }
238}
239
240pub struct AnthropicModel {
241    id: LanguageModelId,
242    model: anthropic::Model,
243    state: Entity<State>,
244    http_client: Arc<dyn HttpClient>,
245    request_limiter: RateLimiter,
246}
247
248impl AnthropicModel {
249    fn stream_completion(
250        &self,
251        request: anthropic::Request,
252        cx: &AsyncApp,
253    ) -> BoxFuture<
254        'static,
255        Result<
256            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
257            LanguageModelCompletionError,
258        >,
259    > {
260        let http_client = self.http_client.clone();
261
262        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
263            let api_url = AnthropicLanguageModelProvider::api_url(cx);
264            (state.api_key_state.key(&api_url), api_url)
265        });
266
267        let beta_headers = self.model.beta_headers();
268
269        async move {
270            let Some(api_key) = api_key else {
271                return Err(LanguageModelCompletionError::NoApiKey {
272                    provider: PROVIDER_NAME,
273                });
274            };
275            let request = anthropic::stream_completion(
276                http_client.as_ref(),
277                &api_url,
278                &api_key,
279                request,
280                beta_headers,
281            );
282            request.await.map_err(Into::into)
283        }
284        .boxed()
285    }
286}
287
288impl LanguageModel for AnthropicModel {
289    fn id(&self) -> LanguageModelId {
290        self.id.clone()
291    }
292
293    fn name(&self) -> LanguageModelName {
294        LanguageModelName::from(self.model.display_name().to_string())
295    }
296
297    fn provider_id(&self) -> LanguageModelProviderId {
298        PROVIDER_ID
299    }
300
301    fn provider_name(&self) -> LanguageModelProviderName {
302        PROVIDER_NAME
303    }
304
305    fn supports_tools(&self) -> bool {
306        true
307    }
308
309    fn supports_images(&self) -> bool {
310        true
311    }
312
313    fn supports_streaming_tools(&self) -> bool {
314        true
315    }
316
317    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
318        match choice {
319            LanguageModelToolChoice::Auto
320            | LanguageModelToolChoice::Any
321            | LanguageModelToolChoice::None => true,
322        }
323    }
324
325    fn supports_thinking(&self) -> bool {
326        self.model.supports_thinking()
327    }
328
329    fn supports_fast_mode(&self) -> bool {
330        self.model.supports_speed()
331    }
332
333    fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
334        if self.model.supports_adaptive_thinking() {
335            vec![
336                language_model::LanguageModelEffortLevel {
337                    name: "Low".into(),
338                    value: "low".into(),
339                    is_default: false,
340                },
341                language_model::LanguageModelEffortLevel {
342                    name: "Medium".into(),
343                    value: "medium".into(),
344                    is_default: false,
345                },
346                language_model::LanguageModelEffortLevel {
347                    name: "High".into(),
348                    value: "high".into(),
349                    is_default: true,
350                },
351                language_model::LanguageModelEffortLevel {
352                    name: "Max".into(),
353                    value: "max".into(),
354                    is_default: false,
355                },
356            ]
357        } else {
358            Vec::new()
359        }
360    }
361
362    fn telemetry_id(&self) -> String {
363        format!("anthropic/{}", self.model.id())
364    }
365
366    fn api_key(&self, cx: &App) -> Option<String> {
367        self.state.read_with(cx, |state, cx| {
368            let api_url = AnthropicLanguageModelProvider::api_url(cx);
369            state.api_key_state.key(&api_url).map(|key| key.to_string())
370        })
371    }
372
373    fn max_token_count(&self) -> u64 {
374        self.model.max_token_count()
375    }
376
377    fn max_output_tokens(&self) -> Option<u64> {
378        Some(self.model.max_output_tokens())
379    }
380
381    fn count_tokens(
382        &self,
383        request: LanguageModelRequest,
384        cx: &App,
385    ) -> BoxFuture<'static, Result<u64>> {
386        let http_client = self.http_client.clone();
387        let model_id = self.model.request_id().to_string();
388        let mode = self.model.mode();
389
390        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
391            let api_url = AnthropicLanguageModelProvider::api_url(cx);
392            (
393                state.api_key_state.key(&api_url).map(|k| k.to_string()),
394                api_url.to_string(),
395            )
396        });
397
398        let background = cx.background_executor().clone();
399        async move {
400            // If no API key, fall back to tiktoken estimation
401            let Some(api_key) = api_key else {
402                return background
403                    .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
404                    .await;
405            };
406
407            let count_request =
408                into_anthropic_count_tokens_request(request.clone(), model_id, mode);
409
410            match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
411                .await
412            {
413                Ok(response) => Ok(response.input_tokens),
414                Err(err) => {
415                    log::error!(
416                        "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
417                    );
418                    background
419                        .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
420                        .await
421                }
422            }
423        }
424        .boxed()
425    }
426
427    fn stream_completion(
428        &self,
429        request: LanguageModelRequest,
430        cx: &AsyncApp,
431    ) -> BoxFuture<
432        'static,
433        Result<
434            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
435            LanguageModelCompletionError,
436        >,
437    > {
438        let mut request = into_anthropic(
439            request,
440            self.model.request_id().into(),
441            self.model.default_temperature(),
442            self.model.max_output_tokens(),
443            self.model.mode(),
444        );
445        if !self.model.supports_speed() {
446            request.speed = None;
447        }
448        let request = self.stream_completion(request, cx);
449        let future = self.request_limiter.stream(async move {
450            let response = request.await?;
451            Ok(AnthropicEventMapper::new().map_stream(response))
452        });
453        async move { Ok(future.await?.boxed()) }.boxed()
454    }
455
456    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
457        self.model
458            .cache_configuration()
459            .map(|config| LanguageModelCacheConfiguration {
460                max_cache_anchors: config.max_cache_anchors,
461                should_speculate: config.should_speculate,
462                min_total_token: config.min_total_token,
463            })
464    }
465}
466
467struct ConfigurationView {
468    api_key_editor: Entity<InputField>,
469    state: Entity<State>,
470    load_credentials_task: Option<Task<()>>,
471    target_agent: ConfigurationViewTargetAgent,
472}
473
474impl ConfigurationView {
475    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
476
477    fn new(
478        state: Entity<State>,
479        target_agent: ConfigurationViewTargetAgent,
480        window: &mut Window,
481        cx: &mut Context<Self>,
482    ) -> Self {
483        cx.observe(&state, |_, _, cx| {
484            cx.notify();
485        })
486        .detach();
487
488        let load_credentials_task = Some(cx.spawn({
489            let state = state.clone();
490            async move |this, cx| {
491                let task = state.update(cx, |state, cx| state.authenticate(cx));
492                // We don't log an error, because "not signed in" is also an error.
493                let _ = task.await;
494                this.update(cx, |this, cx| {
495                    this.load_credentials_task = None;
496                    cx.notify();
497                })
498                .log_err();
499            }
500        }));
501
502        Self {
503            api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)),
504            state,
505            load_credentials_task,
506            target_agent,
507        }
508    }
509
510    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
511        let api_key = self.api_key_editor.read(cx).text(cx);
512        if api_key.is_empty() {
513            return;
514        }
515
516        // url changes can cause the editor to be displayed again
517        self.api_key_editor
518            .update(cx, |editor, cx| editor.set_text("", window, cx));
519
520        let state = self.state.clone();
521        cx.spawn_in(window, async move |_, cx| {
522            state
523                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
524                .await
525        })
526        .detach_and_log_err(cx);
527    }
528
529    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
530        self.api_key_editor
531            .update(cx, |editor, cx| editor.set_text("", window, cx));
532
533        let state = self.state.clone();
534        cx.spawn_in(window, async move |_, cx| {
535            state
536                .update(cx, |state, cx| state.set_api_key(None, cx))
537                .await
538        })
539        .detach_and_log_err(cx);
540    }
541
542    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
543        !self.state.read(cx).is_authenticated()
544    }
545}
546
547impl Render for ConfigurationView {
548    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
549        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
550        let configured_card_label = if env_var_set {
551            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
552        } else {
553            let api_url = AnthropicLanguageModelProvider::api_url(cx);
554            if api_url == ANTHROPIC_API_URL {
555                "API key configured".to_string()
556            } else {
557                format!("API key configured for {}", api_url)
558            }
559        };
560
561        if self.load_credentials_task.is_some() {
562            div()
563                .child(Label::new("Loading credentials..."))
564                .into_any_element()
565        } else if self.should_render_editor(cx) {
566            v_flex()
567                .size_full()
568                .on_action(cx.listener(Self::save_api_key))
569                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
570                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
571                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
572                })))
573                .child(
574                    List::new()
575                        .child(
576                            ListBulletItem::new("")
577                                .child(Label::new("Create one by visiting"))
578                                .child(ButtonLink::new("Anthropic's settings", "https://console.anthropic.com/settings/keys"))
579                        )
580                        .child(
581                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
582                        )
583                )
584                .child(self.api_key_editor.clone())
585                .child(
586                    Label::new(
587                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
588                    )
589                    .size(LabelSize::Small)
590                    .color(Color::Muted)
591                    .mt_0p5(),
592                )
593                .into_any_element()
594        } else {
595            ConfiguredApiCard::new(configured_card_label)
596                .disabled(env_var_set)
597                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
598                .when(env_var_set, |this| {
599                    this.tooltip_label(format!(
600                    "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
601                ))
602                })
603                .into_any_element()
604        }
605    }
606}