- class model.model.EHRAuditGPT2(config, vocab)¶
GPT2 model for the EHR audit log dataset. This model inherits from the GPT2LMHeadModel class in the transformers library; the main difference is that it uses the TabularLoss loss function.
- Parameters:
vocab (EHRVocab)
- forward(input_ids=None, labels=None, attention_mask=None, past=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, should_break=False, **kwargs)¶
Forward pass for the GPT2 model. Full documentation can be found in the GPT2LMHeadModel class.
- Parameters:
input_ids – The input IDs for the model.
labels – The labels for the model. Required to compute the cross-entropy loss.
attention_mask – The attention mask for the model, if desired.
kwargs – Additional keyword arguments.
- Returns:
A CausalLMOutputWithCrossAttentions object with the predicted logits and loss.
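A minimal usage sketch, assuming an EHRVocab instance has already been built elsewhere (its constructor is not shown), and assuming the model.vocab import path, the config sizes, and the use of input_ids as labels purely for illustration:

```python
import torch
from transformers import GPT2Config

from model.model import EHRAuditGPT2
from model.vocab import EHRVocab  # import path assumed

# Vocabulary construction is project-specific; assume an EHRVocab built elsewhere.
vocab: EHRVocab = ...

config = GPT2Config(vocab_size=1024, n_positions=512)  # illustrative sizes only
model = EHRAuditGPT2(config, vocab)

# A dummy batch of tokenized audit-log events (shapes are arbitrary).
input_ids = torch.randint(0, config.vocab_size, (2, 64))
attention_mask = torch.ones_like(input_ids)

# Passing labels triggers the TabularLoss computation described above.
outputs = model(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)
loss = outputs.loss      # scalar when the loss reduction is "mean"
logits = outputs.logits  # (batch, seq_len, vocab_size)
```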
- class model.model.EHRAuditLlama(config, vocab)¶
Llama counterpart of EHRAuditGPT2 for the EHR audit log dataset.
- Parameters:
vocab (EHRVocab)
- forward(input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, should_break=False)¶
- class model.model.EHRAuditRWKV(config, vocab)¶
RWKV counterpart of EHRAuditGPT2 for the EHR audit log dataset.
- Parameters:
vocab (EHRVocab)
- forward(input_ids=None, attention_mask=None, inputs_embeds=None, state=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, should_break=False, **kwargs)¶
- class model.model.EHRAuditTransformerXL(config, vocab)¶
Transformer-XL counterpart of EHRAuditGPT2 for the EHR audit log dataset.
- Parameters:
vocab (EHRVocab)
- forward(input_ids=None, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None, labels=None, training=False, should_break=False, **kwargs)¶
- class model.model.TabularLoss(config, vocab, smoothing=None, reduction='mean')¶
Loss function for the tabular transformer, based on Padhi et al. (2021).
Essentially a cross-entropy loss, but with the ability to handle multiple columns: the loss is computed separately for each column, over the number of vocab entries belonging to that column. The per-column losses are averaged if reduction is set to "mean", or returned as a by-column positional list if set to "none". Note that the first column is shifted by one, since there is no context for the first column's first entry. A small illustrative sketch follows the forward() entry below.
- Parameters:
vocab (EHRVocab)
- forward(lm_logits, labels)¶
Forward pass for the loss function. Some optimizations have been made to make this faster.
- Parameters:
lm_logits – The logits from the model for the sequence.
labels – The labels for the sequence.
- Returns:
The computed loss: a single averaged value when reduction is "mean", or a by-column list when reduction is "none".
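The per-column computation can be illustrated with a self-contained sketch. This is not the library's implementation; the column layout, vocabulary slicing, and local label indexing below are made-up assumptions used only to show the shifted, per-column cross-entropy idea described above.

```python
import torch
import torch.nn.functional as F

# Toy setup: a table with 3 columns, each owning its own slice of the vocabulary.
col_sizes = [10, 7, 5]
n_cols = len(col_sizes)
batch, n_rows = 2, 4
seq_len = n_rows * n_cols          # events flattened column-by-column into one sequence
vocab_size = sum(col_sizes)

lm_logits = torch.randn(batch, seq_len, vocab_size)
# Labels here are local indices within each column's slice of the vocabulary.
labels = torch.randint(0, min(col_sizes), (batch, seq_len))

# Causal shift: position t predicts the token at t + 1, so the first entry of the
# first column has no context and is dropped.
shift_logits = lm_logits[:, :-1, :]
shift_labels = labels[:, 1:]

col_losses, offset = [], 0
for c, size in enumerate(col_sizes):
    # After the shift, position j of shift_labels holds the token from position j + 1,
    # i.e. column (j + 1) % n_cols of the original layout.
    pos = (torch.arange(shift_labels.size(1)) + 1) % n_cols == c
    logits_c = shift_logits[:, pos, offset:offset + size]   # this column's vocab slice
    labels_c = shift_labels[:, pos]
    col_losses.append(F.cross_entropy(logits_c.reshape(-1, size), labels_c.reshape(-1)))
    offset += size

loss = torch.stack(col_losses).mean()   # reduction="mean"; keep col_losses for "none"
print(loss)
```

Keeping the per-column values in a list corresponds to reduction="none", while stacking and averaging them gives the "mean" behavior described above.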