Speed-up 9x7 IDWD by ~30% with OPJ_NUM_THREADS=2
authorEven Rouault <even.rouault@spatialys.com>
Thu, 21 May 2020 14:54:48 +0000 (16:54 +0200)
committerEven Rouault <even.rouault@spatialys.com>
Thu, 21 May 2020 15:21:55 +0000 (17:21 +0200)
"bench_dwt -I" time goes from 2.2s to 1.5s

src/lib/openjp2/dwt.c

index 84d5aaf5376fa7bfcbe4840747b0180e09c9724b..8790626e5401333ebb245a37255f9f9795446ec7 100644 (file)
@@ -2643,12 +2643,99 @@ static void opj_v8dwt_decode(opj_v8dwt_t* OPJ_RESTRICT dwt)
 #endif
 }
 
+typedef struct {
+    opj_v8dwt_t h;
+    OPJ_UINT32 rw;
+    OPJ_UINT32 w;
+    OPJ_FLOAT32 * OPJ_RESTRICT aj;
+    OPJ_UINT32 nb_rows;
+} opj_dwt97_decode_h_job_t;
+
+static void opj_dwt97_decode_h_func(void* user_data, opj_tls_t* tls)
+{
+    OPJ_UINT32 j;
+    opj_dwt97_decode_h_job_t* job;
+    OPJ_FLOAT32 * OPJ_RESTRICT aj;
+    OPJ_UINT32 w;
+    (void)tls;
+
+    job = (opj_dwt97_decode_h_job_t*)user_data;
+    w = job->w;
+
+    assert((job->nb_rows % NB_ELTS_V8) == 0);
+
+    aj = job->aj;
+    for (j = 0; j + NB_ELTS_V8 <= job->nb_rows; j += NB_ELTS_V8) {
+        OPJ_UINT32 k;
+        opj_v8dwt_interleave_h(&job->h, aj, job->w, NB_ELTS_V8);
+        opj_v8dwt_decode(&job->h);
+
+        /* To be adapted if NB_ELTS_V8 changes */
+        for (k = 0; k < job->rw; k++) {
+            aj[k      ] = job->h.wavelet[k].f[0];
+            aj[k + (OPJ_SIZE_T)w  ] = job->h.wavelet[k].f[1];
+            aj[k + (OPJ_SIZE_T)w * 2] = job->h.wavelet[k].f[2];
+            aj[k + (OPJ_SIZE_T)w * 3] = job->h.wavelet[k].f[3];
+        }
+        for (k = 0; k < job->rw; k++) {
+            aj[k + (OPJ_SIZE_T)w * 4] = job->h.wavelet[k].f[4];
+            aj[k + (OPJ_SIZE_T)w * 5] = job->h.wavelet[k].f[5];
+            aj[k + (OPJ_SIZE_T)w * 6] = job->h.wavelet[k].f[6];
+            aj[k + (OPJ_SIZE_T)w * 7] = job->h.wavelet[k].f[7];
+        }
+
+        aj += w * NB_ELTS_V8;
+    }
+
+    opj_aligned_free(job->h.wavelet);
+    opj_free(job);
+}
+
+
+typedef struct {
+    opj_v8dwt_t v;
+    OPJ_UINT32 rh;
+    OPJ_UINT32 w;
+    OPJ_FLOAT32 * OPJ_RESTRICT aj;
+    OPJ_UINT32 nb_columns;
+} opj_dwt97_decode_v_job_t;
+
+static void opj_dwt97_decode_v_func(void* user_data, opj_tls_t* tls)
+{
+    OPJ_UINT32 j;
+    opj_dwt97_decode_v_job_t* job;
+    OPJ_FLOAT32 * OPJ_RESTRICT aj;
+    (void)tls;
+
+    job = (opj_dwt97_decode_v_job_t*)user_data;
+
+    assert((job->nb_columns % NB_ELTS_V8) == 0);
+
+    aj = job->aj;
+    for (j = 0; j + NB_ELTS_V8 <= job->nb_columns; j += NB_ELTS_V8) {
+        OPJ_UINT32 k;
+
+        opj_v8dwt_interleave_v(&job->v, aj, job->w, NB_ELTS_V8);
+        opj_v8dwt_decode(&job->v);
+
+        for (k = 0; k < job->rh; ++k) {
+            memcpy(&aj[k * (OPJ_SIZE_T)job->w], &job->v.wavelet[k],
+                   NB_ELTS_V8 * sizeof(OPJ_FLOAT32));
+        }
+        aj += NB_ELTS_V8;
+    }
+
+    opj_aligned_free(job->v.wavelet);
+    opj_free(job);
+}
+
 
 /* <summary>                             */
 /* Inverse 9-7 wavelet transform in 2-D. */
 /* </summary>                            */
 static
