use rand_distr::Distribution;
#[cfg(feature = "serde-serialize")]
use serde::{Deserialize, Serialize};
use super::{
super::{
super::{
error::Never,
field::LinkMatrix,
lattice::{LatticeCyclic, LatticeElementToIndex, LatticeLink, LatticeLinkCanonical},
su3, Complex, Real,
},
state::{LatticeState, LatticeStateDefault},
},
delta_s_old_new_cmp, MonteCarlo,
};
/// Metropolis-Hastings sweep over all canonical links of a lattice.
///
/// During one sweep every link receives a proposed new matrix, which is accepted with
/// probability `min(1, exp(-delta_s))`. The struct also keeps statistics about the most
/// recent sweep.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
pub struct MetropolisHastingsSweep<Rng: rand::Rng> {
    /// Number of random near-unity SU(3) matrices multiplied onto a link to build a proposal.
    number_of_update: usize,
    /// Spread passed to `su3::random_su3_close_to_unity` when building proposals.
    spread: Real,
    /// Number of links replaced during the last sweep.
    number_replace_last: usize,
    /// Mean acceptance probability over the links visited during the last sweep.
    prob_replace_mean: Real,
    /// Random number generator used both for proposals and for acceptance draws.
    rng: Rng,
}
impl<Rng: rand::Rng> MetropolisHastingsSweep<Rng> {
    // Getter for the random number generator, generated by the crate's `getter!` macro.
    getter!(
        pub const,
        rng,
        Rng
    );

    /// Creates a new sweep.
    ///
    /// - `number_of_update`: how many random near-unity SU(3) matrices are multiplied onto a
    ///   link to build each proposal; must be non-zero.
    /// - `spread`: value passed to `su3::random_su3_close_to_unity` for each of those
    ///   matrices; must lie strictly inside `(0, 1)`.
    /// - `rng`: random number generator used for proposals and acceptance draws.
    ///
    /// The statistics ([`Self::number_replace_last`], [`Self::prob_replace_mean`]) start at 0.
    ///
    /// Returns [`None`] if `number_of_update` is `0` or if `spread` is not strictly between
    /// `0` and `1`, including when `spread` is NaN.
    pub fn new(number_of_update: usize, spread: Real, rng: Rng) -> Option<Self> {
        // NaN compares false against everything. Without the explicit check it would pass the
        // range test unnoticed.
        if number_of_update == 0 || spread.is_nan() || spread <= 0_f64 || spread >= 1_f64 {
            return None;
        }
        Some(Self {
            number_of_update,
            spread,
            number_replace_last: 0,
            prob_replace_mean: 0_f64,
            rng,
        })
    }

    /// Mean acceptance probability over the links visited during the last sweep
    /// (`0` before the first sweep).
    pub const fn prob_replace_mean(&self) -> Real {
        self.prob_replace_mean
    }

    /// Number of links whose proposal was accepted during the last sweep
    /// (`0` before the first sweep).
    pub const fn number_replace_last(&self) -> usize {
        self.number_replace_last
    }

    /// Consumes the sweep and returns its random number generator.
    // NOTE(review): presumably allowed because moving a generic `Rng` out of `self` in a
    // `const fn` is rejected by the compiler (destructors in const context) — confirm.
    #[allow(clippy::missing_const_for_fn)]
    pub fn rng_owned(self) -> Rng {
        self.rng
    }

    /// Mutable reference to the random number generator.
    pub fn rng_mut(&mut self) -> &mut Rng {
        &mut self.rng
    }

    /// Action difference obtained by replacing `link` with `new_link`.
    ///
    /// The current matrix of `link` is looked up in `link_matrix` and handed, together with
    /// `new_link` and `beta`, to `delta_s_old_new_cmp`.
    ///
    /// # Panics
    /// Panics if `link_matrix` has no matrix for `link` on `lattice`.
    #[inline]
    fn delta_s<const D: usize>(
        link_matrix: &LinkMatrix,
        lattice: &LatticeCyclic<D>,
        link: &LatticeLinkCanonical<D>,
        new_link: &na::Matrix3<Complex>,
        beta: Real,
    ) -> Real {
        let old_matrix = link_matrix
            .matrix(&LatticeLink::from(*link), lattice)
            .expect("every canonical link of the lattice must have a matrix");
        delta_s_old_new_cmp(link_matrix, lattice, link, new_link, beta, &old_matrix)
    }

