tnk2908 commited on
Commit
91d5a5e
·
1 Parent(s): 399153d

Implement auto-register seed-scheme factory

Browse files
seed_scheme_factory.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Callable
2
+
3
+ import torch
4
+
5
+
6
+ class SeedSchemeFactory:
7
+ registry = {}
8
+
9
+ @classmethod
10
+ def register(cls, name: str):
11
+ """
12
+ Register the hash scheme by name. Hash scheme must be callable.
13
+
14
+ Args:
15
+ name: name of seed scheme.
16
+ """
17
+
18
+ def wrapper(wrapped_class):
19
+ if name in cls.registry:
20
+ print(f"Override {name} in SeedSchemeFactory")
21
+ cls.registry[name] = wrapped_class
22
+ return wrapped_class
23
+
24
+ return wrapper
25
+
26
+ @classmethod
27
+ def get_instance(cls, name: str, *args, **kwargs):
28
+ """
29
+ Get the hash scheme by name.
30
+
31
+ Args:
32
+ name: name of seed scheme.
33
+ """
34
+ if name in cls.registry:
35
+ return cls.registry[name](*args, **kwargs)
36
+ else:
37
+ return None
38
+
39
+
40
+ class SeedScheme():
41
+ def __call__(self, input_ids: torch.Tensor) -> int:
42
+ return 0
43
+
44
+ from seed_schemes import *
seed_schemes/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isclass
2
+ from pkgutil import iter_modules
3
+ from importlib import import_module
4
+ import os
5
+
6
+ pkg_dir = os.path.dirname(__file__)
7
+ for _, module_name, _ in iter_modules([pkg_dir]):
8
+ module = import_module(f"{__name__}.{module_name}")
9
+
10
+ for attr_name in dir(module):
11
+ attr = getattr(module, attr_name)
12
+ if isclass(attr):
13
+ globals()[attr_name] = attr
seed_schemes/dummy_hash.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from seed_scheme_factory import SeedSchemeFactory, SeedScheme
4
+
5
+ @SeedSchemeFactory.register("dummy_hash")
6
+ class DummyHash(SeedScheme):
7
+ def __init__(self, *args, **kwargs):
8
+ pass
9
+
10
+ def __call__(self, input_ids: torch.Tensor):
11
+ return int(input_ids[-1].item())
12
+
seed_schemes/sha_hash.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ import torch
3
+ import io
4
+
5
+ from cryptography.hazmat.primitives import hashes
6
+
7
+ from seed_scheme_factory import SeedSchemeFactory, SeedScheme
8
+
9
+
10
+ @SeedSchemeFactory.register("sha_left_hash")
11
+ class SHALeftHash(SeedScheme):
12
+ def __init__(self, private_key: Union[int, None] = None, *args, **kwargs):
13
+ self.private_key = (
14
+ private_key.to_bytes(8, "big") if private_key is not None else None
15
+ )
16
+
17
+ def __call__(self, input_ids: torch.Tensor):
18
+ buff = io.BytesIO()
19
+ if self.private_key is not None:
20
+ buff.write(self.private_key)
21
+ for input_id in input_ids:
22
+ buff.write(int(input_id.item()).to_bytes(8, "big"))
23
+ buff.seek(0)
24
+ input_ids_bytes = buff.read()
25
+
26
+ digest = hashes.Hash(hashes.SHA224())
27
+ digest.update(input_ids_bytes)
28
+ hashed_value = digest.finalize()
29
+
30
+ # Only take the first 8 bytes because seed in torch rng only accept int64 seed
31
+ seed = int.from_bytes(hashed_value[:8], byteorder="big")
32
+
33
+ return seed