-OPJ_BOOL opj_dwt_decode_tile_97(opj_tcd_tilecomp_t* OPJ_RESTRICT tilec,
+OPJ_BOOL opj_dwt_decode_tile_97(opj_thread_pool_t* tp,
+                                opj_tcd_tilecomp_t* OPJ_RESTRICT tilec,
                                 OPJ_UINT32 numres)
 {
     opj_v8dwt_t h;
@@ -2666,6 +2753,7 @@ OPJ_BOOL opj_dwt_decode_tile_97(opj_tcd_tilecomp_t* OPJ_RESTRICT tilec,
                                 tilec->resolutions[tilec->minimum_num_resolutions - 1].x0);
 
     OPJ_SIZE_T l_data_size;
+    const int num_threads = opj_thread_pool_get_thread_count(tp);
 
     if (numres == 1) {
         return OPJ_TRUE;
@@ -2705,26 +2793,70 @@ OPJ_BOOL opj_dwt_decode_tile_97(opj_tcd_tilecomp_t* OPJ_RESTRICT tilec,
         h.win_l_x1 = (OPJ_UINT32)h.sn;
         h.win_h_x0 = 0;
         h.win_h_x1 = (OPJ_UINT32)h.dn;
-        for (j = 0; j + (NB_ELTS_V8 - 1) < rh; j += NB_ELTS_V8) {
-            OPJ_UINT32 k;
-            opj_v8dwt_interleave_h(&h, aj, w, rh - j);
-            opj_v8dwt_decode(&h);
 
-            /* To be adapted if NB_ELTS_V8 changes */
-            for (k = 0; k < rw; k++) {
-                aj[k      ] = h.wavelet[k].f[0];
-                aj[k + (OPJ_SIZE_T)w  ] = h.wavelet[k].f[1];
-                aj[k + (OPJ_SIZE_T)w * 2] = h.wavelet[k].f[2];
-                aj[k + (OPJ_SIZE_T)w * 3] = h.wavelet[k].f[3];
+        if (num_threads <= 1 || rh < 2 * NB_ELTS_V8) {
+            for (j = 0; j + (NB_ELTS_V8 - 1) < rh; j += NB_ELTS_V8) {
+                OPJ_UINT32 k;
+                opj_v8dwt_interleave_h(&h, aj, w, NB_ELTS_V8);
+                opj_v8dwt_decode(&h);
+
+                /* To be adapted if NB_ELTS_V8 changes */
+                for (k = 0; k < rw; k++) {
+                    aj[k      ] = h.wavelet[k].f[0];
+                    aj[k + (OPJ_SIZE_T)w  ] = h.wavelet[k].f[1];
+                    aj[k + (OPJ_SIZE_T)w * 2] = h.wavelet[k].f[2];
+                    aj[k + (OPJ_SIZE_T)w * 3] = h.wavelet[k].f[3];
+                }
+                for (k = 0; k < rw; k++) {
+                    aj[k + (OPJ_SIZE_T)w * 4] = h.wavelet[k].f[4];
+                    aj[k + (OPJ_SIZE_T)w * 5] = h.wavelet[k].f[5];
+                    aj[k + (OPJ_SIZE_T)w * 6] = h.wavelet[k].f[6];
+                    aj[k + (OPJ_SIZE_T)w * 7] = h.wavelet[k].f[7];
+                }
+
+                aj += w * NB_ELTS_V8;
             }
-            for (k = 0; k < rw; k++) {
-                aj[k + (OPJ_SIZE_T)w * 4] = h.wavelet[k].f[4];
-                aj[k + (OPJ_SIZE_T)w * 5] = h.wavelet[k].f[5];
-                aj[k + (OPJ_SIZE_T)w * 6] = h.wavelet[k].f[6];
-                aj[k + (OPJ_SIZE_T)w * 7] = h.wavelet[k].f[7];
+        } else {
+            OPJ_UINT32 num_jobs = (OPJ_UINT32)num_threads;
+            OPJ_UINT32 step_j;
+
+            if ((rh / NB_ELTS_V8) < num_jobs) {
+                num_jobs = rh / NB_ELTS_V8;
             }
+            step_j = ((rh / num_jobs) / NB_ELTS_V8) * NB_ELTS_V8;
+            for (j = 0; j < num_jobs; j++) {
+                opj_dwt97_decode_h_job_t* job;
 
-            aj += w * NB_ELTS_V8;
+                job = (opj_dwt97_decode_h_job_t*) opj_malloc(sizeof(opj_dwt97_decode_h_job_t));
+                if (!job) {
+                    opj_thread_pool_wait_completion(tp, 0);
+                    opj_aligned_free(h.wavelet);
+                    return OPJ_FALSE;
+                }
+                job->h.wavelet = (opj_v8_t*)opj_aligned_malloc(l_data_size * sizeof(opj_v8_t));
+                if (!job->h.wavelet) {
+                    opj_thread_pool_wait_completion(tp, 0);
+                    opj_free(job);
+                    opj_aligned_free(h.wavelet);
+                    return OPJ_FALSE;
+                }
+                job->h.dn = h.dn;
+                job->h.sn = h.sn;
+                job->h.cas = h.cas;
+                job->h.win_l_x0 = h.win_l_x0;
+                job->h.win_l_x1 = h.win_l_x1;
+                job->h.win_h_x0 = h.win_h_x0;
+                job->h.win_h_x1 = h.win_h_x1;
+                job->rw = rw;
+                job->w = w;
+                job->aj = aj;
+                job->nb_rows = (j + 1 == num_jobs) ? (rh & (OPJ_UINT32)~
+                                                      (NB_ELTS_V8 - 1)) - j * step_j : step_j;
+                aj += w * job->nb_rows;
+                opj_thread_pool_submit_job(tp, opj_dwt97_decode_h_func, job);
+            }
+            opj_thread_pool_wait_completion(tp, 0);
+            j = rh & (OPJ_UINT32)~(NB_ELTS_V8 - 1);
         }
 
         if (j < rh) {
@@ -2747,16 +2879,62 @@ OPJ_BOOL opj_dwt_decode_tile_97(opj_tcd_tilecomp_t* OPJ_RESTRICT tilec,
         v.win_h_x1 = (OPJ_UINT32)v.dn;
 
         aj = (OPJ_FLOAT32*) tilec->data;
-        for (j = rw; j > (NB_ELTS_V8 - 1); j -= NB_ELTS_V8) {
-            OPJ_UINT32 k;
+        if (num_threads <= 1 || rw < 2 * NB_ELTS_V8) {
+            for (j = rw; j > (NB_ELTS_V8 - 1); j -= NB_ELTS_V8) {
+                OPJ_UINT32 k;
 
-            opj_v8dwt_interleave_v(&v, aj, w, NB_ELTS_V8);
-            opj_v8dwt_decode(&v);
+                opj_v8dwt_interleave_v(&v, aj, w, NB_ELTS_V8);
+                opj_v8dwt_decode(&v);
 
-            for (k = 0; k < rh; ++k) {
-                memcpy(&aj[k * (OPJ_SIZE_T)w], &v.wavelet[k], NB_ELTS_V8 * sizeof(OPJ_FLOAT32));
+                for (k = 0; k < rh; ++k) {
+                    memcpy(&aj[k * (OPJ_SIZE_T)w], &v.wavelet[k], NB_ELTS_V8 * sizeof(OPJ_FLOAT32));
+                }
+                aj += NB_ELTS_V8;
+            }
+        } else {
+            /* "bench_dwt -I" shows that scaling is poor, likely due to RAM
+                transfer being the limiting factor. So limit the number of
+                threads.
+             */
+            OPJ_UINT32 num_jobs = opj_uint_max((OPJ_UINT32)num_threads / 2, 2U);
+            OPJ_UINT32 step_j;
+
+            if ((rw / NB_ELTS_V8) < num_jobs) {
+                num_jobs = rw / NB_ELTS_V8;
+            }
+            step_j = ((rw / num_jobs) / NB_ELTS_V8) * NB_ELTS_V8;
+            for (j = 0; j < num_jobs; j++) {
+                opj_dwt97_decode_v_job_t* job;
+
+                job = (opj_dwt97_decode_v_job_t*) opj_malloc(sizeof(opj_dwt97_decode_v_job_t));
+                if (!job) {
+                    opj_thread_pool_wait_completion(tp, 0);
+                    opj_aligned_free(h.wavelet);
+                    return OPJ_FALSE;
+                }
+                job->v.wavelet = (opj_v8_t*)opj_aligned_malloc(l_data_size * sizeof(opj_v8_t));
+                if (!job->v.wavelet) {
+                    opj_thread_pool_wait_completion(tp, 0);
+                    opj_free(job);
+                    opj_aligned_free(h.wavelet);
+                    return OPJ_FALSE;
+                }
+                job->v.dn = v.dn;
+                job->v.sn = v.sn;
+                job->v.cas = v.cas;
+                job->v.win_l_x0 = v.win_l_x0;
+                job->v.win_l_x1 = v.win_l_x1;
+                job->v.win_h_x0 = v.win_h_x0;
+                job->v.win_h_x1 = v.win_h_x1;
+                job->rh = rh;
+                job->w = w;
+                job->aj = aj;
+                job->nb_columns = (j + 1 == num_jobs) ? (rw & (OPJ_UINT32)~
+                                  (NB_ELTS_V8 - 1)) - j * step_j : step_j;
+                aj += job->nb_columns;
+                opj_thread_pool_submit_job(tp, opj_dwt97_decode_v_func, job);
             }
-            aj += NB_ELTS_V8;
+            opj_thread_pool_wait_completion(tp, 0);
         }
 
         if (rw & (NB_ELTS_V8 - 1)) {
@@ -3018,7 +3196,7 @@ OPJ_BOOL opj_dwt_decode_real(opj_tcd_t *p_tcd,
                              OPJ_UINT32 numres)
 {
     if (p_tcd->whole_tile_decoding) {
-        return opj_dwt_decode_tile_97(tilec, numres);
+        return opj_dwt_decode_tile_97(p_tcd->thread_pool, tilec, numres);
     } else {
         return opj_dwt_decode_partial_97(tilec, numres);
     }