// ollama.rs

  1use crate::{
  2    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  3};
  4use anyhow::Result;
  5use futures::StreamExt as _;
  6use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
  7use gpui::{AnyView, AppContext, Task};
  8use http::HttpClient;
  9use ollama::{
 10    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
 11    Role as OllamaRole,
 12};
 13use std::sync::Arc;
 14use std::time::Duration;
 15use ui::{prelude::*, ButtonLike, ElevationIndex};
 16
// Landing page where users can install Ollama for their platform.
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
// Public index of models available to pull with `ollama pull`.
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 19
/// Completion provider backed by a locally running Ollama server.
pub struct OllamaCompletionProvider {
    // Base URL of the Ollama HTTP API.
    api_url: String,
    // Currently selected model; its name may be empty until one is chosen.
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    // Optional timeout passed through to streaming requests.
    low_speed_timeout: Option<Duration>,
    // Settings version this provider was last configured from.
    settings_version: usize,
    // Models reported by the server; empty until `fetch_models` succeeds.
    available_models: Vec<OllamaModel>,
}
 28
 29impl OllamaCompletionProvider {
 30    pub fn new(
 31        model: OllamaModel,
 32        api_url: String,
 33        http_client: Arc<dyn HttpClient>,
 34        low_speed_timeout: Option<Duration>,
 35        settings_version: usize,
 36        cx: &AppContext,
 37    ) -> Self {
 38        cx.spawn({
 39            let api_url = api_url.clone();
 40            let client = http_client.clone();
 41            let model = model.name.clone();
 42
 43            |_| async move {
 44                if model.is_empty() {
 45                    return Ok(());
 46                }
 47                preload_model(client.as_ref(), &api_url, &model).await
 48            }
 49        })
 50        .detach_and_log_err(cx);
 51
 52        Self {
 53            api_url,
 54            model,
 55            http_client,
 56            low_speed_timeout,
 57            settings_version,
 58            available_models: Default::default(),
 59        }
 60    }
 61
 62    pub fn update(
 63        &mut self,
 64        model: OllamaModel,
 65        api_url: String,
 66        low_speed_timeout: Option<Duration>,
 67        settings_version: usize,
 68        cx: &AppContext,
 69    ) {
 70        cx.spawn({
 71            let api_url = api_url.clone();
 72            let client = self.http_client.clone();
 73            let model = model.name.clone();
 74
 75            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
 76        })
 77        .detach_and_log_err(cx);
 78
 79        if model.name.is_empty() {
 80            self.select_first_available_model()
 81        } else {
 82            self.model = model;
 83        }
 84
 85        self.api_url = api_url;
 86        self.low_speed_timeout = low_speed_timeout;
 87        self.settings_version = settings_version;
 88    }
 89
 90    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
 91        self.available_models.iter()
 92    }
 93
 94    pub fn select_first_available_model(&mut self) {
 95        if let Some(model) = self.available_models.first() {
 96            self.model = model.clone();
 97        }
 98    }
 99
