"""This file contains functions related to miscellaneous utilities."""
import os
import random
import numpy as np
import torch
[docs]
def set_seed(seed: int):
"""A function to set seed for reproducibility."""
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
[docs]
def overwrite_default(user_config: dict, default_config: dict) -> dict:
"""The function recursively overwrites information from user_config
onto the default_config.
"""
for key in user_config.keys():
if key not in default_config.keys() or not isinstance(
user_config[key], dict):
default_config[key] = user_config[key]
else:
default_config[key] = overwrite_default(user_config[key],
default_config[key])
return default_config
[docs]
def build_object(module, config: dict):
"""A builder function to build an object from its config.
Args:
module: Module containing the class.
config: Contains the name of the class and params to initialize the object.
Returns: Object, updated config.
"""
name = config.get('name')
if not name:
raise ValueError('class name not provided!')
params = config.get('params', dict())
default_params = getattr(module, name).get_default_params()
params = overwrite_default(params, default_params)
final_config = dict(name=name, params=params)
return getattr(module, name)(**params), final_config