Precision_proba DecisionTreeClassifier возвращает 0 или 1

Я пытаюсь использовать классифицированное дерево решений для идентификации двух классов (переименованных в 0 и 1) на основе определенных параметров. Я обучаю его с использованием набора данных, а затем запускаю его на «тестовом наборе данных». Когда я пытаюсь рассчитать вероятность для каждой точки данных в тестовом наборе данных, он возвращает только 0 или 1. Интересно, в чем проблема?

Вот пример кода:

clf=tree.DecisionTreeClassifier(random_state=0) trained=clf.fit(data,identifier) # training data where identifier is 0 or 1 predict=trained.predict(test_data) Результаты этого:

In [9]: predict

Out[9]: 
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
       0, 0, 1, 0, 0, 0])

In [10]: trained.predict_proba(test_data)[:,1]

Out[10]: 
array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,
        0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,
        1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,
        0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.])

Я хотел бы сгенерировать и ROC, который на данный момент просто возвращает 3 точки данных для FPR/TPR.

Вот полный набор данных: Идентификатор — это последний столбец «данных».

Данные поезда:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma,Class
1.4304664,0.61,2.18,0.3819051,0.99992716,1.93,0
1.6969398,0.54,1.93,0.66479063,0.9999814,2.11,0
2.233997,1.02,3.18,0.55532146,0.9999979,2.07,0
2.230639,0.77,2.34,0.0012237767,1.0,1.81,0
1.7325432,0.71,2.27,0.34395835,1.0,1.9,0
1.8728518,0.8,2.14,0.4255796,1.0,1.96,0
1.9818852,0.7,2.18,-0.08978904,1.0,1.66,0
2.3864453,0.95,2.51,0.109010585,0.98401743,1.81,0
2.5911317,0.94,2.49,0.60381645,0.99991965,2.03,0
1.9564596,0.81,2.29,0.3843,0.9999495,2.08,0
2.1506176,0.93,2.62,0.28551856,0.9999999,1.91,0
1.9069784,0.62,1.76,0.041608978,1.0,1.86,0
1.6216202,0.77,2.11,-0.14271076,1.0,1.7,0
2.276335,0.68,2.14,0.40399882,1.0,2.06,0
2.2430172,1.0,2.94,0.61844856,1.0,2.12,0
1.0226197,0.66,2.07,-0.14886126,1.0,1.84,0
2.2564504,1.06,2.77,0.6974536,0.99844635,2.16,0
2.2819016,0.88,2.37,0.30696234,0.999996,1.86,0
1.4881139,0.7,2.09,0.40853307,1.0,1.82,0
2.4640048,0.9,2.39,0.35103577,1.0,2.02,0
2.656071,0.72,2.29,0.21568911,0.9999046,2.11,0
1.7204628,0.62,2.01,0.19794853,1.0,1.8,0
1.9134961,0.86,2.27,0.37281907,1.0,1.94,0
1.3061943,0.67,2.01,0.3463318,0.99999976,1.86,0
1.8845558,0.64,2.01,0.12364135,0.9999834,1.84,0
2.4409518,1.12,3.31,0.7502838,1.0,2.17,0
1.9501582,0.85,2.34,0.29961613,0.9999974,1.92,0
2.1314192,1.03,2.62,0.69623667,1.0,2.28,0
1.7345899,0.69,2.61,0.38524705,0.99999887,2.09,0
1.7095753,0.75,2.08,0.21696341,0.9999987,1.95,0
1.9115254,0.83,2.17,-0.046689913,1.0,1.85,0
1.565369,0.67,2.01,-0.04827315,0.9999915,1.79,0
2.2971635,0.59,2.1,0.35741857,1.0,2.0,0
3.042759,1.06,2.94,0.70878696,0.9999844,2.15,0
2.340724,0.96,2.74,0.42822766,0.99999416,1.97,0
1.8552977,0.74,2.09,0.07262661,1.0,1.69,0
2.0324602,0.66,2.05,-0.07643526,0.9999982,1.83,0
1.8508979,0.67,1.96,0.054557554,0.99997455,1.75,0
2.7983437,0.96,2.58,0.8554537,0.9999992,2.2,0
2.1728642,1.09,3.05,0.61488354,1.0,2.04,0
3.113785,0.66,1.85,0.48011553,0.99995273,1.95,0
3.0665417,0.78,2.19,0.27814054,1.0,1.86,0
2.0060341,0.83,2.39,0.20785762,0.9999502,1.85,0
2.1786506,0.57,2.0,0.33096096,1.0,1.91,0
1.823961,0.72,1.96,-0.103285044,1.0,1.6,0
1.612012,0.68,2.15,-0.3136376,0.65517294,1.52,0
2.1615896,0.87,2.4,0.47535577,1.0,2.04,0
2.3053634,1.06,2.92,0.67040676,0.9991328,2.15,0
1.7525402,0.73,2.12,0.25563625,0.9999979,1.92,0
2.7306526,0.91,2.35,0.68943393,-0.4308276,2.1,0
2.2549937,1.07,2.91,0.6077795,0.9999626,2.04,0
2.0924683,0.69,2.04,-0.068183094,0.3497915,1.77,0
2.210627,0.84,2.09,0.6309954,0.99999976,1.99,0
2.4609168,0.67,2.08,0.29552716,0.99964327,1.96,0
2.5169518,0.84,2.45,0.35437247,0.9999745,1.92,0
2.1841373,0.9,2.51,0.5617463,1.0,2.15,0
3.0673068,0.8,2.22,0.17641401,1.0,1.9,0
2.6202004,0.97,2.47,0.36663872,1.0,2.03,0
1.9694642,0.95,2.54,0.33140072,0.99998665,2.04,0
1.8766946,0.84,2.32,-0.024992371,0.99999803,1.94,0
2.9352057,1.2,2.96,0.6385377,0.9951195,2.18,0
1.4075257,0.86,2.27,0.046303034,0.9999998,1.81,0
1.8769667,0.6,2.0,0.08842805,0.15410244,1.83,0
1.2585826,0.71,1.96,0.005930161,0.78259146,1.72,0
2.2046561,0.9,2.37,0.62021697,1.0,2.07,0
1.0217602,0.49,1.89,-0.26944694,0.9999997,1.66,0
2.1021683,1.05,2.78,0.5306551,1.0,2.14,0
2.4789429,0.94,2.52,0.34224525,0.9999965,2.01,0
2.1449182,0.8,2.32,0.37609425,0.9997282,2.25,0
2.7071185,0.83,2.36,0.75363404,1.0,2.31,0
1.8445525,1.04,2.76,0.6075378,0.88632137,2.14,0
1.6024263,1.09,2.63,0.64461184,1.0,2.18,0
2.0292685,0.53,2.15,0.090091705,1.0,1.92,0
2.0858748,0.71,1.86,0.14351326,0.9999994,1.88,0
2.1292083,0.81,2.31,0.33257455,1.0,1.95,0
1.6344122,0.84,2.38,0.6371139,0.9999998,2.11,0
1.7532507,0.75,2.04,0.16182575,1.0,1.78,0
2.2479355,0.97,2.72,0.41953298,1.0,2.04,0
2.5790315,1.07,2.96,0.7216893,0.9999953,2.11,0
3.0039942,1.03,2.44,0.8042694,0.9998856,2.25,1
3.7599833,1.16,3.23,0.9095345,0.66683024,2.39,1
2.8912013,1.05,2.67,0.85215354,0.9967052,2.27,1
3.8784094,1.11,3.18,0.6971026,1.0,2.19,1
2.1862392,1.13,2.7,0.65855825,1.0,2.28,1
2.7684402,1.16,2.79,0.9261603,-0.9540385,2.35,1
1.7551649,0.56,2.18,0.23092282,1.0,1.98,1
2.804592,1.13,2.98,0.84827685,1.0,2.3,1
1.9874831,1.0,2.98,0.87599415,1.0,2.21,1
2.5059428,1.16,2.79,0.97649753,0.9997586,2.42,1
2.812127,1.12,3.11,0.87392867,1.0,2.21,1
2.9445121,1.06,3.17,0.8849491,1.0,2.41,1
2.7388847,1.11,2.78,0.84986275,0.96669436,2.32,1
2.1416433,1.1,3.61,0.7671358,0.9999998,2.29,1
2.3661094,1.05,3.16,0.73194104,0.99990827,2.14,1
2.761189,1.09,2.81,0.7681978,-0.99955946,2.23,1
2.6658804,1.02,3.36,0.8036201,0.98403203,2.28,1
2.720667,0.99,2.78,0.97055733,0.9781505,2.48,1
2.6812658,0.98,3.05,0.73290765,1.0,2.09,1
1.4784714,0.62,1.97,0.418,1.0,2.02,0
1.7488811,0.7,2.05,0.418,0.99999624,2.02,0

