ollama.rs

  1use anyhow::{Context as _, Result};
  2use futures::AsyncReadExt as _;
  3use gpui::{
  4    App, SharedString,
  5    http_client::{self, HttpClient},
  6};
  7use language::language_settings::OpenAiCompatibleEditPredictionSettings;
  8use language_model::{LanguageModelProviderId, LanguageModelRegistry};
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11
 12#[derive(Debug, Serialize)]
 13pub(crate) struct OllamaGenerateRequest {
 14    model: String,
 15    prompt: String,
 16    raw: bool,
 17    stream: bool,
 18    #[serde(skip_serializing_if = "Option::is_none")]
 19    options: Option<OllamaGenerateOptions>,
 20}
 21
 22#[derive(Debug, Serialize)]
 23pub(crate) struct OllamaGenerateOptions {
 24    #[serde(skip_serializing_if = "Option::is_none")]
 25    num_predict: Option<u32>,
 26    #[serde(skip_serializing_if = "Option::is_none")]
 27    temperature: Option<f32>,
 28    #[serde(skip_serializing_if = "Option::is_none")]
 29    stop: Option<Vec<String>>,
 30}
 31
 32#[derive(Debug, Deserialize)]
 33pub(crate) struct OllamaGenerateResponse {
 34    pub created_at: String,
 35    pub response: String,
 36}
 37
/// Identifier of the Ollama provider, used to look it up in the global
/// [`LanguageModelRegistry`].
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
 39
 40pub fn is_available(cx: &App) -> bool {
 41    LanguageModelRegistry::read_global(cx)
 42        .provider(&PROVIDER_ID)
 43        .is_some_and(|provider| provider.is_authenticated(cx))
 44}
 45
 46pub fn ensure_authenticated(cx: &mut App) {
 47    if let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) {
 48        provider.authenticate(cx).detach_and_log_err(cx);
 49    }
 50}
 51
 52pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
 53    let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) else {
 54        return Vec::new();
 55    };
 56    provider.authenticate(cx).detach_and_log_err(cx);
 57    let mut models: Vec<SharedString> = provider
 58        .provided_models(cx)
 59        .into_iter()
 60        .map(|model| SharedString::from(model.id().0.to_string()))
 61        .collect();
 62    models.sort();
 63    models
 64}
 65
 66pub(crate) async fn make_request(
 67    settings: OpenAiCompatibleEditPredictionSettings,
 68    prompt: String,
 69    stop_tokens: Vec<String>,
 70    http_client: Arc<dyn HttpClient>,
 71) -> Result<OllamaGenerateResponse> {
 72    let request = OllamaGenerateRequest {
 73        model: settings.model.clone(),
 74        prompt,
 75        raw: true,
 76        stream: false,
 77        options: Some(OllamaGenerateOptions {
 78            num_predict: Some(settings.max_output_tokens),
 79            temperature: Some(0.2),
 80            stop: Some(stop_tokens),
 81        }),
 82    };
 83
 84    let request_body = serde_json::to_string(&request)?;
 85    let http_request = http_client::Request::builder()
 86        .method(http_client::Method::POST)
 87        .uri(format!("{}/api/generate", settings.api_url))
 88        .header("Content-Type", "application/json")
 89        .body(http_client::AsyncBody::from(request_body))?;
 90
 91    let mut response = http_client.send(http_request).await?;
 92    let status = response.status();
 93
 94    log::debug!("Ollama: Response status: {}", status);
 95
 96    if !status.is_success() {
 97        let mut body = String::new();
 98        response.body_mut().read_to_string(&mut body).await?;
 99        return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
100    }
101
102    let mut body = String::new();
103    response.body_mut().read_to_string(&mut body).await?;
104
105    let ollama_response: OllamaGenerateResponse =
106        serde_json::from_str(&body).context("Failed to parse Ollama response")?;
107    Ok(ollama_response)
108}