// anthropic.rs — Anthropic language-model provider.

   1use crate::AllLanguageModelSettings;
   2use crate::ui::InstructionListItem;
   3use anthropic::{
   4    AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
   5    ToolResultPart, Usage,
   6};
   7use anyhow::{Context as _, Result, anyhow};
   8use collections::{BTreeMap, HashMap};
   9use credentials_provider::CredentialsProvider;
  10use editor::{Editor, EditorElement, EditorStyle};
  11use futures::Stream;
  12use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  13use gpui::{
  14    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  15};
  16use http_client::HttpClient;
  17use language_model::{
  18    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  19    LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
  20    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  21    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
  22    RateLimiter, Role,
  23};
  24use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use settings::{Settings, SettingsStore};
  28use std::pin::Pin;
  29use std::str::FromStr;
  30use std::sync::Arc;
  31use strum::IntoEnumIterator;
  32use theme::ThemeSettings;
  33use ui::{Icon, IconName, List, Tooltip, prelude::*};
  34use util::ResultExt;
  35
// Stable identifier and display name for this provider, re-exported from the
// `language_model` crate so both sides agree on the same values.
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
  38
/// User-facing settings for the Anthropic provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    /// Base URL used for Anthropic API requests (also the credential-store key).
    pub api_url: String,
    /// Extend Zed's list of Anthropic models.
    pub available_models: Vec<AvailableModel>,
}
  45
/// A model entry users can add via settings to extend or override the
/// built-in Anthropic model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: u64,
    /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
    pub tool_override: Option<String>,
    /// Configuration of Anthropic's caching API.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// Maximum number of output tokens the model may generate per request.
    pub max_output_tokens: Option<u64>,
    /// Sampling temperature used when a request doesn't specify one.
    pub default_temperature: Option<f32>,
    /// Extra `anthropic-beta` header values to send with requests for this model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
  65
/// Settings-level counterpart of `AnthropicModelMode`, selecting whether a
/// model runs with extended thinking enabled.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion behavior.
    #[default]
    Default,
    /// Extended thinking mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  76
  77impl From<ModelMode> for AnthropicModelMode {
  78    fn from(value: ModelMode) -> Self {
  79        match value {
  80            ModelMode::Default => AnthropicModelMode::Default,
  81            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  82        }
  83    }
  84}
  85
  86impl From<AnthropicModelMode> for ModelMode {
  87    fn from(value: AnthropicModelMode) -> Self {
  88        match value {
  89            AnthropicModelMode::Default => ModelMode::Default,
  90            AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
  91        }
  92    }
  93}
  94
/// Language-model provider backed by the Anthropic HTTP API.
pub struct AnthropicLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared, observable authentication state (API key) for this provider.
    state: gpui::Entity<State>,
}
  99
/// Environment variable checked before falling back to stored credentials.
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
 101
/// Observable authentication state for the Anthropic provider.
pub struct State {
    api_key: Option<String>,
    // True when the key was read from the `ANTHROPIC_API_KEY` environment
    // variable rather than the credential store.
    api_key_from_env: bool,
    // Keeps us subscribed to global settings changes for the life of this state.
    _subscription: Subscription,
}
 107
impl State {
    /// Deletes the stored credential for the configured API URL and clears the
    /// in-memory key, notifying observers.
    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .anthropic
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            // Best-effort: a credential-store failure shouldn't prevent
            // clearing the in-memory state below.
            credentials_provider
                .delete_credentials(&api_url, &cx)
                .await
                .ok();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    /// Persists `api_key` to the system credential store (keyed by the API
    /// URL) and caches it in memory, notifying observers.
    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .anthropic
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            // Best-effort write: even if the credential store rejects the key,
            // keep it usable in memory for this session.
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
                .await
                .ok();

            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    /// Whether an API key is currently available (from env or storage).
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    /// Resolves the API key via [`ApiKey::get`] (environment variable first,
    /// then credential store) unless one is already cached; resolves
    /// immediately when already authenticated.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let key = ApiKey::get(cx);

