## Copyright 2022 The IREE Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """Helpers to serialize/deserialize objects.""" from enum import Enum from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union import dataclasses import typing # types.NoneType is only added after Python 3.10. NONE_TYPE = type(None) SERIALIZE_FUNC_NAME = "__serialize__" DESERIALIZE_FUNC_NAME = "__deserialize__" SUPPORTED_DICT_KEY_TYPES = {str, int, float, bool} SUPPORTED_PRIMITIVE_TYPES = {str, int, float, bool, NONE_TYPE} def serialize_and_pack( obj, root_obj_field_name="root_obj", keyed_obj_map_field_name="keyed_obj_map" ): """Converts and packs the object into a serializable object. Args: obj: object to be serialized. root_obj_field_name: field name of the top-level object in the return dict. keyed_obj_map_field_name: field name of the keyed object map in the return dict. Returns A serializable dict. """ if root_obj_field_name == keyed_obj_map_field_name: raise ValueError(f"root_obj and keyed_obj_map can't have the same field name.") keyed_obj_map = {} root_obj = _serialize(obj=obj, keyed_obj_map=keyed_obj_map) return {root_obj_field_name: root_obj, keyed_obj_map_field_name: keyed_obj_map} T = TypeVar("T") def unpack_and_deserialize( data, root_type: Type[T], root_obj_field_name="root_obj", keyed_obj_map_field_name="keyed_obj_map", ) -> T: """Unpacks and deserializes the data back to the typed object. Args: data: serialized data dict. root_type: top-level object type of the data. root_obj_field_name: field name of the top-level object in the dict. keyed_obj_map_field_name: field name of the keyed object map in the dict. Returns: A deserialized object. """ obj = _deserialize( data=data[root_obj_field_name], obj_type=root_type, keyed_obj_map=data[keyed_obj_map_field_name], ) return typing.cast(root_type, obj) def _serialize(obj, keyed_obj_map: Dict[str, Any]): """Converts the object into a serializable object. Args: obj: object to be serialized. keyed_obj_map: mutable container to store the keyed serializable object. Returns A serializable object. """ serialize_func = getattr(obj, SERIALIZE_FUNC_NAME, None) if serialize_func is not None: return serialize_func(keyed_obj_map) elif isinstance(obj, list): return [_serialize(value, keyed_obj_map) for value in obj] elif isinstance(obj, Enum): return obj.name elif isinstance(obj, dict): result_dict = {} for key, value in obj.items(): if type(key) not in SUPPORTED_DICT_KEY_TYPES: raise ValueError(f"Unsupported key {key} in the dict {obj}.") result_dict[key] = _serialize(value, keyed_obj_map) return result_dict elif type(obj) in SUPPORTED_PRIMITIVE_TYPES: return obj raise ValueError(f"Unsupported object: {obj}.") def _deserialize( data, obj_type: Type, keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, Any] = {} ): """Deserializes the data back to the typed object. Args: data: serialized data. obj_type: type of the data. keyed_obj_map: container of the keyed serializable object. Returns: A deserialized object. """ deserialize_func = getattr(obj_type, DESERIALIZE_FUNC_NAME, None) if deserialize_func is not None: return deserialize_func(data, keyed_obj_map, obj_cache) elif typing.get_origin(obj_type) == list: (subtype,) = typing.get_args(obj_type) return [_deserialize(item, subtype, keyed_obj_map, obj_cache) for item in data] elif typing.get_origin(obj_type) == dict: _, value_type = typing.get_args(obj_type) return dict( (key, _deserialize(value, value_type, keyed_obj_map, obj_cache)) for key, value in data.items() ) elif typing.get_origin(obj_type) == Union: subtypes = typing.get_args(obj_type) if len(subtypes) != 2 or NONE_TYPE not in subtypes: raise ValueError(f"Unsupported union type: {obj_type}.") if data is None: return None subtype = subtypes[0] if subtypes[1] == NONE_TYPE else subtypes[1] return _deserialize(data, subtype, keyed_obj_map, obj_cache) elif issubclass(obj_type, Enum): for member in obj_type: if data == member.name: return member raise ValueError(f"Member {data} not found in the enum {obj_type}.") return data def serializable(cls=None, type_key: Optional[str] = None, id_field: str = "id"): """Decorator to make a dataclass serializable. Args: type_key: string defines the object type and indeicates that the class is a keyed object, which is unique per id and will only have one copy in the serialization per id. id_field: field name of the id field of a keyed object. Example: @serializable @dataclass class A(object): ... @serialzable(type_key="obj_b") @dataclass class B(object): id: str """ if type_key is not None and ":" in type_key: raise ValueError("':' is the reserved character in type_key.") def wrap(cls): if not dataclasses.is_dataclass(cls): raise ValueError(f"{cls} is not a dataclass.") fields = dataclasses.fields(cls) if type_key is not None and all(field.name != id_field for field in fields): raise ValueError(f'Id field "{id_field}" not found in the class {cls}.') def serialize(self, keyed_obj_map: Dict[str, Any]): if type_key is None: return _fields_to_dict(self, fields, keyed_obj_map) obj_id = getattr(self, id_field) obj_key = f"{type_key}:{obj_id}" if obj_key in keyed_obj_map: # If the value in the map is None, it means we have visited this object # before but not yet finished serializing it. This will only happen if # there is a circular reference. if keyed_obj_map[obj_key] is None: raise ValueError(f"Circular reference is not supported: {obj_key}.") return obj_id # Populate the keyed_obj_map with None first to detect circular reference. keyed_obj_map[obj_key] = None obj_dict = _fields_to_dict(self, fields, keyed_obj_map) keyed_obj_map[obj_key] = obj_dict return obj_id def deserialize(data, keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, Any]): if type_key is None: field_value_map = _dict_to_fields( data, fields, keyed_obj_map, obj_cache ) return cls(**field_value_map) obj_id = data obj_key = f"{type_key}:{obj_id}" if obj_key in obj_cache: return obj_cache[obj_key] field_value_map = _dict_to_fields( keyed_obj_map[obj_key], fields, keyed_obj_map, obj_cache ) derialized_obj = cls(**field_value_map) obj_cache[obj_key] = derialized_obj return derialized_obj setattr(cls, SERIALIZE_FUNC_NAME, serialize) setattr(cls, DESERIALIZE_FUNC_NAME, deserialize) return cls # Trick to allow the decoration with `@serializable(...)`. In that case, # `serializable` is called without cls and should return a decorator. if cls is None: return wrap return wrap(cls) def _fields_to_dict( obj, fields: Sequence[dataclasses.Field], keyed_obj_map: Dict[str, Any] ) -> Dict[str, Any]: return dict( (field.name, _serialize(getattr(obj, field.name), keyed_obj_map)) for field in fields ) def _dict_to_fields( obj_dict, fields: Sequence[dataclasses.Field], keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, Any], ) -> Dict[str, Any]: return dict( ( field.name, _deserialize(obj_dict[field.name], field.type, keyed_obj_map, obj_cache), ) for field in fields )