use crate::LanguageModelCompletionProvider;
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use language_model::Role;
use ollama::Model as OllamaModel;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

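/// Completion provider backed by a locally running Ollama server.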
pub struct OllamaCompletionProvider {
    api_url: String,
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
    available_models: Vec<OllamaModel>,
}

impl LanguageModelCompletionProvider for OllamaCompletionProvider {
    fn available_models(&self) -> Vec<LanguageModel> {
        self.available_models
            .iter()
            .map(|m| LanguageModel::Ollama(m.clone()))
            .collect()
    }

    fn settings_version(&self) -> usize {
        self.settings_version
    }

    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            cx.update_global::<CompletionProvider, _>(|provider, cx| {
                provider
                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
                        provider.fetch_models(cx)
                    })
                    .unwrap_or_else(|| Task::ready(Ok(())))
            })
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

    fn model(&self) -> LanguageModel {
        LanguageModel::Ollama(self.model.clone())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no token-counting endpoint in Ollama _yet_.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        //
        // As a stand-in, estimate roughly one token per four characters of message content.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            // Extract the text of each delta regardless of which role it carries.
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl OllamaCompletionProvider {
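    /// Creates the provider and spawns a background task asking the Ollama
    /// server to preload the configured model.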
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) -> Self {
        cx.spawn({
            let api_url = api_url.clone();
            let client = http_client.clone();
            let model = model.name.clone();

            |_| async move {
                if model.is_empty() {
                    return Ok(());
                }
                preload_model(client.as_ref(), &api_url, &model).await
            }
        })
        .detach_and_log_err(cx);

        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

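    /// Applies updated settings and preloads the newly configured model.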
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) {
        cx.spawn({
            let api_url = api_url.clone();
            let client = self.http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        if model.name.is_empty() {
            self.select_first_available_model()
        } else {
            self.model = model;
        }

        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

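    /// Falls back to the first model reported by the server, if any.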
    pub fn select_first_available_model(&mut self) {
        if let Some(model) = self.available_models.first() {
            self.model = model.clone();
        }
    }

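    /// Fetches the model list from the Ollama server and caches it on the provider.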
    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                    provider.available_models = models;

                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                        provider.select_first_available_model()
                    }
                });
            })
        })
    }

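    /// Converts the editor's generic completion request into Ollama's chat request format.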
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model.clone(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive.unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

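/// Prompt shown when no Ollama server is reachable, offering a download link
/// and a button to retry the connection.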
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}