        cx.spawn(async move |this, cx| {
            let key = key.await?;

            this.update(cx, |this, cx| {
                this.api_key = Some(key.key);
                this.api_key_from_env = key.from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}
 171
/// A resolved Anthropic API key plus where it came from.
pub struct ApiKey {
    pub key: String,
    // True when read from the `ANTHROPIC_API_KEY` environment variable.
    pub from_env: bool,
}
 176
 177impl ApiKey {
 178    pub fn get(cx: &mut App) -> Task<Result<Self>> {
 179        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 180        let api_url = AllLanguageModelSettings::get_global(cx)
 181            .anthropic
 182            .api_url
 183            .clone();
 184
 185        if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
 186            Task::ready(Ok(ApiKey {
 187                key,
 188                from_env: true,
 189            }))
 190        } else {
 191            cx.spawn(async move |cx| {
 192                let (_, api_key) = credentials_provider
 193                    .read_credentials(&api_url, &cx)
 194                    .await?
 195                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 196
 197                Ok(ApiKey {
 198                    key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 199                    from_env: false,
 200                })
 201            })
 202        }
 203    }
 204}
 205
impl AnthropicLanguageModelProvider {
    /// Creates the provider, wiring its state entity to re-notify observers
    /// whenever the global settings store changes.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    /// Wraps an `anthropic::Model` in this provider's `LanguageModel`
    /// implementation, sharing the provider's HTTP client and auth state.
    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
        Arc::new(AnthropicModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            // Cap concurrent in-flight requests per model instance at 4.
            request_limiter: RateLimiter::new(4),
        })
    }
}
 229
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the auth state entity so the UI can observe authentication
    /// changes for this provider.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 237
 238impl LanguageModelProvider for AnthropicLanguageModelProvider {
 239    fn id(&self) -> LanguageModelProviderId {
 240        PROVIDER_ID
 241    }
 242
 243    fn name(&self) -> LanguageModelProviderName {
 244        PROVIDER_NAME
 245    }
 246
 247    fn icon(&self) -> IconName {
 248        IconName::AiAnthropic
 249    }
 250
 251    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 252        Some(self.create_language_model(anthropic::Model::default()))
 253    }
 254
 255    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 256        Some(self.create_language_model(anthropic::Model::default_fast()))
 257    }
 258
 259    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 260        [
 261            anthropic::Model::ClaudeSonnet4,
 262            anthropic::Model::ClaudeSonnet4Thinking,
 263        ]
 264        .into_iter()
 265        .map(|model| self.create_language_model(model))
 266        .collect()
 267    }
 268
 269    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 270        let mut models = BTreeMap::default();
 271
 272        // Add base models from anthropic::Model::iter()
 273        for model in anthropic::Model::iter() {
 274            if !matches!(model, anthropic::Model::Custom { .. }) {
 275                models.insert(model.id().to_string(), model);
 276            }
 277        }
 278
 279        // Override with available models from settings
 280        for model in AllLanguageModelSettings::get_global(cx)
 281            .anthropic
 282            .available_models
 283            .iter()
 284        {
 285            models.insert(
 286                model.name.clone(),
 287                anthropic::Model::Custom {
 288                    name: model.name.clone(),
 289                    display_name: model.display_name.clone(),
 290                    max_tokens: model.max_tokens,
 291                    tool_override: model.tool_override.clone(),
 292                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
 293                        anthropic::AnthropicModelCacheConfiguration {
 294                            max_cache_anchors: config.max_cache_anchors,
 295                            should_speculate: config.should_speculate,
 296                            min_total_token: config.min_total_token,
 297                        }
 298                    }),
 299                    max_output_tokens: model.max_output_tokens,
 300                    default_temperature: model.default_temperature,
 301                    extra_beta_headers: model.extra_beta_headers.clone(),
 302                    mode: model.mode.clone().unwrap_or_default().into(),
 303                },
 304            );
 305        }
 306
 307        models
 308            .into_values()
 309            .map(|model| self.create_language_model(model))
 310            .collect()
 311    }
 312
 313    fn is_authenticated(&self, cx: &App) -> bool {
 314        self.state.read(cx).is_authenticated()
 315    }
 316
 317    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 318        self.state.update(cx, |state, cx| state.authenticate(cx))
 319    }
 320
 321    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
 322        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 323            .into()
 324    }
 325
 326    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 327        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 328    }
 329}
 330
