use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
    AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
    View, WhiteSpace,
};
use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;

const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    /// Extend Zed's list of Anthropic models.
    pub available_models: Vec<AvailableModel>,
    pub needs_setting_migration: bool,
}

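/// A custom Anthropic model made available to Zed in addition to (or in place
/// of) the built-in ones. A hypothetical settings entry might look like the
/// following sketch (field names mirror this struct; where the entry lives in
/// the settings file is an assumption, not something this module defines):
///
/// ```json
/// {
///     "name": "claude-3-5-sonnet-20240620",
///     "display_name": "Claude 3.5 Sonnet",
///     "max_tokens": 200000,
///     "max_output_tokens": 8192
/// }
/// ```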
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the Anthropic API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
    pub tool_override: Option<String>,
    /// Configuration of Anthropic's caching API.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    pub max_output_tokens: Option<u32>,
}

pub struct AnthropicLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

const ANTHROPIC_API_KEY_VAR: &'static str = "ANTHROPIC_API_KEY";

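/// Shared authentication state for the Anthropic provider: the API key (if
/// any), whether it was read from the `ANTHROPIC_API_KEY` environment
/// variable, and a subscription that re-notifies observers when settings
/// change.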
pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

impl State {
    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let delete_credentials =
            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
        cx.spawn(|this, mut cx| async move {
            delete_credentials.await.ok();
            this.update(&mut cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let write_credentials = cx.write_credentials(
            AllLanguageModelSettings::get_global(cx)
                .anthropic
                .api_url
                .as_str(),
            "Bearer",
            api_key.as_bytes(),
        );
        cx.spawn(|this, mut cx| async move {
            write_credentials.await?;

            this.update(&mut cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

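    /// Loads the API key if it isn't set yet: the `ANTHROPIC_API_KEY`
    /// environment variable takes precedence, falling back to credentials
    /// stored for the configured Anthropic API URL.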
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = AllLanguageModelSettings::get_global(cx)
                .anthropic
                .api_url
                .clone();

            cx.spawn(|this, mut cx| async move {
                let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR)
                {
                    (api_key, true)
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    (String::from_utf8(api_key)?, false)
                };

                this.update(&mut cx, |this, cx| {
                    this.api_key = Some(api_key);
                    this.api_key_from_env = from_env;
                    cx.notify();
                })
            })
        }
    }
}

impl AnthropicLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let state = cx.new_model(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }
}

impl LanguageModelProviderState for AnthropicLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for AnthropicLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiAnthropic
    }

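    // Built-in models from `anthropic::Model::iter()` are collected first and
    // then merged with models from settings; a settings entry whose name
    // matches a built-in model id replaces it, since it is inserted last.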
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from anthropic::Model::iter()
        for model in anthropic::Model::iter() {
            if !matches!(model, anthropic::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .anthropic
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    max_output_tokens: model.max_output_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(AnthropicModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

pub struct AnthropicModel {
    id: LanguageModelId,
    model: anthropic::Model,
    state: gpui::Model<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

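/// Estimates the token count of a request. Anthropic's own tokenizer isn't
/// used here; text is counted with the GPT-4 tokenizer from `tiktoken_rs` as
/// an approximation, and an estimate for each image is added on top.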
pub fn count_anthropic_tokens(
    request: LanguageModelRequest,
    cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
    cx.background_executor()
        .spawn(async move {
            let messages = request.messages;
            let mut tokens_from_images = 0;
            let mut string_messages = Vec::with_capacity(messages.len());

            for message in messages {
                use crate::MessageContent;

                let mut string_contents = String::new();

                for content in message.content {
                    match content {
                        MessageContent::Text(string) => {
                            string_contents.push_str(&string);
                        }
                        MessageContent::Image(image) => {
                            tokens_from_images += image.estimate_tokens();
                        }
                    }
                }

                if !string_contents.is_empty() {
                    string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
                        role: match message.role {
                            Role::User => "user".into(),
                            Role::Assistant => "assistant".into(),
                            Role::System => "system".into(),
                        },
                        content: Some(string_contents),
                        name: None,
                        function_call: None,
                    });
                }
            }

            // Tiktoken doesn't yet support these models, so we manually use the
            // same tokenizer as GPT-4.
            tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
                .map(|tokens| tokens + tokens_from_images)
        })
        .boxed()
}

impl AnthropicModel {
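    /// Reads the API key and Anthropic settings out of the shared `State`,
    /// then streams raw Anthropic events for the given request.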
    fn stream_completion(
        &self,
        request: anthropic::Request,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
    {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
            (
                state.api_key.clone(),
                settings.api_url.clone(),
                settings.low_speed_timeout,
            )
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = anthropic::stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}

impl LanguageModel for AnthropicModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("anthropic/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
        Some(self.model.max_output_tokens())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_anthropic_tokens(request, cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request =
            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(|err| anyhow!(err))?;
            Ok(anthropic::extract_content_from_events(response))
        });
        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(|content| match content {
                            anthropic::ResponseContent::Text { text } => {
                                LanguageModelCompletionEvent::Text(text)
                            }
                            anthropic::ResponseContent::ToolUse { id, name, input } => {
                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                    id,
                                    name,
                                    input,
                                })
                            }
                        })
                        .map_err(|err| anyhow!(err))
                })
                .boxed())
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        self.model
            .cache_configuration()
            .map(|config| LanguageModelCacheConfiguration {
                max_cache_anchors: config.max_cache_anchors,
                should_speculate: config.should_speculate,
                min_total_token: config.min_total_token,
            })
    }

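    // Forces the model to call the named tool via `tool_choice`, then streams
    // the tool's input arguments back as strings.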
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let mut request = request.into_anthropic(
            self.model.tool_model_id().into(),
            self.model.max_output_tokens(),
        );
        request.tool_choice = Some(anthropic::ToolChoice::Tool {
            name: tool_name.clone(),
        });
        request.tools = vec![anthropic::Tool {
            name: tool_name.clone(),
            description: tool_description,
            input_schema,
        }];

        let response = self.stream_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                Ok(anthropic::extract_tool_args_from_events(
                    tool_name,
                    Box::pin(response.map_err(|e| anyhow!(e))),
                )
                .await?
                .boxed())
            })
            .boxed()
    }
}

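/// Configuration UI for the Anthropic provider: prompts for an API key while
/// the provider is unauthenticated and shows a reset option once a key is set.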
struct ConfigurationView {
    api_key_editor: View<Editor>,
    state: gpui::Model<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(&mut cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
                editor
            }),
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn(|_, mut cx| async move {
            state
                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", cx));

        let state = self.state.clone();
        cx.spawn(|_, mut cx| async move {
            state
                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
            truncate: None,
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ANTHROPIC_CONSOLE_URL: &str = "https://console.anthropic.com/settings/keys";
        const INSTRUCTIONS: [&str; 4] = [
            "To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
            "You can create an API key at:",
            "",
            "Paste your Anthropic API key below and hit enter to use the assistant:",
        ];
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(INSTRUCTIONS[0]))
                .child(
                    h_flex().child(Label::new(INSTRUCTIONS[1])).child(
                        Button::new("anthropic_console", ANTHROPIC_CONSOLE_URL)
                            .style(ButtonStyle::Subtle)
                            .icon(IconName::ExternalLink)
                            .icon_size(IconSize::XSmall)
                            .icon_color(Color::Muted)
                            .on_click(move |_, cx| cx.open_url(ANTHROPIC_CONSOLE_URL)),
                    ),
                )
                .child(Label::new(INSTRUCTIONS[2]))
                .child(Label::new(INSTRUCTIONS[3]))
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .rounded_md()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small),
                )
                .into_any()
        } else {
            h_flex()
                .size_full()
                .justify_between()
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset key")
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(|cx| {
                                Tooltip::text(
                                    format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."),
                                    cx,
                                )
                            })
                        })
                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
                )
                .into_any()
        }
    }
}