Skip to content

Commit

Permalink
Migrate some TSL code over to ABSL equivalents
Browse files Browse the repository at this point in the history
No functional change is intended.

PiperOrigin-RevId: 706010982
  • Loading branch information
majnemer authored and copybara-github committed Dec 13, 2024
1 parent d58ce1b commit 83c149d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 230 deletions.
4 changes: 4 additions & 0 deletions tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ cc_library(
":stringpiece",
":stringprintf",
":types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@double_conversion//:double-conversion",
],
)
Expand Down Expand Up @@ -1720,6 +1722,8 @@ tsl_cc_test(
":numbers",
":test",
":test_main",
":types",
"@com_google_absl//absl/strings",
],
)

Expand Down
215 changes: 14 additions & 201 deletions tsl/platform/numbers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ limitations under the License.
#include <stdlib.h>

#include <algorithm>
#include <cinttypes>
#include <charconv>
#include <cmath>
#include <cstdint>
#include <locale>
#include <string>
#include <system_error> // NOLINT
#include <unordered_map>

#include "double-conversion/double-conversion.h"
#include "tsl/platform/str_util.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/macros.h"
#include "tsl/platform/stringprintf.h"
Expand Down Expand Up @@ -114,17 +116,6 @@ T locale_independent_strtonum(const char* str, const char** endptr) {
return result;
}

static inline const double_conversion::StringToDoubleConverter&
StringToFloatConverter() {
static const double_conversion::StringToDoubleConverter converter(
double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES |
double_conversion::StringToDoubleConverter::ALLOW_HEX |
double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES |
double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY,
0., 0., "inf", "nan");
return converter;
}

} // namespace

