use crate::LanguageModelCompletionProvider;
use crate::{
    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

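/// A completion provider backed by a locally running Ollama server.
///
/// Ollama has no notion of credentials, so "authentication" throughout this
/// provider really means "the server at `api_url` is reachable and has at
/// least one model available".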
pub struct OllamaCompletionProvider {
    api_url: String,
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
    available_models: Vec<OllamaModel>,
}

impl LanguageModelCompletionProvider for OllamaCompletionProvider {
    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
        self.available_models
            .iter()
            .map(|m| LanguageModel::Ollama(m.clone()))
            .collect()
    }

    fn settings_version(&self) -> usize {
        self.settings_version
    }

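    /// There are no credentials to check; a non-empty model list serves as a
    /// proxy for the server being up and usable.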
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

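    /// "Authenticating" is just a successful round trip to the server's
    /// model-listing endpoint.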
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

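    /// Renders a view prompting the user to install Ollama, with a retry
    /// button that re-fetches the model list from the server.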
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            cx.update_global::<CompletionProvider, _>(|provider, cx| {
                provider
                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
                        provider.fetch_models(cx)
                    })
                    .unwrap_or_else(|| Task::ready(Ok(())))
            })
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

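    /// With no credentials to clear, resetting simply re-fetches the model
    /// list.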
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

    fn model(&self) -> LanguageModel {
        LanguageModel::Ollama(self.model.clone())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
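        // As a stopgap, estimate roughly four characters per token, a common
        // rule of thumb for English text.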
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

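    /// Streams a chat completion, mapping each streamed delta to the text
    /// content of its message regardless of role.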
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl OllamaCompletionProvider {
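    /// Creates the provider and eagerly asks the server to load the
    /// configured model into memory, so the first completion doesn't pay the
    /// model's startup cost.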
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) -> Self {
        cx.spawn({
            let api_url = api_url.clone();
            let client = http_client.clone();
            let model = model.name.clone();

            |_| async move {
                if model.is_empty() {
                    return Ok(());
                }
                preload_model(client.as_ref(), &api_url, &model).await
            }
        })
        .detach_and_log_err(cx);

        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

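    /// Applies new settings, preloading the newly selected model and falling
    /// back to the first available model when the settings name none.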
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) {
        cx.spawn({
            let api_url = api_url.clone();
            let client = self.http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        if model.name.is_empty() {
            self.select_first_available_model()
        } else {
            self.model = model;
        }

        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn select_first_available_model(&mut self) {
        if let Some(model) = self.available_models.first() {
            self.model = model.clone();
        }
    }

    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models.
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                    provider.available_models = models;

                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                        provider.select_first_available_model()
                    }
                });
            })
        })
    }

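    /// Converts the request into Ollama's chat format, mapping each message
    /// role onto the corresponding `ChatMessage` variant and threading the
    /// model's context window, the stop sequences, and the temperature
    /// through `ChatOptions`.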
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model.clone(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive.unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl From<Role> for ollama::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OllamaRole::User,
            Role::Assistant => OllamaRole::Assistant,
            Role::System => OllamaRole::System,
        }
    }
}

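/// View shown while no Ollama server is reachable: a download link, a retry
/// button, and a pointer to the model library.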
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}