@@ -66,12 +66,16 @@ use util::{ResultExt, paths::RemotePathBuf};
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
use wasm_host::{
WasmExtension, WasmHost,
- wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
+ wit::{
+ LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version,
+ wasm_api_version_range,
+ },
};

struct LlmProviderWithModels {
provider_info: LlmProviderInfo,
models: Vec<LlmModelInfo>,
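+ /// Prompt cache configurations reported by the extension, keyed by model ID.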
+ cache_configs: collections::HashMap<String, LlmCacheConfiguration>,
is_authenticated: bool,
icon_path: Option<SharedString>,
auth_config: Option<extension::LanguageModelAuthConfig>,
@@ -1635,6 +1639,32 @@ impl ExtensionStore {
}
};

+ // Query the cache configuration for each model, if the extension provides one
+ let mut cache_configs = collections::HashMap::default();
+ for model in &models {
+ let cache_config_result = wasm_extension
+ .call({
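+ // Clone the ids so the boxed future handed to the extension can own them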
+ let provider_id = provider_info.id.clone();
+ let model_id = model.id.clone();
+ |ext, store| {
+ async move {
+ ext.call_llm_cache_configuration(
+ store,
+ &provider_id,
+ &model_id,
+ )
+ .await
+ }
+ .boxed()
+ }
+ })
+ .await;
+
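+ // A failed or empty result simply leaves this model without a cache configuration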
+ if let Ok(Ok(Some(config))) = cache_config_result {
+ cache_configs.insert(model.id.clone(), config);
+ }
+ }
+
// Query initial authentication state
let is_authenticated = wasm_extension
.call({
@@ -1677,6 +1707,7 @@ impl ExtensionStore {
llm_providers_with_models.push(LlmProviderWithModels {
provider_info,
models,
+ cache_configs,
is_authenticated,
icon_path,
auth_config,
@@ -1776,6 +1807,7 @@ impl ExtensionStore {
let wasm_ext = extension.as_ref().clone();
let pinfo = llm_provider.provider_info.clone();
let mods = llm_provider.models.clone();
+ let cache_cfgs = llm_provider.cache_configs.clone();
let auth = llm_provider.is_authenticated;
let icon = llm_provider.icon_path.clone();
let auth_config = llm_provider.auth_config.clone();
@@ -1784,7 +1816,7 @@ impl ExtensionStore {
provider_id.clone(),
Box::new(move |cx: &mut App| {
let provider = Arc::new(ExtensionLanguageModelProvider::new(
- wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
+ wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx,
));
language_model::LanguageModelRegistry::global(cx).update(
cx,
@@ -5,11 +5,12 @@ use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
use collections::HashSet;
use crate::wasm_host::wit::{
- LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
- LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
- LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
- LlmToolUse,
+ LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
+ LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
+ LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
+ LlmToolResult, LlmToolResultContent, LlmToolUse,
};
+use collections::HashMap;
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use extension::{LanguageModelAuthConfig, OAuthConfig};
@@ -58,6 +59,8 @@ pub struct ExtensionLanguageModelProvider {
pub struct ExtensionLlmProviderState {
is_authenticated: bool,
available_models: Vec<LlmModelInfo>,
+ /// Cache configurations for each model, keyed by model ID.
+ cache_configs: HashMap<String, LlmCacheConfiguration>,
/// Set of env var names that are allowed to be read for this provider.
allowed_env_vars: HashSet<String>,
/// If authenticated via env var, which one was used.
@@ -71,6 +74,7 @@ impl ExtensionLanguageModelProvider {
extension: WasmExtension,
provider_info: LlmProviderInfo,
models: Vec<LlmModelInfo>,
+ cache_configs: HashMap<String, LlmCacheConfiguration>,
is_authenticated: bool,
icon_path: Option<SharedString>,
auth_config: Option<LanguageModelAuthConfig>,
@@ -118,6 +122,7 @@ impl ExtensionLanguageModelProvider {
let state = cx.new(|_| ExtensionLlmProviderState {
is_authenticated,
available_models: models,
+ cache_configs,
allowed_env_vars,
env_var_name_used,
});
@@ -139,6 +144,30 @@ impl ExtensionLanguageModelProvider {
fn credential_key(&self) -> String {
format!("extension-llm-{}", self.provider_id_string())
}
+
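+ /// Builds an `ExtensionLanguageModel` for `model_info`, converting the
+ /// extension-reported cache configuration for that model, if present, into a
+ /// `LanguageModelCacheConfiguration`.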
+ fn create_model(
+ &self,
+ model_info: &LlmModelInfo,
+ cache_configs: &HashMap<String, LlmCacheConfiguration>,
+ ) -> Arc<dyn LanguageModel> {
+ let cache_config = cache_configs.get(&model_info.id).map(|config| {
+ LanguageModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors as usize,
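+ // Speculative caching is always disabled for extension-provided models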
+ should_speculate: false,
+ min_total_token: config.min_total_token_count,
+ }
+ });
+
+ Arc::new(ExtensionLanguageModel {
+ extension: self.extension.clone(),
+ model_info: model_info.clone(),
+ provider_id: self.id(),
+ provider_name: self.name(),
+ provider_info: self.provider_info.clone(),
+ request_limiter: RateLimiter::new(4),
+ cache_config,
+ })
+ }
}

impl LanguageModelProvider for ExtensionLanguageModelProvider {
@@ -165,16 +194,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
.iter()
.find(|m| m.is_default)
.or_else(|| state.available_models.first())
- .map(|model_info| {
- Arc::new(ExtensionLanguageModel {
- extension: self.extension.clone(),
- model_info: model_info.clone(),
- provider_id: self.id(),
- provider_name: self.name(),
- provider_info: self.provider_info.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model_info| self.create_model(model_info, &state.cache_configs))
}

fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -183,16 +203,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
.available_models
.iter()
.find(|m| m.is_default_fast)
- .map(|model_info| {
- Arc::new(ExtensionLanguageModel {
- extension: self.extension.clone(),
- model_info: model_info.clone(),
- provider_id: self.id(),
- provider_name: self.name(),
- provider_info: self.provider_info.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model_info| self.create_model(model_info, &state.cache_configs))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -200,16 +211,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
state
.available_models
.iter()
- .map(|model_info| {
- Arc::new(ExtensionLanguageModel {
- extension: self.extension.clone(),
- model_info: model_info.clone(),
- provider_id: self.id(),
- provider_name: self.name(),
- provider_info: self.provider_info.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model_info| self.create_model(model_info, &state.cache_configs))
.collect()
}
@@ -1595,6 +1597,7 @@ pub struct ExtensionLanguageModel {
provider_name: LanguageModelProviderName,
provider_info: LlmProviderInfo,
request_limiter: RateLimiter,
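+ /// Prompt cache configuration reported by the extension for this model, if any.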
+ cache_config: Option<LanguageModelCacheConfiguration>,
}

impl LanguageModel for ExtensionLanguageModel {
@@ -1615,7 +1618,7 @@ impl LanguageModel for ExtensionLanguageModel {
}

fn telemetry_id(&self) -> String {
- format!("extension-{}", self.model_info.id)
+ format!("{}/{}", self.provider_info.id, self.model_info.id)
}

fn supports_images(&self) -> bool {
@@ -1795,8 +1798,7 @@ impl LanguageModel for ExtensionLanguageModel {
}

fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
- // Extensions can implement this via llm_cache_configuration
- None
+ self.cache_config.clone()
}
}