use crate::{
    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

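/// Completion provider backed by a locally running Ollama server.
///
/// Talks to the server over HTTP at `api_url`; the set of available models is
/// discovered at runtime by querying the server rather than configured
/// statically.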
pub struct OllamaCompletionProvider {
    api_url: String,
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
    available_models: Vec<OllamaModel>,
}

impl OllamaCompletionProvider {
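    /// Creates the provider and spawns a background task asking the server to
    /// preload the configured model, so the first completion request doesn't
    /// pay the model load cost.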
    pub fn new(
        model: OllamaModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) -> Self {
        cx.spawn({
            let api_url = api_url.clone();
            let client = http_client.clone();
            let model = model.name.clone();

            |_| async move {
                if model.is_empty() {
                    return Ok(());
                }
                preload_model(client.as_ref(), &api_url, &model).await
            }
        })
        .detach_and_log_err(cx);

        Self {
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
            available_models: Default::default(),
        }
    }

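    /// Applies new settings, preloading the newly selected model in the
    /// background. An empty model name falls back to the first model the
    /// server reported, if any.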
    pub fn update(
        &mut self,
        model: OllamaModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
        cx: &AppContext,
    ) {
        cx.spawn({
            let api_url = api_url.clone();
            let client = self.http_client.clone();
            let model = model.name.clone();

            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
        })
        .detach_and_log_err(cx);

        if model.name.is_empty() {
            self.select_first_available_model()
        } else {
            self.model = model;
        }

        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
        self.available_models.iter()
    }

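    /// Defaults to the first model reported by the server; a no-op when no
    /// models have been fetched yet.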
    pub fn select_first_available_model(&mut self) {
        if let Some(model) = self.available_models.first() {
            self.model = model.clone();
        }
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

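    /// Ollama has no real authentication; the provider counts as
    /// "authenticated" once it has successfully fetched a model list from the
    /// server.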
    pub fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

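    /// Fetches the model list if the server hasn't been reached yet;
    /// otherwise resolves immediately.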
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }

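    /// Queries the server for its models, filters out embedding models, sorts
    /// the rest by name, and stores them on the global provider.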
    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<OllamaModel> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| OllamaModel::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::Ollama(provider) = provider {
                    provider.available_models = models;

                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                        provider.select_first_available_model()
                    }
                }
            })
        })
    }

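    /// Builds the view shown when the server is unreachable, offering a
    /// download link and a retry button that re-fetches the model list.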
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            cx.update_global::<CompletionProvider, _>(|provider, cx| {
                if let CompletionProvider::Ollama(provider) = provider {
                    provider.fetch_models(cx)
                } else {
                    Task::ready(Ok(()))
                }
            })
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

    pub fn model(&self) -> OllamaModel {
        self.model.clone()
    }

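    /// Estimates token usage with a rough chars/4 heuristic, since Ollama
    /// does not yet expose a tokenizer endpoint.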
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

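    /// Streams a chat completion from the server, mapping each response delta
    /// to the text content of its message.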
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

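    /// Converts a generic `LanguageModelRequest` into Ollama's chat request
    /// shape, falling back to the provider's configured model when the
    /// request doesn't name an Ollama model.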
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let model = match request.model {
            LanguageModel::Ollama(model) => model,
            _ => self.model(),
        };

        ChatRequest {
            model: model.name,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: model.keep_alive.unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl From<Role> for ollama::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OllamaRole::User,
            Role::Assistant => OllamaRole::Assistant,
            Role::System => OllamaRole::System,
        }
    }
}

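/// Prompt shown when no Ollama server is reachable: links to the Ollama
/// download page and lets the user retry the connection.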
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}