Forward DWT 9-7: major speed up by vectorizing vertical pass
authorEven Rouault <even.rouault@spatialys.com>
Fri, 22 May 2020 21:57:51 +0000 (23:57 +0200)
committerEven Rouault <even.rouault@spatialys.com>
Fri, 22 May 2020 23:01:05 +0000 (01:01 +0200)
`bench_dwt -I -encode` times goes from 8.6s to 2.1s

src/lib/openjp2/dwt.c

index c422917ce82a7876539cd69c5021d7e64e62703b..ee9eb5e63894051a5afc87d487964003f5c44047 100644 (file)
@@ -125,13 +125,6 @@ static void opj_dwt_deinterleave_h(const OPJ_INT32 * OPJ_RESTRICT a,
                                    OPJ_INT32 * OPJ_RESTRICT b,
                                    OPJ_INT32 dn,
                                    OPJ_INT32 sn, OPJ_INT32 cas);
-/**
-Forward lazy transform (vertical)
-*/
-static void opj_dwt_deinterleave_v(const OPJ_INT32 * OPJ_RESTRICT a,
-                                   OPJ_INT32 * OPJ_RESTRICT b,
-                                   OPJ_INT32 dn,
-                                   OPJ_INT32 sn, OPJ_UINT32 x, OPJ_INT32 cas);
 
 /**
 Forward 9-7 wavelet transform in 1-D
@@ -252,35 +245,6 @@ static void opj_dwt_deinterleave_h(const OPJ_INT32 * OPJ_RESTRICT a,
     }
 }
 
-/* <summary>                             */
-/* Forward lazy transform (vertical).    */
-/* </summary>                            */
-static void opj_dwt_deinterleave_v(const OPJ_INT32 * OPJ_RESTRICT a,
-                                   OPJ_INT32 * OPJ_RESTRICT b,
-                                   OPJ_INT32 dn,
-                                   OPJ_INT32 sn, OPJ_UINT32 x, OPJ_INT32 cas)
-{
-    OPJ_INT32 i = sn;
-    OPJ_INT32 * OPJ_RESTRICT l_dest = b;
-    const OPJ_INT32 * OPJ_RESTRICT l_src = a + cas;
-
-    while (i--) {
-        *l_dest = *l_src;
-        l_dest += x;
-        l_src += 2;
-    } /* b[i*x]=a[2*i+cas]; */
-
-    l_dest = b + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)x;
-    l_src = a + 1 - cas;
-
-    i = dn;
-    while (i--) {
-        *l_dest = *l_src;
-        l_dest += x;
-        l_src += 2;
-    } /*b[(sn+i)*x]=a[(2*i+1-cas)];*/
-}
-
 #ifdef STANDARD_SLOW_VERSION
 /* <summary>                             */
 /* Inverse lazy transform (horizontal).  */
@@ -989,36 +953,85 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
 #endif
 }
 
