Loss Functions
- dnamite.loss_fns.coxph_loss(y_hat, events, times)
Cox proportional hazards (CoxPH) loss function for survival analysis.
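For reference, the standard negative Cox partial log-likelihood can be written with the parameter names below as follows, with δ_i = events[i], t_i = times[i], and ŷ_i = y_hat[i]. Normalizing by the number of observed events, and placing every sample with t_j ≥ t_i in the risk set when times are tied, are assumptions here; this page does not confirm how dnamite normalizes or handles ties.

```latex
\mathcal{L}_{\text{Cox}}
  = -\frac{1}{\sum_{i=1}^{N} \delta_i}
    \sum_{i:\,\delta_i = 1}
    \Big( \hat{y}_i - \log \sum_{j:\, t_j \ge t_i} \exp(\hat{y}_j) \Big)
```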
- Parameters:
y_hat (torch.Tensor) – Predicted log-risk scores. Shape: (N)
events (torch.Tensor) – Event indicators, where 1 indicates event occurred and 0 indicates censored. Shape: (N)
times (torch.Tensor) – Observed times: the event time where events is 1, otherwise the censoring time. Shape: (N)
- Returns:
Computed CoxPH loss.
- Return type:
torch.Tensor
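A slow but easy-to-check version of the Cox loss is sketched below in plain Python (no torch), which can serve as a reference when testing a vectorized implementation. The function name coxph_loss_reference is hypothetical, and the tie handling (all samples with times[j] >= times[i] form the risk set), the averaging over events, and returning 0.0 when no events are observed are assumptions, not dnamite's confirmed behavior.

```python
import math


def coxph_loss_reference(y_hat, events, times):
    """Hypothetical reference: negative Cox partial log-likelihood.

    Assumptions: averaged over observed events; tied times share a risk set;
    returns 0.0 when no events are observed.
    """
    n = len(y_hat)
    num_events = sum(events)
    if num_events == 0:
        return 0.0  # assumption: no events means no partial-likelihood terms
    total = 0.0
    for i in range(n):
        if events[i] != 1:
            continue  # censored samples only appear inside risk sets
        # Risk set: every sample still under observation at times[i].
        risk = [y_hat[j] for j in range(n) if times[j] >= times[i]]
        # Numerically stable log-sum-exp over the risk set.
        m = max(risk)
        log_denom = m + math.log(sum(math.exp(r - m) for r in risk))
        # Higher log-risk for the sample that actually failed raises this term.
        total += y_hat[i] - log_denom
    return -total / num_events
```

For example, with y_hat all zeros and three uncensored samples at times 1, 2, and 3, the risk sets have sizes 3, 2, and 1, so the loss is (log 3 + log 2) / 3 = log(6) / 3. Scores that rank earlier events as higher risk give a lower loss than the reversed ranking.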
- dnamite.loss_fns.ipcw_rps_loss(cdf_preds, pcw_eval_times, pcw_obs_times, events, times, eval_times)
IPCW (inverse probability of censoring weighting) loss function for survival analysis. It is an alternative to the Cox loss that does not assume proportional hazards, and it requires CDF estimates at the specified evaluation times. For mathematical details, see Equation 6 in https://arxiv.org/pdf/2411.05923.
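Equation 6 of the linked paper is not reproduced here. As an assumption, the standard IPCW Brier score (Graf et al.) averaged over samples and evaluation times is shown instead, with F̂_i(τ_k) = cdf_preds[i, k], τ_k = eval_times[k], Ĝ(τ_k) = pcw_eval_times[k], Ĝ(t_i) = pcw_obs_times[i], δ_i = events[i], and t_i = times[i]. It treats Ĝ as the estimated censoring survival function P(C > t); samples censored before τ_k contribute nothing at that time.

```latex
\mathcal{L}_{\text{IPCW}}
  = \frac{1}{NT} \sum_{i=1}^{N} \sum_{k=1}^{T}
    \left[
      \frac{\mathbb{1}\{t_i \le \tau_k,\ \delta_i = 1\}\,\big(1 - \hat{F}_i(\tau_k)\big)^2}{\hat{G}(t_i)}
      + \frac{\mathbb{1}\{t_i > \tau_k\}\,\hat{F}_i(\tau_k)^2}{\hat{G}(\tau_k)}
    \right]
```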
- Parameters:
cdf_preds (torch.Tensor) – Predicted cumulative distribution function values at eval_times. Shape: (N, T), where N is the batch size and T is len(eval_times).
pcw_eval_times (torch.Tensor) – Estimated probability of censoring at evaluation times. Shape: (T)
pcw_obs_times (torch.Tensor) – Estimated probability of censoring at observed times. Shape: (N)
events (torch.Tensor) – Event indicators, where 1 indicates event occurred and 0 indicates censored. Shape: (N)
times (torch.Tensor) – Observed times: the event time where events is 1, otherwise the censoring time. Shape: (N)
eval_times (torch.Tensor) – Evaluation times. Shape: (T)
- Returns:
Computed IPCW loss.
- Return type:
torch.Tensor
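A plain-Python loop over samples and evaluation times makes the IPCW weighting explicit. The function name ipcw_rps_loss_reference is hypothetical, and the sketch assumes that pcw_eval_times and pcw_obs_times hold censoring survival probabilities G(t) = P(C > t) used as inverse weights, and that the loss is averaged over N and T. Check Equation 6 of the paper and the dnamite source before relying on it.

```python
def ipcw_rps_loss_reference(cdf_preds, pcw_eval_times, pcw_obs_times,
                            events, times, eval_times):
    """Hypothetical reference: IPCW Brier-type loss averaged over N and T.

    Assumes pcw_* values are censoring survival probabilities G(t) > 0.
    """
    n = len(times)
    n_t = len(eval_times)
    total = 0.0
    for i in range(n):
        for k, t in enumerate(eval_times):
            f = cdf_preds[i][k]  # predicted P(event by time t) for sample i
            if times[i] <= t and events[i] == 1:
                # Event observed by t: target CDF is 1, weighted by 1 / G(T_i).
                total += (1.0 - f) ** 2 / pcw_obs_times[i]
            elif times[i] > t:
                # Still event-free at t: target CDF is 0, weighted by 1 / G(t).
                total += f ** 2 / pcw_eval_times[k]
            # Otherwise censored at or before t: status unknown, contributes 0.
    return total / (n * n_t)
```

Because censored samples drop out once their status becomes unknown, the inverse weights upweight the remaining comparable samples so that the average stays approximately unbiased under independent censoring.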