# Implementation of torch Datasets based on the RIR databases


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## PyTorch Datasets

> Define which information from my databases is provided to the model

### Dataset with a fixed environment

- The Dataset fixes the microphones whose signals are known. For example,
  I place 4 microphones at specific locations to measure the acoustics of
  an environment.
- Each iteration returns a dictionary with the information of one
  microphone (position, signal and time samples).

``` python
class DSRirFixedEnv(Dataset):
    """
    Dataset built around a fixed set of environment microphones.

    The environment (the microphones listed in ``ids_env``) is read from the
    database once, at construction time, and stored as float32 tensors.
    Indexing the dataset returns a single target microphone, which may be any
    microphone of the database -- including one that also belongs to the
    environment.
    """

    def __init__(self,
                 mic_database: DB_microphones,
                 ids_env: List[int],
                 ):
        super().__init__()
        self.db = mic_database
        self.ids_env = ids_env

        def _as_float32_tensor(values):
            # Stack the per-microphone arrays along a new leading axis and
            # wrap the float32 result as a torch tensor.
            return torch.from_numpy(np.stack(values).astype(np.float32))

        # Query every environment microphone field by field, then convert.
        signals = [self.db.get_mic(i) for i in ids_env]
        times = [self.db.get_time(i) for i in ids_env]
        positions = [self.db.get_pos(i) for i in ids_env]

        self.env = {
            'signal': _as_float32_tensor(signals),
            'time': _as_float32_tensor(times),
            'position': _as_float32_tensor(positions),
        }

    def __len__(self):
        # One sample per microphone in the database.
        return self.db.n_mics

    def __getitem__(self, idx):
        """
        Return the target microphone ``idx``.

        The environment is fixed (see ``get_env``), so only the target is
        returned: a dict with keys 'signal', 'time' and 'position', holding
        the values exactly as the database getters return them (no tensor
        conversion is applied here).
        """
        db = self.db
        sample = {}
        sample['signal'] = db.get_mic(idx)
        sample['time'] = db.get_time(idx)
        sample['position'] = db.get_pos(idx)
        return sample

    def get_env(self):
        """Return the environment dict built in ``__init__``."""
        return self.env

    def __str__(self):
        header = (
            f"Pytorch Dataset: {type(self).__name__}\n"
            f"With length: {self.__len__()} \n"
            f"Environment (mics ids): {self.ids_env}\n"
        )
        return header + str(self.db)
```

``` python
# Build the fixed-environment dataset on the "Balder" room of ZeaRIR,
# using microphones 10, 30 and 50 as the environment.
ds_Zea = DSRirFixedEnv(
    mic_database=ZeaRIR("./data", dataname="Balder",
                        signal_start=0, signal_size=128),
    ids_env=[10, 30, 50],
)
print(f"\n{ds_Zea}")
```

    Matched resources to download:
    - BalderRIR.mat
    Loading the resource ./data/ZeaRIR/raw/BalderRIR.mat ...

    Pytorch Dataset: DSRirFixedEnv
    With length: 100 
    Environment (mics ids): [10, 30, 50]
    Database: ZeaRIR
    Download: ['BalderRIR.mat']
    Load room: BalderRIR.mat
    Path to raw resource: ./data/ZeaRIR/raw/BalderRIR.mat
    Path to unpacked data folder: ./data/ZeaRIR/raw
    Sampling frequency: 11250 Hz
    Number of microphones: 100
    Number of total time samples: 3623
    Number of time samples selected: 128
    Number of sources: 1
    Signal start: 0
    Signal size: 128
    Source ID: 0

``` python
# Accessing an element: `ds[i]` and `ds.__getitem__(i)` are the same call
print(f"Length of dataset: {len(ds_Zea)}")
print("Position of Target (index 1). ")
print(f"using list indexing:   {ds_Zea[1]['position']} ")
print(f"and using __getitem__: {ds_Zea.__getitem__(1)['position']} ")

# Print the environment (preceded by a blank line)
print("\nEnvironment \nPositions:")
print(ds_Zea.get_env()['position'])
```

    Length of dataset: 100
    Position of Target (index 1). 
    using list indexing:   [0.03 0.   0.  ] 
    and using __getitem__: [0.03 0.   0.  ] 

    Environment 
    Positions:
    tensor([[0.3000, 0.0000, 0.0000],
            [0.9000, 0.0000, 0.0000],
            [1.5000, 0.0000, 0.0000]])

### Dataset with a random environment

------------------------------------------------------------------------

<a
href="https://github.com/Ramon-PR/DataScience_exploration/blob/main/DataScience_exploration/datasets/mics_datasets.py#L23"
target="_blank" style="float:right; font-size:smaller">source</a>

### DS_random_pick

>  DS_random_pick (mic_database:DataScience_exploration.datasets.mics_databa
>                      ses.DB_microphones, n_ref_mics:int=4,
>                      max_combinations:int=1000)