100    pub fn settings_version(&self) -> usize {
101        self.settings_version
102    }
103
104    pub fn is_authenticated(&self) -> bool {
105        !self.available_models.is_empty()
106    }
107
108    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
109        if self.is_authenticated() {
110            Task::ready(Ok(()))
111        } else {
112            self.fetch_models(cx)
113        }
114    }
115
116    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
117        self.fetch_models(cx)
118    }
119
120    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
121        let http_client = self.http_client.clone();
122        let api_url = self.api_url.clone();
123
124        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
125        cx.spawn(|mut cx| async move {
126            let models = get_models(http_client.as_ref(), &api_url, None).await?;
127
128            let mut models: Vec<OllamaModel> = models
129                .into_iter()
130                // Since there is no metadata from the Ollama API
131                // indicating which models are embedding models,
132                // simply filter out models with "-embed" in their name
133                .filter(|model| !model.name.contains("-embed"))
134                .map(|model| OllamaModel::new(&model.name))
135                .collect();
136
137            models.sort_by(|a, b| a.name.cmp(&b.name));
138
139            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
140                if let CompletionProvider::Ollama(provider) = provider {
141                    provider.available_models = models;
142
143                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
144                        provider.select_first_available_model()
145                    }
146                }
147            })
148        })
149    }
150
151    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
152        let fetch_models = Box::new(move |cx: &mut WindowContext| {
153            cx.update_global::<CompletionProvider, _>(|provider, cx| {
154                if let CompletionProvider::Ollama(provider) = provider {
155                    provider.fetch_models(cx)
156                } else {
157                    Task::ready(Ok(()))
158                }
159            })
160        });
161
162        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
163            .into()
164    }
165
166    pub fn model(&self) -> OllamaModel {
167        self.model.clone()
168    }
169
170    pub fn count_tokens(
171        &self,
172        request: LanguageModelRequest,
173        _cx: &AppContext,
174    ) -> BoxFuture<'static, Result<usize>> {
175        // There is no endpoint for this _yet_ in Ollama
176        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
177        let token_count = request
178            .messages
179            .iter()
180            .map(|msg| msg.content.chars().count())
181            .sum::<usize>()
182            / 4;
183
184        async move { Ok(token_count) }.boxed()
185    }
186
187    pub fn complete(
188        &self,
189        request: LanguageModelRequest,
190    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
191        let request = self.to_ollama_request(request);
192
193        let http_client = self.http_client.clone();
194        let api_url = self.api_url.clone();
195        let low_speed_timeout = self.low_speed_timeout;
196        async move {
197            let request =
198                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
199            let response = request.await?;
200            let stream = response
201                .filter_map(|response| async move {
202                    match response {
203                        Ok(delta) => {
204                            let content = match delta.message {
205                                ChatMessage::User { content } => content,
206                                ChatMessage::Assistant { content } => content,
207                                ChatMessage::System { content } => content,
208                            };
209                            Some(Ok(content))
210                        }
211                        Err(error) => Some(Err(error)),
212                    }
213                })
214                .boxed();
215            Ok(stream)
216        }
217        .boxed()
218    }
219
220    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
221        let model = match request.model {
222            LanguageModel::Ollama(model) => model,
223            _ => self.model(),
224        };
225
226        ChatRequest {
227            model: model.name,
228            messages: request
229                .messages
230                .into_iter()
231                .map(|msg| match msg.role {
232                    Role::User => ChatMessage::User {
233                        content: msg.content,
234                    },
235                    Role::Assistant => ChatMessage::Assistant {
236                        content: msg.content,
237                    },
238                    Role::System => ChatMessage::System {
239                        content: msg.content,
240                    },
241                })
242                .collect(),
243            keep_alive: model.keep_alive.unwrap_or_default(),
244            stream: true,
245            options: Some(ChatOptions {
246                num_ctx: Some(model.max_tokens),
247                stop: Some(request.stop),
248                temperature: Some(request.temperature),
249                ..Default::default()
250            }),
251        }
252    }
253}
254
255impl From<Role> for ollama::Role {
256    fn from(val: Role) -> Self {
257        match val {
258            Role::User => OllamaRole::User,
259            Role::Assistant => OllamaRole::Assistant,
260            Role::System => OllamaRole::System,
261        }
262    }
263}
264
/// Onboarding view shown when no Ollama models are available, prompting the
/// user to install Ollama and retry the connection.
struct DownloadOllamaMessage {
    // Callback that re-attempts fetching models from the Ollama server.
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}
268
impl DownloadOllamaMessage {
    /// Creates the view; `retry_connection` re-probes the Ollama server when
    /// the user clicks "Retry".
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    /// Button that opens the Ollama download page in the user's browser.
    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    /// Button that re-runs the connection check in the background; failures
    /// are logged rather than surfaced in this view.
    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                // Detach so the retry runs independently of this view's
                // lifetime; errors go to the log.
                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    /// Follow-up guidance with a link to the public Ollama model library.
    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}
324
impl Render for DownloadOllamaMessage {
    /// Lays out the onboarding message: explanation text, the download and
    /// retry buttons side by side, then the model-library next steps.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(
                        self.render_download_button(cx)
                    )
                    .child(
                        self.render_retry_button(cx)
                    )
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}