/// A single Anthropic model exposed through the `LanguageModel` trait,
/// sharing the provider's HTTP client and authentication state.
pub struct AnthropicModel {
    id: LanguageModelId,
    model: anthropic::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Bounds the number of concurrent completion requests for this model.
    request_limiter: RateLimiter,
}
 338
 339pub fn count_anthropic_tokens(
 340    request: LanguageModelRequest,
 341    cx: &App,
 342) -> BoxFuture<'static, Result<u64>> {
 343    cx.background_spawn(async move {
 344        let messages = request.messages;
 345        let mut tokens_from_images = 0;
 346        let mut string_messages = Vec::with_capacity(messages.len());
 347
 348        for message in messages {
 349            use language_model::MessageContent;
 350
 351            let mut string_contents = String::new();
 352
 353            for content in message.content {
 354                match content {
 355                    MessageContent::Text(text) => {
 356                        string_contents.push_str(&text);
 357                    }
 358                    MessageContent::Thinking { .. } => {
 359                        // Thinking blocks are not included in the input token count.
 360                    }
 361                    MessageContent::RedactedThinking(_) => {
 362                        // Thinking blocks are not included in the input token count.
 363                    }
 364                    MessageContent::Image(image) => {
 365                        tokens_from_images += image.estimate_tokens();
 366                    }
 367                    MessageContent::ToolUse(_tool_use) => {
 368                        // TODO: Estimate token usage from tool uses.
 369                    }
 370                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
 371                        LanguageModelToolResultContent::Text(text) => {
 372                            string_contents.push_str(text);
 373                        }
 374                        LanguageModelToolResultContent::Image(image) => {
 375                            tokens_from_images += image.estimate_tokens();
 376                        }
 377                    },
 378                }
 379            }
 380
 381            if !string_contents.is_empty() {
 382                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
 383                    role: match message.role {
 384                        Role::User => "user".into(),
 385                        Role::Assistant => "assistant".into(),
 386                        Role::System => "system".into(),
 387                    },
 388                    content: Some(string_contents),
 389                    name: None,
 390                    function_call: None,
 391                });
 392            }
 393        }
 394
 395        // Tiktoken doesn't yet support these models, so we manually use the
 396        // same tokenizer as GPT-4.
 397        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
 398            .map(|tokens| (tokens + tokens_from_images) as u64)
 399    })
 400    .boxed()
 401}
 402
impl AnthropicModel {
    /// Starts a streaming completion request against the Anthropic API.
    ///
    /// The API key and URL are read from app state up front so the returned
    /// future is `'static`; the request fails with `NoApiKey` when no key is
    /// configured.
    fn stream_completion(
        &self,
        request: anthropic::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            // The app (and with it our state entity) has been dropped.
            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
        };

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request =
                anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            request.await.map_err(Into::into)
        }
        .boxed()
    }
}
 437
impl LanguageModel for AnthropicModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    // All Anthropic models exposed here support tool calling and image input.
    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        true
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        // Exhaustive match (no `_` arm) so adding a new choice variant forces
        // this to be revisited.
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("anthropic/{}", self.model.id())
    }

    fn api_key(&self, cx: &App) -> Option<String> {
        self.state.read(cx).api_key.clone()
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_anthropic_tokens(request, cx)
    }

    /// Converts the generic request into Anthropic's wire format, then
    /// streams the response through the rate limiter and event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_anthropic(
            request,
            self.model.request_id().into(),
            self.model.default_temperature(),
            self.model.max_output_tokens(),
            self.model.mode(),
        );
        // This resolves to the inherent `AnthropicModel::stream_completion`
        // (different parameter type), not a recursive call to this method.
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await?;
            Ok(AnthropicEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        self.model
            .cache_configuration()
            .map(|config| LanguageModelCacheConfiguration {
                max_cache_anchors: config.max_cache_anchors,
                should_speculate: config.should_speculate,
                min_total_token: config.min_total_token,
            })
    }
}
 531