\*An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should
subclass it. All subclasses should overwrite :meth:`__getitem__`,
supporting fetching a data sample for a given key. Subclasses could also
optionally overwrite :meth:`__len__`, which is expected to return the
size of the dataset by many :class:`~torch.utils.data.Sampler`
implementations and the default options of
:class:`~torch.utils.data.DataLoader`. Subclasses could also optionally
implement :meth:`__getitems__`, for speedup batched samples loading.
This method accepts list of indices of samples of batch and returns list
of samples.

.. note:: :class:`~torch.utils.data.DataLoader` by default constructs an
index sampler that yields integral indices. To make it work with a
map-style dataset with non-integral indices/keys, a custom sampler must
be provided.\*

<table>
<colgroup>
<col style="width: 6%" />
<col style="width: 25%" />
<col style="width: 34%" />
<col style="width: 34%" />
</colgroup>
<thead>
<tr>
<th></th>
<th><strong>Type</strong></th>
<th><strong>Default</strong></th>
<th><strong>Details</strong></th>
</tr>
</thead>
<tbody>
<tr>
<td>mic_database</td>
<td>DB_microphones</td>
<td></td>
<td></td>
</tr>
<tr>
<td>n_ref_mics</td>
<td>int</td>
<td>4</td>
<td>number of mics I will pick as my environment to interpolate</td>
</tr>
<tr>
<td>max_combinations</td>
<td>int</td>
<td>1000</td>
<td>number of maximum combinations</td>
</tr>
</tbody>
</table>

``` python
ds_Mesh = DS_random_pick(
    mic_database=MeshRIR(root="./data", dataname="S1",
                         signal_start=0, signal_size=128, source_id=0),
    n_ref_mics=4,
    max_combinations=20,
)

# NOTE(review): in the recorded output below, the two calls return different
# targets for the same index, so DS_random_pick appears to draw a new random
# sample on every access -- confirm against its implementation.
env, target = ds_Mesh[1]
env_p, target_p = ds_Mesh.__getitem__(1)

# Accessing an element (preceded by a blank line)
print(f"\nLength of dataset: {len(ds_Mesh)}")
print("Position of Target (index 1). ")
print(f"using list indexing:   {target['position']} ")
print(f"and using __getitem__: {target_p['position']} ")

# Print the environment returned by the first call
print("\nEnvironment \nPositions:")
print(env['position'])
```

    Matched resources to download:
    - S1-M3969_npy.zip
    Unpacked folder ./data/MeshRIR/raw/S1-M3969_npy already exists. Skipping unpacking.

    Length of dataset: 20
    Position of Target (index 1). 
    using list indexing:   tensor([0.1500, 0.3000, 0.2000]) 
    and using __getitem__: tensor([0.5000, 0.1500, 0.0500]) 

    Environment 
    Positions:
    tensor([[-0.2000,  0.2500,  0.1500],
            [ 0.4500,  0.1500,  0.0000],
            [-0.0500, -0.4500,  0.2000],
            [-0.1000,  0.4000, -0.1500]])

## Pytorch lightning Datamodules

> The PyTorch Lightning `DataModule` groups the torch `Dataset`s together
> with the operations performed during the “fit” and “test” stages. It
> also defines the `DataLoader`s used for training, validation and
> testing.

``` python
def ensure_list(x):
    """
    Normalize ``x`` to a list of Datasets.

    A single ``Dataset`` is wrapped in a one-element list, a list is returned
    as-is (same object), and ``None`` becomes an empty list.

    Raises:
        TypeError: if ``x`` is none of the above.
    """
    # `None` is neither a Dataset nor a list, so testing it first is safe.
    if x is None:
        return []
    if isinstance(x, Dataset):
        return [x]
    if isinstance(x, list):
        return x
    raise TypeError(f"Expected Dataset or list of Datasets, got {type(x)}")
    
class DM_PL_DataModule(L.LightningDataModule):
    """
    Lightning DataModule that groups torch Datasets for the "fit" and "test" stages.

    Args:
        ls_datasets_train: a Dataset, a list of Datasets, or None. During
            "fit" they are concatenated and randomly split 80/20 into the
            training and validation subsets.
        ls_datasets_test: a Dataset, a list of Datasets, or None. During
            "test" they are concatenated into the test set.
        batch_size: batch size used by every DataLoader.
        num_workers: number of DataLoader worker processes, used by every
            DataLoader.
    """
    def __init__(self,
                 ls_datasets_train: List[torch.utils.data.Dataset] = [],
                 ls_datasets_test: List[torch.utils.data.Dataset] = [],
                 batch_size: int = 64, num_workers: int = 0,
                 ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Copy into fresh lists so an instance never aliases the shared
        # mutable default arguments (or the caller's list object).
        self.ls_datasets_train = list(ensure_list(ls_datasets_train))
        # The test set must come from ls_datasets_test, not the training list.
        self.ls_datasets_test = list(ensure_list(ls_datasets_test))

    def setup(self, stage):
        if stage == "fit":
            # Concatenate all training datasets, then split 80% train / 20% val.
            self.ds_train, self.ds_val = random_split(
                ConcatDataset(self.ls_datasets_train), [0.8, 0.2])

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.ds_test = ConcatDataset(self.ls_datasets_test)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=False, collate_fn=None)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        # Use the configured worker count, consistent with the other loaders.
        return DataLoader(self.ds_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)
```
