Skip to content

Commit

Permalink
msggen: Add classes for MethodName and TypeName
Browse files Browse the repository at this point in the history
This is required for types and methods with names that need
post-processing (`bkpr-listincome`).
  • Loading branch information
cdecker committed Feb 1, 2024
1 parent 85b79bc commit 19af808
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
19 changes: 12 additions & 7 deletions contrib/msggen/msggen/gen/grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# A grpc model
from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service
from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service, MethodName, TypeName
from msggen.gen import IGenerator
from typing import TextIO, List, Dict, Any
from textwrap import indent, dedent
Expand Down Expand Up @@ -60,9 +60,11 @@ def write(self, text: str, cleanup: bool = True) -> None:
else:
self.dest.write(text)

def field2number(self, message_name, field):
def field2number(self, message_name: TypeName, field):
m = self.meta['grpc-field-map']

message_name = message_name.name # TypeName is not JSON-serializable, use the unaltered name.

# Wrap each field mapping by the message_name, since otherwise
# requests and responses share the same number space (just
# cosmetic really, but why not do it?)
Expand Down Expand Up @@ -94,11 +96,14 @@ def enumerate_fields(self, message_name, fields):
for f in fields:
yield (self.field2number(message_name, f), f)

def enumvar2number(self, typename, variant):
def enumvar2number(self, typename: TypeName, variant):
"""Find an existing variant number of generate a new one.
If we don't have a variant number yet we'll just take the
largest one assigned so far and increment it by 1. """

typename = str(typename.name)

m = self.meta['grpc-enum-map']
variant = str(variant)
if typename not in m:
Expand Down Expand Up @@ -149,7 +154,7 @@ def generate_service(self, service: Service) -> None:
""")

for method in service.methods:
mname = method_name_overrides.get(method.name, method.name)
mname = MethodName(method_name_overrides.get(method.name, method.name))
self.write(
f" rpc {mname}({method.request.typename}) returns ({method.response.typename}) {{}}\n",
cleanup=False,
Expand Down Expand Up @@ -202,7 +207,7 @@ def generate_message(self, message: CompositeField):
typename = f.override(f.typename)
self.write(f"\t{opt}{typename} {f.normalized()} = {i};\n", False)

self.write(f"""}}
self.write("""}
""")

def generate(self, service: Service) -> None:
Expand Down Expand Up @@ -250,7 +255,7 @@ def generate_composite(self, prefix, field: CompositeField):
elif isinstance(f, CompositeField):
self.generate_composite(prefix, f)

pbname = self.to_camel_case(field.typename)
pbname = self.to_camel_case(str(field.typename))

# If any of the field accesses would result in a deprecated
# warning we mark the construction here to allow deprecated
Expand Down Expand Up @@ -421,7 +426,7 @@ def generate_composite(self, prefix, field: CompositeField) -> None:
has_deprecated = any([f.deprecated for f in field.fields])
deprecated = ",deprecated" if has_deprecated else ""

pbname = self.to_camel_case(field.typename)
pbname = self.to_camel_case(str(field.typename))
# And now we can convert the current field:
self.write(f"""\
#[allow(unused_variables{deprecated})]
Expand Down
38 changes: 34 additions & 4 deletions contrib/msggen/msggen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@ def __str__(self):
return self.name


class TypeName:
def __init__(self, name: Optional[str]):
if name is None:
raise ValueError("empty typename")
self.name = name

def __str__(self) -> str:
"""Return the normalized typename."""
return (
self.name
.replace(' ', '_')
.replace('-', '')
.replace('/', '_')
)

def __repr__(self) -> str:
return f"Typename[raw={self.name}, str={self}"

def __iadd__(self, other):
self.name += str(other)
return self

def __lt__(self, other) -> bool:
return str(self.name) < str(other)


class MethodName(TypeName):
"""A class encapsulating the naming rules for methods. """


class Field:
def __init__(
self,
Expand Down Expand Up @@ -140,7 +170,7 @@ def __init__(self, name: str, request: Field, response: Field):
class CompositeField(Field):
def __init__(
self,
typename,
typename: TypeName,
fields,
path,
description,
Expand All @@ -159,7 +189,7 @@ def __init__(

@classmethod
def from_js(cls, js, path):
typename = path2type(path)
typename = TypeName(path2type(path))

properties = js.get("properties", {})
# Ok, let's flatten the conditional properties. We do this by
Expand Down Expand Up @@ -257,7 +287,7 @@ def normalized(self):


class EnumField(Field):
def __init__(self, typename, values, path, description, added, deprecated):
def __init__(self, typename: TypeName, values, path, description, added, deprecated):
Field.__init__(self, path, description, added=added, deprecated=deprecated)
self.typename = typename
self.values = values
Expand All @@ -266,7 +296,7 @@ def __init__(self, typename, values, path, description, added, deprecated):
@classmethod
def from_js(cls, js, path):
# Transform the path into something that is a valid TypeName
typename = path2type(path)
typename = TypeName(path2type(path))
return EnumField(
typename,
values=filter(lambda i: i is not None, js["enum"]),
Expand Down

0 comments on commit 19af808

Please sign in to comment.