тестовые данные:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma
1.6724254,0.95,2.58,0.92031854,1.0,2.15
2.552926,0.93,2.74,0.63588345,-0.30092865,2.18
2.5737462,0.86,2.22,0.43023747,1.0,2.08
2.1701677,0.62,2.19,0.6892167,1.0,2.15
3.6152358,0.96,2.58,0.67760235,0.99704355,2.06
3.6193092,0.82,2.34,0.4083981,0.9973078,2.04
2.0209844,1.02,2.86,0.8595182,-0.9979041,2.36
2.166221,1.07,3.0,0.7177616,-0.99961376,2.3
2.7933478,0.94,2.4,0.678935,1.0,2.12
2.2969048,0.86,2.29,0.18689133,1.0,1.96
3.1255674,1.15,2.77,0.9290483,0.6387009,2.28
2.3548958,1.01,2.46,0.75331503,-1.0,2.21
3.9791226,1.15,3.04,0.87006325,-0.99919724,2.43
2.3430493,0.85,2.42,0.81132597,-0.9999996,2.04
3.7431624,0.79,2.57,0.704,0.99952716,2.20784
3.1846259,1.14,2.85,0.9104803,0.99891067,2.3
3.1416001,0.73,2.26,0.5679769,1.0,1.98
2.670179,0.85,2.66,0.7376513,0.97939825,2.1
3.010911,0.79,2.38,0.21750104,0.21187924,1.82
1.4430648,0.9,2.38,0.7361963,0.999758,2.11
2.8149416,1.07,2.62,0.94750744,0.9967568,2.4
3.8395922,1.09,2.91,0.27485812,0.99887043,2.05
3.1686394,0.66,2.11,0.529385,1.0,1.9
3.190167,1.09,3.1,0.8501991,0.9507157,2.23
3.8597586,1.13,3.64,0.89043206,0.17880388,2.42
2.1516426,0.85,2.24,0.6673518,0.9985168,2.2
2.1318088,0.98,2.64,0.85542095,1.0,2.22
1.6740437,0.97,2.99,0.86632746,0.9983954,2.41
4.273427,1.01,2.71,0.8941501,0.64256436,2.47
2.284782,0.92,2.7,0.5820462,0.6981752,2.1
3.343603,1.06,2.84,0.6901738,0.83269715,2.13
5.766362,1.2,3.74,0.99009913,0.99998844,2.49
2.1547525,0.95,3.02,0.75229234,0.99604213,2.57
2.9853358,0.91,2.37,0.62881154,-0.98792726,2.06
2.8614197,0.82,2.15,0.75643075,1.0,2.19
3.6815813,1.14,3.24,0.8886577,-0.030438267,2.39
4.539201,1.17,2.83,0.93989134,0.23378997,2.55
3.35261,1.1,2.73,0.9184936,0.9998006,2.41
3.6697345,1.16,3.57,0.9515105,0.9999988,2.43
1.9781204,0.91,2.85,-0.06649571,0.9999991,1.7
2.6618617,1.1,3.24,0.8348949,-0.9834342,2.29
3.8140056,1.18,3.25,0.8766021,1.0,2.39
2.1926181,1.05,2.3,0.6880097,1.0,2.3
2.0248337,0.83,2.29,0.3604591,0.46159065,2.05
3.904931,1.13,2.46,0.9100119,1.0,2.32
1.9945884,0.94,2.5,0.4632657,0.9869119,2.05
3.3342967,1.1,3.04,0.51323855,-0.5262294,2.23
2.3138714,0.91,2.36,0.90414697,0.9999977,2.29
2.3118904,1.04,3.01,0.87289846,0.998577,2.29
2.246307,1.07,2.72,0.6147379,0.9999993,2.11
1.6369493,0.89,2.34,0.61421084,0.9997295,2.22
3.6198807,0.93,2.62,0.7463702,0.9994778,2.07

