1 |
|
- |
|
|
1 |
+ |
// [[Rcpp::plugins(cpp11)]] |
|
2 |
+ |
// [[Rcpp::depends(RcppThread)]] |
2 |
3 |
|
// [[Rcpp::depends(RcppArmadillo)]] |
|
4 |
+ |
#include "RcppThread.h" |
3 |
5 |
|
#include <RcppArmadillo.h> |
4 |
6 |
|
#include <cmath> |
5 |
7 |
|
#define ARMA_64BIT_WORD |
6 |
8 |
|
using namespace Rcpp ; |
7 |
9 |
|
|
8 |
10 |
|
// [[Rcpp::export]] |
9 |
|
- |
NumericVector calc_sum_squares_latent(arma::sp_mat Y, NumericMatrix X, NumericMatrix W, NumericVector ybar) { |
10 |
|
- |
|
11 |
|
- |
int n_obs = Y.n_rows; // number of observations |
12 |
|
- |
int n_latent_vars = X.ncol(); // number of latent variables |
13 |
|
- |
int n_explicit_vars = W.ncol(); // number of explicit variables |
|
11 |
+ |
NumericVector calc_sum_squares_latent( |
|
12 |
+ |
arma::sp_mat Y, |
|
13 |
+ |
NumericMatrix X, |
|
14 |
+ |
NumericMatrix W, |
|
15 |
+ |
NumericVector ybar, |
|
16 |
+ |
int threads |
|
17 |
+ |
) { |
|
18 |
+ |
|
|
19 |
+ |
int n_obs = Y.n_cols; // number of observations |
14 |
20 |
|
NumericVector result(2); // final result |
15 |
21 |
|
double SSE = 0; // sum of squared errors across all documents |
16 |
22 |
|
double SST = 0; // total sum of squares across all documents |
17 |
23 |
|
|
18 |
24 |
|
|
19 |
25 |
|
// for each observation... |
20 |
|
- |
for(int d = 0; d < n_obs; d++){ |
|
26 |
+ |
RcppThread::parallelFor( |
|
27 |
+ |
0, |
|
28 |
+ |
n_obs, |
|
29 |
+ |
[&Y, |
|
30 |
+ |
&X, |
|
31 |
+ |
&W, |
|
32 |
+ |
&ybar, |
|
33 |
+ |
&SSE, |
|
34 |
+ |
&SST |
|
35 |
+ |
] (unsigned int d){ |
|
36 |
+ |
RcppThread::checkUserInterrupt(); |
21 |
37 |
|
|
22 |
|
- |
R_CheckUserInterrupt(); |
|
38 |
+ |
// Yhat = X %*% W, but computed one entry at a time below instead of forming the full matrix product |
|
39 |
+ |
double sse = 0; |
|
40 |
+ |
double sst = 0; |
23 |
41 |
|
|
24 |
|
- |
// Yhat = X %*% W. But doing it funny below to optimize calculation |
25 |
|
- |
double sse = 0; |
26 |
|
- |
double sst = 0; |
|
42 |
+ |
for(int v = 0; v < W.ncol(); v++ ){ |
|
43 |
+ |
double Yhat = 0; |
27 |
44 |
|
|
28 |
|
- |
for(int v = 0; v < n_explicit_vars; v++ ){ |
29 |
|
- |
double Yhat = 0; |
|
45 |
+ |
for(int k = 0; k < X.ncol(); k++ ){ |
|
46 |
+ |
Yhat = Yhat + X(d , k ) * W(k , v ); |
|
47 |
+ |
} |
30 |
48 |
|
|
31 |
|
- |
for(int k = 0; k < n_latent_vars; k++ ){ |
32 |
|
- |
Yhat = Yhat + X(d , k ) * W(k , v ); |
33 |
|
- |
} |
|
49 |
+ |
sse = sse + ((Y(v, d) - Yhat) * (Y(v, d) - Yhat)); |
34 |
50 |
|
|
35 |
|
- |
sse = sse + ((Y(d , v) - Yhat) * (Y(d , v) - Yhat)); |
|
51 |
+ |
sst = sst + ((Y(v, d) - ybar[ v ]) * (Y(v, d) - ybar[ v ])); |
36 |
52 |
|
|
37 |
|
- |
sst = sst + ((Y(d , v) - ybar[ v ]) * (Y(d , v) - ybar[ v ])); |
|
53 |
+ |
} |
38 |
54 |
|
|
39 |
|
- |
} |
|
55 |
+ |
SSE = SSE + sse; |
40 |
56 |
|
|
41 |
|
- |
SSE = SSE + sse; |
|
57 |
+ |
SST = SST + sst; |
|
58 |
+ |
}, |
|
59 |
+ |
threads); |
42 |
60 |
|
|
43 |
|
- |
SST = SST + sst; |
44 |
|
- |
} |
45 |
61 |
|
|
46 |
62 |
|
result[ 0 ] = SSE; |
47 |
63 |
|
result[ 1 ] = SST; |