{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Selection\n", "\n", "This user guide details how dnamite can be used for feature selection / feature-sparse prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Why Bother with Feature Selection?\n", "\n", "When training a black-box machine learning model, it is common practice to use all available features even for high-dimensional datasets, as modern ML models can easily handle many features. However, there are some settings where feature selection is very useful:\n", "\n", "1) **Interpretability**: When training a glass-box model, we need to care about both predictive performance and the accuracy and utility of explanations. While glass-box models often have good accuracy on high-dimensional datasets, model explanations are much more likely to be impaired in such settings. In particular, when sets of correlated features are all present in the same dataset, additive models run into identifiability issues with how to spread contribution across the feature set. This tends to cause increased variance in feature importances and shape functions, reducing confidence in the model's interpretations.\n", "\n", "2) **Simplicity**: Models that use fewer features are inherently simpler models, which can help with both interpretability and generalizability. For example, consider a model used to predict cancer risk which uses hundreds of features. This model is harder to explain to doctors and patients, since even a glass-box model requires hundreds of shape function plots to describe completely. Also, if the model is deployed to a different medical network than the one used during training, it's likely that several features are not available at the new medical network, which increases missing values and thus decreases predictive performance.\n", "\n", "3) **Extreme High-Dimensionality**: In datasets that are extremely high-dimensional (e.g. 
more features than samples), feature selection can improve predictive accuracy in addition to providing the benefits above. One common example is genomic datasets, where samples are expensive to collect but each sample can have many thousands of features describing the expression levels of various genes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example\n", "\n", "To showcase the capabilities of DNAMite for feature selection, we'll use the [Ames housing dataset](https://www.openml.org/search?type=data&status=active&id=42165). This dataset involves predicting house prices in Ames, Iowa, similar to the California housing dataset. Unlike the California housing data, the Ames housing dataset contains a larger number of features, some of which are categorical, along with missing values. We'll demonstrate how dnamite can be used to fit a low-dimensional interpretable model to this dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape (1460, 79)\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "sns.set_theme()\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.datasets import fetch_openml\n", "import torch\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "housing = fetch_openml(name=\"house_prices\", as_frame=True, parser='auto')\n", "data = housing[\"frame\"]\n", "X = data.drop([\"Id\", \"SalePrice\"], axis=1)\n", "y = data[\"SalePrice\"]\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)\n", "print(\"Dataset shape\", X.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by fitting a model to the complete dataset to serve as a performance benchmark. 
We set a few additional optional parameters to account for the fact that the Ames dataset has relatively few rows (1,460)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
DNAMiteRegressor(random_state=0)