person Phyast10    schedule 12.01.2018    source источник
comment
Это зависит от данных. Возможно, ваши данные таковы, что дерево решений может успешно разделить классы (возможно, дерево переоснащается), и, следовательно, дерево точно знает, что конкретный экземпляр принадлежит классу 1 с вероятностью 100%. Без просмотра образцов данных мы мало что можем сделать.   -  person Vivek Kumar    schedule 12.01.2018
comment
Я использую 100 баллов для обучения и 6 параметров. Как вы думаете, проблема в этом?   -  person Phyast10    schedule 12.01.2018
comment
Взгляните на martin-thoma.com/comparing-classifiers (там есть код). Попробуйте другие классификаторы, особенно логистическую регрессию и k ближайших соседей. Что для них результат?   -  person Martin Thoma    schedule 15.01.2018
comment
По данным, которые вы предоставили, дерево дает 100% точность на поезде. Так что это полностью соответствует данным.   -  person Vivek Kumar    schedule 15.01.2018


Ответы (1)


Проблем нет — дерево ведет себя именно так, как ожидалось.

Дерево решений вычисляет вероятность класса по количеству выборок каждого класса, попадающих в данный лист.

В документации говорится:

Значения по умолчанию для параметров, контролирующих размер деревьев (например, max_depth, min_samples_leaf и т. д.), приводят к полностью выращенным и необрезанным деревьям.