/// Converts a provider-agnostic `LanguageModelRequest` into an
/// `anthropic::Request`.
///
/// System messages are concatenated into the request's `system` field;
/// consecutive same-role user/assistant messages are merged into one (the
/// Messages API expects alternating roles); and when `message.cache` is set,
/// the last cacheable content block of that message is marked with an
/// ephemeral cache-control breakpoint.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> anthropic::Request {
    let mut new_messages: Vec<anthropic::Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        // Empty messages would produce invalid empty content blocks.
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(|content| match content {
                        MessageContent::Text(text) => {
                            // Strip trailing whitespace (the API rejects it on
                            // assistant prefills); only allocate when there is
                            // actually something to trim.
                            let text = if text.chars().last().map_or(false, |c| c.is_whitespace()) {
                                text.trim_end().to_string()
                            } else {
                                text
                            };
                            if !text.is_empty() {
                                Some(anthropic::RequestContent::Text {
                                    text,
                                    cache_control: None,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::Thinking {
                            text: thinking,
                            signature,
                        } => {
                            if !thinking.is_empty() {
                                Some(anthropic::RequestContent::Thinking {
                                    thinking,
                                    signature: signature.unwrap_or_default(),
                                    cache_control: None,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::RedactedThinking(data) => {
                            if !data.is_empty() {
                                Some(anthropic::RequestContent::RedactedThinking { data })
                            } else {
                                None
                            }
                        }
                        // NOTE(review): media_type is hard-coded to image/png —
                        // confirm all image sources here really are PNG.
                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
                            source: anthropic::ImageSource {
                                source_type: "base64".to_string(),
                                media_type: "image/png".to_string(),
                                data: image.source.to_string(),
                            },
                            cache_control: None,
                        }),
                        MessageContent::ToolUse(tool_use) => {
                            Some(anthropic::RequestContent::ToolUse {
                                id: tool_use.id.to_string(),
                                name: tool_use.name.to_string(),
                                input: tool_use.input,
                                cache_control: None,
                            })
                        }
                        MessageContent::ToolResult(tool_result) => {
                            Some(anthropic::RequestContent::ToolResult {
                                tool_use_id: tool_result.tool_use_id.to_string(),
                                is_error: tool_result.is_error,
                                content: match tool_result.content {
                                    LanguageModelToolResultContent::Text(text) => {
                                        ToolResultContent::Plain(text.to_string())
                                    }
                                    LanguageModelToolResultContent::Image(image) => {
                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
                                            source: anthropic::ImageSource {
                                                source_type: "base64".to_string(),
                                                media_type: "image/png".to_string(),
                                                data: image.source.to_string(),
                                            },
                                        }])
                                    }
                                },
                                cache_control: None,
                            })
                        }
                    })
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => anthropic::Role::User,
                    Role::Assistant => anthropic::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                // Fold consecutive same-role messages into one, since the
                // Messages API expects strictly alternating roles.
                // NOTE(review): content merged here skips the cache marking
                // below even when `message.cache` is set — confirm that is
                // intentional.
                if let Some(last_message) = new_messages.last_mut() {
                    if last_message.role == anthropic_role {
                        last_message.content.extend(anthropic_message_content);
                        continue;
                    }
                }

                // Mark the last segment of the message as cached
                if message.cache {
                    let cache_control_value = Some(anthropic::CacheControl {
                        cache_type: anthropic::CacheControlType::Ephemeral,
                    });
                    // Walk backwards to find the last block that accepts a
                    // cache_control field.
                    for message_content in anthropic_message_content.iter_mut().rev() {
                        match message_content {
                            anthropic::RequestContent::RedactedThinking { .. } => {
                                // Caching is not possible, fallback to next message
                            }
                            anthropic::RequestContent::Text { cache_control, .. }
                            | anthropic::RequestContent::Thinking { cache_control, .. }
                            | anthropic::RequestContent::Image { cache_control, .. }
                            | anthropic::RequestContent::ToolUse { cache_control, .. }
                            | anthropic::RequestContent::ToolResult { cache_control, .. } => {
                                *cache_control = cache_control_value;
                                break;
                            }
                        }
                    }
                }

                new_messages.push(anthropic::Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                // All system messages are concatenated into a single system
                // prompt, separated by blank lines.
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    anthropic::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(anthropic::StringOrContents::String(system_message))
        },
        // Extended thinking is enabled only when the request allows it AND
        // the model is configured for thinking mode.
        thinking: if request.thinking_allowed
            && let AnthropicModelMode::Thinking { budget_tokens } = mode
        {
            Some(anthropic::Thinking::Enabled { budget_tokens })
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| anthropic::Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
        }),
        metadata: None,
        stop_sequences: Vec::new(),
        // An explicit request temperature wins over the model default.
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
 713
/// Stateful translator from raw Anthropic stream events to
/// `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
    // Partially-streamed tool calls, keyed by their content-block index.
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    usage: Usage,
    stop_reason: StopReason,
}
 719
 720impl AnthropicEventMapper {
    /// Creates a mapper with empty tool-use accumulation state; the stop
    /// reason defaults to `EndTurn` until the stream reports otherwise.
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }
 728
    /// Consumes this mapper and adapts a raw Anthropic event stream into a
    /// stream of completion events. Each incoming event may expand to zero or
    /// more output events; stream errors pass through as single error items.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(error.into())],
            })
        })
    }
 741
 742    pub fn map_event(
 743        &mut self,
 744        event: Event,
 745    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 746        match event {
 747            Event::ContentBlockStart {
 748                index,
 749                content_block,
 750            } => match content_block {
 751                ResponseContent::Text { text } => {
 752                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 753                }
 754                ResponseContent::Thinking { thinking } => {
 755                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 756                        text: thinking,
 757                        signature: None,
 758                    })]
 759                }
 760                ResponseContent::RedactedThinking { data } => {
 761                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
 762                }
 763                ResponseContent::ToolUse { id, name, .. } => {
 764                    self.tool_uses_by_index.insert(
 765                        index,
 766                        RawToolUse {
 767                            id,
 768                            name,
 769                            input_json: String::new(),
 770                        },
 771                    );
 772                    Vec::new()
 773                }
 774            },
 775            Event::ContentBlockDelta { index, delta } => match delta {
 776                ContentDelta::TextDelta { text } => {
 777                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
 778                }
 779                ContentDelta::ThinkingDelta { thinking } => {
 780                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 781                        text: thinking,
 782                        signature: None,
 783                    })]
 784                }
 785                ContentDelta::SignatureDelta { signature } => {
 786                    vec![Ok(LanguageModelCompletionEvent::Thinking {
 787                        text: "".to_string(),
 788                        signature: Some(signature),
 789                    })]
 790                }
 791                ContentDelta::InputJsonDelta { partial_json } => {
 792                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
 793                        tool_use.input_json.push_str(&partial_json);
 794
 795                        // Try to convert invalid (incomplete) JSON into
 796                        // valid JSON that serde can accept, e.g. by closing
 797                        // unclosed delimiters. This way, we can update the
 798                        // UI with whatever has been streamed back so far.
 799                        if let Ok(input) = serde_json::Value::from_str(
 800                            &partial_json_fixer::fix_json(&tool_use.input_json),
 801                        ) {
 802                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
 803                                LanguageModelToolUse {
 804                                    id: tool_use.id.clone().into(),
 805                                    name: tool_use.name.clone().into(),
 806                                    is_input_complete: false,
 807                                    raw_input: tool_use.input_json.clone(),
 808                                    input,
 809                                },
 810                            ))];
 811                        }
 812                    }
 813                    return vec![];
 814                }
 815            },
 816            Event::ContentBlockStop { index } => {
 817                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
 818                    let input_json = tool_use.input_json.trim();
 819                    let input_value = if input_json.is_empty() {
 820                        Ok(serde_json::Value::Object(serde_json::Map::default()))
 821                    } else {
 822                        serde_json::Value::from_str(input_json)
 823                    };
 824                    let event_result = match input_value {
 825                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
 826                            LanguageModelToolUse {
 827                                id: tool_use.id.into(),
 828                                name: tool_use.name.into(),
 829                                is_input_complete: true,
 830                                input,
 831                                raw_input: tool_use.input_json.clone(),
 832                            },
 833                        )),
 834                        Err(json_parse_err) => {
 835                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 836                                id: tool_use.id.into(),
 837                                tool_name: tool_use.name.into(),
 838                                raw_input: input_json.into(),
 839                                json_parse_error: json_parse_err.to_string(),
 840                            })
 841                        }
 842                    };
 843
 844                    vec![event_result]
 845                } else {
 846                    Vec::new()
 847                }
 848            }
 849            Event::MessageStart { message } => {
 850                update_usage(&mut self.usage, &message.usage);
 851                vec![
 852                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
 853                        &self.usage,
 854                    ))),
 855                    Ok(LanguageModelCompletionEvent::StartMessage {
 856                        message_id: message.id,
 857                    }),
 858                ]
 859            }
 860            Event::MessageDelta { delta, usage } => {
 861                update_usage(&mut self.usage, &usage);
 862                if let Some(stop_reason) = delta.stop_reason.as_deref() {
 863                    self.stop_reason = match stop_reason {
 864                        "end_turn" => StopReason::EndTurn,
 865                        "max_tokens" => StopReason::MaxTokens,
 866                        "tool_use" => StopReason::ToolUse,
 867                        "refusal" => StopReason::Refusal,
 868                        _ => {
 869                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
 870                            StopReason::EndTurn
 871                        }
 872                    };
 873                }
 874                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
 875                    convert_usage(&self.usage),
 876                ))]
 877            }
 878            Event::MessageStop => {
 879                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
 880            }
 881            Event::Error { error } => {
 882                vec![Err(error.into())]
 883            }
 884            _ => Vec::new(),
 885        }
 886    }
 887}
 888
