anthropic.rs

  1use crate::{
  2    settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
  3    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  4    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
  5};
  6use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
  7use anthropic::AnthropicError;
  8use anyhow::{anyhow, Context as _, Result};
  9use collections::BTreeMap;
 10use editor::{Editor, EditorElement, EditorStyle};
 11use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 12use gpui::{
 13    AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
 14    View, WhiteSpace,
 15};
 16use http_client::HttpClient;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use settings::{Settings, SettingsStore};
 20use std::{sync::Arc, time::Duration};
 21use strum::IntoEnumIterator;
 22use theme::ThemeSettings;
 23use ui::{prelude::*, Icon, IconName, Tooltip};
 24use util::ResultExt;
 25
/// Provider identifier; wrapped in `LanguageModelProviderId` by the `id()` and
/// `provider_id()` implementations in this file.
const PROVIDER_ID: &str = "anthropic";
/// Provider name; wrapped in `LanguageModelProviderName` by the `name()` and
/// `provider_name()` implementations in this file.
const PROVIDER_NAME: &str = "Anthropic";
 28
/// Settings for the Anthropic provider, read in this file through
/// `AllLanguageModelSettings::get_global(cx).anthropic`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    /// Base URL passed to `anthropic::stream_completion`; also used as the key
    /// under which the API key credentials are written, read and deleted.
    pub api_url: String,
    /// Optional timeout forwarded unchanged to `anthropic::stream_completion`.
    pub low_speed_timeout: Option<Duration>,
    /// Extend Zed's list of Anthropic models.
    pub available_models: Vec<AvailableModel>,
    // NOTE(review): not read anywhere in this file; presumably set by the settings
    // loader when an older settings format was detected — confirm at the definition site.
    pub needs_setting_migration: bool,
}
 37
// A user-configured model entry. `provided_models` turns each entry into an
// `anthropic::Model::Custom`, keyed by `name`.
//
// New comments here are plain `//` on purpose: this type derives `JsonSchema`, and
// `///` text is picked up by that derive as schema descriptions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
    pub tool_override: Option<String>,
    /// Configuration of Anthropic's caching API.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    // Optional output-token setting; copied as-is into `anthropic::Model::Custom`
    // by `provided_models`.
    pub max_output_tokens: Option<u32>,
}
 52
/// Language model provider for Anthropic.
///
/// Every model returned by `provided_models` receives a clone of both fields.
pub struct AnthropicLanguageModelProvider {
    /// HTTP client handed to each `AnthropicModel` for completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state (API key and where it came from).
    state: gpui::Model<State>,
}
 57
 58const ANTHROPIC_API_KEY_VAR: &'static str = "ANTHROPIC_API_KEY";
 59
