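Purpose: show how to tune the hyperparameters of a decision tree.
Data: Pima Indian Diabetes
Background: a machine-learning model has parameters that are estimated from the data. It also has settings that are not estimated (most of them have default values) but still affect the fitted model and its predictive performance. These settings are called hyperparameters. Because the fit and the predictive performance depend on their values, it is usually worth choosing hyperparameter values that give the model good predictive performance. Fitting models with different hyperparameter values and comparing them is called hyperparameter tuning.
Reference: https://mlr3book.mlr-org.com/optimization.html
The tuning process is walked through below.
First, load mlr3verse, the R machine-learning package collection.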
library("mlr3verse")
## Loading required package: mlr3
Load the plotting package ggplot2.
library("ggplot2")
Load the Pima Indian Diabetes data set (it is already wrapped in a task object).
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
The result is a TaskClassif task object that holds the pima data, with the binary target diabetes. To analyze the data set itself, retrieve it with the task's data() method.
pima = task$data()
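dim() and str() show the size and the column types: the data set has 768 rows and 9 variables (columns). The response variable (the variable being predicted) is the binary variable diabetes, which has only the two classes pos and neg, so each person is classified by whether they have diabetes. head() shows the first 6 rows, and the skimr package summarizes each variable.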
dim(pima)
## [1] 768 9
str(pima)
## Classes 'data.table' and 'data.frame': 768 obs. of 9 variables:
## $ diabetes: Factor w/ 2 levels "pos","neg": 1 2 1 2 1 2 1 2 1 1 ...
## $ age : num 50 31 32 21 33 30 26 29 53 54 ...
## $ glucose : num 148 85 183 89 137 116 78 115 197 125 ...
## $ insulin : num NA NA NA 94 168 NA 88 NA 543 NA ...
## $ mass : num 33.6 26.6 23.3 28.1 43.1 25.6 31 35.3 30.5 NA ...
## $ pedigree: num 0.627 0.351 0.672 0.167 2.288 ...
## $ pregnant: num 6 1 8 1 0 5 3 10 2 8 ...
## $ pressure: num 72 66 64 66 40 74 50 NA 70 96 ...
## $ triceps : num 35 29 NA 23 35 NA 32 NA 45 NA ...
## - attr(*, ".internal.selfref")=<externalptr>
head(pima)
## diabetes age glucose insulin mass pedigree pregnant pressure triceps
## 1: pos 50 148 NA 33.6 0.627 6 72 35
## 2: neg 31 85 NA 26.6 0.351 1 66 29
## 3: pos 32 183 NA 23.3 0.672 8 64 NA
## 4: neg 21 89 94 28.1 0.167 1 66 23
## 5: pos 33 137 168 43.1 2.288 0 40 35
## 6: neg 30 116 NA 25.6 0.201 5 74 NA
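Of the 9 variables, 1 is a factor and 8 are numeric. diabetes, age, pedigree and pregnant have no missing values, while insulin (374 missing) and triceps (227 missing) have the most. Apart from pressure, which looks roughly symmetric, the numeric variables are right-skewed.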
skimr::skim(pima)
Data summary | |
---|---|
Name | pima |
Number of rows | 768 |
Number of columns | 9 |
Key | NULL |
Column type frequency: | |
factor | 1 |
numeric | 8 |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
diabetes | 0 | 1 | FALSE | 2 | neg: 500, pos: 268 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
age | 0 | 1.00 | 33.24 | 11.76 | 21.00 | 24.00 | 29.00 | 41.00 | 81.00 | ▇▃▁▁▁ |
glucose | 5 | 0.99 | 121.69 | 30.54 | 44.00 | 99.00 | 117.00 | 141.00 | 199.00 | ▁▇▇▃▂ |
insulin | 374 | 0.51 | 155.55 | 118.78 | 14.00 | 76.25 | 125.00 | 190.00 | 846.00 | ▇▂▁▁▁ |
mass | 11 | 0.99 | 32.46 | 6.92 | 18.20 | 27.50 | 32.30 | 36.60 | 67.10 | ▅▇▃▁▁ |
pedigree | 0 | 1.00 | 0.47 | 0.33 | 0.08 | 0.24 | 0.37 | 0.63 | 2.42 | ▇▃▁▁▁ |
pregnant | 0 | 1.00 | 3.85 | 3.37 | 0.00 | 1.00 | 3.00 | 6.00 | 17.00 | ▇▃▂▁▁ |
pressure | 35 | 0.95 | 72.41 | 12.38 | 24.00 | 64.00 | 72.00 | 80.00 | 122.00 | ▁▃▇▂▁ |
triceps | 227 | 0.70 | 29.15 | 10.48 | 7.00 | 22.00 | 29.00 | 36.00 | 99.00 | ▆▇▁▁▁ |
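The following shows how to tune two hyperparameters of the decision tree: the complexity parameter cp and the stopping criterion minsplit. In rpart, a split that does not decrease the overall lack of fit by a factor of cp is not attempted, and minsplit is the minimum number of observations a node must contain before a split is attempted. Larger values of either generally give smaller trees.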
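To see what the two hyperparameters control, here is a small illustration that calls rpart directly (outside mlr3). It uses the kyphosis data that ships with rpart, not pima, and counts the leaves of the fitted tree for a few (cp, minsplit) settings. The helper count_leaves() is defined only for this sketch. Exact leaf counts depend on the data, so none are claimed here. The only expectation is that, for a fixed minsplit, raising cp does not add leaves.

```r
# Illustration only: kyphosis ships with rpart; it is not the pima data.
library("rpart")

# number of terminal nodes (leaves) in a fitted rpart tree
count_leaves = function(fit) sum(fit$frame$var == "<leaf>")

settings = expand.grid(cp = c(0.001, 0.01, 0.1), minsplit = c(5, 20))
settings$leaves = mapply(function(cp, minsplit) {
  fit = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              control = rpart.control(cp = cp, minsplit = minsplit, xval = 0))
  count_leaves(fit)
}, settings$cp, settings$minsplit)

print(settings)
```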
The task object task, obtained when the data were loaded, records the data used (pima), the binary target diabetes, and the task type (TaskClassif).
print(task)
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
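First, create a learner object with lrn() that builds decision trees with the rpart package. The learner's param_set field lists every hyperparameter with its default (default column) and any value that has been set (value column). Note that mlr3 sets xval = 0.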
learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
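The default column matches rpart's own control defaults, which can be checked with rpart.control() as below. minbucket appears as <NoDefault> because rpart derives it from minsplit as round(minsplit / 3), which is 7 for minsplit = 20. As far as we know, mlr3 sets xval to 0 to save computation: rpart would otherwise run its own internal cross-validation, and mlr3 handles resampling itself.

```r
library("rpart")

# rpart's own control defaults, for comparison with the "default" column above
ctrl = rpart.control()
unlist(ctrl[c("cp", "minsplit", "minbucket", "maxdepth", "xval")])
```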
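Set the bounds of the hyperparameters to tune by hand: ps() defines the lower and upper bounds of cp and minsplit, and the result is stored in the ParamSet object search_space.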
search_space = ps(
cp = p_dbl(lower = 0.001, upper = 0.1),
minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
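The search space is a box: cp is a real number in [0.001, 0.1] and minsplit an integer in {1, ..., 10}. The base-R sketch below draws a few configurations uniformly from that box. This is roughly what random search (used later with AutoTuner) does. It is an illustration only, not paradox's own sampler.

```r
set.seed(42)
n_draws = 5

# cp: continuous, uniform on [0.001, 0.1]; minsplit: integer, uniform on 1..10
configs = data.frame(
  cp       = runif(n_draws, min = 0.001, max = 0.1),
  minsplit = sample(1:10, n_draws, replace = TRUE)
)
print(configs)
```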
Next, specify how model performance is evaluated: a resampling strategy and a performance measure.
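1. Resampling strategy: holdout splits the sample only once, into one training set and one test set. The default training share is 2/3; it can also be given explicitly with ratio = 2/3.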
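set.seed(1234) # fix the random seed so the sample split is the same on every run
hout = rsmp("holdout")              # default ratio = 2/3
hout = rsmp("holdout", ratio = 2/3) # same thing, with the ratio given explicitly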
print(hout)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
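Other resampling strategies can also be used. Subsampling, for example, repeats a random 2/3 vs. 1/3 split 10 times and averages the test scores.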
resampling = rsmp("subsampling", repeats = 10, ratio = 2/3)
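The base-R sketch below shows what the two strategies do with the 768 rows. It covers only the index arithmetic. mlr3's resampling objects do this internally and draw different indices even with the same seed. With ratio = 2/3 the test set has 768 - 512 = 256 rows. This is consistent with every classification error in the tuning log later in this document being a multiple of 1/256.

```r
n = 768        # rows in pima
set.seed(1234)

# holdout: one split, round(2/3 * n) rows for training and the rest for testing
train_idx = sample(n, size = round(2/3 * n))
test_idx  = setdiff(seq_len(n), train_idx)
c(train = length(train_idx), test = length(test_idx))

# subsampling: the same kind of random split, drawn 10 times independently
subsample_train = lapply(1:10, function(i) sample(n, size = round(2/3 * n)))
lengths(subsample_train)
```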
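2. Performance measure: model performance is measured by the classification error, the share of test observations that are misclassified.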
measure = msr("classif.ce")
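The classification error and the accuracy (used later with classif.acc) are complementary shares of the same test predictions. The toy labels below are made up for illustration. The last line converts an error of 0.2539062 on a 256-row test set back into a count of misclassified rows.

```r
# Toy labels, made up for illustration
truth    = factor(c("pos", "neg", "neg", "pos", "neg"), levels = c("pos", "neg"))
response = factor(c("pos", "neg", "pos", "neg", "neg"), levels = c("pos", "neg"))

ce  = mean(response != truth)   # share of wrong predictions: 2/5
acc = mean(response == truth)   # share of correct predictions: 3/5
c(ce = ce, acc = acc)

# on a 256-row test set, 65 misclassified rows give an error of about 0.2539062
65 / 256
```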
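Set the stopping condition for the tuning process. The following terminators are available:
1. Stop at a given point in (clock) time (TerminatorClockTime)
2. Stop after a given number of evaluations (TerminatorEvals)
3. Stop once a given performance level is reached (TerminatorPerfReached)
4. Stop when tuning no longer improves performance (TerminatorStagnation)
5. Any combination of the above (TerminatorCombo)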
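The way terminators combine can be pictured with a small helper. should_stop() is hypothetical and exists only for this sketch; it is not a bbotk or mlr3tuning function. Each criterion returns TRUE when it wants to stop. The combination stops when any criterion does, which we understand to be TerminatorCombo's default rule, or only when all of them do. The time check uses an elapsed-seconds budget to keep the example simple.

```r
# Hypothetical helper for illustration; not part of bbotk or mlr3tuning.
should_stop = function(n_evals, elapsed_secs, best_ce,
                       max_evals = 20, max_secs = 60, target_ce = 0.2,
                       combine_any = TRUE) {
  checks = c(
    evals        = n_evals >= max_evals,      # evaluation budget used up
    time         = elapsed_secs >= max_secs,  # time budget used up
    perf_reached = best_ce <= target_ce       # target performance reached
  )
  if (combine_any) any(checks) else all(checks)
}

should_stop(n_evals = 20, elapsed_secs = 2, best_ce = 0.25)  # TRUE: 20 evaluations done
should_stop(n_evals = 5,  elapsed_secs = 2, best_ce = 0.25)  # FALSE: keep tuning
should_stop(n_evals = 5,  elapsed_secs = 2, best_ce = 0.18)  # TRUE: target reached
```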
The example below uses the evaluation-count terminator, with a budget of 20 model evaluations.
Load the package and create the terminator object.
library("mlr3tuning")
## Loading required package: paradox
evals20 = trm("evals", n_evals = 20)
Create a single-criterion tuning instance (TuningInstanceSingleCrit) that bundles the task, learner, hout, measure, search_space and evals20 objects created above.
instance = TuningInstanceSingleCrit$new(
task = task,
learner = learner,
resampling = hout,
measure = measure,
search_space = search_space,
terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)
Before tuning starts, an optimization algorithm must also be chosen. mlr3tuning provides several algorithms as Tuner classes, including:
1. Grid Search (TunerGridSearch)
2. Random Search (TunerRandomSearch) (Bergstra and Bengio 2012)
3. Generalized Simulated Annealing (TunerGenSA)
4. Non-Linear Optimization (TunerNLoptr)
The example below uses grid search.
tuner = tnr("grid_search", resolution = 5, batch_size = 1)
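With resolution = 5, each hyperparameter gets 5 equally spaced values that include both bounds. For cp these are 0.001, 0.02575, 0.0505, 0.07525 and 0.1. For the integer minsplit they are 1, 3, 5, 8 and 10. The two hyperparameters therefore give 5 × 5 = 25 candidate configurations.
The evaluation budget here is 20 (see the evals20 object), so only 20 of them, visited in random order, are evaluated.
With a budget of 25 or more, all 25 configurations would be evaluated.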
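The grid can be rebuilt in base R to check the count. The cp values come from seq(), and they match those in the tuning log. The minsplit values are copied from the archive shown later rather than computed, because paradox's exact rule for mapping an integer range onto a grid is not reproduced here. The last step imitates the shuffled visiting order and the 20-evaluation budget.

```r
resolution = 5

# cp: 5 equally spaced values between the bounds (these match the cp values in the log)
cp_values = seq(0.001, 0.1, length.out = resolution)

# minsplit: integer grid values as they appear in the archive (copied, not computed)
minsplit_values = c(1, 3, 5, 8, 10)

grid = expand.grid(cp = cp_values, minsplit = minsplit_values)
nrow(grid)  # 25 candidate configurations

# visit the grid in a shuffled order and stop after 20 evaluations
set.seed(1)
evaluated = grid[sample(nrow(grid))[1:20], ]
nrow(evaluated)
```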
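The tuner's optimize() method runs the tuning, taking the tuning instance as its argument. Each model fit uses one of the 25 (cp, minsplit) configurations. The grid is visited in a shuffled order, and no configuration is repeated. Tuning ends when the terminator fires (here, after 20 evaluations) or when all 25 configurations (25 decision-tree models) have been evaluated.
When the following command is run,
tuner$optimize(instance)
the tuning proceeds as follows: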
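1. The tuner proposes one (or more) hyperparameter configurations; batch_size controls how many are evaluated per batch.
2. Models are fit and evaluated as specified in the tuning instance.
3. If the terminator condition is met (here, 20 evaluations), tuning stops; otherwise it returns to step 1.
4. The configuration with the best predictive performance is selected.
5. The selected hyperparameters and their performance are stored in the instance's result_learner_param_vals and result_y fields.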
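The five steps can be written out as a plain loop. The sketch below is not mlr3tuning's implementation. It fits rpart directly on the kyphosis data bundled with rpart (not pima), reuses one holdout split for every configuration, searches a 4 × 3 grid, and stops after 8 evaluations. The archive data frame plays the role of instance$archive.

```r
library("rpart")

set.seed(1234)
n = nrow(kyphosis)
train_idx = sample(n, size = round(2/3 * n))  # one holdout split, reused for every configuration
train = kyphosis[train_idx, ]
test  = kyphosis[-train_idx, ]

grid = expand.grid(cp = c(0.001, 0.01, 0.05, 0.1), minsplit = c(2, 5, 10))
grid = grid[sample(nrow(grid)), ]             # shuffled visiting order
n_evals = 8                                    # terminator: stop after 8 evaluations
archive = data.frame()

for (i in seq_len(nrow(grid))) {
  if (nrow(archive) >= n_evals) break          # step 3: check the terminator
  cfg = grid[i, ]                              # step 1: next configuration (batch of 1)
  fit = rpart(Kyphosis ~ Age + Number + Start, data = train,  # step 2: fit on the training part
              control = rpart.control(cp = cfg$cp, minsplit = cfg$minsplit, xval = 0))
  pred = predict(fit, newdata = test, type = "class")
  ce = mean(pred != test$Kyphosis)             # classification error on the test part
  archive = rbind(archive,
                  data.frame(cp = cfg$cp, minsplit = cfg$minsplit, classif.ce = ce))
}

best = archive[which.min(archive$classif.ce), ]  # step 4: first configuration with the lowest error
print(archive)
print(best)                                      # step 5: the tuning result
```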
tuner$optimize(instance)
## INFO [11:29:57.782] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO [11:29:57.808] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:57.832] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:57.874] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:57.903] [mlr3] Finished benchmark
## INFO [11:29:57.958] [bbotk] Result of batch 1:
## INFO [11:29:57.960] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:57.960] [bbotk] 0.001 8 0.3085938 0.01
## INFO [11:29:57.960] [bbotk] uhash
## INFO [11:29:57.960] [bbotk] 1640e1a3-94e4-41ad-bd8c-d25662e40637
## INFO [11:29:57.961] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:57.976] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:57.982] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:57.995] [mlr3] Finished benchmark
## INFO [11:29:58.043] [bbotk] Result of batch 2:
## INFO [11:29:58.044] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.044] [bbotk] 0.001 3 0.296875 0.02
## INFO [11:29:58.044] [bbotk] uhash
## INFO [11:29:58.044] [bbotk] 2778f922-8e7a-410c-a45a-cf019b09a79f
## INFO [11:29:58.045] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.060] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.066] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.078] [mlr3] Finished benchmark
## INFO [11:29:58.126] [bbotk] Result of batch 3:
## INFO [11:29:58.128] [bbotk] cp minsplit classif.ce runtime_learners uhash
## INFO [11:29:58.128] [bbotk] 0.1 3 0.2539062 0 eaf3e580-64f2-4591-b5f3-a32e17373fbc
## INFO [11:29:58.129] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.152] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.156] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.168] [mlr3] Finished benchmark
## INFO [11:29:58.221] [bbotk] Result of batch 4:
## INFO [11:29:58.223] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.223] [bbotk] 0.02575 3 0.2539062 0
## INFO [11:29:58.223] [bbotk] uhash
## INFO [11:29:58.223] [bbotk] 9e34687d-70b4-4a46-8ff6-531764a2d9c2
## INFO [11:29:58.224] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.238] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.243] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.255] [mlr3] Finished benchmark
## INFO [11:29:58.309] [bbotk] Result of batch 5:
## INFO [11:29:58.310] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.310] [bbotk] 0.001 5 0.265625 0.01
## INFO [11:29:58.310] [bbotk] uhash
## INFO [11:29:58.310] [bbotk] fd305853-449d-485f-b289-eb08f7fc3112
## INFO [11:29:58.312] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.325] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.330] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.341] [mlr3] Finished benchmark
## INFO [11:29:58.392] [bbotk] Result of batch 6:
## INFO [11:29:58.393] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.393] [bbotk] 0.0505 10 0.2539062 0
## INFO [11:29:58.393] [bbotk] uhash
## INFO [11:29:58.393] [bbotk] b9f10bc6-d76f-4533-882d-af29438748b8
## INFO [11:29:58.394] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.408] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.412] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.424] [mlr3] Finished benchmark
## INFO [11:29:58.480] [bbotk] Result of batch 7:
## INFO [11:29:58.481] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.481] [bbotk] 0.001 10 0.296875 0.01
## INFO [11:29:58.481] [bbotk] uhash
## INFO [11:29:58.481] [bbotk] 71f7bb78-0b09-4d38-ba64-3d186250cc90
## INFO [11:29:58.482] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.496] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.501] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.514] [mlr3] Finished benchmark
## INFO [11:29:58.562] [bbotk] Result of batch 8:
## INFO [11:29:58.563] [bbotk] cp minsplit classif.ce runtime_learners uhash
## INFO [11:29:58.563] [bbotk] 0.1 8 0.2539062 0 1c4d38c3-219f-46d4-8536-1aedba2983b7
## INFO [11:29:58.564] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.581] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.586] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.598] [mlr3] Finished benchmark
## INFO [11:29:58.645] [bbotk] Result of batch 9:
## INFO [11:29:58.646] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.646] [bbotk] 0.02575 1 0.2539062 0.01
## INFO [11:29:58.646] [bbotk] uhash
## INFO [11:29:58.646] [bbotk] dc9eac76-531a-4e22-a012-cd82264d11e8
## INFO [11:29:58.647] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.664] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.669] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.681] [mlr3] Finished benchmark
## INFO [11:29:58.732] [bbotk] Result of batch 10:
## INFO [11:29:58.733] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.733] [bbotk] 0.0505 8 0.2539062 0.01
## INFO [11:29:58.733] [bbotk] uhash
## INFO [11:29:58.733] [bbotk] ac024f47-4be7-4cd0-8ba7-0f5be33f4d95
## INFO [11:29:58.734] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.751] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.756] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.768] [mlr3] Finished benchmark
## INFO [11:29:58.822] [bbotk] Result of batch 11:
## INFO [11:29:58.823] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.823] [bbotk] 0.02575 10 0.2539062 0
## INFO [11:29:58.823] [bbotk] uhash
## INFO [11:29:58.823] [bbotk] 588bd2ca-3ae2-40f5-9f22-83d1f3582897
## INFO [11:29:58.824] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.839] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.843] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.855] [mlr3] Finished benchmark
## INFO [11:29:58.906] [bbotk] Result of batch 12:
## INFO [11:29:58.908] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:58.908] [bbotk] 0.07525 10 0.2539062 0
## INFO [11:29:58.908] [bbotk] uhash
## INFO [11:29:58.908] [bbotk] ec4289d1-9680-4fbf-bbbe-a49afa4ab671
## INFO [11:29:58.909] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:58.923] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:58.927] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:58.938] [mlr3] Finished benchmark
## INFO [11:29:59.028] [bbotk] Result of batch 13:
## INFO [11:29:59.030] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.030] [bbotk] 0.07525 3 0.2539062 0
## INFO [11:29:59.030] [bbotk] uhash
## INFO [11:29:59.030] [bbotk] 674a6bd1-a766-4436-ae89-693818e6599d
## INFO [11:29:59.032] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.047] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.051] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.062] [mlr3] Finished benchmark
## INFO [11:29:59.108] [bbotk] Result of batch 14:
## INFO [11:29:59.109] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.109] [bbotk] 0.07525 8 0.2539062 0
## INFO [11:29:59.109] [bbotk] uhash
## INFO [11:29:59.109] [bbotk] dac74012-f5ad-461b-802b-4d853e1d97fb
## INFO [11:29:59.110] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.126] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.130] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.141] [mlr3] Finished benchmark
## INFO [11:29:59.189] [bbotk] Result of batch 15:
## INFO [11:29:59.190] [bbotk] cp minsplit classif.ce runtime_learners uhash
## INFO [11:29:59.190] [bbotk] 0.1 1 0.2539062 0 67720b97-080c-4266-a858-35197a4642be
## INFO [11:29:59.191] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.205] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.209] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.220] [mlr3] Finished benchmark
## INFO [11:29:59.274] [bbotk] Result of batch 16:
## INFO [11:29:59.275] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.275] [bbotk] 0.02575 5 0.2539062 0
## INFO [11:29:59.275] [bbotk] uhash
## INFO [11:29:59.275] [bbotk] fda6eda0-81f7-427c-b010-8cff139eedd3
## INFO [11:29:59.276] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.290] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.294] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.305] [mlr3] Finished benchmark
## INFO [11:29:59.353] [bbotk] Result of batch 17:
## INFO [11:29:59.355] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.355] [bbotk] 0.02575 8 0.2539062 0.02
## INFO [11:29:59.355] [bbotk] uhash
## INFO [11:29:59.355] [bbotk] 29e3e76b-8a2e-4183-ad40-49d6a514b7b3
## INFO [11:29:59.356] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.369] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.374] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.384] [mlr3] Finished benchmark
## INFO [11:29:59.434] [bbotk] Result of batch 18:
## INFO [11:29:59.435] [bbotk] cp minsplit classif.ce runtime_learners uhash
## INFO [11:29:59.435] [bbotk] 0.1 5 0.2539062 0.02 b78beace-c168-4324-b83c-9fc804c02ff1
## INFO [11:29:59.436] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.450] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.454] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.464] [mlr3] Finished benchmark
## INFO [11:29:59.519] [bbotk] Result of batch 19:
## INFO [11:29:59.520] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.520] [bbotk] 0.07525 5 0.2539062 0
## INFO [11:29:59.520] [bbotk] uhash
## INFO [11:29:59.520] [bbotk] 04cb96a8-a363-43dd-ab11-5afa5dd81a36
## INFO [11:29:59.522] [bbotk] Evaluating 1 configuration(s)
## INFO [11:29:59.534] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:29:59.539] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:29:59.549] [mlr3] Finished benchmark
## INFO [11:29:59.598] [bbotk] Result of batch 20:
## INFO [11:29:59.600] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:29:59.600] [bbotk] 0.0505 1 0.2539062 0.02
## INFO [11:29:59.600] [bbotk] uhash
## INFO [11:29:59.600] [bbotk] 43c99936-7aee-47db-98cb-48b0bc0c8ff7
## INFO [11:29:59.604] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO [11:29:59.605] [bbotk] Result:
## INFO [11:29:59.606] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [11:29:59.606] [bbotk] 0.1 3 <list[3]> <list[2]> 0.2539062
## cp minsplit learner_param_vals x_domain classif.ce
## 1: 0.1 3 <list[3]> <list[2]> 0.2539062
instance$result_learner_param_vals
## $xval
## [1] 0
##
## $cp
## [1] 0.1
##
## $minsplit
## [1] 3
instance$result_y
## classif.ce
## 0.2539062
The following command shows the evaluation results for every configuration that was evaluated.
as.data.table(instance$archive)
## cp minsplit classif.ce x_domain_cp x_domain_minsplit runtime_learners
## 1: 0.00100 8 0.3085938 0.00100 8 0.01
## 2: 0.00100 3 0.2968750 0.00100 3 0.02
## 3: 0.10000 3 0.2539062 0.10000 3 0.00
## 4: 0.02575 3 0.2539062 0.02575 3 0.00
## 5: 0.00100 5 0.2656250 0.00100 5 0.01
## 6: 0.05050 10 0.2539062 0.05050 10 0.00
## 7: 0.00100 10 0.2968750 0.00100 10 0.01
## 8: 0.10000 8 0.2539062 0.10000 8 0.00
## 9: 0.02575 1 0.2539062 0.02575 1 0.01
## 10: 0.05050 8 0.2539062 0.05050 8 0.01
## 11: 0.02575 10 0.2539062 0.02575 10 0.00
## 12: 0.07525 10 0.2539062 0.07525 10 0.00
## 13: 0.07525 3 0.2539062 0.07525 3 0.00
## 14: 0.07525 8 0.2539062 0.07525 8 0.00
## 15: 0.10000 1 0.2539062 0.10000 1 0.00
## 16: 0.02575 5 0.2539062 0.02575 5 0.00
## 17: 0.02575 8 0.2539062 0.02575 8 0.02
## 18: 0.10000 5 0.2539062 0.10000 5 0.02
## 19: 0.07525 5 0.2539062 0.07525 5 0.00
## 20: 0.05050 1 0.2539062 0.05050 1 0.02
## timestamp batch_nr resample_result
## 1: 2021-10-12 11:29:57 1 <ResampleResult[20]>
## 2: 2021-10-12 11:29:58 2 <ResampleResult[20]>
## 3: 2021-10-12 11:29:58 3 <ResampleResult[20]>
## 4: 2021-10-12 11:29:58 4 <ResampleResult[20]>
## 5: 2021-10-12 11:29:58 5 <ResampleResult[20]>
## 6: 2021-10-12 11:29:58 6 <ResampleResult[20]>
## 7: 2021-10-12 11:29:58 7 <ResampleResult[20]>
## 8: 2021-10-12 11:29:58 8 <ResampleResult[20]>
## 9: 2021-10-12 11:29:58 9 <ResampleResult[20]>
## 10: 2021-10-12 11:29:58 10 <ResampleResult[20]>
## 11: 2021-10-12 11:29:58 11 <ResampleResult[20]>
## 12: 2021-10-12 11:29:58 12 <ResampleResult[20]>
## 13: 2021-10-12 11:29:59 13 <ResampleResult[20]>
## 14: 2021-10-12 11:29:59 14 <ResampleResult[20]>
## 15: 2021-10-12 11:29:59 15 <ResampleResult[20]>
## 16: 2021-10-12 11:29:59 16 <ResampleResult[20]>
## 17: 2021-10-12 11:29:59 17 <ResampleResult[20]>
## 18: 2021-10-12 11:29:59 18 <ResampleResult[20]>
## 19: 2021-10-12 11:29:59 19 <ResampleResult[20]>
## 20: 2021-10-12 11:29:59 20 <ResampleResult[20]>
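Most rows tie at 0.2539062, that is, 65 of the 256 test rows misclassified. The reported optimum (cp = 0.1, minsplit = 3) is batch 3, which appears to be simply the first configuration in evaluation order to reach the minimum. With a single holdout split, the tied configurations cannot be told apart. The classif.ce column, copied from the table above, confirms this:

```r
# classif.ce values copied from the archive above, in batch order
ce = c(0.3085938, 0.2968750, 0.2539062, 0.2539062, 0.2656250,
       0.2539062, 0.2968750, 0.2539062, 0.2539062, 0.2539062,
       0.2539062, 0.2539062, 0.2539062, 0.2539062, 0.2539062,
       0.2539062, 0.2539062, 0.2539062, 0.2539062, 0.2539062)

which.min(ce)          # 3: the first batch that reaches the minimum
sum(ce == min(ce))     # 16 tied configurations
round(min(ce) * 256)   # 65 misclassified rows out of 256
```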
So grid search evaluated 20 of the 25 configurations, in random order, before the terminator stopped it. The benchmark_result field gives the details of each resampling iteration.
instance$archive$benchmark_result
## <BenchmarkResult> of 20 rows with 20 resampling runs
## nr task_id learner_id resampling_id iters warnings errors
## 1 pima classif.rpart holdout 1 0 0
## 2 pima classif.rpart holdout 1 0 0
## 3 pima classif.rpart holdout 1 0 0
## 4 pima classif.rpart holdout 1 0 0
## 5 pima classif.rpart holdout 1 0 0
## 6 pima classif.rpart holdout 1 0 0
## 7 pima classif.rpart holdout 1 0 0
## 8 pima classif.rpart holdout 1 0 0
## 9 pima classif.rpart holdout 1 0 0
## 10 pima classif.rpart holdout 1 0 0
## 11 pima classif.rpart holdout 1 0 0
## 12 pima classif.rpart holdout 1 0 0
## 13 pima classif.rpart holdout 1 0 0
## 14 pima classif.rpart holdout 1 0 0
## 15 pima classif.rpart holdout 1 0 0
## 16 pima classif.rpart holdout 1 0 0
## 17 pima classif.rpart holdout 1 0 0
## 18 pima classif.rpart holdout 1 0 0
## 19 pima classif.rpart holdout 1 0 0
## 20 pima classif.rpart holdout 1 0 0
The accuracy (classif.acc) of the 20 models can also be computed; it equals 1 minus the classification error.
instance$archive$benchmark_result$score(msr("classif.acc"))
## uhash nr task task_id
## 1: 1640e1a3-94e4-41ad-bd8c-d25662e40637 1 <TaskClassif[47]> pima
## 2: 2778f922-8e7a-410c-a45a-cf019b09a79f 2 <TaskClassif[47]> pima
## 3: eaf3e580-64f2-4591-b5f3-a32e17373fbc 3 <TaskClassif[47]> pima
## 4: 9e34687d-70b4-4a46-8ff6-531764a2d9c2 4 <TaskClassif[47]> pima
## 5: fd305853-449d-485f-b289-eb08f7fc3112 5 <TaskClassif[47]> pima
## 6: b9f10bc6-d76f-4533-882d-af29438748b8 6 <TaskClassif[47]> pima
## 7: 71f7bb78-0b09-4d38-ba64-3d186250cc90 7 <TaskClassif[47]> pima
## 8: 1c4d38c3-219f-46d4-8536-1aedba2983b7 8 <TaskClassif[47]> pima
## 9: dc9eac76-531a-4e22-a012-cd82264d11e8 9 <TaskClassif[47]> pima
## 10: ac024f47-4be7-4cd0-8ba7-0f5be33f4d95 10 <TaskClassif[47]> pima
## 11: 588bd2ca-3ae2-40f5-9f22-83d1f3582897 11 <TaskClassif[47]> pima
## 12: ec4289d1-9680-4fbf-bbbe-a49afa4ab671 12 <TaskClassif[47]> pima
## 13: 674a6bd1-a766-4436-ae89-693818e6599d 13 <TaskClassif[47]> pima
## 14: dac74012-f5ad-461b-802b-4d853e1d97fb 14 <TaskClassif[47]> pima
## 15: 67720b97-080c-4266-a858-35197a4642be 15 <TaskClassif[47]> pima
## 16: fda6eda0-81f7-427c-b010-8cff139eedd3 16 <TaskClassif[47]> pima
## 17: 29e3e76b-8a2e-4183-ad40-49d6a514b7b3 17 <TaskClassif[47]> pima
## 18: b78beace-c168-4324-b83c-9fc804c02ff1 18 <TaskClassif[47]> pima
## 19: 04cb96a8-a363-43dd-ab11-5afa5dd81a36 19 <TaskClassif[47]> pima
## 20: 43c99936-7aee-47db-98cb-48b0bc0c8ff7 20 <TaskClassif[47]> pima
## learner learner_id resampling
## 1: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 2: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 3: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 4: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 5: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 6: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 7: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 8: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 9: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 10: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 11: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 12: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 13: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 14: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 15: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 16: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 17: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 18: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 19: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## 20: <LearnerClassifRpart[36]> classif.rpart <ResamplingHoldout[19]>
## resampling_id iteration prediction classif.acc
## 1: holdout 1 <PredictionClassif[19]> 0.6914062
## 2: holdout 1 <PredictionClassif[19]> 0.7031250
## 3: holdout 1 <PredictionClassif[19]> 0.7460938
## 4: holdout 1 <PredictionClassif[19]> 0.7460938
## 5: holdout 1 <PredictionClassif[19]> 0.7343750
## 6: holdout 1 <PredictionClassif[19]> 0.7460938
## 7: holdout 1 <PredictionClassif[19]> 0.7031250
## 8: holdout 1 <PredictionClassif[19]> 0.7460938
## 9: holdout 1 <PredictionClassif[19]> 0.7460938
## 10: holdout 1 <PredictionClassif[19]> 0.7460938
## 11: holdout 1 <PredictionClassif[19]> 0.7460938
## 12: holdout 1 <PredictionClassif[19]> 0.7460938
## 13: holdout 1 <PredictionClassif[19]> 0.7460938
## 14: holdout 1 <PredictionClassif[19]> 0.7460938
## 15: holdout 1 <PredictionClassif[19]> 0.7460938
## 16: holdout 1 <PredictionClassif[19]> 0.7460938
## 17: holdout 1 <PredictionClassif[19]> 0.7460938
## 18: holdout 1 <PredictionClassif[19]> 0.7460938
## 19: holdout 1 <PredictionClassif[19]> 0.7460938
## 20: holdout 1 <PredictionClassif[19]> 0.7460938
Set the best hyperparameter values on the learner and fit it on the full data set.
learner$param_set$values = instance$result_learner_param_vals
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01 0.1
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20 3
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
learner$train(task)
Plot the variable importance values as a bar chart with ggplot2.
importance = as.data.table(learner$importance(), keep.rownames = TRUE)
colnames(importance) = c("Feature", "Importance")
ggplot(importance, aes(x = reorder(Feature, Importance), y = Importance)) +
geom_col() + coord_flip() + xlab("")
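As far as we can tell, importance() for mlr3's rpart learner returns the variable.importance vector stored in the fitted rpart model (learner$model). The base-R sketch below builds the same two-column table from rpart directly. It uses the kyphosis data bundled with rpart, not pima, and base graphics in place of ggplot2.

```r
library("rpart")

fit = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

imp = sort(fit$variable.importance, decreasing = TRUE)
importance = data.frame(Feature = names(imp), Importance = unname(imp))
print(importance)

# base-graphics counterpart of the ggplot2 bar chart
barplot(rev(imp), horiz = TRUE, las = 1)
```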
The fitted model can then be used to predict new data.
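In mlr3, a trained learner predicts on a data.frame with learner$predict_newdata(newdata). Underneath, the classif.rpart learner relies on rpart's predict(). The sketch below shows that step with rpart alone, reusing cp = 0.1 and minsplit = 3 from the tuning result. It fits on the kyphosis data bundled with rpart and predicts for two made-up patients (the values in new_patients are hypothetical).

```r
library("rpart")

fit = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
            control = rpart.control(cp = 0.1, minsplit = 3))

# two hypothetical new observations
new_patients = data.frame(Age = c(30, 120), Number = c(3, 5), Start = c(15, 4))

pred_class = predict(fit, newdata = new_patients, type = "class")  # predicted labels
pred_prob  = predict(fit, newdata = new_patients, type = "prob")   # class probabilities
print(pred_class)
print(pred_prob)
```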
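AutoTuner offers another way to tune. Unlike a Tuner, AutoTuner inherits from Learner: it wraps a learner together with a resampling strategy, a measure, a search space, a terminator and a tuner. When it is trained, it first tunes the hyperparameters on the training data and then fits the wrapped learner on that data with the best configuration found.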
learner = lrn("classif.rpart")
search_space = ps(
cp = p_dbl(lower = 0.001, upper = 0.1),
minsplit = p_int(lower = 1, upper = 10)
)
terminator = trm("evals", n_evals = 10)
tuner = tnr("random_search")
at = AutoTuner$new(
learner = learner,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
search_space = search_space,
terminator = terminator,
tuner = tuner
)
at
## <AutoTuner:classif.rpart.tuned>
## * Model: -
## * Search Space:
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
## * Packages: rpart
## * Predict Type: response
## * Feature Types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
Because at is a Learner, its train() and predict() methods can be called directly.
at$train(task)
## INFO [11:30:00.287] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO [11:30:00.304] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.323] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.327] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.340] [mlr3] Finished benchmark
## INFO [11:30:00.399] [bbotk] Result of batch 1:
## INFO [11:30:00.400] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.400] [bbotk] 0.02154074 2 0.2773438 0.02
## INFO [11:30:00.400] [bbotk] uhash
## INFO [11:30:00.400] [bbotk] c626fbcb-004e-4ab7-b78d-61dbe616910b
## INFO [11:30:00.403] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.417] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.422] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.433] [mlr3] Finished benchmark
## INFO [11:30:00.484] [bbotk] Result of batch 2:
## INFO [11:30:00.485] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.485] [bbotk] 0.02171658 8 0.2773438 0
## INFO [11:30:00.485] [bbotk] uhash
## INFO [11:30:00.485] [bbotk] e7d75140-ff60-4451-823e-ac110d68e606
## INFO [11:30:00.489] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.502] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.507] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.517] [mlr3] Finished benchmark
## INFO [11:30:00.570] [bbotk] Result of batch 3:
## INFO [11:30:00.571] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.571] [bbotk] 0.05513162 8 0.296875 0
## INFO [11:30:00.571] [bbotk] uhash
## INFO [11:30:00.571] [bbotk] 23729d21-dc48-4098-ad85-357cca7833a9
## INFO [11:30:00.574] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.588] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.593] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.604] [mlr3] Finished benchmark
## INFO [11:30:00.657] [bbotk] Result of batch 4:
## INFO [11:30:00.659] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.659] [bbotk] 0.08936942 3 0.296875 0.01
## INFO [11:30:00.659] [bbotk] uhash
## INFO [11:30:00.659] [bbotk] aa2ab9ff-ecef-4e29-820d-1d118a2d38bd
## INFO [11:30:00.662] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.676] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.682] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.693] [mlr3] Finished benchmark
## INFO [11:30:00.745] [bbotk] Result of batch 5:
## INFO [11:30:00.746] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.746] [bbotk] 0.01496239 10 0.296875 0.02
## INFO [11:30:00.746] [bbotk] uhash
## INFO [11:30:00.746] [bbotk] b96b107c-4bda-4600-94f7-65487251fdbe
## INFO [11:30:00.749] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.765] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.771] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.785] [mlr3] Finished benchmark
## INFO [11:30:00.835] [bbotk] Result of batch 6:
## INFO [11:30:00.837] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.837] [bbotk] 0.081032 6 0.296875 0.01
## INFO [11:30:00.837] [bbotk] uhash
## INFO [11:30:00.837] [bbotk] a5bd810c-95f3-462e-b01b-f406e44b27d1
## INFO [11:30:00.840] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.854] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.858] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.869] [mlr3] Finished benchmark
## INFO [11:30:00.944] [bbotk] Result of batch 7:
## INFO [11:30:00.945] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:00.945] [bbotk] 0.0865911 3 0.296875 0.02
## INFO [11:30:00.945] [bbotk] uhash
## INFO [11:30:00.945] [bbotk] f6176c32-ba95-4347-a0df-a465413d3057
## INFO [11:30:00.948] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:00.962] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:00.966] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:00.978] [mlr3] Finished benchmark
## INFO [11:30:01.026] [bbotk] Result of batch 8:
## INFO [11:30:01.028] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:01.028] [bbotk] 0.08663804 2 0.296875 0.02
## INFO [11:30:01.028] [bbotk] uhash
## INFO [11:30:01.028] [bbotk] 69011fa2-943b-4ed1-92fa-ae3be22dcd3b
## INFO [11:30:01.031] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:01.044] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:01.048] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:01.061] [mlr3] Finished benchmark
## INFO [11:30:01.118] [bbotk] Result of batch 9:
## INFO [11:30:01.120] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:01.120] [bbotk] 0.02164366 2 0.2773438 0
## INFO [11:30:01.120] [bbotk] uhash
## INFO [11:30:01.120] [bbotk] 52688f5e-7eba-400e-86dd-889429e5f6ed
## INFO [11:30:01.126] [bbotk] Evaluating 1 configuration(s)
## INFO [11:30:01.143] [mlr3] Running benchmark with 1 resampling iterations
## INFO [11:30:01.148] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [11:30:01.161] [mlr3] Finished benchmark
## INFO [11:30:01.230] [bbotk] Result of batch 10:
## INFO [11:30:01.232] [bbotk] cp minsplit classif.ce runtime_learners
## INFO [11:30:01.232] [bbotk] 0.09408432 10 0.296875 0.02
## INFO [11:30:01.232] [bbotk] uhash
## INFO [11:30:01.232] [bbotk] b0287967-9d28-4cc5-b506-46589fc5c903
## INFO [11:30:01.238] [bbotk] Finished optimizing after 10 evaluation(s)
## INFO [11:30:01.238] [bbotk] Result:
## INFO [11:30:01.240] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [11:30:01.240] [bbotk] 0.02154074 2 <list[3]> <list[2]> 0.2773438
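After training, the AutoTuner can be inspected and used like any learner. The sketch below follows the mlr3 API used in this document; argument order and class names may differ in other mlr3tuning versions. It silences the INFO log through the lgr loggers and prints the tuning result. It then predicts with the final model, which gives an optimistic in-sample error. Finally it runs a nested resampling, where tuning is repeated inside each outer fold, to estimate performance honestly. The 3-fold outer CV is an arbitrary choice for illustration.

```r
library("mlr3verse")
library("mlr3tuning")

# quieten the per-batch INFO messages seen in the log above
lgr::get_logger("bbotk")$set_threshold("warn")
lgr::get_logger("mlr3")$set_threshold("warn")

task = tsk("pima")
at = AutoTuner$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = ps(
    cp = p_dbl(lower = 0.001, upper = 0.1),
    minsplit = p_int(lower = 1, upper = 10)
  ),
  terminator = trm("evals", n_evals = 10),
  tuner = tnr("random_search")
)

at$train(task)
at$tuning_result                     # best cp/minsplit found on the inner holdout

prediction = at$predict(task)        # predictions from the final model
prediction$score(msr("classif.ce"))  # in-sample error, optimistic

# nested resampling: tuning is repeated inside each outer fold
rr = resample(task, at, rsmp("cv", folds = 3))
rr$aggregate(msr("classif.ce"))
```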