use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

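/// User-configurable connection settings for the Ollama provider, read via
/// `AllLanguageModelSettings`.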
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

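/// A language model provider backed by a locally running Ollama server.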
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

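/// Shared provider state: the set of models reported by the server, refreshed
/// whenever the global settings change.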
struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.fetch_models(cx).detach();
        this
    }

    fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        let state = self.state.clone();
        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models.
        cx.spawn(|mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            state.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
        Some(cx.observe(&self.state, |_, _, cx| {
            cx.notify();
        }))
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

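    // Ollama has no API key; treat a reachable server that reports at least
    // one model as "authenticated".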
    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).available_models.is_empty()
    }

    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated(cx) {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            state.update(cx, |this, cx| this.fetch_models(cx))
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        self.fetch_models(cx)
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
}

impl OllamaLanguageModel {
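    /// Converts a `LanguageModelRequest` into Ollama's chat format, with
    /// streaming enabled and the model's context window and sampling options
    /// applied.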
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
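        // Estimate instead, assuming roughly four characters per token, a
        // common rule of thumb for English text.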
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let request =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }
}

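/// A view prompting the user to download Ollama, retry the connection, and
/// browse the model library.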
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}