namespace strings {
Expand Down Expand Up @@ -219,154 +210,6 @@ size_t DoubleToBuffer(double value, char* buffer) {
return snprintf_result;
}

namespace {
char SafeFirstChar(absl::string_view str) {
if (str.empty()) return '\0';
return str[0];
}
void SkipSpaces(absl::string_view* str) {
while (isspace(SafeFirstChar(*str))) str->remove_prefix(1);
}
} // namespace

bool safe_strto64(absl::string_view str, int64_t* value) {
SkipSpaces(&str);

int64_t vlimit = kint64max;
int sign = 1;
if (absl::ConsumePrefix(&str, "-")) {
sign = -1;
// Different limit for positive and negative integers.
vlimit = kint64min;
}

if (!isdigit(SafeFirstChar(str))) return false;

int64_t result = 0;
if (sign == 1) {
do {
int digit = SafeFirstChar(str) - '0';
if ((vlimit - digit) / 10 < result) {
return false;
}
result = result * 10 + digit;
str.remove_prefix(1);
} while (isdigit(SafeFirstChar(str)));
} else {
do {
int digit = SafeFirstChar(str) - '0';
if ((vlimit + digit) / 10 > result) {
return false;
}
result = result * 10 - digit;
str.remove_prefix(1);
} while (isdigit(SafeFirstChar(str)));
}

SkipSpaces(&str);
if (!str.empty()) return false;

*value = result;
return true;
}

bool safe_strtou64(absl::string_view str, uint64_t* value) {
SkipSpaces(&str);
if (!isdigit(SafeFirstChar(str))) return false;

uint64_t result = 0;
do {
int digit = SafeFirstChar(str) - '0';
if ((kuint64max - digit) / 10 < result) {
return false;
}
result = result * 10 + digit;
str.remove_prefix(1);
} while (isdigit(SafeFirstChar(str)));

SkipSpaces(&str);
if (!str.empty()) return false;

*value = result;
return true;
}

bool safe_strto32(absl::string_view str, int32_t* value) {
SkipSpaces(&str);

int64_t vmax = kint32max;
int sign = 1;
if (absl::ConsumePrefix(&str, "-")) {
sign = -1;
// Different max for positive and negative integers.
++vmax;
}

if (!isdigit(SafeFirstChar(str))) return false;

int64_t result = 0;
do {
result = result * 10 + SafeFirstChar(str) - '0';
if (result > vmax) {
return false;
}
str.remove_prefix(1);
} while (isdigit(SafeFirstChar(str)));

SkipSpaces(&str);

if (!str.empty()) return false;

*value = static_cast<int32_t>(result * sign);
return true;
}

bool safe_strtou32(absl::string_view str, uint32_t* value) {
SkipSpaces(&str);
if (!isdigit(SafeFirstChar(str))) return false;

int64_t result = 0;
do {
result = result * 10 + SafeFirstChar(str) - '0';
if (result > kuint32max) {
return false;
}
str.remove_prefix(1);
} while (isdigit(SafeFirstChar(str)));

SkipSpaces(&str);
if (!str.empty()) return false;

*value = static_cast<uint32_t>(result);
return true;
}

bool safe_strtof(absl::string_view str, float* value) {
int processed_characters_count = -1;
auto len = str.size();

// If string length exceeds buffer size or int max, fail.
if (len >= kFastToBufferSize) return false;
if (len > std::numeric_limits<int>::max()) return false;

*value = StringToFloatConverter().StringToFloat(
str.data(), static_cast<int>(len), &processed_characters_count);
return processed_characters_count > 0;
}

bool safe_strtod(absl::string_view str, double* value) {
int processed_characters_count = -1;
auto len = str.size();

// If string length exceeds buffer size or int max, fail.
if (len >= kFastToBufferSize) return false;
if (len > std::numeric_limits<int>::max()) return false;

*value = StringToFloatConverter().StringToDouble(
str.data(), static_cast<int>(len), &processed_characters_count);
return processed_characters_count > 0;
}

size_t FloatToBuffer(float value, char* buffer) {
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
// platforms these days. Just in case some system exists where FLT_DIG
Expand Down Expand Up @@ -401,51 +244,21 @@ size_t FloatToBuffer(float value, char* buffer) {
}

std::string FpToString(Fprint fp) {
char buf[17];
snprintf(buf, sizeof(buf), "%016llx", static_cast<long long>(fp));
return std::string(buf);
return absl::StrCat(absl::Hex(fp, absl::kZeroPad16));
}

bool StringToFp(const std::string& s, Fprint* fp) {
char junk;
uint64_t result;
if (sscanf(s.c_str(), "%" SCNx64 "%c", &result, &junk) == 1) {
*fp = result;
return true;
} else {
bool HexStringToUint64(absl::string_view s, uint64_t* result) {
auto end_ptr = s.data() + s.size();
uint64_t parsed_result;
auto [ptr, ec] =
std::from_chars(s.data(), end_ptr, parsed_result, /*base=*/16);
if (ec != std::errc{}) {
return false;
}
}

absl::string_view Uint64ToHexString(uint64_t v, char* buf) {
static const char* hexdigits = "0123456789abcdef";
const int num_byte = 16;
buf[num_byte] = '\0';
for (int i = num_byte - 1; i >= 0; i--) {
buf[i] = hexdigits[v & 0xf];
v >>= 4;
}
return absl::string_view(buf, num_byte);
}

bool HexStringToUint64(const absl::string_view& s, uint64_t* result) {
uint64_t v = 0;
if (s.empty()) {
if (ptr != end_ptr) {
return false;
}
for (size_t i = 0; i < s.size(); i++) {
char c = s[i];
if (c >= '0' && c <= '9') {
v = (v << 4) + (c - '0');
} else if (c >= 'a' && c <= 'f') {
v = (v << 4) + 10 + (c - 'a');
} else if (c >= 'A' && c <= 'F') {
v = (v << 4) + 10 + (c - 'A');
} else {
return false;
}
}
*result = v;
*result = parsed_result;
return true;
}

Expand Down
53 changes: 32 additions & 21 deletions tsl/platform/numbers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PLATFORM_NUMBERS_H_
#define TENSORFLOW_TSL_PLATFORM_NUMBERS_H_

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/base/macros.h"
#include "absl/strings/numbers.h"
#include "tsl/platform/stringpiece.h"
#include "tsl/platform/types.h"

Expand Down Expand Up @@ -46,7 +49,7 @@ namespace strings {
// Int64, UInt64, Int, Uint: 22 bytes
// Time: 30 bytes
// Use kFastToBufferSize rather than hardcoding constants.
static const int kFastToBufferSize = 32;
inline constexpr int kFastToBufferSize = 32;

// ----------------------------------------------------------------------
// FastInt32ToBufferLeft()
Expand Down Expand Up @@ -77,52 +80,60 @@ size_t FloatToBuffer(float value, char* buffer);
// Convert a 64-bit fingerprint value to an ASCII representation.
std::string FpToString(Fprint fp);

// Attempt to parse a fingerprint in the form encoded by FpToString. If
// successful, stores the fingerprint in *fp and returns true. Otherwise,
// returns false.
bool StringToFp(const std::string& s, Fprint* fp);

// Convert a 64-bit fingerprint value to an ASCII representation that
// is terminated by a '\0'.
// Buf must point to an array of at least kFastToBufferSize characters
absl::string_view Uint64ToHexString(uint64_t v, char* buf);

// Attempt to parse a uint64 in the form encoded by FastUint64ToHexString. If
// successful, stores the value in *v and returns true. Otherwise,
// returns false.
bool HexStringToUint64(const absl::string_view& s, uint64_t* result);
// Attempt to parse a `uint64_t` in the form encoded by
// `absl::StrCat(absl::Hex(*result))`. If successful, stores the value in
// `result` and returns true. Otherwise, returns false.
bool HexStringToUint64(absl::string_view s, uint64_t* result);

// Convert strings to 32bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
bool safe_strto32(absl::string_view str, int32_t* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strto32(absl::string_view str, int32_t* value) {
return absl::SimpleAtoi(str, value);
}

// Convert strings to unsigned 32bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
bool safe_strtou32(absl::string_view str, uint32_t* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strtou32(absl::string_view str, uint32_t* value) {
return absl::SimpleAtoi(str, value);
}

// Convert strings to 64bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
bool safe_strto64(absl::string_view str, int64_t* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strto64(absl::string_view str, int64_t* value) {
return absl::SimpleAtoi(str, value);
}

// Convert strings to unsigned 64bit integer values.
// Leading and trailing spaces are allowed.
// Return false with overflow or invalid input.
bool safe_strtou64(absl::string_view str, uint64_t* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strtou64(absl::string_view str, uint64_t* value) {
return absl::SimpleAtoi(str, value);
}

// Convert strings to floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
bool safe_strtof(absl::string_view str, float* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strtof(absl::string_view str, float* value) {
return absl::SimpleAtof(str, value);
}

// Convert strings to double precision floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
bool safe_strtod(absl::string_view str, double* value);
ABSL_DEPRECATE_AND_INLINE()
inline bool safe_strtod(absl::string_view str, double* value) {
return absl::SimpleAtod(str, value);
}

inline bool ProtoParseNumeric(absl::string_view s, int32_t* value) {
return safe_strto32(s, value);
Expand Down
Loading

0 comments on commit 83c149d

Please sign in to comment.