CodeWalk

Scikit-learn自定义Transformer与Estimator的完整实现

作者:我是大山 · 2026-05-30 12:55

请说明如何在sklearn中自定义Transformer和Estimator,并使其兼容Pipeline和GridSearchCV。实现一个自定义Transformer(如自定义文本特征提取器),以及一个自定义Estimator(如包装统计模型的分类器)。需要实现哪些必须的方法(fit/transform/predict/score/get_params/set_params)?

回答

我是大山

自定义Transformer:继承BaseEstimatorTransformerMixin

from sklearn.base import BaseEstimator, TransformerMixin

class TextFeatureExtractor(BaseEstimator, TransformerMixin):
    def __init__(self, max_words=100):
        self.max_words = max_words
        
    def fit(self, X, y=None):
        # X是文本列表,学习词汇表
        from collections import Counter
        word_counts = Counter()
        for text in X:
            word_counts.update(text.split())
        self.vocab_ = [w for w, _ in word_counts.most_common(self.max_words)]
        return self
        
    def transform(self, X):
        # 转换为词频矩阵
        import numpy as np
        result = np.zeros((len(X), len(self.vocab_)))
        for i, text in enumerate(X):
            for w in text.split():
                if w in self.vocab_:
                    j = self.vocab_.index(w)
                    result[i, j] += 1
        return result

自定义Estimator:继承BaseEstimatorClassifierMixin/RegressorMixin

from sklearn.base import BaseEstimator, ClassifierMixin

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, alpha=1.0):
        self.alpha = alpha
        
    def fit(self, X, y):
        # 学习逻辑回归参数
        self.coef_ = ...  # 学习得到的参数
        self.classes_ = np.unique(y)
        return self
        
    def predict(self, X):
        return self.classes_[np.argmax(self.decision_function(X), axis=1)]
        
    def predict_proba(self, X):
        # 实现概率预测
        ...
        return proba

必需方法

  • get_params():由BaseEstimator自动根据__init__签名生成
  • set_params():同上
  • 重要__init__参数必须作为关键字参数且不包含*args**kwargs
  • 所有学习到的属性加下划线后缀(如self.vocab_