class DSRirFixedEnv(Dataset):
    """
    Dataset with a fixed environment.

    The environment is a fixed set of microphones (``ids_env``). Their signal,
    time samples and position are read once at construction time and are
    exposed by :meth:`get_env` as float32 torch tensors. Each item of the
    dataset is a *target* microphone. The target can be any microphone of the
    database, including those that are part of the environment.

    Parameters
    ----------
    mic_database : DB_microphones
        Database providing ``get_mic``, ``get_time``, ``get_pos`` and ``n_mics``.
    ids_env : list of int
        Indices of the microphones that form the fixed environment.

    Raises
    ------
    ValueError
        If ``ids_env`` is empty.
    """

    def __init__(self,
                 mic_database: "DB_microphones",
                 ids_env: List[int],
                 ):
        super().__init__()
        self.db = mic_database
        # Copy the ids so that a later mutation of the caller's list cannot
        # desynchronise ``self.ids_env`` (shown by __str__) from ``self.env``.
        self.ids_env = list(ids_env)
        if not self.ids_env:
            # np.stack would otherwise fail with a far less informative message.
            raise ValueError("ids_env must contain at least one microphone id")
        # Environment microphones, stacked along a new leading axis (one entry
        # per id in ``ids_env``) and converted to float32 torch tensors.
        self.env = {
            'signal': self._stack([self.db.get_mic(i) for i in self.ids_env]),
            'time': self._stack([self.db.get_time(i) for i in self.ids_env]),
            'position': self._stack([self.db.get_pos(i) for i in self.ids_env]),
        }

    @staticmethod
    def _stack(arrays):
        """Stack per-microphone arrays into a single float32 torch tensor."""
        return torch.from_numpy(np.stack(arrays).astype(np.float32))

    @staticmethod
    def _to_tensor(array):
        """Convert one microphone's data into a float32 torch tensor (always a copy)."""
        return torch.from_numpy(np.array(array, dtype=np.float32))

    def __len__(self):
        """Number of microphones in the database (every one can be a target)."""
        return self.db.n_mics

    def __getitem__(self, idx):
        """
        Return the target microphone ``idx``.

        The environment is fixed, so only the target is returned here. The
        result is a dict with ``signal``, ``time`` and ``position``. Each value
        is a float32 tensor, matching the type and dtype of the environment
        tensors returned by :meth:`get_env`.
        """
        return dict(signal=self._to_tensor(self.db.get_mic(idx)),
                    time=self._to_tensor(self.db.get_time(idx)),
                    position=self._to_tensor(self.db.get_pos(idx)))

    def get_env(self):
        """
        Return the environment dict.

        Its keys are ``signal``, ``time`` and ``position``; each value is a
        float32 tensor stacked over ``ids_env``.
        """
        return self.env

    def __str__(self):
        header = (
            f"Pytorch Dataset: {type(self).__name__}\n"
            f"With length: {len(self)} \n"
            f"Environment (mics ids): {self.ids_env}\n"
        )
        return header + str(self.db)

# Implementation of torch Datasets based on the RIR databases
PyTorch Datasets
These datasets define which information from the microphone databases is provided to the model.
Dataset with a fixed environment
- The dataset fixes the set of known (environment) microphones. For example, 4 microphones placed at specific locations to measure the acoustics of a room.
- Each item is a dictionary with the information of a single (target) microphone: position, signal and time samples.
# Fixed environment: microphones 10, 30 and 50 of the "Balder" room,
# using 128 samples starting at sample 0.
zea_db = ZeaRIR("./data", dataname="Balder", signal_start=0, signal_size=128)
ds_Zea = DSRirFixedEnv(mic_database=zea_db, ids_env=[10, 30, 50])
print()
print(ds_Zea)
# Cell output:
# Matched resources to download:
- BalderRIR.mat
Loading the resource ./data/ZeaRIR/raw/BalderRIR.mat ...
Pytorch Dataset: DSRirFixedEnv
With length: 100
Environment (mics ids): [10, 30, 50]
Database: ZeaRIR
Download: ['BalderRIR.mat']
Load room: BalderRIR.mat
Path to raw resource: ./data/ZeaRIR/raw/BalderRIR.mat
Path to unpacked data folder: ./data/ZeaRIR/raw
Sampling frequency: 11250 Hz
Number of microphones: 100
Number of total time samples: 3623
Number of time samples selected: 128
Number of sources: 1
Signal start: 0
Signal size: 128
Source ID: 0
# Accessing an element: [] indexing and an explicit __getitem__ call are equivalent
print(f"Length of dataset: {len(ds_Zea)}")
print("Position of Target (index 1). ")
pos_via_indexing = ds_Zea[1]['position']
print(f"using list indexing: {pos_via_indexing} ")
pos_via_getitem = ds_Zea.__getitem__(1)['position']
print(f"and using __getitem__: {pos_via_getitem} ")
# Show the environment, which is fixed when the dataset is built
print()
print("Environment \nPositions:")
env_positions = ds_Zea.get_env()['position']
print(env_positions)
# Cell output:
# Length of dataset: 100
Position of Target (index 1).
using list indexing: [0.03 0. 0. ]
and using __getitem__: [0.03 0. 0. ]
Environment
Positions:
tensor([[0.3000, 0.0000, 0.0000],
[0.9000, 0.0000, 0.0000],
[1.5000, 0.0000, 0.0000]])
Dataset with random environment
DS_random_pick
DS_random_pick (mic_database:DataScience_exploration.datasets.mics_databases.DB_microphones, n_ref_mics:int=4, max_combinations:int=1000)
*An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts a list of indices of samples of a batch and returns a list of samples.
.. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.*
| | Type | Default | Details |
|---|---|---|---|
| mic_database | DB_microphones | | |
| n_ref_mics | int | 4 | number of mics picked as the environment to interpolate from |
| max_combinations | int | 1000 | maximum number of combinations |
# Random environment: 4 reference mics per item, at most 20 combinations
mesh_db = MeshRIR(root="./data", dataname="S1", signal_start=0, signal_size=128, source_id=0)
ds_Mesh = DS_random_pick(mic_database=mesh_db,
                         n_ref_mics=4,
                         max_combinations=20)
