Forward DWT: small code refactoring to allow future improvements for the vertical pass
author Even Rouault <even.rouault@spatialys.com>
Fri, 22 May 2020 13:58:47 +0000 (15:58 +0200)
committer Even Rouault <even.rouault@spatialys.com>
Fri, 22 May 2020 14:01:45 +0000 (16:01 +0200)
src/lib/openjp2/dwt.c

index 79be0f568d68cfc65cc05df73e59cee9b2310ba6..4f54c57ad441e3a96047a19d3a94a31b1f1a635a 100644 (file)
@@ -157,10 +157,18 @@ static OPJ_BOOL opj_dwt_decode_partial_tile(
     opj_tcd_tilecomp_t* tilec,
     OPJ_UINT32 numres);
 
+/* Forward transform, for the vertical pass, processing cols columns */
+/* where cols <= NB_ELTS_V8 */
 /* Where void* is a OPJ_INT32* for 5x3 and OPJ_FLOAT32* for 9x7 */
-typedef void (*opj_encode_one_row_fnptr_type)(void *, OPJ_INT32, OPJ_INT32,
-        OPJ_INT32);
+typedef void (*opj_encode_and_deinterleave_v_fnptr_type)(
+    void *array,
+    void *tmp,
+    OPJ_UINT32 height,
+    OPJ_BOOL even,
+    OPJ_UINT32 stride_width,
+    OPJ_UINT32 cols);
 
