use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
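
/// Settings for the Ollama provider: the URL of the local Ollama server and
/// an optional low-speed timeout applied to streaming requests.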
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}
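
/// Shared provider state. The fetched model list doubles as the
/// "authenticated" signal: a non-empty list means the server was reachable.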
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }
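
    /// Fetches the list of available models from the Ollama server,
    /// filtering out embedding models by name since the API does not
    /// mark them as such.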
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check if it's up
        // by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}
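
/// A single Ollama chat model, exposed to the assistant through the
/// `LanguageModel` trait.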
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
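    /// Converts a `LanguageModelRequest` into an Ollama `ChatRequest`,
    /// mapping roles one-to-one and requesting a streamed response.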
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }
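
    /// Estimates token usage; Ollama exposes no tokenizer endpoint, so this
    /// approximates by assuming four characters per token.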
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
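
    /// Sends a completion request with a single tool attached and extracts
    /// that tool's arguments from the response, falling back to parsing the
    /// message content as JSON.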
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant {
                    tool_calls,
                    content,
                } = response.message
                else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(function.arguments);
                        }
                    }
                } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
                    // No structured tool call was returned; parse the message
                    // content itself as the arguments.
                    return Ok(args);
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}
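
/// Configuration UI for the Ollama provider: starts fetching models on
/// creation and renders either a loading state or the setup controls.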
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to keep the spacing consistent;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this will never be clickable, use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}