Source code for scalr.nn.dataloader.test_simple_metadataloader
'''This is a test file for simplemetadataloader.'''
import anndata
import numpy as np
import pandas as pd
from scalr.nn.dataloader import build_dataloader
from scalr.utils import generate_dummy_anndata
[docs]
def test_metadataloader():
    '''Checks the feature width of batches built by `SimpleMetaDataLoader`.

    A dummy anndata with 15 samples and 7 features is generated, label
    mappings are derived for every obs column, and a dataloader is built with
    the `batch` and `env` metadata columns and `celltype` as target.

    The features of the first batch are expected to have
    `len(adata.var_names)` columns plus one column per unique value of each
    metadata column (presumably the loader appends one-hot encoded metadata
    to the features -- confirm against the dataloader implementation).
    '''
    # Generating dummy anndata.
    adata = generate_dummy_anndata(n_samples=15, n_features=7)

    # Generating mappings for anndata obs columns.
    mappings = {}
    for column_name in adata.obs.columns:
        id2label = adata.obs[column_name].astype(
            'category').cat.categories.tolist()
        label2id = {label: idx for idx, label in enumerate(id2label)}
        mappings[column_name] = {'id2label': id2label, 'label2id': label2id}

    # Defining required parameters for metadataloader.
    metadata_col = ['batch', 'env']
    dataloader_config = {
        'name': 'SimpleMetaDataLoader',
        'params': {
            'batch_size': 10,
            'metadata_col': metadata_col
        }
    }
    dataloader, _ = build_dataloader(dataloader_config=dataloader_config,
                                     adata=adata,
                                     target='celltype',
                                     mappings=mappings)

    # Expected width: original features plus one column per unique value of
    # every metadata column.
    expected_n_features = len(
        adata.var_names) + adata.obs[metadata_col].nunique().sum()

    # Checking only the first batch is enough. Fetch it explicitly so that an
    # empty dataloader fails the test instead of silently skipping the check.
    first_batch = next(iter(dataloader), None)
    assert first_batch is not None, 'dataloader did not yield any batch'

    # Comparing expected features shape after using metadataloader.
    feature, _ = first_batch
    assert feature.shape[1] == expected_n_features