ollama.rs

use crate::{
    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";

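/// Chat completion provider backed by a locally running Ollama server.
///
/// Ollama requires no API key; a reachable server with a non-empty model
/// list stands in for being "authenticated".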
pub struct OllamaCompletionProvider {
    api_url: String,
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
    available_models: Vec<OllamaModel>,
}

impl OllamaCompletionProvider {
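    /// Creates the provider and spawns a background task asking the server
    /// to preload the selected model, so the first completion doesn't pay
    /// the model load cost.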
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) -> Self {
        cx.spawn({
            let api_url = api_url.clone();
            let client = http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

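    /// Applies new settings and preloads the (possibly changed) model,
    /// mirroring the preload performed in `new`.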
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) {
        cx.spawn({
            let api_url = api_url.clone();
            let client = self.http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
        self.available_models.iter()
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

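    /// Ollama has no credentials; a non-empty model list fetched from the
    /// server is used as the proxy for being "authenticated".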
    pub fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

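    /// No-op once models have been fetched; otherwise probes the server by
    /// fetching its model list.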
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

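    /// There are no credentials to reset, so this simply re-fetches the
    /// model list from the server.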
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

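    /// Queries the Ollama server for its installed models and stores the
    /// chat-capable ones, sorted by name, on the global `CompletionProvider`.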
    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // The Ollama API provides no metadata indicating which models
                // are embedding models, so simply filter out models with
                // "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Ollama(provider) = provider {
                    provider.available_models = models;
                }
            })
        })
    }

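    /// Renders the view shown when no Ollama server is reachable, pointing
    /// the user at the download page.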
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
    }

    pub fn model(&self) -> OllamaModel {
        self.model.clone()
    }

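    /// Estimates the token count of a request. Ollama exposes no
    /// tokenization endpoint yet, so this assumes roughly four characters
    /// per token.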
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no token-counting endpoint in Ollama _yet_, so approximate
        // with an average of ~4 characters per token.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

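    /// Streams a chat completion from the server, yielding the content of
    /// each message delta as it arrives.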
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .map(|response| {
                    response.map(|delta| match delta.message {
                        ChatMessage::User { content } => content,
                        ChatMessage::Assistant { content } => content,
                        ChatMessage::System { content } => content,
                    })
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

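    /// Converts a `LanguageModelRequest` into an Ollama `ChatRequest`,
    /// falling back to the currently selected model when the request does
    /// not name an Ollama model.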
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive,
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl From<Role> for ollama::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OllamaRole::User,
            Role::Assistant => OllamaRole::Assistant,
            Role::System => OllamaRole::System,
        }
    }
}

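/// Placeholder view prompting the user to download Ollama when no server is
/// reachable.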
struct DownloadOllamaMessage {}

impl DownloadOllamaMessage {
    pub fn new(_cx: &mut ViewContext<Self>) -> Self {
        Self {}
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .child(
                Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .child(self.render_download_button(cx)),
            )
            .into_any()
    }
}