Пропускать запрещенные комбинации параметров при использовании GridSearchCV

Я хочу выполнить жадный поиск во всем пространстве параметров моего классификатора опорных векторов, используя GridSearchCV. Однако некоторые комбинации параметров запрещены LinearSVC и выбросить исключение. В частности, существуют взаимоисключающие комбинации параметров dual, penalty и loss:

Например, такой код:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'dual':[True, False], 'penalty' : ['l1', 'l2'], \
              'loss': ['hinge', 'squared_hinge']}
svc = svm.LinearSVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

Возвращает ValueError: Unsupported set of arguments: The combination of penalty='l2' and loss='hinge' are not supported when dual=False, Parameters: penalty='l2', loss='hinge', dual=False

Мой вопрос: можно ли заставить GridSearchCV пропускать комбинации параметров, которые модель запрещает? Если нет, есть ли простой способ создать пространство параметров, не нарушающее правила?


person crypdick    schedule 24.03.2017    source источник
comment
Это все равно было бы проблемой, но меньшей проблемой, если бы мы могли хотя бы подавить операторы FitFailedWarning в этом случае. Я сталкиваюсь с той же битвой, где я знаю, что некоторые комбинации незаконны, но логика (как объясняется ниже) для предотвращения этих комбинаций слишком уродлива.   -  person demongolem    schedule 05.06.2020


Ответы (2)


Я решил эту проблему, передав error_score=0.0 GridSearchCV:

error_score: «поднять» (по умолчанию) или числовой

Значение, присваиваемое баллу, если при подборе оценщика возникает ошибка. Если установлено «поднять», возникает ошибка. Если задано числовое значение, возникает FitFailedWarning. Этот параметр не влияет на этап переоборудования, который всегда вызывает ошибку.

ОБНОВЛЕНИЕ: новые версии sklearn распечатывают кучу ConvergenceWarning и FitFailedWarning. Мне было трудно подавить их с помощью contextlib.suppress, но вокруг этого есть хитрость с участием диспетчера контекста тестирования:

from sklearn import svm, datasets 
from sklearn.utils._testing import ignore_warnings 
from sklearn.exceptions import FitFailedWarning, ConvergenceWarning 
from sklearn.model_selection import GridSearchCV 

with ignore_warnings(category=[ConvergenceWarning, FitFailedWarning]): 
    iris = datasets.load_iris() 
    parameters = {'dual':[True, False], 'penalty' : ['l1', 'l2'], \ 
                 'loss': ['hinge', 'squared_hinge']} 
    svc = svm.LinearSVC() 
    clf = GridSearchCV(svc, parameters, error_score=0.0) 
    clf.fit(iris.data, iris.target)
person crypdick    schedule 24.03.2017
comment
Есть ли обходной путь, чтобы на самом деле избежать этих комбинаций (или любых других) до того, как они фактически выдадут какую-либо ошибку? - person GRoutar; 21.12.2018
comment
@Khabz, мой ответ был слишком большим, чтобы поместиться в комментариях, поэтому я разместил его как еще один ответ. - person crypdick; 21.12.2018
comment
@crypdick есть ли способ избежать появления FitFailedWarning в результате? - person Nihat; 30.07.2020
comment
@Nihat Я отредактировал свой ответ, чтобы отключить новые предупреждения - person crypdick; 31.07.2020

Если вы хотите полностью избежать изучения конкретных комбинаций (не дожидаясь появления ошибок), вам нужно построить сетку самостоятельно. GridSearchCV может принимать список словарей, в которых исследуются сетки, охватываемые каждым словарем в списке.

В данном случае условная логика была не так уж плоха, но для чего-то более сложного было бы утомительно:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from itertools import product

iris = datasets.load_iris()

duals = [True, False]
penaltys = ['l1', 'l2']
losses = ['hinge', 'squared_hinge']
all_params = list(product(duals, penaltys, losses))
filtered_params = [{'dual': [dual], 'penalty' : [penalty], 'loss': [loss]}
                   for dual, penalty, loss in all_params
                   if not (penalty == 'l1' and loss == 'hinge') 
                   and not ((penalty == 'l1' and loss == 'squared_hinge' and dual is True))
                  and not ((penalty == 'l2' and loss == 'hinge' and dual is False))]

svc = svm.LinearSVC()
clf = GridSearchCV(svc, filtered_params)
clf.fit(iris.data, iris.target)
person crypdick    schedule 21.12.2018
comment
Я ценю ваши усилия, но это кажется немного схематичным решением, которое привело бы к большому количеству подробностей для проблемы с большим количеством ограничений. - person GRoutar; 24.12.2018
comment
@Khabz согласился, этот код проклят! Если существует миллиард условных операторов, одна из возможностей состоит в том, чтобы программно создать список условных операторов в filtered_params, затем str.join(conditionals_list) и, наконец, eval() строку для понимания списка. - person crypdick; 26.12.2018