
refactor code to calculate records per shard using n_volumes and number of shards #328

Open
hvgazula opened this issue Apr 20, 2024 · 2 comments

@hvgazula
Contributor

import tensorflow as tf

# dataset, compression_type, parse_fn, and num_parallel_calls come from
# the surrounding code.
first_shard = (
    dataset.take(1)  # take the first shard file
    .flat_map(
        lambda x: tf.data.TFRecordDataset(x, compression_type=compression_type)
    )
    .map(map_func=parse_fn, num_parallel_calls=num_parallel_calls)
)
# Count the records by iterating over the entire first shard.
block_length = len([0 for _ in first_shard])

If the number of volumes in the shard is large, this snippet can be time-consuming because it iterates over every record in the shard just to count them. Alternatives are

  • use a combination of n_volumes and the number of files matching file_pattern to calculate len(first_shard) (see the sketch after this list)
  • store metadata alongside the shards: the number of volumes in each shard as well as the total number of volumes in the dataset
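
A minimal sketch of the first alternative, assuming the writer fills every shard equally except possibly the last; records_per_shard is a hypothetical helper, and file_pattern and n_volumes are assumed to be available from the surrounding code:

import math

import tensorflow as tf

def records_per_shard(file_pattern: str, n_volumes: int) -> int:
    # Hypothetical helper, not from the codebase: infer the per-shard
    # record count from the total volume count and the number of shard
    # files, instead of iterating over the first shard.
    num_shards = len(tf.io.gfile.glob(file_pattern))
    return math.ceil(n_volumes / num_shards)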
@hvgazula hvgazula self-assigned this Apr 20, 2024
@hvgazula
Contributor Author

hvgazula commented May 11, 2024

Ideally, if the TFRecords are created using the API, the aforementioned change guarantees the same number of records in every shard except the last. If n_volumes is not specified, it can then be calculated as num_records_first_shard * (num_shards - 1) + num_records_in_last_shard (see the sketch below).
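
A minimal sketch of that fallback; count_records and infer_n_volumes are hypothetical names, not part of the API:

import tensorflow as tf

def count_records(path: str, compression_type=None) -> int:
    # Count the records in one TFRecord shard by iterating over it once.
    ds = tf.data.TFRecordDataset(path, compression_type=compression_type)
    return sum(1 for _ in ds)

def infer_n_volumes(shard_paths, compression_type=None) -> int:
    # Assumes every shard except the last holds the same number of records,
    # so only the first and last shards need to be scanned.
    first = count_records(shard_paths[0], compression_type)
    last = count_records(shard_paths[-1], compression_type)
    return first * (len(shard_paths) - 1) + last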
