Loss Functions

dnamite.loss_fns.coxph_loss(y_hat, events, times)

Cox proportional hazards (CoxPH) loss function for survival analysis.

Parameters:
  • y_hat (torch.Tensor) – Predicted log-risk scores. Shape: (N)

  • events (torch.Tensor) – Event indicators, where 1 indicates that the event occurred and 0 indicates that the observation is censored. Shape: (N)

  • times (torch.Tensor) – Observed times: the event time when events is 1, otherwise the censoring time. Shape: (N)

Returns:

Computed CoxPH loss.

Return type:

torch.Tensor
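For reference, a Cox loss usually minimises the negative partial log-likelihood, which sums y_hat_i - log(sum over j with times_j >= times_i of exp(y_hat_j)) over uncensored samples i. The sketch below is a minimal PyTorch version under stated assumptions: the helper name cox_partial_likelihood_sketch is hypothetical, ties follow the Breslow convention, and the result is averaged over the number of observed events. It is not dnamite's implementation, which may normalise or handle ties differently.

```python
import torch


def cox_partial_likelihood_sketch(y_hat: torch.Tensor,
                                  events: torch.Tensor,
                                  times: torch.Tensor) -> torch.Tensor:
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical helper for illustration; not dnamite's implementation.
    Ties use the Breslow convention: every sample with time >= t_i is in
    the risk set of sample i.
    """
    y_hat = y_hat.reshape(-1)
    events = events.reshape(-1).to(y_hat.dtype)
    times = times.reshape(-1)
    n = y_hat.shape[0]

    # risk_set[i, j] is True when sample j is still at risk at time t_i
    risk_set = times.unsqueeze(0) >= times.unsqueeze(1)

    # log sum_{j in R(t_i)} exp(y_hat_j), computed stably with logsumexp;
    # entries outside the risk set become -inf and contribute exp(-inf) = 0
    scores = y_hat.unsqueeze(0).repeat(n, 1)
    scores = scores.masked_fill(~risk_set, float("-inf"))
    log_risk = torch.logsumexp(scores, dim=1)

    # Only uncensored samples contribute a term to the partial likelihood
    n_events = events.sum().clamp(min=1.0)
    return -((y_hat - log_risk) * events).sum() / n_events
```

With y_hat = [0, 0], both events observed and times = [1, 2], this sketch returns log(2)/2: the first sample's risk set holds both samples, while the second sample's risk set contains only itself and contributes a zero term.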

dnamite.loss_fns.ipcw_rps_loss(cdf_preds, pcw_eval_times, pcw_obs_times, events, times, eval_times)

Inverse probability of censoring weighted (IPCW) loss function for survival analysis. It is an alternative to the Cox loss that does not assume proportional hazards, and it requires predicted CDF values at the specified evaluation times. For mathematical details, see Equation 6 in https://arxiv.org/pdf/2411.05923.

Parameters:
  • cdf_preds (torch.Tensor) – Predicted cumulative distribution function values at eval_times. Shape: (N, T), where N is the batch size and T is len(eval_times).

  • pcw_eval_times (torch.Tensor) – Estimated probability of censoring at evaluation times. Shape: (T)

  • pcw_obs_times (torch.Tensor) – Estimated probability of censoring at observed times. Shape: (N)

  • events (torch.Tensor) – Event indicators, where 1 indicates that the event occurred and 0 indicates that the observation is censored. Shape: (N)

  • times (torch.Tensor) – Observed times: the event time when events is 1, otherwise the censoring time. Shape: (N)

  • eval_times (torch.Tensor) – Evaluation times. Shape: (T)
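The pcw_eval_times and pcw_obs_times arguments have to be estimated before the loss is called. One common choice in IPCW methods, assumed here rather than taken from dnamite, is a Kaplan-Meier estimate of the censoring survival function G(t) = P(C > t), obtained by treating censored samples (events == 0) as the "events". The helper name censoring_survival_km is hypothetical.

```python
import torch


def censoring_survival_km(times: torch.Tensor, events: torch.Tensor,
                          query_times: torch.Tensor) -> torch.Tensor:
    """Kaplan-Meier estimate of G(t) = P(C > t), evaluated at query_times.

    Censoring is treated as the "event": a sample with events == 0 is a
    censoring observation. Hypothetical helper, not part of dnamite.
    """
    times = times.reshape(-1).to(torch.float64)
    censored = events.reshape(-1) == 0
    unique_times = torch.unique(times)  # sorted ascending

    # Product-limit estimate: one multiplicative step per distinct time
    surv = 1.0
    steps = []
    for t in unique_times:
        at_risk = (times >= t).sum().item()
        n_censored = (censored & (times == t)).sum().item()
        surv *= 1.0 - n_censored / at_risk
        steps.append(surv)
    steps = torch.tensor(steps, dtype=torch.float64)

    # G(q) is the KM value at the last distinct time <= q, and 1.0 before the first
    queries = query_times.reshape(-1).to(torch.float64)
    idx = torch.searchsorted(unique_times, queries, right=True)
    padded = torch.cat([torch.ones(1, dtype=torch.float64), steps])
    return padded[idx]
```

Evaluating such a helper at eval_times and at times gives tensors of shape (T) and (N), matching the two parameters above. Because G can reach 0 after the last censoring time, its values are usually clamped away from zero before being used as denominators.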

Returns:

Computed IPCW loss.

Return type:

torch.Tensor
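The sketch below shows one standard form of an IPCW loss on CDF predictions: a squared-error (Brier-score style) term at each evaluation time, weighted by inverse censoring probabilities and summed over the T times. It assumes the two weight tensors hold censoring survival probabilities G(t) = P(C > t). The function name ipcw_brier_loss_sketch, its g_eval_times and g_obs_times arguments (standing in for pcw_eval_times and pcw_obs_times), the min_g clamp, and the exact weighting are all assumptions; compare against Equation 6 of the paper before relying on it.

```python
import torch


def ipcw_brier_loss_sketch(cdf_preds, g_eval_times, g_obs_times,
                           events, times, eval_times, min_g=1e-3):
    """IPCW squared-error loss on predicted CDF values (hypothetical helper).

    Assumes g_* hold censoring survival probabilities G(t) = P(C > t).
    For each sample i and evaluation time t_k:
      * event observed by t_k (times_i <= t_k, events_i = 1): target 1, weight 1 / G(times_i)
      * still event-free after t_k (times_i > t_k): target 0, weight 1 / G(t_k)
      * censored at or before t_k: weight 0 (status at t_k unknown)
    """
    times = times.reshape(-1, 1)            # (N, 1)
    events = events.reshape(-1, 1).bool()   # (N, 1)
    eval_times = eval_times.reshape(1, -1)  # (1, T)

    had_event = (times <= eval_times) & events  # (N, T)
    still_at_risk = times > eval_times          # (N, T)

    # Clamp G away from zero so the inverse weights stay finite
    g_obs = g_obs_times.reshape(-1, 1).clamp(min=min_g)    # (N, 1)
    g_eval = g_eval_times.reshape(1, -1).clamp(min=min_g)  # (1, T)

    w_event = had_event.to(cdf_preds.dtype) / g_obs
    w_risk = still_at_risk.to(cdf_preds.dtype) / g_eval

    # Squared error against target 1 (event happened) or 0 (not yet)
    per_entry = w_event * (1.0 - cdf_preds) ** 2 + w_risk * cdf_preds ** 2
    return per_entry.sum(dim=1).mean()
```

Samples censored before an evaluation time get zero weight at that time, since whether their event has occurred is unknown; the inverse weights on the remaining samples compensate for the information lost this way.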