1441b394c2cd9776c88f56e7c4c9c1eca1533cc1
[ardour.git] / libs / qm-dsp / dsp / segmentation / cluster_melt.c
1 /*
2  *  cluster.c
3  *  cluster_melt
4  *
5  *  Created by Mark Levy on 21/02/2006.
6  *  Copyright 2006 Centre for Digital Music, Queen Mary, University of London.
7
8     This program is free software; you can redistribute it and/or
9     modify it under the terms of the GNU General Public License as
10     published by the Free Software Foundation; either version 2 of the
11     License, or (at your option) any later version.  See the file
12     COPYING included with this distribution for more information.
13  *
14  */
15
16 #include <stdlib.h>
17
18 #include "cluster_melt.h"
19
/* Weight applied to the non-matching neighbour count in the duration prior. */
#define DEFAULT_LAMBDA 0.02
/* Neighbourhood half-width used when the caller supplies no positive limit. */
#define DEFAULT_LIMIT 20
22
23 double kldist(double* a, double* b, int n) {
24         /* NB assume that all a[i], b[i] are non-negative
25         because a, b represent probability distributions */
26         double q, d;
27         int i;
28         
29         d = 0;
30         for (i = 0; i < n; i++)
31         {
32                 q = (a[i] + b[i]) / 2.0;
33                 if (q > 0)
34                 {
35                         if (a[i] > 0)
36                                 d += a[i] * log(a[i] / q);
37                         if (b[i] > 0)
38                                 d += b[i] * log(b[i] / q);
39                 }
40         }
41         return d;               
42 }       
43
44 void cluster_melt(double *h, int m, int n, double *Bsched, int t, int k, int l, int *c) {
45         double lambda, sum, beta, logsumexp, maxlp;
46         int i, j, a, b, b0, b1, limit, B, it, maxiter, maxiter0, maxiter1;
47         double** cl;    /* reference histograms for each cluster */
48         int** nc;       /* neighbour counts for each histogram */
49         double** lp;    /* soft assignment probs for each histogram */
50         int* oldc;      /* previous hard assignments (to check convergence) */
51         
52         /* NB h is passed as a 1d row major array */
53         
54         /* parameter values */
55         lambda = DEFAULT_LAMBDA;
56         if (l > 0)
57                 limit = l;
58         else
59                 limit = DEFAULT_LIMIT;          /* use default if no valid neighbourhood limit supplied */
60         B = 2 * limit + 1;
61         maxiter0 = 20;  /* number of iterations at initial temperature */
62         maxiter1 = 5;   /* number of iterations at subsequent temperatures */
63         
64         /* allocate memory */   
65         cl = (double**) malloc(k*sizeof(double*));
66         for (i= 0; i < k; i++)
67                 cl[i] = (double*) malloc(m*sizeof(double));
68         
69         nc = (int**) malloc(n*sizeof(int*));
70         for (i= 0; i < n; i++)
71                 nc[i] = (int*) malloc(k*sizeof(int));
72         
73         lp = (double**) malloc(n*sizeof(double*));
74         for (i= 0; i < n; i++)
75                 lp[i] = (double*) malloc(k*sizeof(double));
76         
77         oldc = (int*) malloc(n * sizeof(int));
78         
79         /* initialise */
80         for (i = 0; i < k; i++)
81         {
82                 sum = 0;
83                 for (j = 0; j < m; j++)
84                 {
85                         cl[i][j] = rand();      /* random initial reference histograms */
86                         sum += cl[i][j] * cl[i][j];
87                 }
88                 sum = sqrt(sum);
89                 for (j = 0; j < m; j++)
90                 {
91                         cl[i][j] /= sum;        /* normalise */
92                 }
93         }       
94         //print_array(cl, k, m);
95         
96         for (i = 0; i < n; i++)
97                 c[i] = 1;       /* initially assign all histograms to cluster 1 */
98         
99         for (a = 0; a < t; a++)
100         {
101                 beta = Bsched[a];
102                 
103                 if (a == 0)
104                         maxiter = maxiter0;
105                 else
106                         maxiter = maxiter1;
107                 
108                 for (it = 0; it < maxiter; it++)
109                 {
110                         //if (it == maxiter - 1)
111                         //      mexPrintf("hasn't converged after %d iterations\n", maxiter);
112                         
113                         for (i = 0; i < n; i++)
114                         {
115                                 /* save current hard assignments */
116                                 oldc[i] = c[i];
117                                 
118                                 /* calculate soft assignment logprobs for each cluster */
119                                 sum = 0;
120                                 for (j = 0; j < k; j++)
121                                 {
122                                         lp[i][ j] = -beta * kldist(cl[j], &h[i*m], m);
123                                         
124                                         /* update matching neighbour counts for this histogram, based on current hard assignments */
125                                         /* old version:
126                                         nc[i][j] = 0;   
127                                         if (i >= limit && i <= n - 1 - limit)
128                                         {
129                                                         for (b = i - limit; b <= i + limit; b++)
130                                                         {
131                                                                 if (c[b] == j+1)
132                                                                         nc[i][j]++;
133                                                         }
134                                                         nc[i][j] = B - nc[i][j];
135                                         }
136                                         */
137                                         b0 = i - limit;
138                                         if (b0 < 0)
139                                                 b0 = 0;
140                                         b1 = i + limit;
141                                         if (b1 >= n)
142                                                 b1 = n - 1;
143                                         nc[i][j] = b1 - b0 + 1;         /* = B except at edges */
144                                         for (b = b0; b <= b1; b++)
145                                                 if (c[b] == j+1)
146                                                         nc[i][j]--;
147                                         
148                                         sum += exp(lp[i][j]);
149                                 }
150                                 
151                                 /* normalise responsibilities and add duration logprior */
152                                 logsumexp = log(sum);
153                                 for (j = 0; j < k; j++)
154                                         lp[i][j] -= logsumexp + lambda * nc[i][j];                              
155                         }
156                         //print_array(lp, n, k);
157                         /*
158                         for (i = 0; i < n; i++)
159                         {
160                                  for (j = 0; j < k; j++)
161                                          mexPrintf("%d ", nc[i][j]);
162                                  mexPrintf("\n");
163                         } 
164                         */
165                         
166                         
167                         /* update the assignments now that we know the duration priors
168                         based on the current assignments */
169                         for (i = 0; i < n; i++)
170                         {
171                                 maxlp = lp[i][0];
172                                 c[i] = 1;
173                                 for (j = 1; j < k; j++)
174                                         if (lp[i][j] > maxlp)
175                                         {
176                                                 maxlp = lp[i][j];
177                                                 c[i] = j+1;
178                                         }
179                         }
180                                 
181                         /* break if assignments haven't changed */
182                         i = 0;
183                         while (i < n && oldc[i] == c[i])
184                                 i++;
185                         if (i == n)
186                                 break;
187                         
188                         /* update reference histograms now we know new responsibilities */
189                         for (j = 0; j < k; j++)
190                         {
191                                 for (b = 0; b < m; b++)
192                                 {
193                                         cl[j][b] = 0;
194                                         for (i = 0; i < n; i++)
195                                         {
196                                                 cl[j][b] += exp(lp[i][j]) * h[i*m+b];
197                                         }       
198                                 }
199                                 
200                                 sum = 0;                                
201                                 for (i = 0; i < n; i++)
202                                         sum += exp(lp[i][j]);
203                                 for (b = 0; b < m; b++)
204                                         cl[j][b] /= sum;        /* normalise */
205                         }       
206                         
207                         //print_array(cl, k, m);
208                         //mexPrintf("\n\n");
209                 }
210         }
211                 
212         /* free memory */
213         for (i = 0; i < k; i++)
214                 free(cl[i]);
215         free(cl);
216         for (i = 0; i < n; i++)
217                 free(nc[i]);
218         free(nc);
219         for (i = 0; i < n; i++)
220                 free(lp[i]);
221         free(lp);
222         free(oldc);     
223 }
224
225