/// A streamed `tool_use` content block being accumulated across events:
/// created at `ContentBlockStart`, grown by `InputJsonDelta` fragments, and
/// finalized (JSON parsed) at `ContentBlockStop`.
struct RawToolUse {
    // Tool-use id assigned by the Anthropic API.
    id: String,
    // Name of the tool being invoked.
    name: String,
    // Concatenation of the partial-JSON input fragments received so far.
    input_json: String,
}
 894
 895/// Updates usage data by preferring counts from `new`.
 896fn update_usage(usage: &mut Usage, new: &Usage) {
 897    if let Some(input_tokens) = new.input_tokens {
 898        usage.input_tokens = Some(input_tokens);
 899    }
 900    if let Some(output_tokens) = new.output_tokens {
 901        usage.output_tokens = Some(output_tokens);
 902    }
 903    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
 904        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
 905    }
 906    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
 907        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
 908    }
 909}
 910
 911fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
 912    language_model::TokenUsage {
 913        input_tokens: usage.input_tokens.unwrap_or(0),
 914        output_tokens: usage.output_tokens.unwrap_or(0),
 915        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
 916        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
 917    }
 918}
 919
/// Settings UI for entering, saving, and resetting the Anthropic API key.
struct ConfigurationView {
    // Single-line editor the user types the API key into.
    api_key_editor: Entity<Editor>,
    // Shared provider state holding authentication status and credentials.
    state: gpui::Entity<State>,
    // `Some` while stored credentials are being loaded on construction; the
    // spawned task clears it (and notifies) once the initial authentication
    // attempt finishes, switching the view out of its loading state.
    load_credentials_task: Option<Task<()>>,
}
 925