    /// Builds a proposed new matrix for `link`.
    ///
    /// Starting from the link's current matrix, it left-multiplies `number_of_update` times by
    /// an orthonormalized random SU(3) matrix drawn with `spread`. The state is not modified.
    #[inline]
    fn potential_modif<const D: usize>(
        &mut self,
        state: &LatticeStateDefault<D>,
        link: &LatticeLinkCanonical<D>,
    ) -> na::Matrix3<Complex> {
        let index = link.to_index(state.lattice());
        let old_link_m = state.link_matrix()[index];
        let mut new_link = old_link_m;
        for _ in 0..self.number_of_update {
            let rand_m = su3::orthonormalize_matrix(&su3::random_su3_close_to_unity(
                self.spread,
                &mut self.rng,
            ));
            new_link = rand_m * new_link;
        }
        new_link
    }

    /// Performs one full sweep over the canonical links of `state` and returns the new state.
    ///
    /// For each link, a proposal is built and then accepted with probability
    /// `min(1, exp(-delta_s))`. Both `number_replace_last` and `prob_replace_mean` are reset
    /// at the start and describe this sweep when it returns.
    #[inline]
    fn next_element_default<const D: usize>(
        &mut self,
        mut state: LatticeStateDefault<D>,
    ) -> LatticeStateDefault<D> {
        self.prob_replace_mean = 0_f64;
        self.number_replace_last = 0;
        // Iterate over a cloned lattice so the link iterator does not borrow `state`, which is
        // mutated inside the loop.
        let lattice = state.lattice().clone();
        for link in lattice.get_links() {
            let potential_modif = self.potential_modif(&state, &link);
            let delta_s = Self::delta_s(
                state.link_matrix(),
                state.lattice(),
                &link,
                &potential_modif,
                state.beta(),
            );
            // Metropolis acceptance probability, clamped to [0, 1].
            // NOTE(review): `f64::min` returns the non-NaN operand, so a NaN `delta_s` gives a
            // probability of 1 and the proposal is always accepted — confirm this is intended.
            let proba = (-delta_s).exp().min(1_f64).max(0_f64);
            self.prob_replace_mean += proba;
            let acceptance = rand::distributions::Bernoulli::new(proba)
                .expect("acceptance probability is clamped to [0, 1]");
            if acceptance.sample(&mut self.rng) {
                self.number_replace_last += 1;
                *state
                    .link_mut(&link)
                    .expect("every canonical link of the lattice must be in the state") =
                    potential_modif;
            }
        }
        self.prob_replace_mean /= lattice.number_of_canonical_links_space() as f64;
        state
    }
}
impl<Rng> AsRef<Rng> for MetropolisHastingsSweep<Rng>
where
    Rng: rand::Rng,
{
    /// Borrows the random number generator owned by this sweep.
    fn as_ref(&self) -> &Rng {
        &self.rng
    }
}
impl<Rng> AsMut<Rng> for MetropolisHastingsSweep<Rng>
where
    Rng: rand::Rng,
{
    /// Mutably borrows the random number generator owned by this sweep.
    fn as_mut(&mut self) -> &mut Rng {
        &mut self.rng
    }
}
impl<Rng: rand::Rng, const D: usize> MonteCarlo<LatticeStateDefault<D>, D>
    for MetropolisHastingsSweep<Rng>
{
    /// This implementation always returns `Ok`, so its error type is `Never`.
    type Error = Never;

    /// Runs one full sweep over the links of `state` and returns the updated state.
    #[inline]
    fn next_element(
        &mut self,
        state: LatticeStateDefault<D>,
    ) -> Result<LatticeStateDefault<D>, Self::Error> {
        let swept = self.next_element_default(state);
        Ok(swept)
    }
}
#[cfg(test)]
mod test {
    use std::error::Error;

    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;
    use crate::error::ImplementationError;

    /// `as_ref` must expose the generator given to `new`, and `as_mut` must yield a
    /// `&mut StdRng`.
    #[test]
    fn as_ref_as_mut() -> Result<(), Box<dyn Error>> {
        let seeded = StdRng::seed_from_u64(0);
        let mut sweep = MetropolisHastingsSweep::new(1, 0.1_f64, seeded.clone())
            .ok_or(ImplementationError::OptionWithUnexpectedNone)?;
        assert_eq!(&seeded, sweep.as_ref());
        let _rng_mut: &mut StdRng = sweep.as_mut();
        Ok(())
    }
}