Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor mapnames2 #2501

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "MigrationReport/Statics.h"
#include "RuleInfra/ExprAnalysis.h"
#include "RuleInfra/MapNames.h"
#include "RulesLang/MapNamesLang.h"
#include "RulesMathLib/MapNamesRandom.h"
#include "TextModification.h"
#include "Utility.h"
Expand Down Expand Up @@ -2725,7 +2726,8 @@ std::string CtTypeInfo::getFoldedArraySize(const ConstantArrayTypeLoc &TL) {
if (UETT->isArgumentType()) {
const auto *const RD =
UETT->getArgumentType().getCanonicalType()->getAsRecordDecl();
if (MapNames::SupportedVectorTypes.count(RD->getNameAsString()) == 0) {
if (MapNamesLang::SupportedVectorTypes.count(RD->getNameAsString()) ==
0) {
IsContainSizeOfUserDefinedType = true;
break;
}
Expand Down Expand Up @@ -4056,12 +4058,12 @@ void MemVarMap::merge(const MemVarMap &VarMap,
int MemVarMap::calculateExtraArgsSize() const {
int Size = 0;
if (hasStream())
Size += MapNames::KernelArgTypeSizeMap.at(KernelArgType::KAT_Stream);
Size += MapNamesLang::KernelArgTypeSizeMap.at(KernelArgType::KAT_Stream);

Size = Size + calculateExtraArgsSize(LocalVarMap) +
calculateExtraArgsSize(GlobalVarMap) +
calculateExtraArgsSize(ExternVarMap);
Size = Size + TextureMap.size() * MapNames::KernelArgTypeSizeMap.at(
Size = Size + TextureMap.size() * MapNamesLang::KernelArgTypeSizeMap.at(
KernelArgType::KAT_Texture);

return Size;
Expand Down Expand Up @@ -4256,7 +4258,7 @@ int MemVarMap::calculateExtraArgsSize(const MemVarInfoMap &Map) const {
int Size = 0;
for (auto &VarInfoPair : Map) {
auto D = VarInfoPair.second->getType()->getDimension();
Size += MapNames::getArrayTypeSize(D);
Size += MapNamesLang::getArrayTypeSize(D);
}
return Size;
}
Expand Down Expand Up @@ -5495,7 +5497,7 @@ KernelCallExpr::ArgInfo::ArgInfo(const ParmVarDecl *PVD,
PointerType = Arg->getType();
}
TypeString = DpctGlobalInfo::getReplacedTypeName(PointerType);
ArgSize = MapNames::KernelArgTypeSizeMap.at(KernelArgType::KAT_Default);
ArgSize = MapNamesLang::KernelArgTypeSizeMap.at(KernelArgType::KAT_Default);

// Currently, all the device RNG state structs are passed to kernel by
// pointer. So we check the pointee type, if it is in the type map, we
Expand All @@ -5513,11 +5515,13 @@ KernelCallExpr::ArgInfo::ArgInfo(const ParmVarDecl *PVD,
} else {
auto QT = Arg->getType();
QT = QT.getUnqualifiedType();
auto Iter = MapNames::VectorTypeMigratedTypeSizeMap.find(QT.getAsString());
if (Iter != MapNames::VectorTypeMigratedTypeSizeMap.end())
auto Iter =
MapNamesLang::VectorTypeMigratedTypeSizeMap.find(QT.getAsString());
if (Iter != MapNamesLang::VectorTypeMigratedTypeSizeMap.end())
ArgSize = Iter->second;
else
ArgSize = MapNames::KernelArgTypeSizeMap.at(KernelArgType::KAT_Default);
ArgSize =
MapNamesLang::KernelArgTypeSizeMap.at(KernelArgType::KAT_Default);
if (PVD) {
TypeString = DpctGlobalInfo::getReplacedTypeName(PVD->getType());
}
Expand Down Expand Up @@ -5584,7 +5588,7 @@ KernelCallExpr::ArgInfo::ArgInfo(std::shared_ptr<TextureObjectInfo> Obj,
}
ArgString = ArgStr;
IdString = ArgString + "_";
ArgSize = MapNames::KernelArgTypeSizeMap.at(KernelArgType::KAT_Texture);
ArgSize = MapNamesLang::KernelArgTypeSizeMap.at(KernelArgType::KAT_Texture);
}
const std::string &KernelCallExpr::ArgInfo::getArgString() const {
return ArgString;
Expand Down Expand Up @@ -5958,8 +5962,8 @@ void KernelCallExpr::buildUnionFindSet() {
}
}
void KernelCallExpr::addReplacements() {
if (TotalArgsSize >
MapNames::KernelArgTypeSizeMap.at(KernelArgType::KAT_MaxParameterSize))
if (TotalArgsSize > MapNamesLang::KernelArgTypeSizeMap.at(
KernelArgType::KAT_MaxParameterSize))
DiagnosticsUtils::report(getFilePath(), getOffset(),
Diagnostics::EXCEED_MAX_PARAMETER_SIZE, true,
false);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ add_clang_library(DPCT
RulesMathLib/MapNamesSolver.cpp
RulesMathLib/MapNamesBlas.cpp
RulesMathLib/MapNamesRandom.cpp
RulesLang/MapNamesLang.cpp
RulesDNN/MapNamesDNN.cpp
RulesLangLib/MapNamesLangLib.cpp
FileGenerator/GenFiles.cpp
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/DPCT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "RuleInfra/MemberExprRewriter.h"
#include "RuleInfra/TypeLocRewriters.h"
#include "RulesDNN/MapNamesDNN.h"
#include "RulesLang/MapNamesLang.h"
#include "RulesLangLib/MapNamesLangLib.h"
#include "RulesMathLib/MapNamesBlas.h"
#include "RulesMathLib/MapNamesRandom.h"
Expand Down Expand Up @@ -1146,6 +1147,7 @@ int runDPCT(int argc, const char **argv) {
ExplicitNamespace::EN_SYCL});
}
MapNames::setExplicitNamespaceMap(ExplicitNamespaces);
MapNamesLang::setExplicitNamespaceMap(ExplicitNamespaces);
MapNamesBlas::setExplicitNamespaceMap(ExplicitNamespaces);
MapNamesDNN::setExplicitNamespaceMap(ExplicitNamespaces);
MapNamesLangLib::setExplicitNamespaceMap(ExplicitNamespaces);
Expand Down
13 changes: 7 additions & 6 deletions clang/lib/DPCT/PreProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "AnalysisInfo.h"
#include "Diagnostics/Diagnostics.h"
#include "FileGenerator/GenFiles.h"
#include "RulesLang/MapNamesLang.h"
#include "RulesLangLib/MapNamesLangLib.h"
#include "TextModification.h"
#include "Utility.h"
Expand Down Expand Up @@ -219,17 +220,17 @@ void IncludesCallbacks::MacroDefined(const Token &MacroNameTok,
#endif
}

if (MapNames::AtomicFuncNamesMap.find(II->getName().str()) !=
MapNames::AtomicFuncNamesMap.end()) {
if (MapNamesLang::AtomicFuncNamesMap.find(II->getName().str()) !=
MapNamesLang::AtomicFuncNamesMap.end()) {
std::string HashStr =
getHashStrFromLoc(MI->getReplacementToken(0).getLocation());
DpctGlobalInfo::getInstance().insertAtomicInfo(
HashStr, MacroNameTok.getLocation(), II->getName().str());
} else if (MacroNameTok.getLocation().isValid() &&
MacroNameTok.getIdentifierInfo() &&
MapNames::VectorTypeMigratedTypeSizeMap.find(
MapNamesLang::VectorTypeMigratedTypeSizeMap.find(
MacroNameTok.getIdentifierInfo()->getName().str()) !=
MapNames::VectorTypeMigratedTypeSizeMap.end()) {
MapNamesLang::VectorTypeMigratedTypeSizeMap.end()) {
DiagnosticsUtils::report(
MacroNameTok.getLocation(), Diagnostics::MACRO_SAME_AS_SYCL_TYPE,
&TransformSet, false,
Expand Down Expand Up @@ -492,8 +493,8 @@ void IncludesCallbacks::MacroExpands(const Token &MacroNameTok,
#endif
}

auto Iter = MapNames::HostAllocSet.find(Name.str());
if (TKind == tok::identifier && Iter != MapNames::HostAllocSet.end()) {
auto Iter = MapNamesLang::HostAllocSet.find(Name.str());
if (TKind == tok::identifier && Iter != MapNamesLang::HostAllocSet.end()) {
if (MI->getNumTokens() == 1) {
auto ReplToken = MI->getReplacementToken(0);
if (ReplToken.getKind() == tok::numeric_constant) {
Expand Down
24 changes: 14 additions & 10 deletions clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "RuleInfra/TypeLocRewriters.h"
#include "RulesDNN/DNNAPIMigration.h"
#include "RulesDNN/MapNamesDNN.h"
#include "RulesLang/MapNamesLang.h"
#include "RulesLang/RulesLang.h"
#include "RulesLangLib/CUBAPIMigration.h"
#include "RulesLangLib/MapNamesLangLib.h"
Expand Down Expand Up @@ -500,7 +501,7 @@ bool isMathFunction(std::string Name) {
}

bool isCGAPI(std::string Name) {
return MapNames::CooperativeGroupsAPISet.count(Name);
return MapNamesLang::CooperativeGroupsAPISet.count(Name);
}

void ExprAnalysis::analyzeExpr(const DeclRefExpr *DRE) {
Expand Down Expand Up @@ -603,7 +604,7 @@ void ExprAnalysis::analyzeExpr(const DeclRefExpr *DRE) {
addReplacement(DRE, Repl); \
} while (0)
REPLACE_ENUM(MapNamesBlas::BLASEnumsMap);
REPLACE_ENUM(MapNames::FunctionAttrMap);
REPLACE_ENUM(MapNamesLang::FunctionAttrMap);
REPLACE_ENUM(CuDNNTypeRule::CuDNNEnumNamesMap);
REPLACE_ENUM(MapNamesRandom::RandomEngineTypeMap);
REPLACE_ENUM(MapNamesRandom::RandomOrderingTypeMap);
Expand Down Expand Up @@ -772,10 +773,11 @@ void ExprAnalysis::analyzeExpr(const MemberExpr *ME) {
std::string FieldName = ME->getMemberDecl()->getName().str();
if (MapNames::replaceName(TextureRule::TextureMemberNames, FieldName)) {
addReplacement(ME->getMemberLoc(), buildString("get_", FieldName, "()"));
requestFeature(MapNames::ImageWrapperBaseToGetFeatureMap.at(FieldName));
requestFeature(
MapNamesLang::ImageWrapperBaseToGetFeatureMap.at(FieldName));
}
} else if (MapNames::SupportedVectorTypes.find(BaseType) !=
MapNames::SupportedVectorTypes.end()) {
} else if (MapNamesLang::SupportedVectorTypes.find(BaseType) !=
MapNamesLang::SupportedVectorTypes.end()) {
// Skip user-defined type.
if (isTypeInAnalysisScope(ME->getBase()->getType().getTypePtr()))
return;
Expand All @@ -799,9 +801,10 @@ void ExprAnalysis::analyzeExpr(const MemberExpr *ME) {
addReplacement(ME->getOperatorLoc(), ME->getEndLoc(), "");
} else {
std::string MemberName = ME->getMemberNameInfo().getAsString();
const auto &MArrayIdx = MapNames::MArrayMemberNamesMap.find(MemberName);
if (MapNames::VectorTypes2MArray.count(BaseType) &&
MArrayIdx != MapNames::MArrayMemberNamesMap.end()) {
const auto &MArrayIdx =
MapNamesLang::MArrayMemberNamesMap.find(MemberName);
if (MapNamesLang::VectorTypes2MArray.count(BaseType) &&
MArrayIdx != MapNamesLang::MArrayMemberNamesMap.end()) {
std::string RepStr = "";
if (isImplicit) {
RepStr = "(*this)";
Expand All @@ -810,7 +813,8 @@ void ExprAnalysis::analyzeExpr(const MemberExpr *ME) {
RepStr = ")";
}
addReplacement(Begin, ME->getEndLoc(), RepStr + MArrayIdx->second);
} else if (MapNames::replaceName(MapNames::MemberNamesMap, MemberName)) {
} else if (MapNames::replaceName(MapNamesLang::MemberNamesMap,
MemberName)) {
std::string RepStr = "";
const auto *MD = DpctGlobalInfo::findAncestor<CXXMethodDecl>(ME);
if (MD && MD->isVolatile()) {
Expand Down Expand Up @@ -1243,7 +1247,7 @@ void ExprAnalysis::analyzeDecltypeType(DecltypeTypeLoc TL) {
auto Name = getNestedNameSpecifierString(Qualifier);
auto Range = getDefinitionRange(SR.getBegin(), SR.getEnd());
Name.resize(Name.length() - 2); // Remove the "::".
if (MapNames::SupportedVectorTypes.count(Name)) {
if (MapNamesLang::SupportedVectorTypes.count(Name)) {
auto ReplacedStr =
MapNames::findReplacedName(MapNames::TypeNamesMap, Name);
if (Name.back() != '1') {
Expand Down
Loading
Loading