+#if 0
 static void opj_dwt_encode_step1(OPJ_FLOAT32* fw,
-                                 OPJ_UINT32 start,
                                  OPJ_UINT32 end,
                                  const OPJ_FLOAT32 c)
 {
-    OPJ_UINT32 i;
-    for (i = start; i < end; ++i) {
-        fw[i * 2] *= c;
+    OPJ_UINT32 i = 0;
+    for (; i < end; ++i) {
+        fw[0] *= c;
+        fw += 2;
+    }
+}
+#else
+static void opj_dwt_encode_step1_combined(OPJ_FLOAT32* fw,
+        OPJ_UINT32 iters_c1,
+        OPJ_UINT32 iters_c2,
+        const OPJ_FLOAT32 c1,
+        const OPJ_FLOAT32 c2)
+{
+    OPJ_UINT32 i = 0;
+    const OPJ_UINT32 iters_common =  opj_uint_min(iters_c1, iters_c2);
+    assert((((OPJ_SIZE_T)fw) & 0xf) == 0);
+    assert(opj_int_abs((OPJ_INT32)iters_c1 - (OPJ_INT32)iters_c2) <= 1);
+    for (; i + 3 < iters_common; i += 4) {
+#ifdef __SSE__
+        const __m128 vcst = _mm_set_ps(c2, c1, c2, c1);
+        *(__m128*)fw = _mm_mul_ps(*(__m128*)fw, vcst);
+        *(__m128*)(fw + 4) = _mm_mul_ps(*(__m128*)(fw + 4), vcst);
+#else
+        fw[0] *= c1;
+        fw[1] *= c2;
+        fw[2] *= c1;
+        fw[3] *= c2;
+        fw[4] *= c1;
+        fw[5] *= c2;
+        fw[6] *= c1;
+        fw[7] *= c2;
+#endif
+        fw += 8;
+    }
+    for (; i < iters_common; i++) {
+        fw[0] *= c1;
+        fw[1] *= c2;
+        fw += 2;
+    }
+    if (i < iters_c1) {
+        fw[0] *= c1;
+    } else if (i < iters_c2) {
+        fw[1] *= c2;
     }
 }
+
+#endif
+
 static void opj_dwt_encode_step2(OPJ_FLOAT32* fl, OPJ_FLOAT32* fw,
-                                 OPJ_UINT32 start,
                                  OPJ_UINT32 end,
                                  OPJ_UINT32 m,
                                  OPJ_FLOAT32 c)
 {
     OPJ_UINT32 i;
     OPJ_UINT32 imax = opj_uint_min(end, m);
-    if (start > 0) {
-        fw += 2 * start;
-        fl = fw - 2;
-    }
-    for (i = start; i < imax; ++i) {
+    if (imax > 0) {
         fw[-1] += (fl[0] + fw[0]) * c;
-        fl = fw;
         fw += 2;
+        i = 1;
+        for (; i + 3 < imax; i += 4) {
+            fw[-1] += (fw[-2] + fw[0]) * c;
+            fw[1] += (fw[0] + fw[2]) * c;
+            fw[3] += (fw[2] + fw[4]) * c;
+            fw[5] += (fw[4] + fw[6]) * c;
+            fw += 8;
+        }
+        for (; i < imax; ++i) {
+            fw[-1] += (fw[-2] + fw[0]) * c;
+            fw += 2;
+        }
     }
     if (m < end) {
         assert(m + 1 == end);
-        fw[-1] += (2 * fl[0]) * c;
+        fw[-1] += (2 * fw[-2]) * c;
     }
 }
 
@@ -1027,39 +1040,50 @@ static void opj_dwt_encode_1_real(void *aIn, OPJ_INT32 dn, OPJ_INT32 sn,
 {
     OPJ_FLOAT32* w = (OPJ_FLOAT32*)aIn;
     OPJ_INT32 a, b;
+    assert(dn + sn > 1);
     if (cas == 0) {
-        if (!((dn > 0) || (sn > 1))) {
-            return;
-        }
         a = 0;
         b = 1;
     } else {
-        if (!((sn > 0) || (dn > 1))) {
-            return;
-        }
         a = 1;
         b = 0;
     }
     opj_dwt_encode_step2(w + a, w + b + 1,
-                         0, (OPJ_UINT32)dn,
+                         (OPJ_UINT32)dn,
                          (OPJ_UINT32)opj_int_min(dn, sn - b),
                          opj_dwt_alpha);
     opj_dwt_encode_step2(w + b, w + a + 1,
-                         0, (OPJ_UINT32)sn,
+                         (OPJ_UINT32)sn,
                          (OPJ_UINT32)opj_int_min(sn, dn - a),
                          opj_dwt_beta);
     opj_dwt_encode_step2(w + a, w + b + 1,
-                         0, (OPJ_UINT32)dn,
+                         (OPJ_UINT32)dn,
                          (OPJ_UINT32)opj_int_min(dn, sn - b),
                          opj_dwt_gamma);
     opj_dwt_encode_step2(w + b, w + a + 1,
-                         0, (OPJ_UINT32)sn,
+                         (OPJ_UINT32)sn,
                          (OPJ_UINT32)opj_int_min(sn, dn - a),
                          opj_dwt_delta);
-    opj_dwt_encode_step1(w + b, 0, (OPJ_UINT32)dn,
+#if 0
+    opj_dwt_encode_step1(w + b, (OPJ_UINT32)dn,
                          opj_K);
-    opj_dwt_encode_step1(w + a, 0, (OPJ_UINT32)sn,
+    opj_dwt_encode_step1(w + a, (OPJ_UINT32)sn,
                          opj_invK);
+#else
+    if (a == 0) {
+        opj_dwt_encode_step1_combined(w,
+                                      (OPJ_UINT32)sn,
+                                      (OPJ_UINT32)dn,
+                                      opj_invK,
+                                      opj_K);
+    } else {
+        opj_dwt_encode_step1_combined(w,
+                                      (OPJ_UINT32)dn,
+                                      (OPJ_UINT32)sn,
+                                      opj_K,
+                                      opj_invK);
+    }
+#endif
 }
 
 static void opj_dwt_encode_stepsize(OPJ_INT32 stepsize, OPJ_INT32 numbps,
@@ -1143,6 +1167,9 @@ void opj_dwt_encode_and_deinterleave_h_one_row_real(void* rowIn,
     OPJ_FLOAT32* OPJ_RESTRICT tmp = (OPJ_FLOAT32*)tmpIn;
     const OPJ_INT32 sn = (OPJ_INT32)((width + (even ? 1 : 0)) >> 1);
     const OPJ_INT32 dn = (OPJ_INT32)(width - (OPJ_UINT32)sn);
+    if (width == 1) {
+        return;
+    }
     memcpy(tmp, row, width * sizeof(OPJ_FLOAT32));
     opj_dwt_encode_1_real(tmp, dn, sn, even ? 0 : 1);
     opj_dwt_deinterleave_h((OPJ_INT32 * OPJ_RESTRICT)tmp,
@@ -1258,29 +1285,49 @@ static INLINE void opj_dwt_deinterleave_v_cols(
     OPJ_INT32 cas,
     OPJ_UINT32 cols)
 {
+    OPJ_INT32 k;
     OPJ_INT32 i = sn;
     OPJ_INT32 * OPJ_RESTRICT l_dest = dst;
     const OPJ_INT32 * OPJ_RESTRICT l_src = src + cas * NB_ELTS_V8;
     OPJ_UINT32 c;
 
-    while (i--) {
-        for (c = 0; c < cols; c++) {
-            l_dest[c] = l_src[c];
+    for (k = 0; k < 2; k++) {
+        while (i--) {
+            if (cols == NB_ELTS_V8) {
+                memcpy(l_dest, l_src, NB_ELTS_V8 * sizeof(OPJ_INT32));
+            } else {
+                c = 0;
+                switch (cols) {
+                case 7:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                case 6:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                case 5:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                case 4:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                case 3:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                case 2:
+                    l_dest[c] = l_src[c];
+                    c++; /* fallthru */
+                default:
+                    l_dest[c] = l_src[c];
+                    break;
+                }
+            }
+            l_dest += stride_width;
+            l_src += 2 * NB_ELTS_V8;
         }
-        l_dest += stride_width;
-        l_src += 2 * NB_ELTS_V8;
-    }
 
-    l_dest = dst + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)stride_width;
-    l_src = src + (1 - cas) * NB_ELTS_V8;
-
-    i = dn;
-    while (i--) {
-        for (c = 0; c < cols; c++) {
-            l_dest[c] = l_src[c];
-        }
-        l_dest += stride_width;
-        l_src += 2 * NB_ELTS_V8;
+        l_dest = dst + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)stride_width;
+        l_src = src + (1 - cas) * NB_ELTS_V8;
+        i = dn;
     }
 }
 
@@ -1517,6 +1564,84 @@ static void opj_dwt_encode_and_deinterleave_v(
     }
 }
 
+static void opj_v8dwt_encode_step1(OPJ_FLOAT32* fw,
+                                   OPJ_UINT32 end,
+                                   const OPJ_FLOAT32 cst)
+{
+    OPJ_UINT32 i;
+#ifdef __SSE__
+    __m128* vw = (__m128*) fw;
+    const __m128 vcst = _mm_set1_ps(cst);
+    for (i = 0; i < end; ++i) {
+        vw[0] = _mm_mul_ps(vw[0], vcst);
+        vw[1] = _mm_mul_ps(vw[1], vcst);
+        vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
+    }
+#else
+    OPJ_UINT32 c;
+    for (i = 0; i < end; ++i) {
+        for (c = 0; c < NB_ELTS_V8; c++) {
+            fw[i * 2 * NB_ELTS_V8 + c] *= cst;
+        }
+    }
+#endif
+}
+
+static void opj_v8dwt_encode_step2(OPJ_FLOAT32* fl, OPJ_FLOAT32* fw,
+                                   OPJ_UINT32 end,
+                                   OPJ_UINT32 m,
+                                   OPJ_FLOAT32 cst)
+{
+    OPJ_UINT32 i;
+    OPJ_UINT32 imax = opj_uint_min(end, m);
+#ifdef __SSE__
+    __m128* vw = (__m128*) fw;
+    __m128 vcst = _mm_set1_ps(cst);
+    if (imax > 0) {
+        __m128* vl = (__m128*) fl;
+        vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(_mm_add_ps(vl[0], vw[0]), vcst));
+        vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(_mm_add_ps(vl[1], vw[1]), vcst));
+        vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
+        i = 1;
+
+        for (; i < imax; ++i) {
+            vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(_mm_add_ps(vw[-4], vw[0]), vcst));
+            vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(_mm_add_ps(vw[-3], vw[1]), vcst));
+            vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
+        }
+    }
+    if (m < end) {
+        assert(m + 1 == end);
+        vcst = _mm_add_ps(vcst, vcst);
+        vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(vw[-4], vcst));
+        vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(vw[-3], vcst));
+    }
+#else
+    OPJ_INT32 c;
+    if (imax > 0) {
+        for (c = 0; c < NB_ELTS_V8; c++) {
+            fw[-1 * NB_ELTS_V8 + c] += (fl[0 * NB_ELTS_V8 + c] + fw[0 * NB_ELTS_V8 + c]) *
+                                       cst;
+        }
+        fw += 2 * NB_ELTS_V8;
+        i = 1;
+        for (; i < imax; ++i) {
+            for (c = 0; c < NB_ELTS_V8; c++) {
+                fw[-1 * NB_ELTS_V8 + c] += (fw[-2 * NB_ELTS_V8 + c] + fw[0 * NB_ELTS_V8 + c]) *
+                                           cst;
+            }
+            fw += 2 * NB_ELTS_V8;
+        }
+    }
+    if (m < end) {
+        assert(m + 1 == end);
+        for (c = 0; c < NB_ELTS_V8; c++) {
+            fw[-1 * NB_ELTS_V8 + c] += (2 * fw[-2 * NB_ELTS_V8 + c]) * cst;
+        }
+    }
+#endif
+}
+
 /* Forward 9-7 transform, for the vertical pass, processing cols columns */
 /* where cols <= NB_ELTS_V8 */
 static void opj_dwt_encode_and_deinterleave_v_real(
@@ -1529,20 +1654,59 @@ static void opj_dwt_encode_and_deinterleave_v_real(
 {
     OPJ_FLOAT32* OPJ_RESTRICT array = (OPJ_FLOAT32 * OPJ_RESTRICT)arrayIn;
     OPJ_FLOAT32* OPJ_RESTRICT tmp = (OPJ_FLOAT32 * OPJ_RESTRICT)tmpIn;
-    OPJ_UINT32 c;
     const OPJ_INT32 sn = (OPJ_INT32)((height + (even ? 1 : 0)) >> 1);
     const OPJ_INT32 dn = (OPJ_INT32)(height - (OPJ_UINT32)sn);
-    for (c = 0; c < cols; c++) {
-        OPJ_UINT32 k;
-        for (k = 0; k < height; ++k) {
-            tmp[k] = array[c + k * stride_width];
-        }
+    OPJ_INT32 a, b;
 
-        opj_dwt_encode_1_real(tmp, dn, sn, even ? 0 : 1);
+    if (height == 1) {
+        return;
+    }
 
-        opj_dwt_deinterleave_v((OPJ_INT32*)tmpIn,
-                               ((OPJ_INT32*)(arrayIn)) + c,
-                               dn, sn, stride_width, even ? 0 : 1);
+    opj_dwt_fetch_cols_vertical_pass(arrayIn, tmpIn, height, stride_width, cols);
+
+    if (even) {
+        a = 0;
+        b = 1;
+    } else {
+        a = 1;
+        b = 0;
+    }
+    opj_v8dwt_encode_step2(tmp + a * NB_ELTS_V8,
+                           tmp + (b + 1) * NB_ELTS_V8,
+                           (OPJ_UINT32)dn,
+                           (OPJ_UINT32)opj_int_min(dn, sn - b),
+                           opj_dwt_alpha);
+    opj_v8dwt_encode_step2(tmp + b * NB_ELTS_V8,
+                           tmp + (a + 1) * NB_ELTS_V8,
+                           (OPJ_UINT32)sn,
+                           (OPJ_UINT32)opj_int_min(sn, dn - a),
+                           opj_dwt_beta);
+    opj_v8dwt_encode_step2(tmp + a * NB_ELTS_V8,
+                           tmp + (b + 1) * NB_ELTS_V8,
+                           (OPJ_UINT32)dn,
+                           (OPJ_UINT32)opj_int_min(dn, sn - b),
+                           opj_dwt_gamma);
+    opj_v8dwt_encode_step2(tmp + b * NB_ELTS_V8,
+                           tmp + (a + 1) * NB_ELTS_V8,
+                           (OPJ_UINT32)sn,
+                           (OPJ_UINT32)opj_int_min(sn, dn - a),
+                           opj_dwt_delta);
+    opj_v8dwt_encode_step1(tmp + b * NB_ELTS_V8, (OPJ_UINT32)dn,
+                           opj_K);
+    opj_v8dwt_encode_step1(tmp + a * NB_ELTS_V8, (OPJ_UINT32)sn,
+                           opj_invK);
+
+
+    if (cols == NB_ELTS_V8) {
+        opj_dwt_deinterleave_v_cols((OPJ_INT32*)tmp,
+                                    (OPJ_INT32*)array,
+                                    (OPJ_INT32)dn, (OPJ_INT32)sn,
+                                    stride_width, even ? 0 : 1, NB_ELTS_V8);
+    } else {
+        opj_dwt_deinterleave_v_cols((OPJ_INT32*)tmp,
+                                    (OPJ_INT32*)array,
+                                    (OPJ_INT32)dn, (OPJ_INT32)sn,
+                                    stride_width, even ? 0 : 1, cols);
     }
 }