diff --git a/src/main.rs b/src/main.rs index 106e4eb..1b480ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,11 +36,11 @@ enum Commands { .args(&["rows", "cols"]), ))] Subset { - /// Input MTX file - #[arg(short, long)] - input: String, + /// Input MTX file or directory containing matrix.mtx.gz, features.tsv.gz, and barcodes.tsv + #[arg(short = 'i', long = "input")] + input: Option, - /// Output MTX file + /// Output MTX file or directory name (if directory is used, matrix.mtx.gz, features.tsv.gz, and barcodes.tsv will be subsetted) #[arg(short, long)] output: String, @@ -72,7 +72,14 @@ fn main() -> Result<(), Box> { cols, no_reindex, } => { - subset::subset_matrix(&input, &output, rows, cols, no_reindex)?; + + subset::subset_matrix( + input, + &output, + rows, + cols, + no_reindex, + )?; } } diff --git a/src/subset.rs b/src/subset.rs index fd08d76..468b36f 100644 --- a/src/subset.rs +++ b/src/subset.rs @@ -2,19 +2,25 @@ use crate::io_utils; use std::collections::{HashMap, HashSet}; use std::error::Error; use std::io::{self, BufRead, Write}; +use std::path::Path; +use std::fs; pub fn subset_matrix( - input_file: &str, - output_file: &str, + input: Option, + output: &str, rows_file: Option, cols_file: Option, no_reindex: bool, ) -> Result<(), Box> { + // Check if both rows_file and cols_file are None if rows_file.is_none() && cols_file.is_none() { return Err("At least one of --rows or --cols must be specified.".into()); } + // check if input is a file or directory + let is_directory = input.as_ref().map_or(false, |path| Path::new(path).is_dir()); + // Read the indices to retain let rows_to_retain = if let Some(ref file_path) = rows_file { read_indices(file_path)? @@ -41,17 +47,80 @@ pub fn subset_matrix( }; let rows_set: HashSet = if !rows_to_retain.is_empty() { - rows_to_retain.into_iter().collect() + rows_to_retain.clone().into_iter().collect() } else { HashSet::new() // Empty set indicates all rows are retained }; let cols_set: HashSet = if !cols_to_retain.is_empty() { - cols_to_retain.into_iter().collect() + cols_to_retain.clone().into_iter().collect() } else { HashSet::new() // Empty set indicates all columns are retained }; + if is_directory { + fs::create_dir_all(output)?; + let dir_path = Path::new(input.as_ref().unwrap()); + let output_path = Path::new(output); + let matrix_file = dir_path.join("matrix.mtx.gz").to_string_lossy().into_owned(); + let output_matrix = output_path.join("matrix.mtx").to_string_lossy().into_owned(); + let features_file = dir_path.join("features.tsv.gz").to_string_lossy().into_owned(); + let barcodes_file = if dir_path.join("barcodes.tsv.gz").exists() { + dir_path.join("barcodes.tsv.gz").to_string_lossy().into_owned() + } else { + dir_path.join("barcodes.tsv").to_string_lossy().into_owned() + }; + + let output_features = output_path.join("features.tsv").to_string_lossy().into_owned(); + let output_barcodes = output_path.join("barcodes.tsv").to_string_lossy().into_owned(); + + subset_mtx_file(&matrix_file, &output_matrix, &row_mapping, &col_mapping, &rows_set, &cols_set)?; + subset_dimnames_file(&features_file, &output_features, &rows_to_retain, &row_mapping)?; + subset_dimnames_file(&barcodes_file, &output_barcodes, &cols_to_retain, &col_mapping)?; + + } else { + let matrix_file = input.as_ref().unwrap(); + subset_mtx_file(&matrix_file, output, &row_mapping, &col_mapping, &rows_set, &cols_set)?; + } + + Ok(()) +} + +fn read_indices(file_path: &str) -> Result, Box> { + let reader = io_utils::get_reader(file_path)?; + let mut indices = Vec::new(); + + for (line_number, line) in reader.lines().enumerate() { + let line = line?; + let trimmed_line = line.trim(); + + if trimmed_line.is_empty() { + continue; + } + + match trimmed_line.parse::() { + Ok(idx) => indices.push(idx), + Err(_) => { + eprintln!( + "Warning: Failed to parse index on line {} in '{}'", + line_number + 1, + file_path + ); + } + } + } + + Ok(indices) +} + +fn subset_mtx_file( + input_file: &str, + output_file: &str, + row_mapping: &Option>, + col_mapping: &Option>, + rows_set: &HashSet, + cols_set: &HashSet, +) -> Result<(), Box> { let temp_data_file = "temp_data.mtx"; // Open the input and temporary data files @@ -140,9 +209,11 @@ pub fn subset_matrix( // Now write the header to the output file let mut output_writer = io_utils::get_writer(output_file)?; - // Write the header lines - for line in &header_lines { - writeln!(output_writer, "{}", line.trim_end())?; + // Write the header lines excluding the last one + for (index, line) in header_lines.iter().enumerate() { + if index < header_lines.len() - 1 { + writeln!(output_writer, "{}", line.trim_end())?; + } } // Calculate new dimensions @@ -180,31 +251,32 @@ pub fn subset_matrix( Ok(()) } -fn read_indices(file_path: &str) -> Result, Box> { - let reader = io_utils::get_reader(file_path)?; - let mut indices = Vec::new(); - for (line_number, line) in reader.lines().enumerate() { - let line = line?; - let trimmed_line = line.trim(); +fn subset_dimnames_file( + input_file: &str, + output_file: &str, + indices_to_retain: &[usize], + mapping: &Option>, +) -> Result<(), Box> { + let reader = io_utils::get_reader(input_file)?; + let mut writer = io_utils::get_writer(output_file)?; - if trimmed_line.is_empty() { - continue; - } + let indices_set: HashSet = indices_to_retain.iter().cloned().collect(); - match trimmed_line.parse::() { - Ok(idx) => indices.push(idx), - Err(_) => { - eprintln!( - "Warning: Failed to parse index on line {} in '{}'", - line_number + 1, - file_path - ); - } + for (index, line) in reader.lines().enumerate() { + let line = line?; + if indices_set.is_empty() || indices_set.contains(&(index + 1)) { + let _ = if let Some(ref mapping) = mapping { + *mapping.get(&(index + 1)).unwrap_or(&(index + 1)) + } else { + index + 1 + }; + writeln!(writer, "{}", line)?; } } - Ok(indices) + writer.flush()?; + Ok(()) } fn create_index_mapping(indices: &[usize]) -> HashMap {