/// Authentication state shared between the provider, its models and the
/// configuration view.
pub struct State {
    /// The API key currently in use; `None` until `authenticate` or `set_api_key`
    /// succeeds, and again after `reset_api_key`.
    api_key: Option<String>,
    /// `true` when `authenticate` took the key from the `ANTHROPIC_API_KEY`
    /// environment variable rather than from stored credentials.
    api_key_from_env: bool,
    /// Observer on the global `SettingsStore` that calls `cx.notify()` on change;
    /// stored so the subscription is held by this struct.
    _subscription: Subscription,
}
 65
 66impl State {
 67    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 68        let delete_credentials =
 69            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
 70        cx.spawn(|this, mut cx| async move {
 71            delete_credentials.await.ok();
 72            this.update(&mut cx, |this, cx| {
 73                this.api_key = None;
 74                this.api_key_from_env = false;
 75                cx.notify();
 76            })
 77        })
 78    }
 79
 80    fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 81        let write_credentials = cx.write_credentials(
 82            AllLanguageModelSettings::get_global(cx)
 83                .anthropic
 84                .api_url
 85                .as_str(),
 86            "Bearer",
 87            api_key.as_bytes(),
 88        );
 89        cx.spawn(|this, mut cx| async move {
 90            write_credentials.await?;
 91
 92            this.update(&mut cx, |this, cx| {
 93                this.api_key = Some(api_key);
 94                cx.notify();
 95            })
 96        })
 97    }
 98
 99    fn is_authenticated(&self) -> bool {
100        self.api_key.is_some()
101    }
102
103    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
104        if self.is_authenticated() {
105            Task::ready(Ok(()))
106        } else {
107            let api_url = AllLanguageModelSettings::get_global(cx)
108                .anthropic
109                .api_url
110                .clone();
111
112            cx.spawn(|this, mut cx| async move {
113                let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR)
114                {
115                    (api_key, true)
116                } else {
117                    let (_, api_key) = cx
118                        .update(|cx| cx.read_credentials(&api_url))?
119                        .await?
120                        .ok_or_else(|| anyhow!("credentials not found"))?;
121                    (String::from_utf8(api_key)?, false)
122                };
123
124                this.update(&mut cx, |this, cx| {
125                    this.api_key = Some(api_key);
126                    this.api_key_from_env = from_env;
127                    cx.notify();
128                })
129            })
130        }
131    }
132}
133
134impl AnthropicLanguageModelProvider {
135    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
136        let state = cx.new_model(|cx| State {
137            api_key: None,
138            api_key_from_env: false,
139            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
140                cx.notify();
141            }),
142        });
143
144        Self { http_client, state }
145    }
146}
147
148impl LanguageModelProviderState for AnthropicLanguageModelProvider {
149    type ObservableEntity = State;
150
151    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
152        Some(self.state.clone())
153    }
154}
155
156impl LanguageModelProvider for AnthropicLanguageModelProvider {
157    fn id(&self) -> LanguageModelProviderId {
158        LanguageModelProviderId(PROVIDER_ID.into())
159    }
160
161    fn name(&self) -> LanguageModelProviderName {
162        LanguageModelProviderName(PROVIDER_NAME.into())
163    }
164
165    fn icon(&self) -> IconName {
166        IconName::AiAnthropic
167    }
168
169    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
170        let mut models = BTreeMap::default();
171
172        // Add base models from anthropic::Model::iter()
173        for model in anthropic::Model::iter() {
174            if !matches!(model, anthropic::Model::Custom { .. }) {
175                models.insert(model.id().to_string(), model);
176            }
177        }
178
179        // Override with available models from settings
180        for model in AllLanguageModelSettings::get_global(cx)
181            .anthropic
182            .available_models
183            .iter()
184        {
185            models.insert(
186                model.name.clone(),
187                anthropic::Model::Custom {
188                    name: model.name.clone(),
189                    display_name: model.display_name.clone(),
190                    max_tokens: model.max_tokens,
191                    tool_override: model.tool_override.clone(),
192                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
193                        anthropic::AnthropicModelCacheConfiguration {
194                            max_cache_anchors: config.max_cache_anchors,
195                            should_speculate: config.should_speculate,
196                            min_total_token: config.min_total_token,
197                        }
198                    }),
199                    max_output_tokens: model.max_output_tokens,
200                },
201            );
202        }
203
204        models
205            .into_values()
206            .map(|model| {
207                Arc::new(AnthropicModel {
208                    id: LanguageModelId::from(model.id().to_string()),
209                    model,
210                    state: self.state.clone(),
211                    http_client: self.http_client.clone(),
212                    request_limiter: RateLimiter::new(4),
213                }) as Arc<dyn LanguageModel>
214            })
215            .collect()
216    }
217
218    fn is_authenticated(&self, cx: &AppContext) -> bool {
219        self.state.read(cx).is_authenticated()
220    }
221
222    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
223        self.state.update(cx, |state, cx| state.authenticate(cx))
224    }
225
226    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
227        cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
228            .into()
229    }
230
231    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
232        self.state.update(cx, |state, cx| state.reset_api_key(cx))
233    }
234}
235
/// A single Anthropic model exposed through the `LanguageModel` trait.
pub struct AnthropicModel {
    /// Identifier built from `model.id()` when the model is listed.
    id: LanguageModelId,
    /// The underlying Anthropic model description (built-in or custom).
    model: anthropic::Model,
    /// Shared authentication state; the API key is read from here per request.
    state: gpui::Model<State>,
    /// HTTP client used to stream completions.
    http_client: Arc<dyn HttpClient>,
    /// Limiter that completion and tool requests are run through; constructed as
    /// `RateLimiter::new(4)` in `provided_models`.
    request_limiter: RateLimiter,
}
243
244pub fn count_anthropic_tokens(
245    request: LanguageModelRequest,
246    cx: &AppContext,
247) -> BoxFuture<'static, Result<usize>> {
248    cx.background_executor()
249        .spawn(async move {
250            let messages = request.messages;
251            let mut tokens_from_images = 0;
252            let mut string_messages = Vec::with_capacity(messages.len());
253
254            for message in messages {
255                use crate::MessageContent;
256
257                let mut string_contents = String::new();
258
259                for content in message.content {
260                    match content {
261                        MessageContent::Text(string) => {
262                            string_contents.push_str(&string);
263                        }
264                        MessageContent::Image(image) => {
265                            tokens_from_images += image.estimate_tokens();
266                        }
267                    }
268                }
269
270                if !string_contents.is_empty() {
271                    string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
272                        role: match message.role {
273                            Role::User => "user".into(),
274                            Role::Assistant => "assistant".into(),
275                            Role::System => "system".into(),
276                        },
277                        content: Some(string_contents),
278                        name: None,
279                        function_call: None,
280                    });
281                }
282            }
283
284            // Tiktoken doesn't yet support these models, so we manually use the
285            // same tokenizer as GPT-4.
286            tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
287                .map(|tokens| tokens + tokens_from_images)
288        })
289        .boxed()
290}
291
292impl AnthropicModel {
293    fn stream_completion(
294        &self,
295        request: anthropic::Request,
296        cx: &AsyncAppContext,
297    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
298    {
299        let http_client = self.http_client.clone();
300
301        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
302            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
303            (
304                state.api_key.clone(),
305                settings.api_url.clone(),
306                settings.low_speed_timeout,
307            )
308        }) else {
309            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
310        };
311
312        async move {
313            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
314            let request = anthropic::stream_completion(
315                http_client.as_ref(),
316                &api_url,
317                &api_key,
318                request,
319                low_speed_timeout,
320            );
321            request.await.context("failed to stream completion")
322        }
323        .boxed()
324    }
325}
326
327impl LanguageModel for AnthropicModel {
328    fn id(&self) -> LanguageModelId {
329        self.id.clone()
330    }
331
332    fn name(&self) -> LanguageModelName {
333        LanguageModelName::from(self.model.display_name().to_string())
334    }
335
336    fn provider_id(&self) -> LanguageModelProviderId {
337        LanguageModelProviderId(PROVIDER_ID.into())
338    }
339
340    fn provider_name(&self) -> LanguageModelProviderName {
341        LanguageModelProviderName(PROVIDER_NAME.into())
342    }
343
344    fn telemetry_id(&self) -> String {
345        format!("anthropic/{}", self.model.id())
346    }
347
348    fn max_token_count(&self) -> usize {
349        self.model.max_token_count()
350    }
351
352    fn max_output_tokens(&self) -> Option<u32> {
353        Some(self.model.max_output_tokens())
354    }
355
356    fn count_tokens(
357        &self,
358        request: LanguageModelRequest,
359        cx: &AppContext,
360    ) -> BoxFuture<'static, Result<usize>> {
361        count_anthropic_tokens(request, cx)
362    }
363
364    fn stream_completion(
365        &self,
366        request: LanguageModelRequest,
367        cx: &AsyncAppContext,
368    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
369        let request =
370            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
371        let request = self.stream_completion(request, cx);
372        let future = self.request_limiter.stream(async move {
373            let response = request.await.map_err(|err| anyhow!(err))?;
374            Ok(anthropic::extract_content_from_events(response))
375        });
376        async move {
377            Ok(future
378                .await?
379                .map(|result| {
380                    result
381                        .map(|content| match content {
382                            anthropic::ResponseContent::Text { text } => {
383                                LanguageModelCompletionEvent::Text(text)
384                            }
385                            anthropic::ResponseContent::ToolUse { id, name, input } => {
386                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
387                                    id,
388                                    name,
389                                    input,
390                                })
391                            }
392                        })
393                        .map_err(|err| anyhow!(err))
394                })
395                .boxed())
396        }
397        .boxed()
398    }
399
400    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
401        self.model
402            .cache_configuration()
403            .map(|config| LanguageModelCacheConfiguration {
404                max_cache_anchors: config.max_cache_anchors,
405                should_speculate: config.should_speculate,
406                min_total_token: config.min_total_token,
407            })
408    }
409
410    fn use_any_tool(
411        &self,
412        request: LanguageModelRequest,
413        tool_name: String,
414        tool_description: String,
415        input_schema: serde_json::Value,
416        cx: &AsyncAppContext,
417    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
418        let mut request = request.into_anthropic(
419            self.model.tool_model_id().into(),
420            self.model.max_output_tokens(),
421        );
422        request.tool_choice = Some(anthropic::ToolChoice::Tool {
423            name: tool_name.clone(),
424        });
425        request.tools = vec![anthropic::Tool {
426            name: tool_name.clone(),
427            description: tool_description,
428            input_schema,
429        }];
430
431        let response = self.stream_completion(request, cx);
432        self.request_limiter
433            .run(async move {
434                let response = response.await?;
435                Ok(anthropic::extract_tool_args_from_events(
436                    tool_name,
437                    Box::pin(response.map_err(|e| anyhow!(e))),
438                )
439                .await?
440                .boxed())
441            })
442            .boxed()
443    }
444}
445
/// View in which the user enters, inspects or resets the Anthropic API key.
struct ConfigurationView {
    /// Single-line editor the API key is typed or pasted into.
    api_key_editor: View<Editor>,
    /// Shared authentication state; observed so the view re-renders on change.
    state: gpui::Model<State>,
    /// `Some` while the initial `authenticate` call started in `new` is still
    /// running; the task sets it to `None` when done. While `Some`, `render` shows
    /// "Loading credentials...".
    load_credentials_task: Option<Task<()>>,
}
451
452impl ConfigurationView {
453    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
454
455    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
456        cx.observe(&state, |_, _, cx| {
457            cx.notify();
458        })
459        .detach();
460
461        let load_credentials_task = Some(cx.spawn({
462            let state = state.clone();
463            |this, mut cx| async move {
464                if let Some(task) = state
465                    .update(&mut cx, |state, cx| state.authenticate(cx))
466                    .log_err()
467                {
468                    // We don't log an error, because "not signed in" is also an error.
469                    let _ = task.await;
470                }
471                this.update(&mut cx, |this, cx| {
472                    this.load_credentials_task = None;
473                    cx.notify();
474                })
475                .log_err();
476            }
477        }));
478
479        Self {
480            api_key_editor: cx.new_view(|cx| {
481                let mut editor = Editor::single_line(cx);
482                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
483                editor
484            }),
485            state,
486            load_credentials_task,
487        }
488    }
489
490    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
491        let api_key = self.api_key_editor.read(cx).text(cx);
492        if api_key.is_empty() {
493            return;
494        }
495
496        let state = self.state.clone();
497        cx.spawn(|_, mut cx| async move {
498            state
499                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
500                .await
501        })
502        .detach_and_log_err(cx);
503
504        cx.notify();
505    }
506
507    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
508        self.api_key_editor
509            .update(cx, |editor, cx| editor.set_text("", cx));
510
511        let state = self.state.clone();
512        cx.spawn(|_, mut cx| async move {
513            state
514                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
515                .await
516        })
517        .detach_and_log_err(cx);
518
519        cx.notify();
520    }
521
522    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
523        let settings = ThemeSettings::get_global(cx);
524        let text_style = TextStyle {
525            color: cx.theme().colors().text,
526            font_family: settings.ui_font.family.clone(),
527            font_features: settings.ui_font.features.clone(),
528            font_fallbacks: settings.ui_font.fallbacks.clone(),
529            font_size: rems(0.875).into(),
530            font_weight: settings.ui_font.weight,
531            font_style: FontStyle::Normal,
532            line_height: relative(1.3),
533            background_color: None,
534            underline: None,
535            strikethrough: None,
536            white_space: WhiteSpace::Normal,
537            truncate: None,
538        };
539        EditorElement::new(
540            &self.api_key_editor,
541            EditorStyle {
542                background: cx.theme().colors().editor_background,
543                local_player: cx.theme().players().local(),
544                text: text_style,
545                ..Default::default()
546            },
547        )
548    }
549
550    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
551        !self.state.read(cx).is_authenticated()
552    }
553}
554
555impl Render for ConfigurationView {
556    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
557        const ANTHROPIC_CONSOLE_URL: &str = "https://console.anthropic.com/settings/keys";
558        const INSTRUCTIONS: [&str; 4] = [
559            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
560            "You can create an API key at:",
561            "",
562            "Paste your Anthropic API key below and hit enter to use the assistant:",
563        ];
564        let env_var_set = self.state.read(cx).api_key_from_env;
565
566        if self.load_credentials_task.is_some() {
567            div().child(Label::new("Loading credentials...")).into_any()
568        } else if self.should_render_editor(cx) {
569            v_flex()
570                .size_full()
571                .on_action(cx.listener(Self::save_api_key))
572                .child(Label::new(INSTRUCTIONS[0]))
573                .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
574                    Button::new("anthropic_console", ANTHROPIC_CONSOLE_URL)
575                        .style(ButtonStyle::Subtle)
576                        .icon(IconName::ExternalLink)
577                        .icon_size(IconSize::XSmall)
578                        .icon_color(Color::Muted)
579                        .on_click(move |_, cx| cx.open_url(ANTHROPIC_CONSOLE_URL))
580                    )
581                )
582                .child(Label::new(INSTRUCTIONS[2]))
583                .child(Label::new(INSTRUCTIONS[3]))
584                .child(
585                    h_flex()
586                        .w_full()
587                        .my_2()
588                        .px_2()
589                        .py_1()
590                        .bg(cx.theme().colors().editor_background)
591                        .rounded_md()
592                        .child(self.render_api_key_editor(cx)),
593                )
594                .child(
595                    Label::new(
596                        "You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed.",
597                    )
598                    .size(LabelSize::Small),
599                )
600                .into_any()
601        } else {
602            h_flex()
603                .size_full()
604                .justify_between()
605                .child(
606                    h_flex()
607                        .gap_1()
608                        .child(Icon::new(IconName::Check).color(Color::Success))
609                        .child(Label::new(if env_var_set {
610                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
611                        } else {
612                            "API key configured.".to_string()
613                        })),
614                )
615                .child(
616                    Button::new("reset-key", "Reset key")
617                        .icon(Some(IconName::Trash))
618                        .icon_size(IconSize::Small)
619                        .icon_position(IconPosition::Start)
620                        .disabled(env_var_set)
621                        .when(env_var_set, |this| {
622                            this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."), cx))
623                        })
624                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
625                )
626                .into_any()
627        }
628    }
629}