impl ConfigurationView {
    // Placeholder shown in the empty editor, shaped like a real Anthropic
    // key (`sk-ant-…`) so users recognize what to paste.
    const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

    /// Builds the view and kicks off an authentication attempt with any
    /// stored credentials; while that task is pending the view renders a
    /// loading indicator.
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        // Re-render whenever the provider state changes (e.g. a key is set
        // or reset elsewhere).
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                // Clear the loading state so `render` switches to either the
                // key editor or the "configured" view.
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| {
                let mut editor = Editor::single_line(window, cx);
                editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
                editor
            }),
            state,
            load_credentials_task,
        }
    }

    /// Saves the API key currently typed in the editor to the provider
    /// state. No-op when the editor is empty. Bound to `menu::Confirm`
    /// (Enter) in `render`.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        // Persisting the key is async; errors are logged rather than shown.
        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Clears the editor and deletes the stored API key from the provider
    /// state.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Renders the API key input styled to match the current UI font/theme.
    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    /// The key-entry editor is shown only while not authenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
1023
impl Render for ConfigurationView {
    // Renders one of three states:
    // 1. a loading label while stored credentials are being checked,
    // 2. the API-key entry form with setup instructions when unauthenticated,
    // 3. a "key configured" banner with a reset button otherwise.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Whether the key came from the environment variable; if so it can't
        // be reset from this UI.
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                // Enter in the editor triggers `save_api_key` via menu::Confirm.
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with Anthropic, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            InstructionListItem::new(
                                "Create one by visiting",
                                Some("Anthropic's settings"),
                                Some("https://console.anthropic.com/settings/keys")
                            )
                        )
                        .child(
                            InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant")
                        )
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    // Reset is disabled when the key comes from the env var,
                    // since this UI can't unset an environment variable.
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        }
    }
}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use anthropic::AnthropicModelMode;
    use language_model::{LanguageModelRequestMessage, MessageContent};

    /// When a message is marked `cache: true`, only its LAST content segment
    /// should carry the ephemeral `cache_control` marker; all earlier
    /// segments must have `cache_control: None`.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                    MessageContent::Image(language_model::LanguageModelImage::empty()),
                ],
                cache: true,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        // The leading text segment must not be cached.
        assert!(matches!(
            message.content[0],
            anthropic::RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Every intermediate image segment (indices 1..=3) must not be
        // cached either. (Previously this loop was `1..3`, which skipped
        // index 3 and left the fourth segment unchecked.)
        for i in 1..4 {
            assert!(matches!(
                message.content[i],
                anthropic::RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        // Only the final segment carries the ephemeral cache marker.
        assert!(matches!(
            message.content[4],
            anthropic::RequestContent::Image {
                cache_control: Some(anthropic::CacheControl {
                    cache_type: anthropic::CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }
}