use anyhow::{Context as _, Result};
use futures::AsyncReadExt as _;
use gpui::{
    App, SharedString,
    http_client::{self, HttpClient},
};
use language::language_settings::OpenAiCompatibleEditPredictionSettings;
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

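/// Request body sent to the Ollama `/api/generate` endpoint.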
#[derive(Debug, Serialize)]
pub(crate) struct OllamaGenerateRequest {
    model: String,
    prompt: String,
    raw: bool,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaGenerateOptions>,
}

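/// Optional generation parameters (output-token limit, sampling temperature,
/// and stop sequences). `None` fields are omitted from the serialized JSON.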
#[derive(Debug, Serialize)]
pub(crate) struct OllamaGenerateOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}

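/// The subset of the Ollama generate response that this module deserializes.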
#[derive(Debug, Deserialize)]
pub(crate) struct OllamaGenerateResponse {
    pub created_at: String,
    pub response: String,
}

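/// Identifier of the Ollama provider in the language model registry.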
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");

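/// Returns whether the Ollama provider is registered and currently authenticated.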
pub fn is_available(cx: &App) -> bool {
    LanguageModelRegistry::read_global(cx)
        .provider(&PROVIDER_ID)
        .is_some_and(|provider| provider.is_authenticated(cx))
}

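/// Starts authentication for the Ollama provider if it is registered,
/// detaching the task and logging any error.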
pub fn ensure_authenticated(cx: &mut App) {
    if let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) {
        provider.authenticate(cx).detach_and_log_err(cx);
    }
}

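/// Returns the sorted IDs of the models advertised by the Ollama provider,
/// triggering authentication as a side effect. Returns an empty list when the
/// provider is not registered.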
pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
    let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) else {
        return Vec::new();
    };
    provider.authenticate(cx).detach_and_log_err(cx);
    let mut models: Vec<SharedString> = provider
        .provided_models(cx)
        .into_iter()
        .map(|model| SharedString::from(model.id().0.to_string()))
        .collect();
    models.sort();
    models
}

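/// Sends a non-streaming, raw-mode completion request to `{api_url}/api/generate`
/// and deserializes the JSON response. Non-2xx responses are returned as errors
/// containing the status code and response body.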
pub(crate) async fn make_request(
    settings: OpenAiCompatibleEditPredictionSettings,
    prompt: String,
    stop_tokens: Vec<String>,
    http_client: Arc<dyn HttpClient>,
) -> Result<OllamaGenerateResponse> {
    let request = OllamaGenerateRequest {
        model: settings.model.clone(),
        prompt,
        raw: true,
        stream: false,
        options: Some(OllamaGenerateOptions {
            num_predict: Some(settings.max_output_tokens),
            temperature: Some(0.2),
            stop: Some(stop_tokens),
        }),
    };

    let request_body = serde_json::to_string(&request)?;
    let http_request = http_client::Request::builder()
        .method(http_client::Method::POST)
        .uri(format!("{}/api/generate", settings.api_url))
        .header("Content-Type", "application/json")
        .body(http_client::AsyncBody::from(request_body))?;

    let mut response = http_client.send(http_request).await?;
    let status = response.status();

    log::debug!("Ollama: Response status: {}", status);

    if !status.is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
    }

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    let ollama_response: OllamaGenerateResponse =
        serde_json::from_str(&body).context("Failed to parse Ollama response")?;
    Ok(ollama_response)
}