Enhance code review for ML/DL/AI project
Implement batch processing using PyTorch’s `DataLoader` and add docstrings for all public functions and classes in clustering modules.

* **Batch Processing**:
  - Add `EmbeddingDataset` class for custom dataset handling in `attention_clustering.py`, `cluster_manager.py`, `dynamic_cluster_manager.py`, and `dynamic_clusterer.py`.
  - Implement batch processing using `DataLoader` in `refine_embeddings` method of `HybridClusteringModule` in `attention_clustering.py`.
  - Implement batch processing using `DataLoader` in `fit_predict` method of `ClusterManager` in `cluster_manager.py`.
  - Implement batch processing using `DataLoader` in `fit_predict` method of `DynamicClusterManager` in `dynamic_cluster_manager.py`.
  - Implement batch processing using `DataLoader` in `select_best_algorithm` method of `DynamicClusterer` in `dynamic_clusterer.py`.

* **Multiprocessing**:
  - Add multiprocessing for per-cluster explanation generation in the `explain_clusters` method of `ClusterExplainer` in `cluster_explainer.py`.

* **Docstrings**:
  - Add docstrings for all public functions and classes in `attention_clustering.py`, `cluster_explainer.py`, `cluster_manager.py`, `clustering_utils.py`, `dynamic_cluster_manager.py`, and `dynamic_clusterer.py`.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/stochastic-sisyphus/synsearch?shareId=XXXX-XXXX-XXXX-XXXX).
stochastic-sisyphus committed Dec 10, 2024
1 parent f656a6c commit 1d7ba48
Showing 20 changed files with 1,284 additions and 180 deletions.
44 changes: 43 additions & 1 deletion src/clustering/attention_clustering.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from typing import List, Dict, Optional
import numpy as np
from torch.utils.data import DataLoader, Dataset

class AttentionRefiner(nn.Module):
"""Refines embeddings using self-attention before clustering."""
@@ -29,9 +30,50 @@ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:

return attn_output.squeeze(0)

class EmbeddingDataset(Dataset):
"""Custom Dataset for embeddings."""

def __init__(self, embeddings: np.ndarray):
self.embeddings = embeddings

def __len__(self):
return len(self.embeddings)

def __getitem__(self, idx):
return self.embeddings[idx]

class HybridClusteringModule:
"""Combines attention-refined embeddings with dynamic clustering."""

def __init__(self, embedding_dim: int, device: Optional[str] = None):
"""
Initialize the HybridClusteringModule with embedding dimension and device.
Args:
embedding_dim (int): Dimension of the embeddings.
device (Optional[str], optional): Device to use for computation. Defaults to None.
"""
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.attention_refiner = AttentionRefiner(embedding_dim).to(self.device)

def refine_embeddings(self, embeddings: np.ndarray, batch_size: int = 32) -> np.ndarray:
"""
Refine embeddings using self-attention in batches.
Args:
embeddings (np.ndarray): Array of embeddings.
batch_size (int, optional): Batch size for processing. Defaults to 32.
Returns:
np.ndarray: Refined embeddings.
"""
dataset = EmbeddingDataset(embeddings)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

refined_embeddings = []
for batch in dataloader:
batch = batch.to(self.device)
refined_batch = self.attention_refiner(batch)
refined_embeddings.append(refined_batch.cpu().numpy())

return np.concatenate(refined_embeddings, axis=0)
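
For context, a minimal usage sketch of the batched refinement added above. The import path, embedding dimensionality, and batch size are illustrative assumptions, not values taken from the repository:

```python
import numpy as np

# Hypothetical import path; adjust to however src/ is exposed in this project.
from src.clustering.attention_clustering import HybridClusteringModule

# Toy input: 1,000 embeddings of dimension 384, float32 to match the module weights.
embeddings = np.random.rand(1000, 384).astype(np.float32)

module = HybridClusteringModule(embedding_dim=384)  # device defaults to CUDA when available
refined = module.refine_embeddings(embeddings, batch_size=64)

# One attention-refined vector per input row is expected back.
print(refined.shape)
```
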
110 changes: 90 additions & 20 deletions src/clustering/cluster_explainer.py
@@ -4,11 +4,18 @@
import spacy
from collections import Counter
import logging
from multiprocessing import Pool, cpu_count

class ClusterExplainer:
"""Explains cluster characteristics and key features."""

def __init__(self, config: Dict[str, Any]):
"""
Initialize the ClusterExplainer with configuration settings.
Args:
config (Dict[str, Any]): Configuration dictionary.
"""
self.config = config
self.logger = logging.getLogger(__name__)
self.nlp = spacy.load('en_core_web_sm')
@@ -22,7 +29,16 @@ def explain_clusters(
texts: List[str],
labels: np.ndarray
) -> Dict[str, Dict[str, Any]]:
"""Generate explanations for each cluster."""
"""
Generate explanations for each cluster.
Args:
texts (List[str]): List of texts.
labels (np.ndarray): Array of cluster labels.
Returns:
Dict[str, Dict[str, Any]]: Explanations for each cluster.
"""
try:
explanations = {}
unique_labels = np.unique(labels)
@@ -31,36 +47,74 @@
tfidf_matrix = self.vectorizer.fit_transform(texts)
feature_names = self.vectorizer.get_feature_names_out()

for label in unique_labels:
if label == -1: # Skip noise cluster
continue

cluster_texts = [text for text, l in zip(texts, labels) if l == label]
cluster_indices = np.where(labels == label)[0]

explanations[str(label)] = {
'size': len(cluster_texts),
'key_terms': self._get_key_terms(
tfidf_matrix[cluster_indices],
feature_names
),
'entities': self._extract_entities(cluster_texts),
'summary_stats': self._calculate_summary_stats(cluster_texts)
}
with Pool(processes=cpu_count()) as pool:
results = pool.starmap(
self._process_cluster,
[(label, texts, labels, tfidf_matrix, feature_names) for label in unique_labels if label != -1]
)

for label, explanation in results:
explanations[str(label)] = explanation

return explanations

except Exception as e:
self.logger.error(f"Error generating explanations: {e}")
raise

def _process_cluster(
self,
label: int,
texts: List[str],
labels: np.ndarray,
tfidf_matrix: np.ndarray,
feature_names: np.ndarray
) -> (int, Dict[str, Any]):
"""
Process a single cluster to generate explanations.
Args:
label (int): Cluster label.
texts (List[str]): List of texts.
labels (np.ndarray): Array of cluster labels.
tfidf_matrix (np.ndarray): TF-IDF matrix.
feature_names (np.ndarray): Feature names from TF-IDF vectorizer.
Returns:
(int, Dict[str, Any]): Cluster label and its explanation.
"""
cluster_texts = [text for text, l in zip(texts, labels) if l == label]
cluster_indices = np.where(labels == label)[0]

explanation = {
'size': len(cluster_texts),
'key_terms': self._get_key_terms(
tfidf_matrix[cluster_indices],
feature_names
),
'entities': self._extract_entities(cluster_texts),
'summary_stats': self._calculate_summary_stats(cluster_texts)
}

return label, explanation

def _get_key_terms(
self,
cluster_tfidf: np.ndarray,
feature_names: np.ndarray,
top_n: int = 5
) -> List[Dict[str, float]]:
"""Extract key terms using TF-IDF scores."""
"""
Extract key terms using TF-IDF scores.
Args:
cluster_tfidf (np.ndarray): TF-IDF matrix for the cluster.
feature_names (np.ndarray): Feature names from TF-IDF vectorizer.
top_n (int, optional): Number of top terms to extract. Defaults to 5.
Returns:
List[Dict[str, float]]: List of key terms and their scores.
"""
avg_tfidf = np.asarray(cluster_tfidf.mean(axis=0)).ravel()
top_indices = avg_tfidf.argsort()[-top_n:][::-1]

@@ -70,7 +124,15 @@ def _get_key_terms(
]

def _extract_entities(self, texts: List[str]) -> Dict[str, List[str]]:
"""Extract named entities from cluster texts."""
"""
Extract named entities from cluster texts.
Args:
texts (List[str]): List of texts in the cluster.
Returns:
Dict[str, List[str]]: Most frequent named entities in the cluster.
"""
entities = {'ORG': [], 'PERSON': [], 'GPE': [], 'TOPIC': []}

for text in texts:
Expand All @@ -86,7 +148,15 @@ def _extract_entities(self, texts: List[str]) -> Dict[str, List[str]]:
}

def _calculate_summary_stats(self, texts: List[str]) -> Dict[str, float]:
"""Calculate summary statistics for cluster texts."""
"""
Calculate summary statistics for cluster texts.
Args:
texts (List[str]): List of texts in the cluster.
Returns:
Dict[str, float]: Summary statistics for the cluster texts.
"""
lengths = [len(text.split()) for text in texts]
return {
'avg_length': float(np.mean(lengths)),
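
A rough usage sketch of the parallelised explainer. The import path, config contents, and example texts are illustrative assumptions; the real config keys depend on the vectorizer setup in the part of `__init__` not shown here. The `if __name__ == "__main__":` guard matters because `explain_clusters` now spawns worker processes:

```python
import numpy as np

# Hypothetical import path; adjust to the project's actual package layout.
from src.clustering.cluster_explainer import ClusterExplainer

if __name__ == "__main__":
    # Placeholder config; fill in whatever keys the hidden vectorizer setup expects.
    # Also requires the spaCy model en_core_web_sm to be installed.
    config = {}
    explainer = ClusterExplainer(config)

    texts = [
        "Transformers dominate modern NLP benchmarks.",
        "Self-attention lets models weigh distant tokens.",
        "The central bank raised interest rates again.",
    ]
    labels = np.array([0, 0, 1])  # toy assignment: one cluster label per text

    explanations = explainer.explain_clusters(texts, labels)
    for cluster_id, info in explanations.items():
        print(cluster_id, info["key_terms"], info["summary_stats"])
```
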
86 changes: 76 additions & 10 deletions src/clustering/cluster_manager.py
@@ -13,15 +13,39 @@
import torch
import multiprocessing
from joblib import parallel_backend, Parallel, delayed
from torch.utils.data import DataLoader, Dataset

class EmbeddingDataset(Dataset):
"""Custom Dataset for embeddings."""

def __init__(self, embeddings: np.ndarray):
self.embeddings = embeddings

def __len__(self):
return len(self.embeddings)

def __getitem__(self, idx):
return self.embeddings[idx]

class ClusterManager:
"""
Manages dynamic clustering operations with adaptive algorithm selection.
"""

def __init__(
self,
config: Dict,
device: Optional[str] = None,
n_jobs: Optional[int] = None
):
"""Initialize the cluster manager with parallel processing support"""
"""
Initialize the cluster manager with parallel processing support.
Args:
config (Dict): Configuration dictionary.
device (Optional[str], optional): Device to use for computation. Defaults to None.
n_jobs (Optional[int], optional): Number of CPU cores to use. Defaults to None.
"""
self.logger = logging.getLogger(__name__)
self.config = config

@@ -40,7 +64,9 @@ def __init__(
self._initialize_clusterer()

def _initialize_clusterer(self):
"""Initialize clustering algorithm with parallel processing support"""
"""
Initialize clustering algorithm with parallel processing support.
"""
params = self.config.get('clustering_params', {})

if self.method == 'hdbscan':
@@ -59,8 +85,17 @@ def _initialize_clusterer(self):
n_jobs=self.n_jobs
)

def fit_predict(self, embeddings: np.ndarray) -> Tuple[np.ndarray, Dict]:
"""Fit and predict clusters using parallel processing"""
def fit_predict(self, embeddings: np.ndarray, batch_size: int = 32) -> Tuple[np.ndarray, Dict]:
"""
Fit and predict clusters using parallel processing.
Args:
embeddings (np.ndarray): Array of embeddings.
batch_size (int, optional): Batch size for processing. Defaults to 32.
Returns:
Tuple[np.ndarray, Dict]: Cluster labels and metrics.
"""
self.logger.info(f"Starting clustering with {self.method} on {len(embeddings)} documents")

# Move embeddings to GPU if available and algorithm supports it
@@ -69,14 +104,30 @@ def fit_predict(self, embeddings: np.ndarray) -> Tuple[np.ndarray, Dict]:
self.labels_ = self._gpu_kmeans(embeddings_tensor)
else:
# Use parallel CPU processing
with parallel_backend('loky', n_jobs=self.n_jobs):
self.labels_ = self.clusterer.fit_predict(embeddings)
dataset = EmbeddingDataset(embeddings)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

all_labels = []
for batch in dataloader:
with parallel_backend('loky', n_jobs=self.n_jobs):
labels = self.clusterer.fit_predict(batch)
all_labels.append(labels)

self.labels_ = np.concatenate(all_labels)

metrics = self._calculate_metrics(embeddings)
return self.labels_, metrics

def _gpu_kmeans(self, embeddings_tensor: torch.Tensor) -> np.ndarray:
"""Perform K-means clustering on GPU"""
"""
Perform K-means clustering on GPU.
Args:
embeddings_tensor (torch.Tensor): Tensor of embeddings.
Returns:
np.ndarray: Cluster labels.
"""
from kmeans_pytorch import kmeans

cluster_ids_x, cluster_centers = kmeans(
@@ -89,7 +140,15 @@ def _gpu_kmeans(self, embeddings_tensor: torch.Tensor) -> np.ndarray:
return cluster_ids_x.cpu().numpy()

def _calculate_metrics(self, embeddings: np.ndarray) -> Dict:
"""Calculate clustering metrics in parallel"""
"""
Calculate clustering metrics in parallel.
Args:
embeddings (np.ndarray): Array of embeddings.
Returns:
Dict: Calculated metrics.
"""
metrics = {}

try:
@@ -114,7 +173,14 @@ def save_results(
metrics: Dict,
output_dir: Union[str, Path]
) -> None:
"""Save clustering results and metrics"""
"""
Save clustering results and metrics.
Args:
clusters (Dict[str, List[Dict]]): Cluster assignments.
metrics (Dict): Clustering metrics.
output_dir (Union[str, Path]): Directory to save results.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@@ -136,4 +202,4 @@ def save_results(
with open(clusters_file, 'w') as f:
json.dump(cluster_summary, f, indent=2)

self.logger.info(f"Saved clustering results to {output_dir}")
self.logger.info(f"Saved clustering results to {output_dir}")