anthropic.rs

  1pub mod telemetry;
  2
  3use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
  4use anyhow::Result;
  5use collections::BTreeMap;
  6use credentials_provider::CredentialsProvider;
  7use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  8use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
  9use http_client::HttpClient;
 10use language_model::{
 11    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, ApiKeyState, AuthenticateError,
 12    ConfigurationViewTargetAgent, EnvVar, IconOrSvg, LanguageModel,
 13    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
 14    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 15    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 16    LanguageModelToolChoice, RateLimiter, env_var,
 17};
 18use settings::{Settings, SettingsStore};
 19use std::sync::{Arc, LazyLock};
 20use strum::IntoEnumIterator;
 21use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 22use ui_input::InputField;
 23use util::ResultExt;
 24
 25pub use anthropic::completion::{
 26    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
 27    into_anthropic_count_tokens_request,
 28};
 29pub use settings::AnthropicAvailableModel as AvailableModel;
 30
 31const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
 32const PROVIDER_NAME: LanguageModelProviderName = ANTHROPIC_PROVIDER_NAME;
 33
 34#[derive(Default, Clone, Debug, PartialEq)]
 35pub struct AnthropicSettings {
 36    pub api_url: String,
 37    /// Extend Zed's list of Anthropic models.
 38    pub available_models: Vec<AvailableModel>,
 39}
 40
 41pub struct AnthropicLanguageModelProvider {
 42    http_client: Arc<dyn HttpClient>,
 43    state: Entity<State>,
 44}
 45
 46const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
 47static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 48
 49pub struct State {
 50    api_key_state: ApiKeyState,
 51    credentials_provider: Arc<dyn CredentialsProvider>,
 52}
 53
 54impl State {
 55    fn is_authenticated(&self) -> bool {
 56        self.api_key_state.has_key()
 57    }
 58
 59    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 60        let credentials_provider = self.credentials_provider.clone();
 61        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 62        self.api_key_state.store(
 63            api_url,
 64            api_key,
 65            |this| &mut this.api_key_state,
 66            credentials_provider,
 67            cx,
 68        )
 69    }
 70
 71    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 72        let credentials_provider = self.credentials_provider.clone();
 73        let api_url = AnthropicLanguageModelProvider::api_url(cx);
 74        self.api_key_state.load_if_needed(
 75            api_url,
 76            |this| &mut this.api_key_state,
 77            credentials_provider,
 78            cx,
 79        )
 80    }
 81}
 82
 83impl AnthropicLanguageModelProvider {
 84    pub fn new(
 85        http_client: Arc<dyn HttpClient>,
 86        credentials_provider: Arc<dyn CredentialsProvider>,
 87        cx: &mut App,
 88    ) -> Self {
 89        let state = cx.new(|cx| {
 90            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 91                let credentials_provider = this.credentials_provider.clone();
 92                let api_url = Self::api_url(cx);
 93                this.api_key_state.handle_url_change(
 94                    api_url,
 95                    |this| &mut this.api_key_state,
 96                    credentials_provider,
 97                    cx,
 98                );
 99                cx.notify();
100            })
101            .detach();
102            State {
103                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
104                credentials_provider,
105            }
106        });
107
108        Self { http_client, state }
109    }
110
111    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
112        Arc::new(AnthropicModel {
113            id: LanguageModelId::from(model.id().to_string()),
114            model,
115            state: self.state.clone(),
116            http_client: self.http_client.clone(),
117            request_limiter: RateLimiter::new(4),
118        })
119    }
120
121    fn settings(cx: &App) -> &AnthropicSettings {
122        &crate::AllLanguageModelSettings::get_global(cx).anthropic
123    }
124
125    fn api_url(cx: &App) -> SharedString {
126        let api_url = &Self::settings(cx).api_url;
127        if api_url.is_empty() {
128            ANTHROPIC_API_URL.into()
129        } else {
130            SharedString::new(api_url.as_str())
131        }
132    }
133}
134
135impl LanguageModelProviderState for AnthropicLanguageModelProvider {
136    type ObservableEntity = State;
137
138    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
139        Some(self.state.clone())
140    }
141}
142
143impl LanguageModelProvider for AnthropicLanguageModelProvider {
144    fn id(&self) -> LanguageModelProviderId {
145        PROVIDER_ID
146    }
147
148    fn name(&self) -> LanguageModelProviderName {
149        PROVIDER_NAME
150    }
151
152    fn icon(&self) -> IconOrSvg {
153        IconOrSvg::Icon(IconName::AiAnthropic)
154    }
155
156    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
157        Some(self.create_language_model(anthropic::Model::default()))
158    }
159
160    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
161        Some(self.create_language_model(anthropic::Model::default_fast()))
162    }
163
164    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165        [anthropic::Model::ClaudeSonnet4_6]
166            .into_iter()
167            .map(|model| self.create_language_model(model))
168            .collect()
169    }
170
171    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
172        let mut models = BTreeMap::default();
173
174        // Add base models from anthropic::Model::iter()
175        for model in anthropic::Model::iter() {
176            if !matches!(model, anthropic::Model::Custom { .. }) {
177                models.insert(model.id().to_string(), model);
178            }
179        }
180
181        // Override with available models from settings
182        for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
183            models.insert(
184                model.name.clone(),
185                anthropic::Model::Custom {
186                    name: model.name.clone(),
187                    display_name: model.display_name.clone(),
188                    max_tokens: model.max_tokens,
189                    tool_override: model.tool_override.clone(),
190                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
191                        anthropic::AnthropicModelCacheConfiguration {
192                            max_cache_anchors: config.max_cache_anchors,
193                            should_speculate: config.should_speculate,
194                            min_total_token: config.min_total_token,
195                        }
196                    }),
197                    max_output_tokens: model.max_output_tokens,
198                    default_temperature: model.default_temperature,
199                    extra_beta_headers: model.extra_beta_headers.clone(),
200                    mode: match model.mode.unwrap_or_default() {
201                        settings::ModelMode::Default => AnthropicModelMode::Default,
202                        settings::ModelMode::Thinking { budget_tokens } => {
203                            AnthropicModelMode::Thinking { budget_tokens }
204                        }
205                    },
206                },
207            );
208        }
209
210        models
211            .into_values()
212            .map(|model| self.create_language_model(model))
213            .collect()
214    }
215
216    fn is_authenticated(&self, cx: &App) -> bool {
217        self.state.read(cx).is_authenticated()
218    }
219
220    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
221        self.state.update(cx, |state, cx| state.authenticate(cx))
222    }
223
224    fn configuration_view(
225        &self,
226        target_agent: ConfigurationViewTargetAgent,
227        window: &mut Window,
228        cx: &mut App,
229    ) -> AnyView {
230        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
231            .into()
232    }
233
234    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
235        self.state
236            .update(cx, |state, cx| state.set_api_key(None, cx))
237    }
238}
239
240pub struct AnthropicModel {
241    id: LanguageModelId,
242    model: anthropic::Model,
243    state: Entity<State>,
244    http_client: Arc<dyn HttpClient>,
245    request_limiter: RateLimiter,
246}
247
248impl AnthropicModel {
249    fn stream_completion(
250        &self,
251        request: anthropic::Request,
252        cx: &AsyncApp,
253    ) -> BoxFuture<
254        'static,
255        Result<
256            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
257            LanguageModelCompletionError,
258        >,
259    > {
260        let http_client = self.http_client.clone();
261
262        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
263            let api_url = AnthropicLanguageModelProvider::api_url(cx);
264            (state.api_key_state.key(&api_url), api_url)
265        });
266
267        let beta_headers = self.model.beta_headers();
268
269        async move {
270            let Some(api_key) = api_key else {
271                return Err(LanguageModelCompletionError::NoApiKey {
272                    provider: PROVIDER_NAME,
273                });
274            };
275            let request = anthropic::stream_completion(
276                http_client.as_ref(),
277                &api_url,
278                &api_key,
279                request,
280                beta_headers,
281            );
282            request.await.map_err(Into::into)
283        }
284        .boxed()
285    }
286}
287
288impl LanguageModel for AnthropicModel {
289    fn id(&self) -> LanguageModelId {
290        self.id.clone()
291    }
292
293    fn name(&self) -> LanguageModelName {
294        LanguageModelName::from(self.model.display_name().to_string())
295    }
296
297    fn provider_id(&self) -> LanguageModelProviderId {
298        PROVIDER_ID
299    }
300
301    fn provider_name(&self) -> LanguageModelProviderName {
302        PROVIDER_NAME
303    }
304
305    fn supports_tools(&self) -> bool {
306        true
307    }
308
309    fn supports_images(&self) -> bool {
310        true
311    }
312
313    fn supports_streaming_tools(&self) -> bool {
314        true
315    }
316
317    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
318        match choice {
319            LanguageModelToolChoice::Auto
320            | LanguageModelToolChoice::Any
321            | LanguageModelToolChoice::None => true,
322        }
323    }
324
325    fn supports_thinking(&self) -> bool {
326        self.model.supports_thinking()
327    }
328
329    fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
330        if self.model.supports_adaptive_thinking() {
331            vec![
332                language_model::LanguageModelEffortLevel {
333                    name: "Low".into(),
334                    value: "low".into(),
335                    is_default: false,
336                },
337                language_model::LanguageModelEffortLevel {
338                    name: "Medium".into(),
339                    value: "medium".into(),
340                    is_default: false,
341                },
342                language_model::LanguageModelEffortLevel {
343                    name: "High".into(),
344                    value: "high".into(),
345                    is_default: true,
346                },
347                language_model::LanguageModelEffortLevel {
348                    name: "Max".into(),
349                    value: "max".into(),
350                    is_default: false,
351                },
352            ]
353        } else {
354            Vec::new()
355        }
356    }
357
358    fn telemetry_id(&self) -> String {
359        format!("anthropic/{}", self.model.id())
360    }
361
362    fn api_key(&self, cx: &App) -> Option<String> {
363        self.state.read_with(cx, |state, cx| {
364            let api_url = AnthropicLanguageModelProvider::api_url(cx);
365            state.api_key_state.key(&api_url).map(|key| key.to_string())
366        })
367    }
368
369    fn max_token_count(&self) -> u64 {
370        self.model.max_token_count()
371    }
372
373    fn max_output_tokens(&self) -> Option<u64> {
374        Some(self.model.max_output_tokens())
375    }
376
377    fn count_tokens(
378        &self,
379        request: LanguageModelRequest,
380        cx: &App,
381    ) -> BoxFuture<'static, Result<u64>> {
382        let http_client = self.http_client.clone();
383        let model_id = self.model.request_id().to_string();
384        let mode = self.model.mode();
385
386        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
387            let api_url = AnthropicLanguageModelProvider::api_url(cx);
388            (
389                state.api_key_state.key(&api_url).map(|k| k.to_string()),
390                api_url.to_string(),
391            )
392        });
393
394        let background = cx.background_executor().clone();
395        async move {
396            // If no API key, fall back to tiktoken estimation
397            let Some(api_key) = api_key else {
398                return background
399                    .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
400                    .await;
401            };
402
403            let count_request =
404                into_anthropic_count_tokens_request(request.clone(), model_id, mode);
405
406            match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
407                .await
408            {
409                Ok(response) => Ok(response.input_tokens),
410                Err(err) => {
411                    log::error!(
412                        "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
413                    );
414                    background
415                        .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
416                        .await
417                }
418            }
419        }
420        .boxed()
421    }
422
423    fn stream_completion(
424        &self,
425        request: LanguageModelRequest,
426        cx: &AsyncApp,
427    ) -> BoxFuture<
428        'static,
429        Result<
430            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
431            LanguageModelCompletionError,
432        >,
433    > {
434        let request = into_anthropic(
435            request,
436            self.model.request_id().into(),
437            self.model.default_temperature(),
438            self.model.max_output_tokens(),
439            self.model.mode(),
440        );
441        let request = self.stream_completion(request, cx);
442        let future = self.request_limiter.stream(async move {
443            let response = request.await?;
444            Ok(AnthropicEventMapper::new().map_stream(response))
445        });
446        async move { Ok(future.await?.boxed()) }.boxed()
447    }
448
449    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
450        self.model
451            .cache_configuration()
452            .map(|config| LanguageModelCacheConfiguration {
453                max_cache_anchors: config.max_cache_anchors,
454                should_speculate: config.should_speculate,
455                min_total_token: config.min_total_token,
456            })
457    }
458}
459
460struct ConfigurationView {
461    api_key_editor: Entity<InputField>,
462    state: Entity<State>,
463    load_credentials_task: Option<Task<()>>,
464    target_agent: ConfigurationViewTargetAgent,
465}
466
467impl ConfigurationView {
468    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
469
470    fn new(
471        state: Entity<State>,
472        target_agent: ConfigurationViewTargetAgent,
473        window: &mut Window,
474        cx: &mut Context<Self>,
475    ) -> Self {
476        cx.observe(&state, |_, _, cx| {
477            cx.notify();
478        })
479        .detach();
480
481        let load_credentials_task = Some(cx.spawn({
482            let state = state.clone();
483            async move |this, cx| {
484                let task = state.update(cx, |state, cx| state.authenticate(cx));
485                // We don't log an error, because "not signed in" is also an error.
486                let _ = task.await;
487                this.update(cx, |this, cx| {
488                    this.load_credentials_task = None;
489                    cx.notify();
490                })
491                .log_err();
492            }
493        }));
494
495        Self {
496            api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)),
497            state,
498            load_credentials_task,
499            target_agent,
500        }
501    }
502
503    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
504        let api_key = self.api_key_editor.read(cx).text(cx);
505        if api_key.is_empty() {
506            return;
507        }
508
509        // url changes can cause the editor to be displayed again
510        self.api_key_editor
511            .update(cx, |editor, cx| editor.set_text("", window, cx));
512
513        let state = self.state.clone();
514        cx.spawn_in(window, async move |_, cx| {
515            state
516                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
517                .await
518        })
519        .detach_and_log_err(cx);
520    }
521
522    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
523        self.api_key_editor
524            .update(cx, |editor, cx| editor.set_text("", window, cx));
525
526        let state = self.state.clone();
527        cx.spawn_in(window, async move |_, cx| {
528            state
529                .update(cx, |state, cx| state.set_api_key(None, cx))
530                .await
531        })
532        .detach_and_log_err(cx);
533    }
534
535    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
536        !self.state.read(cx).is_authenticated()
537    }
538}
539
540impl Render for ConfigurationView {
541    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
542        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
543        let configured_card_label = if env_var_set {
544            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
545        } else {
546            let api_url = AnthropicLanguageModelProvider::api_url(cx);
547            if api_url == ANTHROPIC_API_URL {
548                "API key configured".to_string()
549            } else {
550                format!("API key configured for {}", api_url)
551            }
552        };
553
554        if self.load_credentials_task.is_some() {
555            div()
556                .child(Label::new("Loading credentials..."))
557                .into_any_element()
558        } else if self.should_render_editor(cx) {
559            v_flex()
560                .size_full()
561                .on_action(cx.listener(Self::save_api_key))
562                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
563                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
564                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
565                })))
566                .child(
567                    List::new()
568                        .child(
569                            ListBulletItem::new("")
570                                .child(Label::new("Create one by visiting"))
571                                .child(ButtonLink::new("Anthropic's settings", "https://console.anthropic.com/settings/keys"))
572                        )
573                        .child(
574                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
575                        )
576                )
577                .child(self.api_key_editor.clone())
578                .child(
579                    Label::new(
580                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
581                    )
582                    .size(LabelSize::Small)
583                    .color(Color::Muted)
584                    .mt_0p5(),
585                )
586                .into_any_element()
587        } else {
588            ConfiguredApiCard::new(configured_card_label)
589                .disabled(env_var_set)
590                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
591                .when(env_var_set, |this| {
592                    this.tooltip_label(format!(
593                    "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
594                ))
595                })
596                .into_any_element()
597        }
598    }
599}