use std::rc::Weak;
use nalgebra::DMatrix;
use gen_rs::{GLOBAL_RNG, Distribution, mvnormal, GenFn, Trace, GfDiff};
use super::model::PointedBuffer;
use super::types_2d::{Point,Bounds};

// NOTE(review): generic arguments appear to have been stripped from this copy
// of the file (e.g. the unbalanced `Weak>` below, and the bare `DMatrix`,
// `GenFn` and `Trace` types). Restore them from the canonical source before
// building — TODO confirm.

/// Proposal that moves ("drifts") the latent point of a previous trace by
/// drawing a new one from `mvnormal`.
///
/// Every draw and density evaluation uses the `mvnormal` parameters
/// `(previous latent, drift_cov)`, presumably (mean, covariance); confirm
/// this against the gen_rs `mvnormal` docs. The second component of the choice
/// buffer is copied unchanged from the previous trace.
pub struct DriftProposal {
    /// Matrix passed as the second `mvnormal` parameter on every draw and
    /// every log-density evaluation.
    pub drift_cov: DMatrix
}

/// Arguments to [`DriftProposal`].
///
/// The first element is a weak reference to the previous trace whose latent is
/// drifted. The second is a unit slot that this proposal never reads.
pub type DriftProposalArgs = (Weak>,());

impl GenFn for DriftProposal {
    /// Samples a new latent from the drift distribution and returns a trace
    /// holding it. The trace's `logp` is the `mvnormal` log density of that
    /// latent.
    ///
    /// # Panics
    /// Panics if the previous trace has been dropped (`upgrade()` returns
    /// `None`), or if the previous trace's latent (`data.0`) is `None`.
    fn simulate(&self, args: DriftProposalArgs) -> Trace {
        let prev_trace = args.0.upgrade().unwrap();
        // The latent slot is filled below. The second component is carried
        // over from the previous trace unchanged.
        let mut choices = (None, prev_trace.data.1.clone());
        // Draw the new latent using the shared `GLOBAL_RNG`.
        GLOBAL_RNG.with_borrow_mut(|rng| {
            let new_latent = mvnormal.random(rng, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
            choices.0 = Some(new_latent);
        });
        // Score the sampled latent under the same parameters it was drawn with.
        let logp = mvnormal.logpdf(&choices.0.clone().unwrap(), (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
        Trace::new(args, choices, (), logp)
    }

    /// Builds a drift-proposal trace, using a latent constraint if one is given.
    ///
    /// If `constraints.0` is `Some`, that point becomes the new latent, and the
    /// returned weight is its `mvnormal` log density. Otherwise a latent is
    /// sampled exactly as in `simulate`, and the returned weight is `0.`.
    /// Only `constraints.0` is consulted. `constraints.1` is ignored, and the
    /// second buffer component is always copied from the previous trace.
    ///
    /// # Panics
    /// Panics if the previous trace has been dropped, or if its latent
    /// (`data.0`) is `None`.
    fn generate(&self, args: DriftProposalArgs, constraints: PointedBuffer) -> (Trace, f64) {
        let prev_trace = args.0.upgrade().unwrap();
        let mut choices = (None, prev_trace.data.1.clone());
        let new_latent: Point;
        let logp: f64;
        // Stays 0. unless the latent is constrained.
        let mut weight = 0.;
        match constraints.0 {
            Some(latent_constraint) => {
                // Use the caller-supplied latent instead of sampling one.
                new_latent = latent_constraint;
                logp = mvnormal.logpdf(&new_latent, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
                // Report the constrained latent's log density as the weight.
                weight = logp;
            }
            None => {
                // Unconstrained: sample the latent from the drift distribution.
                new_latent = GLOBAL_RNG.with_borrow_mut(|rng| {
                    mvnormal.random(rng, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()))
                });
                logp = mvnormal.logpdf(&new_latent, (prev_trace.data.0.clone().unwrap(), self.drift_cov.clone()));
            }
        }
        choices.0 = Some(new_latent);
        (Trace::new(args, choices, (), logp), weight)
    }

    /// This proposal does not support updates.
    ///
    /// # Panics
    /// Always panics with the message `"not implemented"`.
    fn update(&self, _: Trace, _: DriftProposalArgs, _: GfDiff, _: PointedBuffer) -> (Trace, PointedBuffer, f64) {
        panic!("not implemented")
    }
}