use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

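/// Settings for the Ollama provider: the server's API URL and an optional
/// low-speed timeout applied to requests.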
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

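/// A language model provider backed by a locally running Ollama server.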
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

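/// Provider state: the models fetched from the server, plus a settings
/// subscription that triggers a refetch whenever settings change.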
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

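    /// Fetches the model list from the Ollama server, filters out embedding
    /// models, and stores the remainder sorted by name. Also serves as the
    /// provider's connectivity check.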
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // The Ollama API provides no metadata indicating which models
                // are embedding models, so simply filter out any model with
                // "-embed" in its name.
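                // For example, "nomic-embed-text" is filtered out, while
                // "llama3.1:latest" is kept.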
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

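/// A chat model served by the local Ollama instance.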
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
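    /// Converts a `LanguageModelRequest` into Ollama's `ChatRequest`, mapping
    /// each message by role and forwarding the model's context window and the
    /// request's stop sequences and temperature.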
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
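
    /// Performs a single chat completion via `ollama::complete`, resolving the
    /// API URL from the current settings.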
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // Ollama has no tokenization endpoint _yet_, so estimate roughly four
        // characters per token.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
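        // For example, a 400-character request is estimated at ~100 tokens.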
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

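    /// Streams a chat completion, yielding each response delta's message
    /// content as it arrives.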
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

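    /// Sends the request with a single tool attached and, if the model invoked
    /// that tool, returns its arguments as a one-item stream.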
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

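/// Configuration UI for the provider: explains how to install and run Ollama,
/// links to its site and model library, and shows the connection status.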
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only so the spacing stays
                            // consistent with the "Connect" state; it should
                            // remain disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this is never clickable, use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}