Computing the Gaussian Wasserstein distance
Problem
I have to compute the Wasserstein distance between two bivariate Gaussian distributions with means m1, m2 and covariances Σ1, Σ2.
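According to equation 9 of this paper, in the Gaussian case the Wasserstein distance admits the following analytic expression:
$$W_2^2\big(\mathcal{N}(m_1,\Sigma_1),\,\mathcal{N}(m_2,\Sigma_2)\big) = \|m_1 - m_2\|_2^2 + \operatorname{tr}\Big(\Sigma_1 + \Sigma_2 - 2\big(\Sigma_1^{1/2}\,\Sigma_2\,\Sigma_1^{1/2}\big)^{1/2}\Big)$$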
My problem is to implement this equation.
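In the scalar case (Σ1 = σ1², Σ2 = σ2² with σ1, σ2 > 0), every matrix square root in equation 9 is an ordinary positive square root, which gives a simple check for any implementation:

$$W_2^2 = (m_1 - m_2)^2 + \sigma_1^2 + \sigma_2^2 - 2\left(\sigma_1\,\sigma_2^2\,\sigma_1\right)^{1/2} = (m_1 - m_2)^2 + (\sigma_1 - \sigma_2)^2$$

In particular, two identical Gaussians must give a distance of exactly zero.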
Attempted solution
It is not clear to me what is meant by the square root of a matrix, and this is probably the core of my problem. I assumed that in equation 9 the square root of a matrix is its Cholesky factor.
Based on this interpretation, I wrote the following code to compute the Wasserstein distance given the parameters of the two Gaussian distributions:
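function [dd] = wass_dist(m1, Sigma1, m2, Sigma2)
    sqrtSigma1 = chol(Sigma1);                          % intended as Sigma1^(1/2)
    sqrt_temp = chol(sqrtSigma1 * Sigma2 * sqrtSigma1); % intended as (Sigma1^(1/2)*Sigma2*Sigma1^(1/2))^(1/2)
    ddm = (m1 - m2)' * (m1 - m2);                       % squared distance between the means
    ddSigma = trace(Sigma1 + Sigma2 - 2 * sqrt_temp);   % covariance (trace) term
    dd = ddm + ddSigma;                                 % squared Wasserstein distance
end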
Firstly, this code causes problems because the matrix sqrtSigma1 * Sigma2 * sqrtSigma1 is often not positive definite. I suspect this can be fixed in one of two ways: by transposing the first factor, i.e. using sqrtSigma1' * Sigma2 * sqrtSigma1, or by transposing the third factor, i.e. using sqrtSigma1 * Sigma2 * sqrtSigma1'. However, in the aforementioned paper, and in other papers as well, the formula for the Wasserstein distance is always written without any transposition, so I assume it does not contain a typo.
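For reference, MATLAB's chol and sqrtm return different kinds of "square root". The illustrative sketch below reuses Sigma1 from the test case in this question together with an arbitrary SPD matrix Sigma2 (a hypothetical value, not from the question). It shows that chol returns an upper-triangular R with R'*R = Sigma1 (not R*R = Sigma1), that sqrtm returns a symmetric S with S*S = Sigma1, and that the untransposed product R*Sigma2*R is generally not symmetric, whereas R*Sigma2*R' is.

```matlab
% Sigma1 is the covariance from the test case; Sigma2 is a hypothetical
% SPD matrix chosen only for illustration.
Sigma1 = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826];
Sigma2 = [2.0 0.5; 0.5 1.0];

R = chol(Sigma1);    % upper triangular, R' * R == Sigma1
S = sqrtm(Sigma1);   % symmetric principal root, S * S == Sigma1

fprintf('norm(R''*R - Sigma1) = %g\n', norm(R' * R - Sigma1));
fprintf('norm(R*R  - Sigma1) = %g\n', norm(R * R - Sigma1));  % not small
fprintf('norm(S*S  - Sigma1) = %g\n', norm(S * S - Sigma1));

% Without a transpose the product is generally not symmetric,
% so it is not a valid input for chol().
A = R * Sigma2 * R;
% With the third factor transposed the product is congruent to Sigma2,
% hence symmetric positive definite (up to round-off).
B = R * Sigma2 * R';
fprintf('asymmetry of R*Sigma2*R  : %g\n', norm(A - A'));
fprintf('asymmetry of R*Sigma2*R'' : %g\n', norm(B - B'));
```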
At this point I tried to compute the Wasserstein distance between two identical Gaussian distributions using the following modified function:
function [dd] = wass_dist(m1, Sigma1, m2, Sigma2)
    sqrtSigma1 = chol(Sigma1);
    sqrt_temp = chol(sqrtSigma1 * Sigma2 * sqrtSigma1'); % third term transposed
    ddm = (m1 - m2)' * (m1 - m2);
    ddSigma = trace(Sigma1 + Sigma2 - 2 * sqrt_temp);
    dd = ddm + ddSigma;
end
Unexpectedly, the output is not zero, because ddSigma is not zero. More precisely, the input arguments I tried are
m1 = [500 500]', Sigma1 = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826]
m2 = m1, Sigma2 = Sigma1
and the resulting output is
dd = 6.6613
where ddm = 0 and ddSigma = 6.6613. This strongly suggests that something is wrong in the code, perhaps because the intended square root is not the Cholesky factor. I also tried the version in which the first term is transposed, and the result is even worse. With the same inputs, the result is
dd = ddSigma = 1.3002e+03
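One possible reading of the 6.6613 result, offered as a sketch rather than a definitive diagnosis: with the third factor transposed, M = R*Sigma2*R' is symmetric positive definite, but chol(M) is a triangular factor of M (chol(M)'*chol(M) = M), not a square root of it, so its trace differs from the trace of the principal square root sqrtm(M). The snippet below reruns the identical-Gaussians test with the inputs given above, once with chol and once with sqrtm for the outer root.

```matlab
% Same inputs as the test above: two identical Gaussians.
m1 = [500 500]';
Sigma1 = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826];
m2 = m1;
Sigma2 = Sigma1;

R = chol(Sigma1);        % R' * R == Sigma1
M = R * Sigma2 * R';     % the "third term transposed" product
M = (M + M') / 2;        % remove round-off asymmetry before factoring

% Outer "root" taken with chol: a triangular factor, not a square root.
ddSigma_chol = trace(Sigma1 + Sigma2 - 2 * chol(M));

% Outer root taken with sqrtm: sqrtm(M) * sqrtm(M) == M.
ddSigma_sqrtm = trace(Sigma1 + Sigma2 - 2 * sqrtm(M));

fprintf('ddSigma with chol : %g\n', ddSigma_chol);
fprintf('ddSigma with sqrtm: %g\n', ddSigma_sqrtm);
```

The sqrtm variant should come out at zero up to round-off, while the chol variant stays positive (the question reports 6.6613). The transposed product is acceptable here because R = Q*sqrtm(Sigma1) for some orthogonal Q, so R*Sigma2*R' is orthogonally similar to sqrtm(Sigma1)*Sigma2*sqrtm(Sigma1) and the trace of its square root is unchanged.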
Question
Is my code a correct implementation of equation 9? If not, how can I fix it?
1 Comment
Ogul Can Yurdakul
16 May 2023
Hey Matteo,
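I believe using chol() is your problem. The Cholesky decomposition of Σ gives you a matrix C such that Σ = CᵀC (MATLAB's chol returns the upper-triangular C), and the problem (I assume) is that this is a transposed square root. Using a symmetric square root, meaning a matrix square root S such that S = Sᵀ and S·S = Σ, might just solve your problem. I use the code below with no problems.
function [dist] = GW_dist(mu_1, cov_1, mu_2, cov_2)
    % Squared Euclidean distance between the means
    dist = (mu_1 - mu_2).' * (mu_1 - mu_2);
    % For a symmetric positive definite matrix, cov^0.5 is its symmetric (principal) square root, like sqrtm(cov)
    dist = dist + trace(cov_1 + cov_2 - 2*(cov_1^0.5 * cov_2 * cov_1^0.5)^0.5);
    % Return W2 itself rather than its square
    dist = dist^0.5;
end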
Hope it helps!
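A quick way to test any implementation is to compare it against cases with known answers. The sketch below re-expresses the GW_dist formula as an anonymous function (gw is a hypothetical name) so the snippet is self-contained, and checks three cases: identical Gaussians (distance 0), the scalar case where W2 = sqrt((μ1 − μ2)² + (σ1 − σ2)²), and diagonal covariances, where the trace term reduces to the squared Frobenius norm of Σ1^(1/2) − Σ2^(1/2). Like GW_dist, it returns W2 itself, not its square.

```matlab
% Hypothetical helper mirroring GW_dist, using sqrtm in place of ^0.5.
% max(0, real(...)) clamps round-off that could otherwise make the
% identical-inputs case a tiny negative or complex number.
gw = @(m1, S1, m2, S2) sqrt(max(0, real(norm(m1 - m2)^2 + ...
        trace(S1 + S2 - 2 * sqrtm(sqrtm(S1) * S2 * sqrtm(S1))))));

% Check 1: identical Gaussians from the question -> distance ~ 0.
m = [500 500]';
Sigma = 1.0e+04 * [1.6767 -0.3302; -0.3302 0.0826];
d_same = gw(m, Sigma, m, Sigma);

% Check 2: scalar case, W2 = sqrt((mu1 - mu2)^2 + (s1 - s2)^2).
% Here mu1 = 1, s1 = 2, mu2 = 4, s2 = 6, so W2 = sqrt(9 + 16) = 5.
d_1d = gw(1, 2^2, 4, 6^2);

% Check 3: diagonal (commuting) covariances; the trace term equals
% ||Sigma1^(1/2) - Sigma2^(1/2)||_F^2 = (1-3)^2 + (2-4)^2 = 8.
S1 = diag([1 4]);
S2 = diag([9 16]);
d_diag = gw([0; 0], S1, [0; 0], S2);

fprintf('identical: %g, 1-D: %g, diagonal: %g\n', d_same, d_1d, d_diag);
```

The max(0, real(...)) guard is not part of GW_dist; without it, GW_dist can return a small complex value for identical inputs when round-off drives the bracketed sum slightly below zero.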
Answers (0)