use crate::LanguageModelCompletionProvider;
use crate::{
    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

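/// A completion provider backed by a locally running Ollama server.
///
/// Ollama has no notion of credentials, so "authentication" throughout this
/// provider really means "the server at `api_url` is reachable and has at
/// least one model available".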
pub struct OllamaCompletionProvider {
    api_url: String,
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
    available_models: Vec<OllamaModel>,
}

impl LanguageModelCompletionProvider for OllamaCompletionProvider {
    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
        self.available_models
            .iter()
            .map(|m| LanguageModel::Ollama(m.clone()))
            .collect()
    }

    fn settings_version(&self) -> usize {
        self.settings_version
    }

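    /// There are no credentials to check; a non-empty model list serves as a
    /// proxy for the server being up and usable.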
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

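    /// "Authenticating" is just a successful round trip to the server's
    /// model-listing endpoint.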
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

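    /// Renders a view prompting the user to install Ollama, with a retry
    /// button that re-fetches the model list from the server.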
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            cx.update_global::<CompletionProvider, _>(|provider, cx| {
                provider
                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
                        provider.fetch_models(cx)
                    })
                    .unwrap_or_else(|| Task::ready(Ok(())))
            })
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

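    /// With no credentials to clear, resetting simply re-fetches the model
    /// list.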
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

    fn model(&self) -> LanguageModel {
        LanguageModel::Ollama(self.model.clone())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
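        // As a stopgap, estimate roughly four characters per token, a common
        // rule of thumb for English text.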
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

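    /// Streams a chat completion, mapping each streamed delta to the text
    /// content of its message regardless of role.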
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl OllamaCompletionProvider {
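    /// Creates the provider and eagerly asks the server to load the
    /// configured model into memory, so the first completion doesn't pay the
    /// model's startup cost.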
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) -> Self {
        cx.spawn({
            let api_url = api_url.clone();
            let client = http_client.clone();
            let model = model.name.clone();

            |_| async move {
                if model.is_empty() {
                    return Ok(());
                }
                preload_model(client.as_ref(), &api_url, &model).await
            }
        })
        .detach_and_log_err(cx);

        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

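    /// Applies new settings, preloading the newly selected model and falling
    /// back to the first available model when the settings name none.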
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) {
        cx.spawn({
            let api_url = api_url.clone();
            let client = self.http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        if model.name.is_empty() {
            self.select_first_available_model()
        } else {
            self.model = model;
        }

        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn select_first_available_model(&mut self) {
        if let Some(model) = self.available_models.first() {
            self.model = model.clone();
        }
    }

    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models.
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                    provider.available_models = models;

                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                        provider.select_first_available_model()
                    }
                });
            })
        })
    }

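    /// Converts the request into Ollama's chat format, mapping each message
    /// role onto the corresponding `ChatMessage` variant and threading the
    /// model's context window, the stop sequences, and the temperature
    /// through `ChatOptions`.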
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model.clone(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive.unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl From<Role> for ollama::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OllamaRole::User,
            Role::Assistant => OllamaRole::Assistant,
            Role::System => OllamaRole::System,
        }
    }
}

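/// View shown while no Ollama server is reachable: a download link, a retry
/// button, and a pointer to the model library.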
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}