use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

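/// User-facing settings for the Ollama provider: the base URL of the local
/// Ollama API and an optional timeout applied to slow connections.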
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

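/// Registers Ollama as a language model provider and owns the shared [`State`]
/// entity that tracks the models reported by the local server.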
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

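/// Shared provider state: the list of models most recently fetched from the
/// Ollama server, refreshed whenever the global settings change.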
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
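    /// Asks the Ollama server for its model list, filters out embedding
    /// models, and stores the remainder sorted by name.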
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }
}

impl OllamaLanguageModelProvider {
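    /// Creates the provider and kicks off an initial model fetch; the stored
    /// subscription re-fetches the model list whenever settings change.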
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

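    /// Ollama has no real authentication; a non-empty model list is treated as
    /// evidence that the local server is reachable.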
    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).available_models.is_empty()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        if self.is_authenticated(cx) {
            Task::ready(Ok(()))
        } else {
            self.state.update(cx, |state, cx| state.fetch_models(cx))
        }
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            state.update(cx, |this, cx| this.fetch_models(cx))
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

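/// A single chat model served by the local Ollama instance, with a limiter
/// capping the number of concurrent requests.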
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
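    /// Converts a provider-agnostic [`LanguageModelRequest`] into Ollama's
    /// streaming chat format, carrying over roles, stop sequences,
    /// temperature, and the model's context window.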
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

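    /// A rough estimate only: Ollama exposes no tokenizer endpoint, so we
    /// assume roughly four characters per token.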
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

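    /// Streams a chat completion from Ollama, reading the configured API URL
    /// from settings and yielding message content deltas as they arrive.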
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _name: String,
        _description: String,
        _schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        future::ready(Err(anyhow!("not implemented"))).boxed()
    }
}

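/// The view shown when no Ollama models are available: a download link, a
/// retry button that re-fetches models, and a pointer to the model library.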
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}