# Fetch index 1 twice: once with [] indexing, once with an explicit __getitem__ call.
# NOTE(review): the captured output shows different targets for these two calls,
# so DS_random_pick appears to draw a new random pick on every access - confirm.
env, target = ds_Mesh[1]
env_p, target_p = ds_Mesh.__getitem__(1)
print()
# Accessing an element
print(f"Length of dataset: {len(ds_Mesh)}")
print("Position of Target (index 1). ")
print(f"using list indexing: {target['position']} ")
print(f"and using __getitem__: {target_p['position']} ")
# Show the environment returned by the [] access
print()
print("Environment \nPositions:")
print(env['position'])
# Cell output:
# Matched resources to download:
- S1-M3969_npy.zip
Unpacked folder ./data/MeshRIR/raw/S1-M3969_npy already exists. Skipping unpacking.
Length of dataset: 20
Position of Target (index 1).
using list indexing: tensor([0.1500, 0.3000, 0.2000])
and using __getitem__: tensor([0.5000, 0.1500, 0.0500])
Environment
Positions:
tensor([[-0.2000, 0.2500, 0.1500],
[ 0.4500, 0.1500, 0.0000],
[-0.0500, -0.4500, 0.2000],
[-0.1000, 0.4000, -0.1500]])
PyTorch Lightning DataModules
The PyTorch Lightning DataModule bundles the torch
`Dataset`s with the operations that have to be performed during the "fit" and "test" stages. It also holds the configuration of the `DataLoader`s used for training.
def ensure_list(x):
    """
    Normalise ``x`` into a new list of datasets.

    Parameters
    ----------
    x : Dataset, list or tuple of Dataset, or None
        A single dataset is wrapped in a list. A list or tuple is copied into a
        new list. ``None`` becomes an empty list.

    Returns
    -------
    list
        Always a fresh list. Callers can store or mutate it without affecting
        the object that was passed in (e.g. a shared default argument).

    Raises
    ------
    TypeError
        If ``x`` is of any other type.
    """
    if isinstance(x, Dataset):
        return [x]
    if isinstance(x, (list, tuple)):
        # Copy instead of returning the same object, to avoid aliasing.
        return list(x)
    if x is None:
        return []
    raise TypeError(f"Expected Dataset or list of Datasets, got {type(x)}")
class DM_PL_DataModule(L.LightningDataModule):
    """
    Lightning DataModule built from lists of torch Datasets.

    The training datasets are concatenated and randomly split into train and
    validation subsets. The test datasets are concatenated as-is.

    Parameters
    ----------
    ls_datasets_train : Dataset, list of Dataset, or None
        Datasets for the "fit"/"validate" stages (split into train and val).
    ls_datasets_test : Dataset, list of Dataset, or None
        Datasets for the "test" stage.
    batch_size : int
        Batch size used by every dataloader.
    num_workers : int
        Number of worker processes used by every dataloader.
    val_fraction : float
        Fraction of the concatenated training data held out for validation,
        in ``[0, 1)``. The default of 0.2 reproduces the original 80/20 split.
    seed : int or None
        Seed for the train/validation split. ``None`` uses torch's global RNG
        (non-reproducible split).

    Raises
    ------
    ValueError
        If ``val_fraction`` is outside ``[0, 1)``.
    """

    def __init__(self,
                 ls_datasets_train: "List[torch.utils.data.Dataset] | None" = None,
                 ls_datasets_test: "List[torch.utils.data.Dataset] | None" = None,
                 batch_size: int = 64, num_workers: int = 0,
                 val_fraction: float = 0.2,
                 seed: "int | None" = None,
                 ):
        super().__init__()
        if not 0.0 <= val_fraction < 1.0:
            raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_fraction = val_fraction
        self.seed = seed
        # None defaults avoid sharing one mutable list across instances;
        # ensure_list turns None into an empty list.
        self.ls_datasets_train = ensure_list(ls_datasets_train)
        # The test datasets must come from ls_datasets_test, not the training list.
        self.ls_datasets_test = ensure_list(ls_datasets_test)
        self.ds_train = None
        self.ds_val = None
        self.ds_test = None

    def setup(self, stage):
        """
        Build the datasets needed for ``stage``.

        "fit" and "validate" share one random train/validation split of the
        concatenated training datasets. The split is computed only once, so a
        later "validate" run evaluates on the same held-out samples used during
        "fit". "test" concatenates the test datasets. ``stage=None`` prepares
        everything.
        """
        if stage in ("fit", "validate", None) and self.ds_train is None:
            split_kwargs = {}
            if self.seed is not None:
                split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
            self.ds_train, self.ds_val = random_split(
                ConcatDataset(self.ls_datasets_train),
                [1.0 - self.val_fraction, self.val_fraction],
                **split_kwargs,
            )
        # Assign test dataset for use in dataloader(s)
        if stage in ("test", None):
            self.ds_test = ConcatDataset(self.ls_datasets_test)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=False, collate_fn=None)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        # Use the same worker count as the other dataloaders.
        return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers)