1/*
  2Copyright (c) 2023 Evan Wallace
  3
  4Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
  5
  6The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
  7
  8THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  9*/
 10package com.madebyevan.thumbhash;
 11
 12public final class ThumbHash {
 13    /**
 14     * Encodes an RGBA image to a ThumbHash. RGB should not be premultiplied by A.
 15     *
 16     * @param w    The width of the input image. Must be ≤100px.
 17     * @param h    The height of the input image. Must be ≤100px.
 18     * @param rgba The pixels in the input image, row-by-row. Must have w*h*4 elements.
 19     * @return The ThumbHash as a byte array.
 20     */
 21    public static byte[] rgbaToThumbHash(int w, int h, byte[] rgba) {
 22        // Encoding an image larger than 100x100 is slow with no benefit
 23        if (w > 100 || h > 100) throw new IllegalArgumentException(w + "x" + h + " doesn't fit in 100x100");
 24
 25        // Determine the average color
 26        float avg_r = 0, avg_g = 0, avg_b = 0, avg_a = 0;
 27        for (int i = 0, j = 0; i < w * h; i++, j += 4) {
 28            float alpha = (rgba[j + 3] & 255) / 255.0f;
 29            avg_r += alpha / 255.0f * (rgba[j] & 255);
 30            avg_g += alpha / 255.0f * (rgba[j + 1] & 255);
 31            avg_b += alpha / 255.0f * (rgba[j + 2] & 255);
 32            avg_a += alpha;
 33        }
 34        if (avg_a > 0) {
 35            avg_r /= avg_a;
 36            avg_g /= avg_a;
 37            avg_b /= avg_a;
 38        }
 39
 40        boolean hasAlpha = avg_a < w * h;
 41        int l_limit = hasAlpha ? 5 : 7; // Use fewer luminance bits if there's alpha
 42        int lx = Math.max(1, Math.round((float) (l_limit * w) / (float) Math.max(w, h)));
 43        int ly = Math.max(1, Math.round((float) (l_limit * h) / (float) Math.max(w, h)));
 44        float[] l = new float[w * h]; // luminance
 45        float[] p = new float[w * h]; // yellow - blue
 46        float[] q = new float[w * h]; // red - green
 47        float[] a = new float[w * h]; // alpha
 48
 49        // Convert the image from RGBA to LPQA (composite atop the average color)
 50        for (int i = 0, j = 0; i < w * h; i++, j += 4) {
 51            float alpha = (rgba[j + 3] & 255) / 255.0f;
 52            float r = avg_r * (1.0f - alpha) + alpha / 255.0f * (rgba[j] & 255);
 53            float g = avg_g * (1.0f - alpha) + alpha / 255.0f * (rgba[j + 1] & 255);
 54            float b = avg_b * (1.0f - alpha) + alpha / 255.0f * (rgba[j + 2] & 255);
 55            l[i] = (r + g + b) / 3.0f;
 56            p[i] = (r + g) / 2.0f - b;
 57            q[i] = r - g;
 58            a[i] = alpha;
 59        }
 60
 61        // Encode using the DCT into DC (constant) and normalized AC (varying) terms
 62        Channel l_channel = new Channel(Math.max(3, lx), Math.max(3, ly)).encode(w, h, l);
 63        Channel p_channel = new Channel(3, 3).encode(w, h, p);
 64        Channel q_channel = new Channel(3, 3).encode(w, h, q);
 65        Channel a_channel = hasAlpha ? new Channel(5, 5).encode(w, h, a) : null;
 66
 67        // Write the constants
 68        boolean isLandscape = w > h;
 69        int header24 = Math.round(63.0f * l_channel.dc)
 70                | (Math.round(31.5f + 31.5f * p_channel.dc) << 6)
 71                | (Math.round(31.5f + 31.5f * q_channel.dc) << 12)
 72                | (Math.round(31.0f * l_channel.scale) << 18)
 73                | (hasAlpha ? 1 << 23 : 0);
 74        int header16 = (isLandscape ? ly : lx)
 75                | (Math.round(63.0f * p_channel.scale) << 3)
 76                | (Math.round(63.0f * q_channel.scale) << 9)
 77                | (isLandscape ? 1 << 15 : 0);
 78        int ac_start = hasAlpha ? 6 : 5;
 79        int ac_count = l_channel.ac.length + p_channel.ac.length + q_channel.ac.length
 80                + (hasAlpha ? a_channel.ac.length : 0);
 81        byte[] hash = new byte[ac_start + (ac_count + 1) / 2];
 82        hash[0] = (byte) header24;
 83        hash[1] = (byte) (header24 >> 8);
 84        hash[2] = (byte) (header24 >> 16);
 85        hash[3] = (byte) header16;
 86        hash[4] = (byte) (header16 >> 8);
 87        if (hasAlpha) hash[5] = (byte) (Math.round(15.0f * a_channel.dc)
 88                | (Math.round(15.0f * a_channel.scale) << 4));
 89
 90        // Write the varying factors
 91        int ac_index = 0;
 92        ac_index = l_channel.writeTo(hash, ac_start, ac_index);
 93        ac_index = p_channel.writeTo(hash, ac_start, ac_index);
 94        ac_index = q_channel.writeTo(hash, ac_start, ac_index);
 95        if (hasAlpha) a_channel.writeTo(hash, ac_start, ac_index);
 96        return hash;
 97    }
 98
 99    /**
100     * Decodes a ThumbHash to an RGBA image. RGB is not be premultiplied by A.
101     *
102     * @param hash The bytes of the ThumbHash.
103     * @return The width, height, and pixels of the rendered placeholder image.
104     */
105    public static Image thumbHashToRGBA(byte[] hash) {
106        // Read the constants
107        int header24 = (hash[0] & 255) | ((hash[1] & 255) << 8) | ((hash[2] & 255) << 16);
108        int header16 = (hash[3] & 255) | ((hash[4] & 255) << 8);
109        float l_dc = (float) (header24 & 63) / 63.0f;
110        float p_dc = (float) ((header24 >> 6) & 63) / 31.5f - 1.0f;
111        float q_dc = (float) ((header24 >> 12) & 63) / 31.5f - 1.0f;
112        float l_scale = (float) ((header24 >> 18) & 31) / 31.0f;
113        boolean hasAlpha = (header24 >> 23) != 0;
114        float p_scale = (float) ((header16 >> 3) & 63) / 63.0f;
115        float q_scale = (float) ((header16 >> 9) & 63) / 63.0f;
116        boolean isLandscape = (header16 >> 15) != 0;
117        int lx = Math.max(3, isLandscape ? hasAlpha ? 5 : 7 : header16 & 7);
118        int ly = Math.max(3, isLandscape ? header16 & 7 : hasAlpha ? 5 : 7);
119        float a_dc = hasAlpha ? (float) (hash[5] & 15) / 15.0f : 1.0f;
120        float a_scale = (float) ((hash[5] >> 4) & 15) / 15.0f;
121
122        // Read the varying factors (boost saturation by 1.25x to compensate for quantization)
123        int ac_start = hasAlpha ? 6 : 5;
124        int ac_index = 0;
125        Channel l_channel = new Channel(lx, ly);
126        Channel p_channel = new Channel(3, 3);
127        Channel q_channel = new Channel(3, 3);
128        Channel a_channel = null;
129        ac_index = l_channel.decode(hash, ac_start, ac_index, l_scale);
130        ac_index = p_channel.decode(hash, ac_start, ac_index, p_scale * 1.25f);
131        ac_index = q_channel.decode(hash, ac_start, ac_index, q_scale * 1.25f);
132        if (hasAlpha) {
133            a_channel = new Channel(5, 5);
134            a_channel.decode(hash, ac_start, ac_index, a_scale);
135        }
136        float[] l_ac = l_channel.ac;
137        float[] p_ac = p_channel.ac;
138        float[] q_ac = q_channel.ac;
139        float[] a_ac = hasAlpha ? a_channel.ac : null;
140
141        // Decode using the DCT into RGB
142        float ratio = thumbHashToApproximateAspectRatio(hash);
143        int w = Math.round(ratio > 1.0f ? 32.0f : 32.0f * ratio);
144        int h = Math.round(ratio > 1.0f ? 32.0f / ratio : 32.0f);
145        byte[] rgba = new byte[w * h * 4];
146        int cx_stop = Math.max(lx, hasAlpha ? 5 : 3);
147        int cy_stop = Math.max(ly, hasAlpha ? 5 : 3);
148        float[] fx = new float[cx_stop];
149        float[] fy = new float[cy_stop];
150        for (int y = 0, i = 0; y < h; y++) {
151            for (int x = 0; x < w; x++, i += 4) {
152                float l = l_dc, p = p_dc, q = q_dc, a = a_dc;
153
154                // Precompute the coefficients
155                for (int cx = 0; cx < cx_stop; cx++)
156                    fx[cx] = (float) Math.cos(Math.PI / w * (x + 0.5f) * cx);
157                for (int cy = 0; cy < cy_stop; cy++)
158                    fy[cy] = (float) Math.cos(Math.PI / h * (y + 0.5f) * cy);
159
160                // Decode L
161                for (int cy = 0, j = 0; cy < ly; cy++) {
162                    float fy2 = fy[cy] * 2.0f;
163                    for (int cx = cy > 0 ? 0 : 1; cx * ly < lx * (ly - cy); cx++, j++)
164                        l += l_ac[j] * fx[cx] * fy2;
165                }
166
167                // Decode P and Q
168                for (int cy = 0, j = 0; cy < 3; cy++) {
169                    float fy2 = fy[cy] * 2.0f;
170                    for (int cx = cy > 0 ? 0 : 1; cx < 3 - cy; cx++, j++) {
171                        float f = fx[cx] * fy2;
172                        p += p_ac[j] * f;
173                        q += q_ac[j] * f;
174                    }
175                }
176
177                // Decode A
178                if (hasAlpha)
179                    for (int cy = 0, j = 0; cy < 5; cy++) {
180                        float fy2 = fy[cy] * 2.0f;
181                        for (int cx = cy > 0 ? 0 : 1; cx < 5 - cy; cx++, j++)
182                            a += a_ac[j] * fx[cx] * fy2;
183                    }
184
185                // Convert to RGB
186                float b = l - 2.0f / 3.0f * p;
187                float r = (3.0f * l - b + q) / 2.0f;
188                float g = r - q;
189                rgba[i] = (byte) Math.max(0, Math.round(255.0f * Math.min(1, r)));
190                rgba[i + 1] = (byte) Math.max(0, Math.round(255.0f * Math.min(1, g)));
191                rgba[i + 2] = (byte) Math.max(0, Math.round(255.0f * Math.min(1, b)));
192                rgba[i + 3] = (byte) Math.max(0, Math.round(255.0f * Math.min(1, a)));
193            }
194        }
195        return new Image(w, h, rgba);
196    }
197
198    /**
199     * Extracts the average color from a ThumbHash. RGB is not be premultiplied by A.
200     *
201     * @param hash The bytes of the ThumbHash.
202     * @return The RGBA values for the average color. Each value ranges from 0 to 1.
203     */
204    public static RGBA thumbHashToAverageRGBA(byte[] hash) {
205        int header = (hash[0] & 255) | ((hash[1] & 255) << 8) | ((hash[2] & 255) << 16);
206        float l = (float) (header & 63) / 63.0f;
207        float p = (float) ((header >> 6) & 63) / 31.5f - 1.0f;
208        float q = (float) ((header >> 12) & 63) / 31.5f - 1.0f;
209        boolean hasAlpha = (header >> 23) != 0;
210        float a = hasAlpha ? (float) (hash[5] & 15) / 15.0f : 1.0f;
211        float b = l - 2.0f / 3.0f * p;
212        float r = (3.0f * l - b + q) / 2.0f;
213        float g = r - q;
214        return new RGBA(
215                Math.max(0, Math.min(1, r)),
216                Math.max(0, Math.min(1, g)),
217                Math.max(0, Math.min(1, b)),
218                a);
219    }
220
221    /**
222     * Extracts the approximate aspect ratio of the original image.
223     *
224     * @param hash The bytes of the ThumbHash.
225     * @return The approximate aspect ratio (i.e. width / height).
226     */
227    public static float thumbHashToApproximateAspectRatio(byte[] hash) {
228        byte header = hash[3];
229        boolean hasAlpha = (hash[2] & 0x80) != 0;
230        boolean isLandscape = (hash[4] & 0x80) != 0;
231        int lx = isLandscape ? hasAlpha ? 5 : 7 : header & 7;
232        int ly = isLandscape ? header & 7 : hasAlpha ? 5 : 7;
233        return (float) lx / (float) ly;
234    }
235
236    public static final class Image {
237        public int width;
238        public int height;
239        public byte[] rgba;
240
241        public Image(int width, int height, byte[] rgba) {
242            this.width = width;
243            this.height = height;
244            this.rgba = rgba;
245        }
246    }
247
248    public static final class RGBA {
249        public float r;
250        public float g;
251        public float b;
252        public float a;
253
254        public RGBA(float r, float g, float b, float a) {
255            this.r = r;
256            this.g = g;
257            this.b = b;
258            this.a = a;
259        }
260    }
261
262    private static final class Channel {
263        int nx;
264        int ny;
265        float dc;
266        float[] ac;
267        float scale;
268
269        Channel(int nx, int ny) {
270            this.nx = nx;
271            this.ny = ny;
272            int n = 0;
273            for (int cy = 0; cy < ny; cy++)
274                for (int cx = cy > 0 ? 0 : 1; cx * ny < nx * (ny - cy); cx++)
275                    n++;
276            ac = new float[n];
277        }
278
279        Channel encode(int w, int h, float[] channel) {
280            int n = 0;
281            float[] fx = new float[w];
282            for (int cy = 0; cy < ny; cy++) {
283                for (int cx = 0; cx * ny < nx * (ny - cy); cx++) {
284                    float f = 0;
285                    for (int x = 0; x < w; x++)
286                        fx[x] = (float) Math.cos(Math.PI / w * cx * (x + 0.5f));
287                    for (int y = 0; y < h; y++) {
288                        float fy = (float) Math.cos(Math.PI / h * cy * (y + 0.5f));
289                        for (int x = 0; x < w; x++)
290                            f += channel[x + y * w] * fx[x] * fy;
291                    }
292                    f /= w * h;
293                    if (cx > 0 || cy > 0) {
294                        ac[n++] = f;
295                        scale = Math.max(scale, Math.abs(f));
296                    } else {
297                        dc = f;
298                    }
299                }
300            }
301            if (scale > 0)
302                for (int i = 0; i < ac.length; i++)
303                    ac[i] = 0.5f + 0.5f / scale * ac[i];
304            return this;
305        }
306
307        int decode(byte[] hash, int start, int index, float scale) {
308            for (int i = 0; i < ac.length; i++) {
309                int data = hash[start + (index >> 1)] >> ((index & 1) << 2);
310                ac[i] = ((float) (data & 15) / 7.5f - 1.0f) * scale;
311                index++;
312            }
313            return index;
314        }
315
316        int writeTo(byte[] hash, int start, int index) {
317            for (float v : ac) {
318                hash[start + (index >> 1)] |= Math.round(15.0f * v) << ((index & 1) << 2);
319                index++;
320            }
321            return index;
322        }
323    }
324}