+/* Where void* is a OPJ_INT32* for 5x3 and OPJ_FLOAT32* for 9x7 */
 typedef void (*opj_encode_and_deinterleave_h_one_row_fnptr_type)(
     void *row,
     void *tmp,
@@ -169,7 +177,7 @@ typedef void (*opj_encode_and_deinterleave_h_one_row_fnptr_type)(
 
 static OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
         opj_tcd_tilecomp_t * tilec,
-        opj_encode_one_row_fnptr_type p_function,
+        opj_encode_and_deinterleave_v_fnptr_type p_encode_and_deinterleave_v,
         opj_encode_and_deinterleave_h_one_row_fnptr_type
         p_encode_and_deinterleave_h_one_row);
 
@@ -1226,7 +1234,7 @@ typedef struct {
     OPJ_INT32 * OPJ_RESTRICT tiledp;
     OPJ_UINT32 min_j;
     OPJ_UINT32 max_j;
-    opj_encode_one_row_fnptr_type p_function;
+    opj_encode_and_deinterleave_v_fnptr_type p_encode_and_deinterleave_v;
 } opj_dwt_encode_v_job_t;
 
 static void opj_dwt_encode_v_func(void* user_data, opj_tls_t* tls)
@@ -1236,29 +1244,90 @@ static void opj_dwt_encode_v_func(void* user_data, opj_tls_t* tls)
     (void)tls;
 
     job = (opj_dwt_encode_v_job_t*)user_data;
-    for (j = job->min_j; j < job->max_j; j++) {
-        OPJ_INT32* OPJ_RESTRICT aj = job->tiledp + j;
+    for (j = job->min_j; j + NB_ELTS_V8 - 1 < job->max_j; j += NB_ELTS_V8) {
+        (*job->p_encode_and_deinterleave_v)(job->tiledp + j,
+                                            job->v.mem,
+                                            job->rh,
+                                            job->v.cas == 0,
+                                            job->w,
+                                            NB_ELTS_V8);
+    }
+    if (j < job->max_j) {
+        (*job->p_encode_and_deinterleave_v)(job->tiledp + j,
+                                            job->v.mem,
+                                            job->rh,
+                                            job->v.cas == 0,
+                                            job->w,
+                                            job->max_j - j);
+    }
+
+    opj_aligned_free(job->v.mem);
+    opj_free(job);
+}
+
+/* Forward 5-3 transform, for the vertical pass, processing cols columns */
+/* where cols <= NB_ELTS_V8 */
+static void opj_dwt_encode_and_deinterleave_v(
+    void *arrayIn,
+    void *tmpIn,
+    OPJ_UINT32 height,
+    OPJ_BOOL even,
+    OPJ_UINT32 stride_width,
+    OPJ_UINT32 cols)
+{
+    OPJ_INT32* OPJ_RESTRICT array = (OPJ_INT32 * OPJ_RESTRICT)arrayIn;
+    OPJ_INT32* OPJ_RESTRICT tmp = (OPJ_INT32 * OPJ_RESTRICT)tmpIn;
+    OPJ_UINT32 c;
+    const OPJ_INT32 sn = (OPJ_INT32)((height + (even ? 1 : 0)) >> 1);
+    const OPJ_INT32 dn = (OPJ_INT32)(height - (OPJ_UINT32)sn);
+    for (c = 0; c < cols; c++) {
         OPJ_UINT32 k;
-        for (k = 0; k < job->rh; ++k) {
-            job->v.mem[k] = aj[k * job->w];
+        for (k = 0; k < height; ++k) {
+            tmp[k] = array[c + k * stride_width];
         }
 
-        (*job->p_function)(job->v.mem, job->v.dn, job->v.sn, job->v.cas);
+        opj_dwt_encode_1(tmp, dn, sn, even ? 0 : 1);
 
-        opj_dwt_deinterleave_v(job->v.mem, aj, job->v.dn, job->v.sn, job->w,
-                               job->v.cas);
+        opj_dwt_deinterleave_v(tmp, array + c, dn, sn, stride_width, even ? 0 : 1);
     }
+}
 
-    opj_aligned_free(job->v.mem);
-    opj_free(job);
+/* Forward 9-7 transform, for the vertical pass, processing cols columns */
+/* where cols <= NB_ELTS_V8 */
+static void opj_dwt_encode_and_deinterleave_v_real(
+    void *arrayIn,
+    void *tmpIn,
+    OPJ_UINT32 height,
+    OPJ_BOOL even,
+    OPJ_UINT32 stride_width,
+    OPJ_UINT32 cols)
+{
+    OPJ_FLOAT32* OPJ_RESTRICT array = (OPJ_FLOAT32 * OPJ_RESTRICT)arrayIn;
+    OPJ_FLOAT32* OPJ_RESTRICT tmp = (OPJ_FLOAT32 * OPJ_RESTRICT)tmpIn;
+    OPJ_UINT32 c;
+    const OPJ_INT32 sn = (OPJ_INT32)((height + (even ? 1 : 0)) >> 1);
+    const OPJ_INT32 dn = (OPJ_INT32)(height - (OPJ_UINT32)sn);
+    for (c = 0; c < cols; c++) {
+        OPJ_UINT32 k;
+        for (k = 0; k < height; ++k) {
+            tmp[k] = array[c + k * stride_width];
+        }
+
+        opj_dwt_encode_1_real(tmp, dn, sn, even ? 0 : 1);
+
+        opj_dwt_deinterleave_v((OPJ_INT32*)tmpIn,
+                               ((OPJ_INT32*)(arrayIn)) + c,
+                               dn, sn, stride_width, even ? 0 : 1);
+    }
 }
 
+
 /* <summary>                            */
 /* Forward 5-3 wavelet transform in 2-D. */
 /* </summary>                           */
 static INLINE OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
         opj_tcd_tilecomp_t * tilec,
-        opj_encode_one_row_fnptr_type p_function,
+        opj_encode_and_deinterleave_v_fnptr_type p_encode_and_deinterleave_v,
         opj_encode_and_deinterleave_h_one_row_fnptr_type
         p_encode_and_deinterleave_h_one_row)
 {
@@ -1282,11 +1351,11 @@ static INLINE OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
 
     l_data_size = opj_dwt_max_resolution(tilec->resolutions, tilec->numresolutions);
     /* overflow check */
-    if (l_data_size > (SIZE_MAX / sizeof(OPJ_INT32))) {
+    if (l_data_size > (SIZE_MAX / (NB_ELTS_V8 * sizeof(OPJ_INT32)))) {
         /* FIXME event manager error callback */
         return OPJ_FALSE;
     }
-    l_data_size *= sizeof(OPJ_INT32);
+    l_data_size *= NB_ELTS_V8 * sizeof(OPJ_INT32);
     bj = (OPJ_INT32*)opj_aligned_32_malloc(l_data_size);
     /* l_data_size is equal to 0 when numresolutions == 1 but bj is not used */
     /* in that case, so do not error out */
@@ -1319,17 +1388,22 @@ static INLINE OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
         dn = (OPJ_INT32)(rh - rh1);
 
         /* Perform vertical pass */
-        if (num_threads <= 1 || rw <= 1) {
-            for (j = 0; j < rw; ++j) {
-                OPJ_INT32* OPJ_RESTRICT aj = tiledp + j;
-                OPJ_UINT32 k;
-                for (k = 0; k < rh; ++k) {
-                    bj[k] = aj[k * w];
-                }
-
-                (*p_function)(bj, dn, sn, cas_col);
-
-                opj_dwt_deinterleave_v(bj, aj, dn, sn, w, cas_col);
+        if (num_threads <= 1 || rw < 2 * NB_ELTS_V8) {
+            for (j = 0; j + NB_ELTS_V8 - 1 < rw; j += NB_ELTS_V8) {
+                p_encode_and_deinterleave_v(tiledp + j,
+                                            bj,
+                                            rh,
+                                            cas_col == 0,
+                                            w,
+                                            NB_ELTS_V8);
+            }
+            if (j < rw) {
+                p_encode_and_deinterleave_v(tiledp + j,
+                                            bj,
+                                            rh,
+                                            cas_col == 0,
+                                            w,
+                                            rw - j);
             }
         }  else {
             OPJ_UINT32 num_jobs = (OPJ_UINT32)num_threads;
@@ -1338,7 +1412,7 @@ static INLINE OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
             if (rw < num_jobs) {
                 num_jobs = rw;
             }
-            step_j = (rw / num_jobs);
+            step_j = ((rw / num_jobs) / NB_ELTS_V8) * NB_ELTS_V8;
 
             for (j = 0; j < num_jobs; j++) {
                 opj_dwt_encode_v_job_t* job;
@@ -1363,11 +1437,8 @@ static INLINE OPJ_BOOL opj_dwt_encode_procedure(opj_thread_pool_t* tp,
                 job->w = w;
                 job->tiledp = tiledp;
                 job->min_j = j * step_j;
-                job->max_j = (j + 1U) * step_j; /* this can overflow */
-                if (j == (num_jobs - 1U)) {  /* this will take care of the overflow */
-                    job->max_j = rw;
-                }
-                job->p_function = p_function;
+                job->max_j = (j + 1 == num_jobs) ? rw : (j + 1) * step_j;
+                job->p_encode_and_deinterleave_v = p_encode_and_deinterleave_v;
                 opj_thread_pool_submit_job(tp, opj_dwt_encode_v_func, job);
             }
             opj_thread_pool_wait_completion(tp, 0);
@@ -1440,7 +1511,7 @@ OPJ_BOOL opj_dwt_encode(opj_tcd_t *p_tcd,
                         opj_tcd_tilecomp_t * tilec)
 {
     return opj_dwt_encode_procedure(p_tcd->thread_pool, tilec,
-                                    opj_dwt_encode_1,
+                                    opj_dwt_encode_and_deinterleave_v,
                                     opj_dwt_encode_and_deinterleave_h_one_row);
 }
 
@@ -1480,7 +1551,7 @@ OPJ_BOOL opj_dwt_encode_real(opj_tcd_t *p_tcd,
                              opj_tcd_tilecomp_t * tilec)
 {
     return opj_dwt_encode_procedure(p_tcd->thread_pool, tilec,
-                                    opj_dwt_encode_1_real,
+                                    opj_dwt_encode_and_deinterleave_v_real,
                                     opj_dwt_encode_and_deinterleave_h_one_row_real);
 }