use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

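/// A model declared manually in the user's settings, used to extend or
/// override what the Ollama API reports. A minimal sketch of the
/// corresponding settings entry follows; the exact key layout and the
/// `keep_alive` format shown here are assumptions, not verified against
/// Zed's settings schema:
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "available_models": [
///         {
///           "name": "llama3.2:latest",
///           "display_name": "Llama 3.2",
///           "max_tokens": 8192,
///           "keep_alive": "10m"
///         }
///       ]
///     }
///   }
/// }
/// ```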
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest").
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown
    /// menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length (Ollama's `num_ctx`, also known as `n_ctx`).
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request.
    pub keep_alive: Option<KeepAlive>,
}

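/// The Ollama provider: talks to an Ollama server at the configured
/// `api_url` (a locally running Ollama listens on `http://localhost:11434`
/// by default) and exposes its models to Zed.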
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

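/// Shared provider state: the models reported by the Ollama server and the
/// in-flight task that refreshes them. The settings subscription is held
/// here so the model list is re-fetched whenever the Ollama settings change.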
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        // Ollama has no credentials; a non-empty model list is the signal
        // that the local server is reachable.
        !self.available_models.is_empty()
    }

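    /// Queries the Ollama server for its installed models via the ollama
    /// crate's `get_models` (which presumably wraps Ollama's `/api/tags`
    /// endpoint), then stores the filtered, sorted list on `State`.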
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                // Re-fetch the model list whenever the Ollama settings change.
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

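    /// Merges the models reported by the server with those declared in the
    /// settings; on a name collision the settings entry wins, since it is
    /// inserted into the map second.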
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
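    /// Converts Zed's provider-agnostic request into Ollama's chat format.
    /// Roles map one-to-one onto Ollama's `user`/`assistant`/`system`
    /// messages, and the configured context length is forwarded as `num_ctx`.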
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

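    /// Estimates the token count as total characters divided by four, a
    /// common rough heuristic for English text (e.g. a 400-character prompt
    /// counts as roughly 100 tokens).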
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama. See:
        // https://github.com/ollama/ollama/issues/1716 and
        // https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

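    /// Streams a chat completion from the Ollama server, mapping each
    /// streamed message delta to a `LanguageModelCompletionEvent::Text`
    /// event. Requests are gated by the rate limiter (at most 4 concurrent).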
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }
}

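/// Settings UI for the Ollama provider: it kicks off a model fetch on
/// creation and renders either a loading state, setup instructions with a
/// download link, or a "Connected" indicator with a retry button.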
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "Ollama must be running with at least one model installed to use it in the assistant.",
                            ))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}