- class model.modules.EHRAuditDataModule(yaml_config_path, vocab, batch_size=1, n_positions=1024, reset_cache=False, debug=False)¶
- Parameters:
yaml_config_path (str) – Configuration file for the data module.
vocab (EHRVocab) – Vocabulary for tokenization.
batch_size – Size of batches for the dataloader/collation.
n_positions – Context length of the model.
reset_cache – Whether to reset the cached sequences; set this flag if the underlying data has changed.
debug – Whether to run in debug mode (single-threaded).
- prepare_data()¶
Prepares the data for the model: loads the datasets from the audit logs and tokenizes them. The datasets are loaded in parallel and cached (unless debug mode is enabled).
If self.reset_cache is set, the cached sequences are rebuilt. This is useful if the data has changed.
- setup(stage=None)¶
Sets up the data for the model and loads it fully into memory. Splits the datasets into training, validation, and testing sets.
- test_dataloader()¶
- train_dataloader()¶
- val_dataloader()¶
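The lifecycle above (prepare_data, then setup, then the dataloader accessors) can be illustrated with a minimal pure-Python stand-in. The class name, the 80/10/10 split, and the batching logic here are assumptions for illustration only, not the module's actual behavior:

```python
# Toy stand-in for the EHRAuditDataModule lifecycle; NOT the real implementation.
class DataModuleSketch:
    def __init__(self, sequences, batch_size=2):
        self.sequences = sequences
        self.batch_size = batch_size

    def prepare_data(self):
        # In the real module: tokenize the audit logs and cache the sequences.
        pass

    def setup(self, stage=None):
        # Split the cached sequences into train/val/test partitions
        # (80/10/10 is an assumed ratio).
        n = len(self.sequences)
        a, b = int(0.8 * n), int(0.9 * n)
        self.train = self.sequences[:a]
        self.val = self.sequences[a:b]
        self.test = self.sequences[b:]

    def train_dataloader(self):
        # Yield fixed-size batches from the training split.
        for i in range(0, len(self.train), self.batch_size):
            yield self.train[i:i + self.batch_size]

dm = DataModuleSketch(list(range(10)))
dm.prepare_data()
dm.setup()
batches = list(dm.train_dataloader())
```

A caller would follow the same order with the real class: construct it, call prepare_data() and setup(), then request the dataloaders.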
- class model.modules.EHRAuditPretraining(model)¶
PyTorch Lightning module for the pretraining task on the EHR audit data.
- Parameters:
model – The model to be trained.
- configure_optimizers()¶
Configures the optimizer for the model. Uses SophiaG.
- forward(input_ids, labels, should_break=False)¶
Forward pass for the model.
- predict_step(batch, batch_idx, dataloader_idx=0)¶
- test_step(batch, batch_idx)¶
- training_step(batch, batch_idx)¶
- validation_step(batch, batch_idx)¶
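The step hooks above follow the standard Lightning pattern: a batch of (input_ids, labels) is passed through forward(), which produces a loss that Lightning backpropagates. A pure-Python stand-in of that pattern (the class name and the toy "model" are hypothetical; the real module wraps a transformer and trains it with SophiaG):

```python
# Minimal sketch of the Lightning-style step pattern; NOT the real class.
class PretrainingSketch:
    def __init__(self, model):
        self.model = model

    def forward(self, input_ids, labels):
        # The wrapped model maps inputs and labels to a scalar loss.
        return self.model(input_ids, labels)

    def training_step(self, batch, batch_idx):
        input_ids, labels = batch
        loss = self.forward(input_ids, labels)
        return loss  # Lightning backpropagates the returned loss.

# Toy "model": loss = number of mismatched positions.
toy_model = lambda x, y: sum(a != b for a, b in zip(x, y))
module = PretrainingSketch(toy_model)
loss = module.training_step(([1, 2, 3], [1, 2, 4]), 0)
```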
- model.modules.collate_fn(batch, n_positions=1024)¶
Collate function for loading batches from the dataloader for the model.
- Parameters:
batch – The batch of data to be collated.
n_positions – The context length of the model.
- Returns:
Collated input_ids and labels.
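A sketch of the collation described above, under stated assumptions: sequences are lists of token ids, the padding id is 0, and labels mirror input_ids (as in language-model pretraining). The real collate_fn may differ in padding value and label masking:

```python
# Illustrative collate function; NOT the actual model.modules.collate_fn.
def collate_sketch(batch, n_positions=1024, pad_id=0):
    # Truncate each sequence to the context length, then pad to the longest.
    seqs = [seq[:n_positions] for seq in batch]
    max_len = max(len(s) for s in seqs)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    labels = [row[:] for row in input_ids]  # labels mirror inputs here
    return input_ids, labels

ids, labels = collate_sketch([[5, 6, 7], [8, 9]], n_positions=4)
```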
- model.modules.worker_fn(worker_id, seed=0)¶
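worker_fn is undocumented; its signature suggests a DataLoader worker-init function that seeds each worker's RNG reproducibly. A hedged sketch of that common pattern (an assumption, not the actual implementation):

```python
import random

# Assumed behavior: derive a distinct, reproducible seed per worker process.
def worker_fn_sketch(worker_id, seed=0):
    random.seed(seed + worker_id)

worker_fn_sketch(1, seed=42)
first_draw = random.randint(0, 1000)
```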