Spaces:
Sleeping
Sleeping
Implement auto-register seed-scheme factory
Browse files- seed_scheme_factory.py +44 -0
- seed_schemes/__init__.py +13 -0
- seed_schemes/dummy_hash.py +12 -0
- seed_schemes/sha_hash.py +33 -0
seed_scheme_factory.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Callable
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SeedSchemeFactory:
|
| 7 |
+
registry = {}
|
| 8 |
+
|
| 9 |
+
@classmethod
|
| 10 |
+
def register(cls, name: str):
|
| 11 |
+
"""
|
| 12 |
+
Register the hash scheme by name. Hash scheme must be callable.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
name: name of seed scheme.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def wrapper(wrapped_class):
|
| 19 |
+
if name in cls.registry:
|
| 20 |
+
print(f"Override {name} in SeedSchemeFactory")
|
| 21 |
+
cls.registry[name] = wrapped_class
|
| 22 |
+
return wrapped_class
|
| 23 |
+
|
| 24 |
+
return wrapper
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def get_instance(cls, name: str, *args, **kwargs):
|
| 28 |
+
"""
|
| 29 |
+
Get the hash scheme by name.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
name: name of seed scheme.
|
| 33 |
+
"""
|
| 34 |
+
if name in cls.registry:
|
| 35 |
+
return cls.registry[name](*args, **kwargs)
|
| 36 |
+
else:
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SeedScheme():
|
| 41 |
+
def __call__(self, input_ids: torch.Tensor) -> int:
|
| 42 |
+
return 0
|
| 43 |
+
|
| 44 |
+
from seed_schemes import *
|
seed_schemes/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inspect import isclass
|
| 2 |
+
from pkgutil import iter_modules
|
| 3 |
+
from importlib import import_module
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
pkg_dir = os.path.dirname(__file__)
|
| 7 |
+
for _, module_name, _ in iter_modules([pkg_dir]):
|
| 8 |
+
module = import_module(f"{__name__}.{module_name}")
|
| 9 |
+
|
| 10 |
+
for attr_name in dir(module):
|
| 11 |
+
attr = getattr(module, attr_name)
|
| 12 |
+
if isclass(attr):
|
| 13 |
+
globals()[attr_name] = attr
|
seed_schemes/dummy_hash.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from seed_scheme_factory import SeedSchemeFactory, SeedScheme
|
| 4 |
+
|
| 5 |
+
@SeedSchemeFactory.register("dummy_hash")
|
| 6 |
+
class DummyHash(SeedScheme):
|
| 7 |
+
def __init__(self, *args, **kwargs):
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
def __call__(self, input_ids: torch.Tensor):
|
| 11 |
+
return int(input_ids[-1].item())
|
| 12 |
+
|
seed_schemes/sha_hash.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
import torch
|
| 3 |
+
import io
|
| 4 |
+
|
| 5 |
+
from cryptography.hazmat.primitives import hashes
|
| 6 |
+
|
| 7 |
+
from seed_scheme_factory import SeedSchemeFactory, SeedScheme
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@SeedSchemeFactory.register("sha_left_hash")
|
| 11 |
+
class SHALeftHash(SeedScheme):
|
| 12 |
+
def __init__(self, private_key: Union[int, None] = None, *args, **kwargs):
|
| 13 |
+
self.private_key = (
|
| 14 |
+
private_key.to_bytes(8, "big") if private_key is not None else None
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def __call__(self, input_ids: torch.Tensor):
|
| 18 |
+
buff = io.BytesIO()
|
| 19 |
+
if self.private_key is not None:
|
| 20 |
+
buff.write(self.private_key)
|
| 21 |
+
for input_id in input_ids:
|
| 22 |
+
buff.write(int(input_id.item()).to_bytes(8, "big"))
|
| 23 |
+
buff.seek(0)
|
| 24 |
+
input_ids_bytes = buff.read()
|
| 25 |
+
|
| 26 |
+
digest = hashes.Hash(hashes.SHA224())
|
| 27 |
+
digest.update(input_ids_bytes)
|
| 28 |
+
hashed_value = digest.finalize()
|
| 29 |
+
|
| 30 |
+
# Only take the first 8 bytes because seed in torch rng only accept int64 seed
|
| 31 |
+
seed = int.from_bytes(hashed_value[:8], byteorder="big")
|
| 32 |
+
|
| 33 |
+
return seed
|