Skip to content

Commit

Permalink
prefix module qualified names with __module__ (pytorch#23630)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#23630

This is temporary, won't be needed with the new serialization format.
But for now, since the main module gets its name from the archive name,
we need this for safety, other wise something like
`torch.jit.save("torch.pt") will break things.

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D16592404

Pulled By: suo

fbshipit-source-id: b538dc3438a80ea7bca14d84591ecd63f4b1289f
  • Loading branch information
suo authored and facebook-github-bot committed Aug 1, 2019
1 parent 230f7f9 commit 0ce950d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/qualified_name.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct QualifiedName {
cacheAccessors();
}

/* implicit */ QualifiedName(const std::vector<std::string>& atoms) {
explicit QualifiedName(std::vector<std::string> atoms) {
for (const auto& atom : atoms) {
TORCH_CHECK(!atom.empty(), "Atom cannot be empty");
TORCH_CHECK(
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ script::Module ScriptModuleDeserializer::convertModule(
for (const auto& atom : atoms) {
moduleStack_.emplace_back(atom);
}
auto module = script::Module(moduleStack_, compilation_unit_);
auto module =
script::Module(c10::QualifiedName(moduleStack_), compilation_unit_);
for (int i = 0; i < module_def.submodules_size(); ++i) {
const torch::ModuleDef& sub_def = module_def.submodules(i);
auto submodule = convertModule(sub_def);
Expand Down
9 changes: 9 additions & 0 deletions torch/csrc/jit/script/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ static ModulePtr create_module_object(
c10::QualifiedName class_name,
std::shared_ptr<CompilationUnit> cu,
bool shouldMangle = false) {
// XXX: This is a temporary hack so that module names cannot clash with
// builtins like `torch`. Delete this with the new serialization format.
std::vector<std::string> new_class_name{"__module__"};
new_class_name.insert(
new_class_name.end(),
class_name.atoms().begin(),
class_name.atoms().end());
class_name = c10::QualifiedName(std::move(new_class_name));

if (shouldMangle && cu->get_class(class_name) != nullptr) {
class_name = cu->mangle(class_name);
}
Expand Down

0 comments on commit 0ce950d

Please sign in to comment.