т.е. дерево растет до тех пор, пока оно полностью не будет соответствовать обучающим данным. Это означает, что все обучающие выборки в каждом листе относятся к одному и тому же классу, а тестовая выборка либо соответствует этому классу (p=1), либо нет (p=0).

Чтобы получить более точные оценки вероятности, вы можете ограничить min_samples_leaf так, чтобы в каждом листе было минимальное количество выборок, которые будут использоваться для вычисления вероятностей (с одной выборкой вы получите [0, 1] - например, с 10 выборками вы можете получить [0 , 0,1, 0,2, ..., 0,9, 1]). Вам придется поэкспериментировать с настройками, чтобы найти, какие числа лучше всего подходят для вас и ваших данных.

person kazemakase    schedule 15.01.2018
comment
Большое спасибо за разъяснение. Тогда проблема заключается в кривой ROC, которая в этом случае покажет только 3 точки из-за двух значений вероятности и еще одного в бесконечности. Как мне представить результаты этой классификации? - person Phyast10; 16.01.2018
comment
Я бы ограничил min_samples_leaf так, чтобы в каждом листе было минимальное количество выборок, которые будут использоваться для вычисления вероятностей (с одной выборкой вы получите [0, 1] - например, с 10 выборками вы можете получить [0, 0,1, 0,2, ..., 0,9, 1]). Вам придется поэкспериментировать с настройками, чтобы найти, какие числа лучше всего подходят для вас и ваших данных. - person kazemakase; 22.01.2018