Skip to content

Commit

Permalink
Ensure all parameters are aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
cmikeh2 committed Nov 15, 2023
1 parent 901d807 commit 25e9cd5
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from ..inference_utils import elem_size


def pad_to_aligned_offset(offset: int, alignment: int = 256) -> int:
"""
Pad the provided offset to a well-aligned value.
"""
return ((offset + alignment - 1) // alignment) * alignment


class TensorMetadata(DeepSpeedConfigModel):
"""
A class to represent a tensor specification.
Expand Down Expand Up @@ -149,15 +156,15 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int)
strides=param.stride(),
offset=cur_offset)

cur_offset += elem_size(param.dtype) * param.numel()
cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel())

for t_name, tensor in param.aux_attrs.items():
param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype),
shape=tensor.shape,
strides=tensor.stride(),
offset=cur_offset)

cur_offset += elem_size(param.dtype) * param.numel()
cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel())

layer_metadata.params[p_name] = param_metadata

Expand Down

0 comments on commit 25e9cd5

Please sign in to comment.