Skip to content

Commit

Permalink
MCH LRU Eviction Policy (pytorch#1598)
Browse files Browse the repository at this point in the history
Summary:

LRU eviction policy with user-variable decay exponent (e.g. decay_exponent=1 is LRU with linear distance).

Reviewed By: dstaay-fb

Differential Revision: D52100910
  • Loading branch information
gsethi523 authored and facebook-github-bot committed Jan 3, 2024
1 parent 3499bcc commit 06ac6f2
Showing 1 changed file with 135 additions and 0 deletions.
135 changes: 135 additions & 0 deletions torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,141 @@ def update_metadata_and_generate_eviction_scores(
return evicted_indices, selected_new_indices


class LRU_EvictionPolicy(MCHEvictionPolicy):
def __init__(
self,
decay_exponent: float = 1.0,
threshold_filtering_func: Optional[
Callable[[torch.Tensor], Tuple[torch.Tensor, Union[float, torch.Tensor]]]
] = None, # experimental
) -> None:
super().__init__(
metadata_info=[
MCHEvictionPolicyMetadataInfo(
metadata_name="last_access_iter",
is_mch_metadata=True,
is_history_metadata=True,
),
],
threshold_filtering_func=threshold_filtering_func,
)
self._decay_exponent = decay_exponent

@property
def metadata_info(self) -> List[MCHEvictionPolicyMetadataInfo]:
return self._metadata_info

def record_history_metadata(
self,
current_iter: int,
incoming_ids: torch.Tensor,
history_metadata: Dict[str, torch.Tensor],
) -> None:
history_last_access_iter = history_metadata["last_access_iter"]
history_last_access_iter[:] = current_iter

def coalesce_history_metadata(
self,
current_iter: int,
history_metadata: Dict[str, torch.Tensor],
unique_ids_counts: torch.Tensor,
unique_inverse_mapping: torch.Tensor,
additional_ids: Optional[torch.Tensor] = None,
threshold_mask: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
coalesced_history_metadata: Dict[str, torch.Tensor] = {}
history_last_access_iter = history_metadata["last_access_iter"]
if additional_ids is not None:
history_last_access_iter = torch.cat(
[
history_last_access_iter,
torch.full_like(additional_ids, current_iter),
]
)
coalesced_history_metadata["last_access_iter"] = torch.zeros_like(
unique_ids_counts
).scatter_reduce_(
0,
unique_inverse_mapping,
history_last_access_iter,
reduce="amax",
include_self=False,
)
if threshold_mask is not None:
coalesced_history_metadata["last_access_iter"] = coalesced_history_metadata[
"last_access_iter"
][threshold_mask]
return coalesced_history_metadata

def update_metadata_and_generate_eviction_scores(
self,
current_iter: int,
mch_size: int,
coalesced_history_argsort_mapping: torch.Tensor,
coalesced_history_sorted_unique_ids_counts: torch.Tensor,
coalesced_history_mch_matching_elements_mask: torch.Tensor,
coalesced_history_mch_matching_indices: torch.Tensor,
mch_metadata: Dict[str, torch.Tensor],
coalesced_history_metadata: Dict[str, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
mch_last_access_iter = mch_metadata["last_access_iter"]

# sort coalesced history metadata
coalesced_history_metadata["last_access_iter"].copy_(
coalesced_history_metadata["last_access_iter"][
coalesced_history_argsort_mapping
]
)
coalesced_history_sorted_uniq_ids_last_access_iter = coalesced_history_metadata[
"last_access_iter"
]

# update metadata for matching ids
mch_last_access_iter[
coalesced_history_mch_matching_indices
] = coalesced_history_sorted_uniq_ids_last_access_iter[
coalesced_history_mch_matching_elements_mask
]

# incoming non-matching ids
new_sorted_uniq_ids_last_access = (
coalesced_history_sorted_uniq_ids_last_access_iter[
~coalesced_history_mch_matching_elements_mask
]
)

# TODO: find cleaner way to avoid last element of zch
mch_last_access_iter[mch_size - 1] = current_iter
merged_access_iter = torch.cat(
[
mch_last_access_iter,
new_sorted_uniq_ids_last_access,
]
)
# lower scores are evicted first.
merged_eviction_scores = torch.neg(
torch.pow(
current_iter - merged_access_iter + 1,
self._decay_exponent,
)
)

# calculate evicted and replacement indices
(
evicted_indices,
selected_new_indices,
) = self._compute_selected_eviction_and_replacement_indices(
mch_size,
merged_eviction_scores,
)

mch_last_access_iter[evicted_indices] = new_sorted_uniq_ids_last_access[
selected_new_indices
]

return evicted_indices, selected_new_indices


class DistanceLFU_EvictionPolicy(MCHEvictionPolicy):
def __init__(
self,
Expand Down

0 comments on commit 06ac6f2

Please sign in to comment.