1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
//! Metropolis Hastings method
//!
//! # Example
//! see [`MetropolisHastingsSweep`]

use rand_distr::Distribution;
#[cfg(feature = "serde-serialize")]
use serde::{Deserialize, Serialize};

use super::{
    super::{
        super::{
            error::Never,
            field::LinkMatrix,
            lattice::{LatticeCyclic, LatticeElementToIndex, LatticeLink, LatticeLinkCanonical},
            su3, Complex, Real,
        },
        state::{LatticeState, LatticeStateDefault},
    },
    delta_s_old_new_cmp, MonteCarlo,
};

/// Metropolis Hastings method by doing a pass on all points
///
/// # Example
/// ```
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use lattice_qcd_rs::error::ImplementationError;
/// use lattice_qcd_rs::simulation::{LatticeState, LatticeStateDefault, MetropolisHastingsSweep};
/// use rand::SeedableRng;
///
/// let rng = rand::rngs::StdRng::seed_from_u64(0); // change with your seed
/// let mut mh = MetropolisHastingsSweep::new(1, 0.1_f64, rng)
///     .ok_or(ImplementationError::OptionWithUnexpectedNone)?;
/// // Realistically you want more steps than 10
///
/// let mut state = LatticeStateDefault::<3>::new_cold(1_f64, 6_f64, 4)?;
/// for _ in 0..10 {
///     state = state.monte_carlo_step(&mut mh)?;
///     println!(
///         "mean probability of acceptance during last step {} and replaced {} links",
///         mh.prob_replace_mean(),
///         mh.number_replace_last()
///     );
///     // operation to track the progress or the evolution
/// }
/// // operation at the end of the simulation
/// #     Ok(())
/// # }
/// ```
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
pub struct MetropolisHastingsSweep<Rng: rand::Rng> {
    number_of_update: usize,
    spread: Real,
    number_replace_last: usize,
    prob_replace_mean: Real,
    rng: Rng,
}

impl<Rng: rand::Rng> MetropolisHastingsSweep<Rng> {
    getter!(
        /// Get a ref to the rng.
        pub const,
        rng,
        Rng
    );

    /// `spread` should be between 0 and 1 both not included and number_of_update should be greater
    /// than 0.
    ///
    /// `number_of_update` is the number of times a link matrix is randomly changed.
    /// `spread` is the spread factor for the random matrix change
    /// ( used in [`su3::random_su3_close_to_unity`]).
    pub fn new(number_of_update: usize, spread: Real, rng: Rng) -> Option<Self> {
        if number_of_update == 0 || spread <= 0_f64 || spread >= 1_f64 {
            return None;
        }
        Some(Self {
            number_of_update,
            spread,
            number_replace_last: 0,
            prob_replace_mean: 0_f64,
            rng,
        })
    }

    /// Get the mean of last probably of acceptance of the random change.
    pub const fn prob_replace_mean(&self) -> Real {
        self.prob_replace_mean
    }

    /// Number of accepted change during last sweep
    pub const fn number_replace_last(&self) -> usize {
        self.number_replace_last
    }

    /// Get the last probably of acceptance of the random change.
    #[allow(clippy::missing_const_for_fn)] // false positive
    pub fn rng_owned(self) -> Rng {
        self.rng
    }

    /// Get a mutable reference to the rng.
    pub fn rng_mut(&mut self) -> &mut Rng {
        &mut self.rng
    }

    #[inline]
    fn delta_s<const D: usize>(
        link_matrix: &LinkMatrix,
        lattice: &LatticeCyclic<D>,
        link: &LatticeLinkCanonical<D>,
        new_link: &na::Matrix3<Complex>,
        beta: Real,
    ) -> Real {
        let old_matrix = link_matrix
            .matrix(&LatticeLink::from(*link), lattice)
            .unwrap();
        delta_s_old_new_cmp(link_matrix, lattice, link, new_link, beta, &old_matrix)
    }

    #[inline]
    fn potential_modif<const D: usize>(
        &mut self,
        state: &LatticeStateDefault<D>,
        link: &LatticeLinkCanonical<D>,
    ) -> na::Matrix3<Complex> {
        let index = link.to_index(state.lattice());
        let old_link_m = state.link_matrix()[index];
        let mut new_link = old_link_m;
        for _ in 0..self.number_of_update {
            let rand_m = su3::orthonormalize_matrix(&su3::random_su3_close_to_unity(
                self.spread,
                &mut self.rng,
            ));
            new_link = rand_m * new_link;
        }

        new_link
    }

    #[inline]
    fn next_element_default<const D: usize>(
        &mut self,
        mut state: LatticeStateDefault<D>,
    ) -> LatticeStateDefault<D> {
        self.prob_replace_mean = 0_f64;
        self.number_replace_last = 0;
        let lattice = state.lattice().clone();
        lattice.get_links().for_each(|link| {
            let potential_modif = self.potential_modif(&state, &link);
            let proba = (-Self::delta_s(
                state.link_matrix(),
                state.lattice(),
                &link,
                &potential_modif,
                state.beta(),
            ))
            .exp()
            .min(1_f64)
            .max(0_f64);
            self.prob_replace_mean += proba;
            let d = rand::distributions::Bernoulli::new(proba).unwrap();
            if d.sample(&mut self.rng) {
                self.number_replace_last += 1;
                *state.link_mut(&link).unwrap() = potential_modif;
            }
        });
        self.prob_replace_mean /= lattice.number_of_canonical_links_space() as f64;
        state
    }
}

impl<Rng: rand::Rng> AsRef<Rng> for MetropolisHastingsSweep<Rng> {
    fn as_ref(&self) -> &Rng {
        self.rng()
    }
}

impl<Rng: rand::Rng> AsMut<Rng> for MetropolisHastingsSweep<Rng> {
    fn as_mut(&mut self) -> &mut Rng {
        self.rng_mut()
    }
}

impl<Rng, const D: usize> MonteCarlo<LatticeStateDefault<D>, D> for MetropolisHastingsSweep<Rng>
where
    Rng: rand::Rng,
{
    type Error = Never;

    #[inline]
    fn next_element(
        &mut self,
        state: LatticeStateDefault<D>,
    ) -> Result<LatticeStateDefault<D>, Self::Error> {
        Ok(self.next_element_default(state))
    }
}

#[cfg(test)]
mod test {
    use std::error::Error;

    use rand::SeedableRng;

    use super::*;
    use crate::error::ImplementationError;

    #[test]
    fn as_ref_as_mut() -> Result<(), Box<dyn Error>> {
        let rng = rand::rngs::StdRng::seed_from_u64(0);
        let mut mh = MetropolisHastingsSweep::new(1, 0.1_f64, rng.clone())
            .ok_or(ImplementationError::OptionWithUnexpectedNone)?;
        assert_eq!(&rng, mh.as_ref());

        let _: &mut rand::rngs::StdRng = mh.as_mut();

        Ok(())
    }
}