Skip to content

Commit

Permalink
Add Union cases to Python get_dtype function (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
naegelejd authored Jul 8, 2024
1 parent 863223d commit d10dbf4
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 41 deletions.
5 changes: 5 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
// Python linting.
"python.analysis.typeCheckingMode": "strict",
"python.analysis.diagnosticMode": "workspace",
"python.analysis.include": [
"${workspaceFolder}/python",
"${workspaceFolder}/smoketest",
"${workspaceFolder}/tooling/internal/python/static_files"
],

"python.defaultInterpreterPath": "/opt/conda/envs/yardl/bin/python",
"python.terminal.activateEnvironment": false, // Disable the extension calling activate when the integrated terminal launches. We take care of this in ~/.bashrc.
Expand Down
8 changes: 6 additions & 2 deletions python/test_model/basic_types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,15 @@ def _mk_get_dtype():
dtype_map.setdefault(GenericUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(GenericNullableUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(GenericNullableUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True))
dtype_map.setdefault(Int32OrString, np.dtype(np.object_))
dtype_map.setdefault(Int32OrString.Int32, np.dtype(np.int32))
dtype_map.setdefault(Int32OrString.String, np.dtype(np.object_))
dtype_map.setdefault(TimeOrDatetime, np.dtype(np.object_))
dtype_map.setdefault(RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True))
dtype_map.setdefault(T0OrT1, np.dtype(np.object_))
dtype_map.setdefault(TimeOrDatetime.Time, np.dtype(np.timedelta64))
dtype_map.setdefault(TimeOrDatetime.Datetime, np.dtype(np.datetime64))
dtype_map.setdefault(GenericRecordWithComputedFields, lambda type_args: np.dtype([('f1', np.dtype(np.object_))], align=True))
dtype_map.setdefault(T0OrT1, np.dtype(np.object_))

return get_dtype

Expand Down
38 changes: 32 additions & 6 deletions python/test_model/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,11 +2027,15 @@ def _mk_get_dtype():
dtype_map.setdefault(basic_types.GenericUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(basic_types.GenericNullableUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(basic_types.GenericNullableUnion2, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(basic_types.RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True))
dtype_map.setdefault(basic_types.Int32OrString, np.dtype(np.object_))
dtype_map.setdefault(basic_types.Int32OrString.Int32, np.dtype(np.int32))
dtype_map.setdefault(basic_types.Int32OrString.String, np.dtype(np.object_))
dtype_map.setdefault(basic_types.TimeOrDatetime, np.dtype(np.object_))
dtype_map.setdefault(basic_types.RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True))
dtype_map.setdefault(basic_types.T0OrT1, np.dtype(np.object_))
dtype_map.setdefault(basic_types.TimeOrDatetime.Time, np.dtype(np.timedelta64))
dtype_map.setdefault(basic_types.TimeOrDatetime.Datetime, np.dtype(np.datetime64))
dtype_map.setdefault(basic_types.GenericRecordWithComputedFields, lambda type_args: np.dtype([('f1', np.dtype(np.object_))], align=True))
dtype_map.setdefault(basic_types.T0OrT1, np.dtype(np.object_))
dtype_map.setdefault(SmallBenchmarkRecord, np.dtype([('a', np.dtype(np.float64)), ('b', np.dtype(np.float32)), ('c', np.dtype(np.float32))], align=True))
dtype_map.setdefault(SimpleEncodingCounters, np.dtype([('e1', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.uint32))], align=True)), ('e2', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.uint32))], align=True)), ('slice', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.uint32))], align=True)), ('repetition', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.uint32))], align=True))], align=True))
dtype_map.setdefault(SimpleAcquisition, np.dtype([('flags', np.dtype(np.uint64)), ('idx', get_dtype(SimpleEncodingCounters)), ('data', np.dtype(np.object_)), ('trajectory', np.dtype(np.object_))], align=True))
Expand All @@ -2055,10 +2059,16 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordWithDynamicNDArrays, np.dtype([('ints', np.dtype(np.object_)), ('simple_record_array', np.dtype(np.object_)), ('record_with_vlens_array', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithFixedCollections, np.dtype([('fixed_vector', np.dtype(np.int32), (3,)), ('fixed_array', np.dtype(np.int32), (2, 3,))], align=True))
dtype_map.setdefault(RecordWithVlenCollections, np.dtype([('vector', np.dtype(np.object_)), ('array', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithUnionsOfContainers, np.dtype([('map_or_scalar', np.dtype(np.object_)), ('vector_or_scalar', np.dtype(np.object_)), ('array_or_scalar', np.dtype(np.object_))], align=True))
dtype_map.setdefault(MapOrScalar, np.dtype(np.object_))
dtype_map.setdefault(MapOrScalar.Map, np.dtype(np.object_))
dtype_map.setdefault(MapOrScalar.Scalar, np.dtype(np.int32))
dtype_map.setdefault(VectorOrScalar, np.dtype(np.object_))
dtype_map.setdefault(VectorOrScalar.Vector, np.dtype(np.object_))
dtype_map.setdefault(VectorOrScalar.Scalar, np.dtype(np.int32))
dtype_map.setdefault(ArrayOrScalar, np.dtype(np.object_))
dtype_map.setdefault(RecordWithUnionsOfContainers, np.dtype([('map_or_scalar', np.dtype(np.object_)), ('vector_or_scalar', np.dtype(np.object_)), ('array_or_scalar', np.dtype(np.object_))], align=True))
dtype_map.setdefault(ArrayOrScalar.Array, np.dtype(np.object_))
dtype_map.setdefault(ArrayOrScalar.Scalar, np.dtype(np.int32))
dtype_map.setdefault(Fruits, get_dtype(basic_types.Fruits))
dtype_map.setdefault(UInt64Enum, np.dtype(np.uint64))
dtype_map.setdefault(Int64Enum, np.dtype(np.int64))
Expand All @@ -2083,8 +2093,8 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordWithGenericVectorOfRecords, lambda type_args: np.dtype([('v', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithOptionalGenericField, lambda type_args: np.dtype([('v', np.dtype([('has_value', np.dtype(np.bool_)), ('value', get_dtype(type_args[0]))], align=True))], align=True))
dtype_map.setdefault(RecordWithAliasedOptionalGenericField, lambda type_args: np.dtype([('v', np.dtype([('has_value', np.dtype(np.bool_)), ('value', get_dtype(type_args[0]))], align=True))], align=True))
dtype_map.setdefault(UOrV, np.dtype(np.object_))
dtype_map.setdefault(RecordWithOptionalGenericUnionField, lambda type_args: np.dtype([('v', np.dtype(np.object_))], align=True))
dtype_map.setdefault(UOrV, np.dtype(np.object_))
dtype_map.setdefault(RecordWithAliasedOptionalGenericUnionField, lambda type_args: np.dtype([('v', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithGenericVectors, lambda type_args: np.dtype([('v', np.dtype(np.object_)), ('av', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithGenericFixedVectors, lambda type_args: np.dtype([('fv', get_dtype(type_args[0]), (3,)), ('afv', get_dtype(type_args[0]), (3,))], align=True))
Expand All @@ -2094,14 +2104,20 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordContainingNestedGenericRecords, np.dtype([('f1', get_dtype(types.GenericAlias(RecordWithOptionalGenericField, (str,)))), ('f1a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericField, (str,)))), ('f2', get_dtype(types.GenericAlias(RecordWithOptionalGenericUnionField, (str, yardl.Int32,)))), ('f2a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericUnionField, (str, yardl.Int32,)))), ('nested', get_dtype(types.GenericAlias(RecordContainingGenericRecords, (str, yardl.Int32,))))], align=True))
dtype_map.setdefault(AliasedIntOrSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(AliasedIntOrAliasedSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(AliasedNullableIntSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(typing.Optional[AliasedNullableIntSimpleRecord], np.dtype(np.object_))
dtype_map.setdefault(AliasedNullableIntSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(AliasedNullableIntSimpleRecord.Int32, np.dtype(np.int32))
dtype_map.setdefault(AliasedNullableIntSimpleRecord.SimpleRecord, get_dtype(SimpleRecord))
dtype_map.setdefault(RecordWithIntVectors, get_dtype(types.GenericAlias(RecordWithGenericVectors, (yardl.Int32,))))
dtype_map.setdefault(RecordWithFloatArrays, get_dtype(types.GenericAlias(RecordWithGenericArrays, (yardl.Float32,))))
dtype_map.setdefault(UnionOfContainerRecords, np.dtype(np.object_))
dtype_map.setdefault(RecordWithComputedFields, np.dtype([('array_field', np.dtype(np.object_)), ('array_field_map_dimensions', np.dtype(np.object_)), ('dynamic_array_field', np.dtype(np.object_)), ('fixed_array_field', np.dtype(np.int32), (3, 4,)), ('int_field', np.dtype(np.int32)), ('int8_field', np.dtype(np.int8)), ('uint8_field', np.dtype(np.uint8)), ('int16_field', np.dtype(np.int16)), ('uint16_field', np.dtype(np.uint16)), ('uint32_field', np.dtype(np.uint32)), ('int64_field', np.dtype(np.int64)), ('uint64_field', np.dtype(np.uint64)), ('size_field', np.dtype(np.uint64)), ('float32_field', np.dtype(np.float32)), ('float64_field', np.dtype(np.float64)), ('complexfloat32_field', np.dtype(np.complex64)), ('complexfloat64_field', np.dtype(np.complex128)), ('string_field', np.dtype(np.object_)), ('tuple_field', get_dtype(types.GenericAlias(tuples.Tuple, (yardl.Int32, yardl.Int32,)))), ('vector_field', np.dtype(np.object_)), ('vector_of_vectors_field', np.dtype(np.object_)), ('fixed_vector_field', np.dtype(np.int32), (3,)), ('fixed_vector_of_vectors_field', np.dtype(np.int32), (2,)), ('optional_named_array', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.object_))], align=True)), ('int_float_union', np.dtype(np.object_)), ('nullable_int_float_union', np.dtype(np.object_)), ('union_with_nested_generic_union', np.dtype(np.object_)), ('map_field', np.dtype(np.object_))], align=True))
dtype_map.setdefault(Int32OrFloat32, np.dtype(np.object_))
dtype_map.setdefault(Int32OrFloat32.Int32, np.dtype(np.int32))
dtype_map.setdefault(Int32OrFloat32.Float32, np.dtype(np.float32))
dtype_map.setdefault(IntOrGenericRecordWithComputedFields, np.dtype(np.object_))
dtype_map.setdefault(RecordWithComputedFields, np.dtype([('array_field', np.dtype(np.object_)), ('array_field_map_dimensions', np.dtype(np.object_)), ('dynamic_array_field', np.dtype(np.object_)), ('fixed_array_field', np.dtype(np.int32), (3, 4,)), ('int_field', np.dtype(np.int32)), ('int8_field', np.dtype(np.int8)), ('uint8_field', np.dtype(np.uint8)), ('int16_field', np.dtype(np.int16)), ('uint16_field', np.dtype(np.uint16)), ('uint32_field', np.dtype(np.uint32)), ('int64_field', np.dtype(np.int64)), ('uint64_field', np.dtype(np.uint64)), ('size_field', np.dtype(np.uint64)), ('float32_field', np.dtype(np.float32)), ('float64_field', np.dtype(np.float64)), ('complexfloat32_field', np.dtype(np.complex64)), ('complexfloat64_field', np.dtype(np.complex128)), ('string_field', np.dtype(np.object_)), ('tuple_field', get_dtype(types.GenericAlias(tuples.Tuple, (yardl.Int32, yardl.Int32,)))), ('vector_field', np.dtype(np.object_)), ('vector_of_vectors_field', np.dtype(np.object_)), ('fixed_vector_field', np.dtype(np.int32), (3,)), ('fixed_vector_of_vectors_field', np.dtype(np.int32), (2,)), ('optional_named_array', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.object_))], align=True)), ('int_float_union', np.dtype(np.object_)), ('nullable_int_float_union', np.dtype(np.object_)), ('union_with_nested_generic_union', np.dtype(np.object_)), ('map_field', np.dtype(np.object_))], align=True))
dtype_map.setdefault(IntOrGenericRecordWithComputedFields.Int, np.dtype(np.int32))
dtype_map.setdefault(IntOrGenericRecordWithComputedFields.GenericRecordWithComputedFields, get_dtype(types.GenericAlias(basic_types.GenericRecordWithComputedFields, (str, yardl.Float32,))))
dtype_map.setdefault(GenericUnionWithRepeatedTypeParameters, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(GenericUnion3, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(GenericUnion3Alternate, lambda type_args: np.dtype(np.object_))
Expand All @@ -2110,10 +2126,20 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordWithKeywordFields, np.dtype([('int_', np.dtype(np.object_)), ('sizeof', np.dtype(np.object_)), ('if_', get_dtype(EnumWithKeywordSymbols))], align=True))
dtype_map.setdefault(RecordWithOptionalDate, np.dtype([('date_field', np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.datetime64))], align=True))], align=True))
dtype_map.setdefault(AcquisitionOrImage, np.dtype(np.object_))
dtype_map.setdefault(AcquisitionOrImage.Acquisition, get_dtype(SimpleAcquisition))
dtype_map.setdefault(AcquisitionOrImage.Image, np.dtype(np.object_))
dtype_map.setdefault(StringOrInt32, np.dtype(np.object_))
dtype_map.setdefault(StringOrInt32.String, np.dtype(np.object_))
dtype_map.setdefault(StringOrInt32.Int32, np.dtype(np.int32))
dtype_map.setdefault(Int32OrSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(Int32OrSimpleRecord.Int32, np.dtype(np.int32))
dtype_map.setdefault(Int32OrSimpleRecord.SimpleRecord, get_dtype(SimpleRecord))
dtype_map.setdefault(Int32OrRecordWithVlens, np.dtype(np.object_))
dtype_map.setdefault(Int32OrRecordWithVlens.Int32, np.dtype(np.int32))
dtype_map.setdefault(Int32OrRecordWithVlens.RecordWithVlens, get_dtype(RecordWithVlens))
dtype_map.setdefault(ImageFloatOrImageDouble, np.dtype(np.object_))
dtype_map.setdefault(ImageFloatOrImageDouble.ImageFloat, np.dtype(np.object_))
dtype_map.setdefault(ImageFloatOrImageDouble.ImageDouble, np.dtype(np.object_))

return get_dtype

Expand Down
13 changes: 13 additions & 0 deletions python/tests/test_generated_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,18 @@ def test_get_dtype():
assert tm.get_dtype(typing.Union[tm.Int32, tm.Float32]) == np.object_

assert tm.get_dtype(tm.basic_types.Int32OrString) == np.object_
assert tm.get_dtype(tm.basic_types.Int32OrString.Int32) == np.int32
assert tm.get_dtype(tm.basic_types.Int32OrString.String) == np.object_

assert tm.get_dtype(tm.basic_types.TimeOrDatetime) == np.object_
assert tm.get_dtype(tm.basic_types.TimeOrDatetime.Time) == np.timedelta64
assert tm.get_dtype(tm.basic_types.TimeOrDatetime.Datetime) == np.datetime64

assert tm.get_dtype(tm.Int32OrSimpleRecord) == np.object_
assert tm.get_dtype(tm.Int32OrSimpleRecord.Int32) == np.int32
assert tm.get_dtype(tm.Int32OrSimpleRecord.SimpleRecord) == tm.get_dtype(
tm.SimpleRecord
)

assert tm.get_dtype(tm.AliasedOptional) == np.dtype(
[("has_value", "?"), ("value", np.int32)], align=True
Expand All @@ -216,6 +225,10 @@ def test_get_dtype():
)

assert tm.get_dtype(tm.AliasedNullableIntSimpleRecord) == np.object_
assert tm.get_dtype(tm.AliasedNullableIntSimpleRecord.Int32) == np.int32
assert tm.get_dtype(tm.AliasedNullableIntSimpleRecord.SimpleRecord) == tm.get_dtype(
tm.SimpleRecord
)
assert (
tm.get_dtype(typing.Optional[tm.AliasedNullableIntSimpleRecord]) == np.object_
)
Expand Down
Loading

0 comments on commit d10dbf4

Please sign in to comment.