major upload of (python) course material & solutions
@@ -0,0 +1,101 @@
import numpy as np
import pandas as pd
from ISLP import load_data
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
import seaborn as sns

# Load and preprocess data
Carseats = load_data('Carseats').dropna()

# Create qualitative variable "High" vs "Low" Sales
Carseats['High'] = np.where(Carseats['Sales'] <= 8, 'No', 'Yes')
Carseats['High'] = Carseats['High'].astype('category')

# Drop 'Sales' from predictors
X = Carseats.drop(columns=['Sales', 'High'])
X = pd.get_dummies(X, drop_first=True)  # Convert categorical predictors to dummy variables
y = Carseats['High']

# Train/test split (200 obs each)
np.random.seed(2)
train_idx = np.random.choice(len(Carseats), size=200, replace=False)
X_train = X.iloc[train_idx]
X_test = X.drop(train_idx)  # the default RangeIndex makes labels equal positions, so this drops the training rows
y_train = y.iloc[train_idx]
y_test = y.drop(train_idx)
# Fit classification tree
|
||||
tree_model = DecisionTreeClassifier(criterion='entropy', random_state=2)
|
||||
tree_model.fit(X_train, y_train)
|
||||
|
||||
# Summary
|
||||
print(f"Tree depth: {tree_model.get_depth()}, Terminal nodes: {tree_model.get_n_leaves()}")
|
||||
|
||||
# Plot tree
|
||||
plt.figure(figsize=(16, 8))
|
||||
plot_tree(tree_model, filled=True, feature_names=X.columns, class_names=tree_model.classes_, fontsize=8)
|
||||
plt.title("Classification Tree")
|
||||
plt.show()
|
||||
|
||||
# Test error rate
|
||||
y_pred = tree_model.predict(X_test)
|
||||
error_rate_test = np.mean(y_pred != y_test)
|
||||
print(f"Test Error (Unpruned Tree): {error_rate_test:.3f}")
|
||||

# Cross-validation to find the optimal pruning parameter via cost-complexity pruning
path = tree_model.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # exclude the last (trivial) alpha, which prunes back to the root
cv_errors = []

for alpha in ccp_alphas:
    # use the same split criterion as the unpruned tree so the alphas stay comparable
    clf = DecisionTreeClassifier(criterion='entropy', random_state=2, ccp_alpha=alpha)
    scores = cross_val_score(clf, X_train, y_train, cv=5, scoring='accuracy')
    cv_errors.append(1 - scores.mean())

# Plot CV errors
plt.figure(figsize=(8, 5))
plt.plot(ccp_alphas, cv_errors, marker='o')
plt.xlabel("ccp_alpha")
plt.ylabel("Cross-Validated Classification Error")
plt.title("CV Error vs. Tree Complexity")
plt.show()

# Prune tree with optimal alpha (min CV error)
optimal_alpha = ccp_alphas[np.argmin(cv_errors)]
pruned_tree = DecisionTreeClassifier(criterion='entropy', random_state=2, ccp_alpha=optimal_alpha)
pruned_tree.fit(X_train, y_train)

# Plot pruned tree
plt.figure(figsize=(16, 8))
plot_tree(pruned_tree, filled=True, feature_names=X.columns, class_names=pruned_tree.classes_, fontsize=8)
plt.title("Pruned Classification Tree")
plt.show()

# Test error of pruned tree
y_pred_pruned = pruned_tree.predict(X_test)
error_rate_pruned = np.mean(y_pred_pruned != y_test)
print(f"Test Error (Pruned Tree): {error_rate_pruned:.3f}")

# Fit Random Forest
rf_model = RandomForestClassifier(n_estimators=500, max_features=3, oob_score=True, random_state=2)
rf_model.fit(X_train, y_train)

# OOB Error (available because oob_score=True was set above)
oob_error = 1 - rf_model.oob_score_
print(f"OOB Error Rate: {oob_error:.3f}")

# Test error of RF
rf_pred = rf_model.predict(X_test)
error_rate_rf = np.mean(rf_pred != y_test)
print(f"Test Error (Random Forest): {error_rate_rf:.3f}")

# Feature importance
importances = pd.Series(rf_model.feature_importances_, index=X.columns)
importances.sort_values(ascending=True).plot(kind='barh', figsize=(10, 8), title="Variable Importance")
plt.xlabel("Importance")
plt.tight_layout()
plt.show()
@@ -0,0 +1,293 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "67cd5699-6111-4576-9386-0fe46130f060",
"metadata": {},
"source": [
"# Preliminary setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ea9c10a-5919-467d-8aca-efa3f2bc05e3",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from ISLP import load_data\n",
"from matplotlib.pyplot import subplots, show\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# Load and preprocess data\n",
"Hitters = load_data('Hitters').dropna()"
]
},
{
"cell_type": "markdown",
"id": "ce3b15bc-bebb-48cb-b0ab-8754b5004796",
"metadata": {},
"source": [
"# Task 1"
]
},
{
"cell_type": "markdown",
"id": "a277a01e-5932-4376-9771-ca735b510eab",
"metadata": {},
"source": [
"1. Use the Hitters data and remove all rows that contain missing values. Create a new\n",
"variable that is the log of Salary and provide histograms for Salary and Log(Salary).\n",
"Interpret."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcc5d1a2-c5b8-401d-b854-dd0ff5837704",
"metadata": {},
"outputs": [],
"source": []
},
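{
"cell_type": "markdown",
"id": "sketch-task1-md",
"metadata": {},
"source": [
"*A possible sketch (one way to approach this, not the official solution): add a log-salary column and plot both histograms. The column name `LogSalary` is illustrative.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task1-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: log-transform Salary and compare the two distributions\n",
"Hitters['LogSalary'] = np.log(Hitters['Salary'])\n",
"\n",
"fig, axes = subplots(1, 2, figsize=(10, 4))\n",
"Hitters['Salary'].hist(ax=axes[0])\n",
"axes[0].set_title('Salary')\n",
"Hitters['LogSalary'].hist(ax=axes[1])\n",
"axes[1].set_title('Log(Salary)')\n",
"show()\n",
"# Salary is strongly right-skewed; the log transform makes it much more symmetric"
]
},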
{
"cell_type": "markdown",
"id": "5ce10e96-7257-4e74-b4dd-61eadc98090a",
"metadata": {},
"source": [
"2. Split the sample into a training dataset consisting of the first 200 observations and a\n",
"test dataset containing the remaining observations."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1c39b34-4e4e-42bb-a915-ff7d9edc2bb5",
"metadata": {},
"outputs": [],
"source": []
},
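{
"cell_type": "markdown",
"id": "sketch-task2-md",
"metadata": {},
"source": [
"*A possible sketch: dummy-encode the categorical predictors (mirroring the Carseats solution) and split by row position. The names `X_train`, `X_test`, etc. are illustrative.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task2-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: dummy-encode predictors, then split by position (first 200 rows = train)\n",
"X = pd.get_dummies(Hitters.drop(columns=['Salary', 'LogSalary']), drop_first=True)\n",
"y = Hitters['LogSalary']\n",
"\n",
"X_train, X_test = X.iloc[:200], X.iloc[200:]\n",
"y_train, y_test = y.iloc[:200], y.iloc[200:]"
]
},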
{
"cell_type": "markdown",
"id": "2cffb0ba-7e62-4cff-b79d-ef5e027a62ec",
"metadata": {},
"source": [
"3. Fit a large, unpruned regression tree to predict Log(Salary). Which features are used\n",
"to construct the tree, which features are the most important and how many terminal\n",
"nodes does the tree have? You might want to plot the tree for this exercise."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "425892e5-ba65-4be4-b103-5d1968973cf5",
"metadata": {},
"outputs": [],
"source": []
},
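{
"cell_type": "markdown",
"id": "sketch-task3-md",
"metadata": {},
"source": [
"*A possible sketch with scikit-learn's `DecisionTreeRegressor`: growing the tree without depth limits gives the large, unpruned tree; features with non-zero `feature_importances_` are the ones used. `random_state=2` is an assumption for reproducibility.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task3-code",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeRegressor, plot_tree\n",
"\n",
"# Sketch: grow a deep, unpruned regression tree\n",
"reg_tree = DecisionTreeRegressor(random_state=2)\n",
"reg_tree.fit(X_train, y_train)\n",
"print(f'Terminal nodes: {reg_tree.get_n_leaves()}, depth: {reg_tree.get_depth()}')\n",
"\n",
"# Features actually used (non-zero importance), most important first\n",
"imp = pd.Series(reg_tree.feature_importances_, index=X.columns)\n",
"print(imp[imp > 0].sort_values(ascending=False))\n",
"\n",
"fig, ax = subplots(figsize=(16, 8))\n",
"plot_tree(reg_tree, feature_names=X.columns, ax=ax, fontsize=6)\n",
"show()"
]
},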
{
"cell_type": "markdown",
"id": "0c19dc38-6d3d-4d83-8e77-eab071883a1e",
"metadata": {},
"source": [
"4. Compute the mean squared prediction error for the test data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb73ed7b-6730-4a98-b04e-0d12c0c7125d",
"metadata": {},
"outputs": [],
"source": []
},
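{
"cell_type": "markdown",
"id": "sketch-task4-md",
"metadata": {},
"source": [
"*A sketch of the test mean squared prediction error, assuming `reg_tree` from the Task 3 sketch.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task4-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: test MSE of the unpruned tree\n",
"mse_unpruned = np.mean((y_test - reg_tree.predict(X_test)) ** 2)\n",
"print(f'Test MSE (unpruned): {mse_unpruned:.3f}')"
]
},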
{
"cell_type": "markdown",
"id": "dbae3448-f484-4fe2-afd1-40a741b8ef9e",
"metadata": {},
"source": [
"5. Let’s try to improve predictions using k-fold CV. Set the seed to 2 and run 5-fold\n",
"cross-validation. Plot the mean squared cross-validation error against the tree size\n",
"and report the tree size and the pruning parameter α that minimize the mean squared\n",
"cross-validation error."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31280859-0b4f-4b8d-9aeb-4e9c83bd008a",
"metadata": {},
"outputs": [],
"source": []
},
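{
"cell_type": "markdown",
"id": "sketch-task5-md",
"metadata": {},
"source": [
"*A possible sketch mirroring the Carseats solution: 5-fold CV along the cost-complexity pruning path, with `scoring='neg_mean_squared_error'`. Refitting per alpha to read off the tree size is one convenient way to plot error against size.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task5-code",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# Sketch: CV error along the cost-complexity path, tracking tree size\n",
"np.random.seed(2)\n",
"path = reg_tree.cost_complexity_pruning_path(X_train, y_train)\n",
"ccp_alphas = path.ccp_alphas[:-1]  # exclude the last (trivial) alpha\n",
"\n",
"cv_mse, sizes = [], []\n",
"for alpha in ccp_alphas:\n",
"    t = DecisionTreeRegressor(random_state=2, ccp_alpha=alpha)\n",
"    scores = cross_val_score(t, X_train, y_train, cv=5, scoring='neg_mean_squared_error')\n",
"    cv_mse.append(-scores.mean())\n",
"    sizes.append(t.fit(X_train, y_train).get_n_leaves())\n",
"\n",
"fig, ax = subplots(figsize=(8, 5))\n",
"ax.plot(sizes, cv_mse, marker='o')\n",
"ax.set_xlabel('Tree size (terminal nodes)')\n",
"ax.set_ylabel('CV MSE')\n",
"show()\n",
"\n",
"best = int(np.argmin(cv_mse))\n",
"optimal_alpha = ccp_alphas[best]\n",
"print(f'Best size: {sizes[best]} leaves, alpha = {optimal_alpha:.4f}')"
]
},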
{
"cell_type": "markdown",
"id": "37322a0e-a542-4b10-88e3-eb88d7b1f2ac",
"metadata": {},
"source": [
"6. Use the pruning parameter from the previous task to prune the tree. Plot the tree and\n",
"report the most important variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8bf40b3-8cba-4335-92e2-686ba0a93185",
"metadata": {},
"outputs": [],
"source": []
},
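{
"cell_type": "markdown",
"id": "sketch-task6-md",
"metadata": {},
"source": [
"*A sketch, assuming `optimal_alpha` from the Task 5 sketch.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task6-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: refit with the CV-selected alpha and inspect the pruned tree\n",
"pruned_reg = DecisionTreeRegressor(random_state=2, ccp_alpha=optimal_alpha)\n",
"pruned_reg.fit(X_train, y_train)\n",
"\n",
"fig, ax = subplots(figsize=(14, 7))\n",
"plot_tree(pruned_reg, feature_names=X.columns, filled=True, ax=ax, fontsize=8)\n",
"show()\n",
"\n",
"imp_pruned = pd.Series(pruned_reg.feature_importances_, index=X.columns)\n",
"print(imp_pruned[imp_pruned > 0].sort_values(ascending=False))"
]
},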
{
"cell_type": "markdown",
"id": "67496351-580b-4e9f-9b17-2776f2c55843",
"metadata": {},
"source": [
"7. Compute the test mean squared prediction error for the pruned tree and compare to the\n",
"results from Task 4."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3104831-7607-4eab-a0a2-861adde2658d",
"metadata": {},
"outputs": [],
"source": []
},
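{
"cell_type": "markdown",
"id": "sketch-task7-md",
"metadata": {},
"source": [
"*A sketch comparing against `mse_unpruned` from the Task 4 sketch.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task7-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: test MSE of the pruned tree vs. the unpruned tree\n",
"mse_pruned = np.mean((y_test - pruned_reg.predict(X_test)) ** 2)\n",
"print(f'Test MSE (pruned): {mse_pruned:.3f} vs. unpruned: {mse_unpruned:.3f}')"
]
},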
{
"cell_type": "markdown",
"id": "30021421-8807-4481-b28d-6ea23cb06b82",
"metadata": {},
"source": [
"8. Use a random forest to improve the predictions. Fit $500$ trees using $m = \\sqrt{p}$ (round to the nearest integer)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c907edbf-5755-4a5c-bd12-ea80a2358358",
"metadata": {},
"outputs": [],
"source": []
},
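{
"cell_type": "markdown",
"id": "sketch-task8-md",
"metadata": {},
"source": [
"*A possible sketch with scikit-learn's `RandomForestRegressor`. `oob_score=True` is enabled here because later tasks need OOB estimates; `random_state=2` is an assumption.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task8-code",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"# Sketch: 500 trees with m = round(sqrt(p)) candidate features per split\n",
"p = X_train.shape[1]\n",
"m = round(np.sqrt(p))\n",
"rf = RandomForestRegressor(n_estimators=500, max_features=m, oob_score=True, random_state=2)\n",
"rf.fit(X_train, y_train)\n",
"print(f'p = {p}, m = {m}')"
]
},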
{
"cell_type": "markdown",
"id": "4b014396-e91b-4f72-9b58-85fa80805eb0",
"metadata": {},
"source": [
"9. Do you think it was necessary to fit $500$ trees, or would fewer trees have been sufficient? Determine the number of trees that provides the lowest OOB error."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77cb58bd-6d3d-4b0d-ad5e-e18737501cb8",
"metadata": {},
"outputs": [],
"source": []
},
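{
"cell_type": "markdown",
"id": "sketch-task9-md",
"metadata": {},
"source": [
"*A possible sketch: grow one forest incrementally with `warm_start=True` and track the OOB MSE via `oob_prediction_`. The grid (50 to 500 in steps of 25) is an arbitrary choice; it starts at 50 because OOB predictions are unreliable for very small forests.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task9-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: OOB MSE as a function of the number of trees\n",
"grid = list(range(50, 501, 25))\n",
"oob_mse = []\n",
"rf_grow = RandomForestRegressor(max_features=m, oob_score=True, warm_start=True, random_state=2)\n",
"for n in grid:\n",
"    rf_grow.set_params(n_estimators=n)\n",
"    rf_grow.fit(X_train, y_train)  # warm start: only the new trees are fit\n",
"    oob_mse.append(np.mean((y_train - rf_grow.oob_prediction_) ** 2))\n",
"\n",
"fig, ax = subplots(figsize=(8, 5))\n",
"ax.plot(grid, oob_mse, marker='o')\n",
"ax.set_xlabel('Number of trees')\n",
"ax.set_ylabel('OOB MSE')\n",
"show()\n",
"print(f'Lowest OOB MSE at {grid[int(np.argmin(oob_mse))]} trees')"
]
},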
{
"cell_type": "markdown",
"id": "2cea0e71-cc51-4890-b776-e4f03d7af94d",
"metadata": {},
"source": [
"10. Compute the OOB estimate of the out-of-sample error and compare it to the best pruned model from the CV in Task 5. Interpret the outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6aafe1d3-b54c-4bca-9070-ea62ac27f885",
"metadata": {},
"outputs": [],
"source": []
},
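{
"cell_type": "markdown",
"id": "sketch-task10-md",
"metadata": {},
"source": [
"*A sketch, assuming `rf`, `cv_mse`, and `best` from the earlier sketches. Both numbers estimate the out-of-sample MSE using only the training data, so they are directly comparable.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task10-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: OOB MSE of the 500-tree forest vs. the pruned tree's CV MSE\n",
"oob_mse_rf = np.mean((y_train - rf.oob_prediction_) ** 2)\n",
"print(f'OOB MSE (random forest): {oob_mse_rf:.3f}')\n",
"print(f'CV MSE (pruned tree): {cv_mse[best]:.3f}')"
]
},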
{
"cell_type": "markdown",
"id": "992771aa-1fec-44d0-b3f5-e8525bd1ce79",
"metadata": {},
"source": [
"11. Which are the most important variables used in the random forest?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85841a9e-4df5-4d14-ae2b-107002042fd8",
"metadata": {},
"outputs": [],
"source": []
},
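{
"cell_type": "markdown",
"id": "sketch-task11-md",
"metadata": {},
"source": [
"*A sketch using the impurity-based `feature_importances_`, as in the Carseats solution.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task11-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: variable importance in the random forest\n",
"rf_imp = pd.Series(rf.feature_importances_, index=X.columns)\n",
"rf_imp.sort_values().plot(kind='barh', figsize=(8, 6), title='Variable Importance')\n",
"plt.xlabel('Importance')\n",
"plt.tight_layout()\n",
"show()"
]
},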
{
"cell_type": "markdown",
"id": "bc5eee45-8c48-41dd-ba38-7f78c4bcd036",
"metadata": {},
"source": [
"12. Let’s try to improve the random forest by trying out different values for $m$. Set up a grid for $m$ going from $1$ to $p$. Write a loop that fits a random forest for each $m$. Explain which model you would choose."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0361acc5-041d-46b1-848d-eadea0ce717b",
"metadata": {},
"outputs": [],
"source": []
},
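{
"cell_type": "markdown",
"id": "sketch-task12-md",
"metadata": {},
"source": [
"*A possible sketch: one forest per value of $m$, compared by OOB MSE. Choosing the $m$ with the lowest OOB error is one defensible rule.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task12-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: tune m over 1..p using the OOB error\n",
"oob_by_m = []\n",
"for m_try in range(1, p + 1):\n",
"    rf_m = RandomForestRegressor(n_estimators=500, max_features=m_try, oob_score=True, random_state=2)\n",
"    rf_m.fit(X_train, y_train)\n",
"    oob_by_m.append(np.mean((y_train - rf_m.oob_prediction_) ** 2))\n",
"\n",
"best_m = int(np.argmin(oob_by_m)) + 1\n",
"print(f'Best m by OOB error: {best_m}')"
]
},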
{
"cell_type": "markdown",
"id": "6f38e2e4-8242-46c6-9c49-69b7ee73be1e",
"metadata": {},
"source": [
"13. For the best model, compute the test errors and compare them to the best pruned model from Task 7."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d31d199b-116f-4585-8e4d-e40d4b6ff685",
"metadata": {},
"outputs": [],
"source": []
},
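{
"cell_type": "markdown",
"id": "sketch-task13-md",
"metadata": {},
"source": [
"*A sketch: refit the forest with `best_m` from the Task 12 sketch and compare its test MSE to `mse_pruned` from Task 7.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task13-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: test MSE of the best forest vs. the pruned tree\n",
"rf_best = RandomForestRegressor(n_estimators=500, max_features=best_m, oob_score=True, random_state=2)\n",
"rf_best.fit(X_train, y_train)\n",
"mse_rf = np.mean((y_test - rf_best.predict(X_test)) ** 2)\n",
"print(f'Test MSE (RF, m={best_m}): {mse_rf:.3f} vs. pruned tree: {mse_pruned:.3f}')"
]
},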
{
"cell_type": "markdown",
"id": "d6f1407e-5ad1-4690-bf9e-ecc36c4a50e5",
"metadata": {},
"source": [
"14. What is the OOB error obtained from bagging (you can infer the answer from the previous task)?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7ed7a03-8520-4fba-b2ff-500979e92496",
"metadata": {},
"outputs": [],
"source": []
},
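{
"cell_type": "markdown",
"id": "sketch-task14-md",
"metadata": {},
"source": [
"*A sketch: bagging is the special case $m = p$, so its OOB error is already the last entry of the Task 12 grid.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-task14-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: bagging corresponds to max_features = p\n",
"print(f'OOB MSE (bagging, m = p): {oob_by_m[-1]:.3f}')"
]
}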
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.