1use anyhow::{Context as _, Result};
2use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
3use futures::AsyncReadExt as _;
4use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
5use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
6use language_model::{ApiKeyState, EnvVar, env_var};
7use std::sync::Arc;
8
9pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
10 all_language_settings(None, cx)
11 .edit_predictions
12 .open_ai_compatible_api
13 .as_ref()
14 .map(|settings| settings.api_url.clone())
15 .unwrap_or_default()
16 .into()
17}
18
/// Account/username string associated with the stored OpenAI-compatible edit
/// prediction API token.
///
/// NOTE(review): not referenced in this section; presumably used as the
/// credentials-store username by callers elsewhere — confirm at the call site.
pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
/// Environment variable consulted for the OpenAI-compatible edit prediction API key.
///
/// Handed to `ApiKeyState::new` when the shared key state is created.
/// NOTE(review): precedence between this variable and a stored credential is decided
/// by `ApiKeyState` — confirm there.
pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
    env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
22
/// gpui global wrapping the single shared [`ApiKeyState`] entity for the
/// OpenAI-compatible edit prediction provider.
///
/// Registered lazily by `open_ai_compatible_api_token` so that every caller observes
/// the same key state.
struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);

impl Global for GlobalOpenAiCompatibleApiKey {}
26
27pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
28 if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
29 return global.0.clone();
30 }
31
32 let entity = cx.new(|cx| {
33 ApiKeyState::new(
34 open_ai_compatible_api_url(cx),
35 OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
36 )
37 });
38 cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
39 entity
40}
41
42pub fn load_open_ai_compatible_api_token(
43 cx: &mut App,
44) -> Task<Result<(), language_model::AuthenticateError>> {
45 let credentials_provider = zed_credentials_provider::global(cx);
46 let api_url = open_ai_compatible_api_url(cx);
47 open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
48 key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
49 })
50}
51
52pub fn load_open_ai_compatible_api_key_if_needed(
53 provider: settings::EditPredictionProvider,
54 cx: &mut App,
55) -> Option<Arc<str>> {
56 if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
57 return None;
58 }
59 _ = load_open_ai_compatible_api_token(cx);
60 let url = open_ai_compatible_api_url(cx);
61 return open_ai_compatible_api_token(cx).read(cx).key(&url);
62}
63
64pub(crate) async fn send_custom_server_request(
65 provider: settings::EditPredictionProvider,
66 settings: &OpenAiCompatibleEditPredictionSettings,
67 prompt: String,
68 max_tokens: u32,
69 stop_tokens: Vec<String>,
70 api_key: Option<Arc<str>>,
71 http_client: &Arc<dyn http_client::HttpClient>,
72) -> Result<(String, String)> {
73 match provider {
74 settings::EditPredictionProvider::Ollama => {
75 let response = crate::ollama::make_request(
76 settings.clone(),
77 prompt,
78 stop_tokens,
79 http_client.clone(),
80 )
81 .await?;
82 Ok((response.response, response.created_at))
83 }
84 _ => {
85 let request = RawCompletionRequest {
86 model: settings.model.clone(),
87 prompt,
88 max_tokens: Some(max_tokens),
89 temperature: None,
90 stop: stop_tokens
91 .into_iter()
92 .map(std::borrow::Cow::Owned)
93 .collect(),
94 environment: None,
95 };
96
97 let request_body = serde_json::to_string(&request)?;
98 let mut http_request_builder = http_client::Request::builder()
99 .method(http_client::Method::POST)
100 .uri(settings.api_url.as_ref())
101 .header("Content-Type", "application/json");
102
103 if let Some(api_key) = api_key {
104 http_request_builder =
105 http_request_builder.header("Authorization", format!("Bearer {}", api_key));
106 }
107
108 let http_request =
109 http_request_builder.body(http_client::AsyncBody::from(request_body))?;
110
111 let mut response = http_client.send(http_request).await?;
112 let status = response.status();
113
114 if !status.is_success() {
115 let mut body = String::new();
116 response.body_mut().read_to_string(&mut body).await?;
117 anyhow::bail!("custom server error: {} - {}", status, body);
118 }
119
120 let mut body = String::new();
121 response.body_mut().read_to_string(&mut body).await?;
122
123 let parsed: RawCompletionResponse =
124 serde_json::from_str(&body).context("Failed to parse completion response")?;
125 let text = parsed
126 .choices
127 .into_iter()
128 .next()
129 .map(|choice| choice.text)
130 .unwrap_or_default();
131 Ok((text, parsed.id))
132 }
133 }
134}