Skip to content

Commit

Permalink
fit log model
Browse files Browse the repository at this point in the history
  • Loading branch information
timoast committed Sep 29, 2024
1 parent 5520f69 commit 33a65f2
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
9 changes: 8 additions & 1 deletion src/loess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ pub fn loess(
let n = x.len();
let mut y_predict = Vec::with_capacity(x_predict.len());

let k = ((span * n as f64).ceil() as usize).max(degree + 1); // Ensure enough points
// The value of k is typically determined by the span parameter
// In standard LOESS implementations, k is usually calculated as:
let k = (span * n as f64).round() as usize;

// However, we still need to ensure k is at least degree + 1 to avoid underdetermined systems
let k = k.max(degree + 1);

// println!("{}", k);

for &x0 in x_predict {
// Find distances from x0 to all x
Expand Down
84 changes: 69 additions & 15 deletions src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,31 @@ pub fn compute_stats(input_file: &str, output_prefix: &str) -> Result<(), Box<dy
let mut row_indices = Vec::new();
let mut row_means = Vec::new();
let mut row_variances = Vec::new();
let mut row_means_log = Vec::new();
let mut row_variances_log = Vec::new();

for row_idx in 1..=n_rows {
if let Some(stats) = row_stats.get(&row_idx) {
row_indices.push(row_idx);
row_means.push(stats.mean());
row_variances.push(stats.variance());
let mean = stats.mean();
let variance = stats.variance();

// Log-transform mean and variance, handling zero values
let log_mean = if mean > 0.0 { mean.ln() } else { f64::NEG_INFINITY };
let log_variance = if variance > 0.0 { variance.ln() } else { f64::NEG_INFINITY };

row_means.push(mean);
row_variances.push(variance);
row_means_log.push(log_mean);
row_variances_log.push(log_variance);
}
}

// Compute residual variances with LOESS
println!("");
println!("Fitting Mean-Variance model");
let span = 0.2;
let residual_variances = compute_residual_variance(&row_means, &row_variances, span)?;
let span = 0.3;
let residual_variances = compute_residual_variance(&row_means_log, &row_variances_log, span)?;

// Update the Stats structs with residual variances
for (&row_idx, &res_var) in row_indices.iter().zip(residual_variances.iter()) {
Expand Down Expand Up @@ -302,23 +314,59 @@ fn compute_residual_variance(
variances: &Vec<f64>,
span: f64,
) -> Result<Vec<f64>, Box<dyn std::error::Error>> {

if means.len() != variances.len() {
return Err("Means and variances vectors must have the same length.".into());
}

// Filter out -inf values
let filtered: Vec<(f64, f64, usize)> = means.iter()
.zip(variances.iter())
.enumerate()
.filter(|(_, (&m, &v))| m.is_finite() && v.is_finite())
.map(|(i, (&m, &v))| (m, v, i))
.collect();

let mut filtered_means = Vec::new();
let mut filtered_variances = Vec::new();
let mut indices = Vec::new();
for (mean, variance, index) in filtered {
filtered_means.push(mean);
filtered_variances.push(variance);
indices.push(index);
}

// Check if we should use binning or fit the model to all values
let use_binning = false;

if !use_binning {
// If not using binning, directly apply LOESS to all filtered data
let x: Vec<f64> = filtered_means.clone();
let y: Vec<f64> = filtered_variances.clone();
let fitted = loess::loess(&x, &y, span, 2, &x);

// Create the result vector with the same length as the original input
let mut result = vec![f64::NAN; means.len()];
for (&idx, &fitted_value) in indices.iter().zip(fitted.iter()) {
result[idx] = fitted_value;
}

return Ok(result);
}

// Compute min and max of means
let min_mean = means.iter().cloned().fold(f64::INFINITY, f64::min);
let max_mean = means.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let min_mean = filtered_means.iter().cloned().fold(f64::INFINITY, f64::min);
let max_mean = filtered_means.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

// Define the number of bins
let num_bins = 200;
let bin_width = (max_mean - min_mean) / num_bins as f64;

println!("{}", bin_width);

// Create bins
let mut bins: Vec<Vec<usize>> = vec![Vec::new(); num_bins];

for (idx, &mean_value) in means.iter().enumerate() {
for (idx, &mean_value) in filtered_means.iter().enumerate() {
let bin_idx = ((mean_value - min_mean) / bin_width).floor() as usize;
let bin_idx = bin_idx.min(num_bins - 1); // Ensure the index is within bounds
bins[bin_idx].push(idx);
Expand All @@ -342,18 +390,24 @@ fn compute_residual_variance(
}

// Create sampled mean and variance vectors
let sampled_means = sampled_indices.iter().map(|&i| means[i]).collect::<Vec<f64>>();
let sampled_variances = sampled_indices.iter().map(|&i| variances[i]).collect::<Vec<f64>>();
let sampled_means = sampled_indices.iter().map(|&i| filtered_means[i]).collect::<Vec<f64>>();
let sampled_variances = sampled_indices.iter().map(|&i| filtered_variances[i]).collect::<Vec<f64>>();

let degree = 2; // quadratic fitting
let var_fitted = loess::loess(&sampled_means, &sampled_variances, span, degree, &means);
// Compute residuals
let residuals: Vec<f64> = variances
let var_fitted = loess::loess(&sampled_means, &sampled_variances, span, degree, &filtered_means);

// Compute residuals for filtered values
let filtered_residuals: Vec<f64> = filtered_variances
.iter()
.zip(var_fitted.iter())
.map(|(&yi, &y_fit)| yi - y_fit)
.collect();
let residual_variance = residuals.iter().map(|&r| r * r).collect();


// Create full residual variance vector, setting 0 for -inf values
let mut residual_variance = vec![0.0; means.len()];
for ((&r, &idx), &var) in filtered_residuals.iter().zip(indices.iter()).zip(filtered_variances.iter()) {
residual_variance[idx] = if var > 0.0 { r * r / var } else { 0.0 };
}

Ok(residual_variance)
}

0 comments on commit 33a65f2

Please sign in to comment.