1use anyhow::{Context as _, Result};
2use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
3use futures::AsyncReadExt as _;
4use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
5use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
6use language_model::{ApiKeyState, EnvVar, env_var};
7use std::sync::Arc;
8use zed_credentials_provider::global as global_credentials_provider;
9
10pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
11 all_language_settings(None, cx)
12 .edit_predictions
13 .open_ai_compatible_api
14 .as_ref()
15 .map(|settings| settings.api_url.clone())
16 .unwrap_or_default()
17 .into()
18}
19
/// Username under which the OpenAI-compatible API token is stored in the
/// system credentials store.
pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
/// Environment variable that can supply the API key directly, taking
/// precedence over the credentials store (see `ApiKeyState`).
pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
    env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");

// App-global holding the shared `ApiKeyState` entity; created lazily by
// `open_ai_compatible_api_token`.
struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);

impl Global for GlobalOpenAiCompatibleApiKey {}
27
28pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
29 if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
30 return global.0.clone();
31 }
32
33 let entity = cx.new(|cx| {
34 ApiKeyState::new(
35 open_ai_compatible_api_url(cx),
36 OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
37 )
38 });
39 cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
40 entity
41}
42
43pub fn load_open_ai_compatible_api_token(
44 cx: &mut App,
45) -> Task<Result<(), language_model::AuthenticateError>> {
46 let credentials_provider = global_credentials_provider(cx);
47 let api_url = open_ai_compatible_api_url(cx);
48 open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
49 key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
50 })
51}
52
53pub fn load_open_ai_compatible_api_key_if_needed(
54 provider: settings::EditPredictionProvider,
55 cx: &mut App,
56) -> Option<Arc<str>> {
57 if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
58 return None;
59 }
60 _ = load_open_ai_compatible_api_token(cx);
61 let url = open_ai_compatible_api_url(cx);
62 return open_ai_compatible_api_token(cx).read(cx).key(&url);
63}
64
65pub(crate) async fn send_custom_server_request(
66 provider: settings::EditPredictionProvider,
67 settings: &OpenAiCompatibleEditPredictionSettings,
68 prompt: String,
69 max_tokens: u32,
70 stop_tokens: Vec<String>,
71 api_key: Option<Arc<str>>,
72 http_client: &Arc<dyn http_client::HttpClient>,
73) -> Result<(String, String)> {
74 match provider {
75 settings::EditPredictionProvider::Ollama => {
76 let response = crate::ollama::make_request(
77 settings.clone(),
78 prompt,
79 stop_tokens,
80 http_client.clone(),
81 )
82 .await?;
83 Ok((response.response, response.created_at))
84 }
85 _ => {
86 let request = RawCompletionRequest {
87 model: settings.model.clone(),
88 prompt,
89 max_tokens: Some(max_tokens),
90 temperature: None,
91 stop: stop_tokens
92 .into_iter()
93 .map(std::borrow::Cow::Owned)
94 .collect(),
95 environment: None,
96 };
97
98 let request_body = serde_json::to_string(&request)?;
99 let mut http_request_builder = http_client::Request::builder()
100 .method(http_client::Method::POST)
101 .uri(settings.api_url.as_ref())
102 .header("Content-Type", "application/json");
103
104 if let Some(api_key) = api_key {
105 http_request_builder =
106 http_request_builder.header("Authorization", format!("Bearer {}", api_key));
107 }
108
109 let http_request =
110 http_request_builder.body(http_client::AsyncBody::from(request_body))?;
111
112 let mut response = http_client.send(http_request).await?;
113 let status = response.status();
114
115 if !status.is_success() {
116 let mut body = String::new();
117 response.body_mut().read_to_string(&mut body).await?;
118 anyhow::bail!("custom server error: {} - {}", status, body);
119 }
120
121 let mut body = String::new();
122 response.body_mut().read_to_string(&mut body).await?;
123
124 let parsed: RawCompletionResponse =
125 serde_json::from_str(&body).context("Failed to parse completion response")?;
126 let text = parsed
127 .choices
128 .into_iter()
129 .next()
130 .map(|choice| choice.text)
131 .unwrap_or_default();
132 Ok((text, parsed.id))
133 }
134 }
135}