1use crate::{
2 assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::Result;
5use futures::StreamExt as _;
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
7use gpui::{AnyView, AppContext, Task};
8use http::HttpClient;
9use ollama::{
10 get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
11};
12use std::sync::Arc;
13use std::time::Duration;
14use ui::{prelude::*, ButtonLike, ElevationIndex};
15
/// Ollama's download page, opened by the button in `DownloadOllamaMessage`.
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
17
/// Completion provider backed by a locally running Ollama server.
pub struct OllamaCompletionProvider {
    /// Base URL of the Ollama HTTP API.
    api_url: String,
    /// The currently selected model.
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    /// Optional low-speed timeout forwarded to `stream_chat_completion`.
    low_speed_timeout: Option<Duration>,
    /// Version of the settings this provider was last configured from.
    settings_version: usize,
    /// Models discovered from the server; empty until `fetch_models` succeeds.
    available_models: Vec<OllamaModel>,
}
26
impl OllamaCompletionProvider {
    /// Creates a provider with an empty model list; call `fetch_models` (or
    /// `authenticate`) to discover which models the server offers.
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

    /// Applies updated settings in place. The cached `available_models` list
    /// is left untouched; re-run `fetch_models` to refresh it.
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    /// Iterates over the models most recently fetched from the server
    /// (sorted by name in `fetch_models`).
    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
        self.available_models.iter()
    }

    /// The settings version this provider was last configured from.
    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// Ollama has no real authentication; a non-empty model list is treated
    /// as evidence that the server was reachable.
    pub fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Resolves immediately if models are already cached; otherwise probes
    /// the server by fetching its model list.
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

    /// There are no credentials to reset, so this simply re-fetches models.
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

    /// Asynchronously fetches the server's model list and stores it on the
    /// globally registered `CompletionProvider` (if it is the Ollama variant).
    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Ollama(provider) = provider {
                    provider.available_models = models;
                }
            })
        })
    }

    /// View shown when not "authenticated": a prompt to download Ollama.
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
    }

    /// The currently configured model (cloned).
    pub fn model(&self) -> OllamaModel {
        self.model.clone()
    }

    /// Estimates the token count of the request as total characters / 4,
    /// since Ollama exposes no tokenization endpoint.
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    /// Streams a chat completion, yielding each delta's message content as a
    /// `String` regardless of which role the delta carries.
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            // Extract the text payload; all three roles carry
                            // a plain `content` string.
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    /// Converts a generic `LanguageModelRequest` into Ollama's `ChatRequest`,
    /// falling back to the provider's configured model when the request does
    /// not target an Ollama model.
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive,
            // Always stream; `complete` consumes the response as a stream.
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}
201
202impl From<Role> for ollama::Role {
203 fn from(val: Role) -> Self {
204 match val {
205 Role::User => OllamaRole::User,
206 Role::Assistant => OllamaRole::Assistant,
207 Role::System => OllamaRole::System,
208 }
209 }
210}
211
/// Stateless view prompting the user to install Ollama, rendered when no
/// models could be fetched from the server.
struct DownloadOllamaMessage {}
213
214impl DownloadOllamaMessage {
215 pub fn new(_cx: &mut ViewContext<Self>) -> Self {
216 Self {}
217 }
218
219 fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
220 ButtonLike::new("download_ollama_button")
221 .style(ButtonStyle::Filled)
222 .size(ButtonSize::Large)
223 .layer(ElevationIndex::ModalSurface)
224 .child(Label::new("Get Ollama"))
225 .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
226 }
227}
228
229impl Render for DownloadOllamaMessage {
230 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
231 v_flex()
232 .p_4()
233 .size_full()
234 .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
235 .child(
236 h_flex()
237 .w_full()
238 .p_4()
239 .justify_center()
240 .child(
241 self.render_download_button(cx)
242 )
243 )
244 .into_any()
245 }
246}