diff --git a/contrib/msggen/msggen/gen/grpc.py b/contrib/msggen/msggen/gen/grpc.py index 7e66afc68456..d8b860e27954 100644 --- a/contrib/msggen/msggen/gen/grpc.py +++ b/contrib/msggen/msggen/gen/grpc.py @@ -1,5 +1,5 @@ # A grpc model -from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service +from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service, MethodName, TypeName from msggen.gen import IGenerator from typing import TextIO, List, Dict, Any from textwrap import indent, dedent @@ -60,9 +60,11 @@ def write(self, text: str, cleanup: bool = True) -> None: else: self.dest.write(text) - def field2number(self, message_name, field): + def field2number(self, message_name: TypeName, field): m = self.meta['grpc-field-map'] + message_name = message_name.name # TypeName is not JSON-serializable, use the unaltered name. + # Wrap each field mapping by the message_name, since otherwise # requests and responses share the same number space (just # cosmetic really, but why not do it?) @@ -94,11 +96,14 @@ def enumerate_fields(self, message_name, fields): for f in fields: yield (self.field2number(message_name, f), f) - def enumvar2number(self, typename, variant): + def enumvar2number(self, typename: TypeName, variant): """Find an existing variant number of generate a new one. If we don't have a variant number yet we'll just take the largest one assigned so far and increment it by 1. """ + + typename = str(typename.name) + m = self.meta['grpc-enum-map'] variant = str(variant) if typename not in m: @@ -149,7 +154,7 @@ def generate_service(self, service: Service) -> None: """) for method in service.methods: - mname = method_name_overrides.get(method.name, method.name) + mname = MethodName(method_name_overrides.get(method.name, method.name)) self.write( f" rpc {mname}({method.request.typename}) returns ({method.response.typename}) {{}}\n", cleanup=False, @@ -202,7 +207,7 @@ def generate_message(self, message: CompositeField): typename = f.override(f.typename) self.write(f"\t{opt}{typename} {f.normalized()} = {i};\n", False) - self.write(f"""}} + self.write("""} """) def generate(self, service: Service) -> None: @@ -250,7 +255,7 @@ def generate_composite(self, prefix, field: CompositeField): elif isinstance(f, CompositeField): self.generate_composite(prefix, f) - pbname = self.to_camel_case(field.typename) + pbname = self.to_camel_case(str(field.typename)) # If any of the field accesses would result in a deprecated # warning we mark the construction here to allow deprecated @@ -421,7 +426,7 @@ def generate_composite(self, prefix, field: CompositeField) -> None: has_deprecated = any([f.deprecated for f in field.fields]) deprecated = ",deprecated" if has_deprecated else "" - pbname = self.to_camel_case(field.typename) + pbname = self.to_camel_case(str(field.typename)) # And now we can convert the current field: self.write(f"""\ #[allow(unused_variables{deprecated})] diff --git a/contrib/msggen/msggen/model.py b/contrib/msggen/msggen/model.py index f1a7813601d4..c706e71446c9 100644 --- a/contrib/msggen/msggen/model.py +++ b/contrib/msggen/msggen/model.py @@ -26,6 +26,36 @@ def __str__(self): return self.name +class TypeName: + def __init__(self, name: Optional[str]): + if name is None: + raise ValueError("empty typename") + self.name = name + + def __str__(self) -> str: + """Return the normalized typename.""" + return ( + self.name + .replace(' ', '_') + .replace('-', '') + .replace('/', '_') + ) + + def __repr__(self) -> str: + return f"Typename[raw={self.name}, str={self}" + + def __iadd__(self, other): + self.name += str(other) + return self + + def __lt__(self, other) -> bool: + return str(self.name) < str(other) + + +class MethodName(TypeName): + """A class encapsulating the naming rules for methods. """ + + class Field: def __init__( self, @@ -140,7 +170,7 @@ def __init__(self, name: str, request: Field, response: Field): class CompositeField(Field): def __init__( self, - typename, + typename: TypeName, fields, path, description, @@ -159,7 +189,7 @@ def __init__( @classmethod def from_js(cls, js, path): - typename = path2type(path) + typename = TypeName(path2type(path)) properties = js.get("properties", {}) # Ok, let's flatten the conditional properties. We do this by @@ -257,7 +287,7 @@ def normalized(self): class EnumField(Field): - def __init__(self, typename, values, path, description, added, deprecated): + def __init__(self, typename: TypeName, values, path, description, added, deprecated): Field.__init__(self, path, description, added=added, deprecated=deprecated) self.typename = typename self.values = values @@ -266,7 +296,7 @@ def __init__(self, typename, values, path, description, added, deprecated): @classmethod def from_js(cls, js, path): # Transform the path into something that is a valid TypeName - typename = path2type(path) + typename = TypeName(path2type(path)) return EnumField( typename, values=filter(lambda i: i is not None, js["enum"]),