
loss

Classification loss adjusted by fairness threshold

Since R2023a

    Description

    L = loss(thresholder,Tbl) computes the classification loss (specified by thresholder.LossFun) by using the fairnessThresholder object thresholder and the table data Tbl.


    L = loss(thresholder,X,attribute,Y) computes the classification loss (specified by thresholder.LossFun) by using the fairnessThresholder object thresholder, the matrix data X, the sensitive attribute specified by attribute, and the true class labels Y.

    Examples


    Train a tree ensemble for binary classification, and compute the disparate impact for each group in the sensitive attribute. To reduce the disparate impact value of the nonreference group, adjust the score threshold for classifying observations.

    Load the data census1994, which contains the data set adultdata and the test data set adulttest. The data sets consist of demographic information from the US Census Bureau that can be used to predict whether an individual makes over $50,000 per year. Preview the first few rows of adultdata.

    load census1994
    head(adultdata)
        age       workClass          fnlwgt      education    education_num       marital_status           occupation        relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
        ___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______
    
        39     State-gov                77516    Bachelors         13          Never-married            Adm-clerical         Not-in-family    White    Male          2174             0                40          United-States     <=50K 
        50     Self-emp-not-inc         83311    Bachelors         13          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                13          United-States     <=50K 
        38     Private             2.1565e+05    HS-grad            9          Divorced                 Handlers-cleaners    Not-in-family    White    Male             0             0                40          United-States     <=50K 
        53     Private             2.3472e+05    11th               7          Married-civ-spouse       Handlers-cleaners    Husband          Black    Male             0             0                40          United-States     <=50K 
        28     Private             3.3841e+05    Bachelors         13          Married-civ-spouse       Prof-specialty       Wife             Black    Female           0             0                40          Cuba              <=50K 
        37     Private             2.8458e+05    Masters           14          Married-civ-spouse       Exec-managerial      Wife             White    Female           0             0                40          United-States     <=50K 
        49     Private             1.6019e+05    9th                5          Married-spouse-absent    Other-service        Not-in-family    Black    Female           0             0                16          Jamaica           <=50K 
        52     Self-emp-not-inc    2.0964e+05    HS-grad            9          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                45          United-States     >50K  
    

    Each row contains the demographic information for one adult. The information includes sensitive attributes such as age, marital_status, relationship, race, and sex. The third column, fnlwgt, contains observation weights, and the last column, salary, shows whether a person has a salary less than or equal to $50,000 per year (<=50K) or greater than $50,000 per year (>50K).

    Remove observations with missing values.

    adultdata = rmmissing(adultdata);
    adulttest = rmmissing(adulttest);

    Partition adultdata into training and validation sets. Use 60% of the observations for the training set trainingData and 40% of the observations for the validation set validationData.

    rng("default") % For reproducibility
    c = cvpartition(adultdata.salary,"Holdout",0.4);
    trainingIdx = training(c);
    validationIdx = test(c);
    trainingData = adultdata(trainingIdx,:);
    validationData = adultdata(validationIdx,:);

    Train a boosted ensemble of trees using the training data set trainingData. Specify the response variable, predictor variables, and observation weights by using the variable names in the adultdata table. Use random undersampling boosting as the ensemble aggregation method.

    predictors = ["capital_gain","capital_loss","education", ...
        "education_num","hours_per_week","occupation","workClass"];
    Mdl = fitcensemble(trainingData,"salary", ...
        PredictorNames=predictors, ...
        Weights="fnlwgt",Method="RUSBoost");

    Predict salary values for the observations in the test data set adulttest, and calculate the classification error.

    labels = predict(Mdl,adulttest);
    L = loss(Mdl,adulttest)
    L = 0.2080
    

    The model correctly predicts the salary category for approximately 80% of the test set observations.
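    The default classification error ("classiferror") is the weighted fraction of misclassified observations; with unit weights it reduces to the plain misclassification rate, and the accuracy is 1 minus the error. The following MATLAB sketch uses made-up labels and weights (yTrue, yPred, and w are hypothetical) to show how weighting changes the result. It assumes the weights are simply normalized to sum to 1; the toolbox can also rescale weights by class prior probabilities, which this sketch ignores.

```matlab
% Made-up toy data (hypothetical): 1 = ">50K", 0 = "<=50K".
yTrue = [1; 0; 1; 1; 0];   % true labels
yPred = [1; 0; 0; 1; 0];   % predicted labels (observation 3 is misclassified)
w     = [1; 1; 2; 1; 1];   % observation weights (observation 3 counts double)

% Unweighted error: plain fraction of misclassified observations.
errUnit = mean(yPred ~= yTrue);

% Weighted error: normalize the weights to sum to 1, then add up the
% weights of the misclassified observations.
wNorm = w / sum(w);
err = sum(wNorm .* (yPred ~= yTrue));
accuracy = 1 - err;

fprintf('Unweighted error: %.4f\n', errUnit);
fprintf('Weighted error: %.4f, accuracy: %.4f\n', err, accuracy);
```

    The single misclassified observation carries a third of the total weight, so the weighted error (0.3333) is larger than the unweighted error (0.2).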

    Compute fairness metrics with respect to the sensitive attribute sex by using the test set model predictions. In particular, find the disparate impact for each group in sex. Use the report and plot object functions of fairnessMetrics to display the results.

    metricsResults = fairnessMetrics(adulttest,"salary", ...
        SensitiveAttributeNames="sex",Predictions=labels, ...
        ModelNames="Ensemble",Weights="fnlwgt");
    metricsResults.PositiveClass
    ans = categorical
         >50K 
    
    
    metricsResults.ReferenceGroup
    ans = 'Male'
    
    report(metricsResults,BiasMetrics="DisparateImpact")
    ans=2×4 table
        ModelNames    SensitiveAttributeNames    Groups    DisparateImpact
        __________    _______________________    ______    _______________
    
         Ensemble               sex              Female        0.73792    
         Ensemble               sex              Male                1    
    
    
    plot(metricsResults,"DisparateImpact")

    [Figure: bar chart titled Disparate Impact, with x-axis Fairness Metric Value and y-axis sex, showing one bar per group and a constant reference line.]

    For the nonreference group (Female), the disparate impact value is the proportion of predictions in the group with a positive class value (>50K) divided by the proportion of predictions in the reference group (Male) with a positive class value. Ideally, disparate impact values are close to 1.
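    The disparate impact definition above is a ratio of (weighted) positive-prediction rates. This MATLAB sketch computes it for made-up predictions, group labels, and unit weights; all variable names are hypothetical, and with the census data you would substitute the fnlwgt weights. It illustrates the arithmetic only and is not meant to reproduce fairnessMetrics exactly.

```matlab
% Made-up toy data (hypothetical).
predPositive = logical([1 0 0 0 1 1 1 0]');  % true = predicted ">50K"
isFemale     = logical([1 1 1 1 0 0 0 0]');  % true = nonreference group
w            = ones(8, 1);                   % unit observation weights

% Weighted proportion of positive predictions within each group.
rateFemale = sum(w(isFemale) .* predPositive(isFemale)) / sum(w(isFemale));
rateMale   = sum(w(~isFemale) .* predPositive(~isFemale)) / sum(w(~isFemale));

% Disparate impact of the nonreference group relative to the reference group.
% (For the reference group itself, the ratio is always 1.)
diFemale = rateFemale / rateMale;

% Check against the default fairness bounds [0.8, 1.25].
inRange = diFemale >= 0.8 && diFemale <= 1.25;
fprintf('DI(Female) = %.4f, within bounds: %d\n', diFemale, inRange);
```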

    To try to improve the nonreference group disparate impact value, you can adjust model predictions by using the fairnessThresholder function. The function uses validation data to search for an optimal score threshold that maximizes accuracy while satisfying fairness bounds. For observations in the critical region below the optimal threshold, the function changes the labels so that the fairness constraints hold for the reference and nonreference groups. By default, the function tries to find a score threshold so that the disparate impact value for the nonreference group is in the range [0.8,1.25].
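    To make the threshold search concrete, the following MATLAB sketch runs a simplified version of it on a tiny, made-up validation set. For each candidate threshold, observations whose maximum score is below the threshold are assigned to the positive class if they are in the nonreference group and to the negative class otherwise; the sketch then computes the disparate impact of the nonreference group and the misclassification rate, and keeps the lowest-loss threshold whose disparate impact lies in [0.8,1.25]. The candidate grid, the unit weights, the probability-like scores, and all variable names are assumptions; the search procedure, loss function, and weighting that fairnessThresholder actually uses may differ.

```matlab
% Hypothetical validation data. Column 1 of scores is the negative class,
% column 2 is the positive class.
scores   = [0.2 0.8; 0.6 0.4; 0.9 0.1; 0.7 0.3; ...
            0.1 0.9; 0.45 0.55; 0.3 0.7; 0.8 0.2];
yTrue    = logical([1 0 0 0 1 0 1 0]');  % true = positive class
isNonRef = logical([1 1 1 1 0 0 0 0]');  % true = nonreference group
diRange  = [0.8 1.25];                   % default fairness bounds

[maxScore, idx] = max(scores, [], 2);
basePred = (idx == 2);                   % unadjusted predictions

candidates = [0.5 0.6 0.7 0.8 0.9 1.0];  % hypothetical threshold grid
bestLoss = Inf;
bestThreshold = NaN;
bestDI = NaN;
for t = candidates
    % Inside the critical region (maximum score below t), favor the
    % nonreference group and disfavor the reference group.
    pred = basePred;
    critical = maxScore < t;
    pred(critical & isNonRef) = true;
    pred(critical & ~isNonRef) = false;

    % Disparate impact (unit weights) and misclassification rate.
    di = mean(pred(isNonRef)) / mean(pred(~isNonRef));
    lossValue = mean(pred ~= yTrue);

    % Keep the lowest-loss threshold that satisfies the fairness bounds.
    if di >= diRange(1) && di <= diRange(2) && lossValue < bestLoss
        bestLoss = lossValue;
        bestThreshold = t;
        bestDI = di;
    end
end
fprintf('Threshold %.2f, DI %.4f, loss %.4f\n', bestThreshold, bestDI, bestLoss);
```

    With these data, the threshold 0.6 would give zero error but a disparate impact of 0.5, which violates the bounds, so the sketch settles on the threshold 0.7, with a disparate impact of 1 and an error of 0.125.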

    fairnessMdl = fairnessThresholder(Mdl,validationData,"sex","salary")
    fairnessMdl = 
      fairnessThresholder with properties:
    
                   Learner: [1×1 classreg.learning.classif.CompactClassificationEnsemble]
        SensitiveAttribute: 'sex'
           ReferenceGroups: Male
              ResponseName: 'salary'
             PositiveClass: >50K
            ScoreThreshold: 1.6749
                BiasMetric: 'DisparateImpact'
           BiasMetricValue: 0.9702
           BiasMetricRange: [0.8000 1.2500]
            ValidationLoss: 0.2017
    
    

    fairnessMdl is a fairnessThresholder model object. Note that the predict function of the ensemble model Mdl returns scores that are not posterior probabilities. Instead, scores are in the range (-∞,∞), and the maximum score for each observation is greater than 0. For observations whose maximum scores are less than the new score threshold (fairnessMdl.ScoreThreshold), the predict function of the fairnessMdl object adjusts the prediction: if the observation is in the nonreference group, the function assigns the observation to the positive class; if the observation is in the reference group, the function assigns the observation to the negative class. These adjustments do not always change the predicted label, because an observation in this critical region might already be assigned to the class that the adjustment selects.
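    The adjustment rule can be illustrated with a few made-up ensemble scores, reusing the ScoreThreshold value 1.6749 from the example. The scores, group memberships, and variable names in this MATLAB sketch are hypothetical; it mirrors only the rule described in the preceding paragraph and is not the implementation of the predict function of fairnessThresholder.

```matlab
% Made-up ensemble scores (not posterior probabilities), one row per
% observation. Column 1: "<=50K" (negative), column 2: ">50K" (positive).
scores = [ 2.1 -2.1;    % confident negative, outside the critical region
           0.9 -0.9;    % uncertain negative
          -1.2  1.2;    % uncertain positive
          -0.5  0.5;    % uncertain positive
          -3.0  3.0];   % confident positive, outside the critical region
isNonRef  = logical([1 1 1 0 0]');  % true = nonreference group ("Female")
threshold = 1.6749;                 % ScoreThreshold value from the example

[maxScore, idx] = max(scores, [], 2);
originalPred = (idx == 2);          % true = positive class

% Adjust only the observations in the critical region.
adjustedPred = originalPred;
critical = maxScore < threshold;
adjustedPred(critical & isNonRef)  = true;   % nonreference group -> positive
adjustedPred(critical & ~isNonRef) = false;  % reference group -> negative

changed = adjustedPred ~= originalPred;
fprintf('Observations in critical region: %s\n', mat2str(find(critical)'));
fprintf('Observations with changed labels: %s\n', mat2str(find(changed)'));
```

    Observations 2, 3, and 4 fall in the critical region, but only observations 2 and 4 change labels: observation 3 is in the nonreference group and is already predicted to be in the positive class.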

    Adjust the test set predictions by using the new score threshold, and calculate the classification error.

    fairnessLabels = predict(fairnessMdl,adulttest);
    fairnessLoss = loss(fairnessMdl,adulttest)
    fairnessLoss = 0.2064
    

    The new classification error is similar to the original classification error.

    Compare the disparate impact values across the two sets of test predictions: the original predictions computed using Mdl and the adjusted predictions computed using fairnessMdl.

    newMetricsResults = fairnessMetrics(adulttest,"salary", ...
        SensitiveAttributeNames="sex",Predictions=[labels,fairnessLabels], ...
        ModelNames=["Original","Adjusted"],Weights="fnlwgt");
    newMetricsResults.PositiveClass
    ans = categorical
         >50K 
    
    
    newMetricsResults.ReferenceGroup
    ans = 'Male'
    
    report(newMetricsResults,BiasMetrics="DisparateImpact")
    ans=2×5 table
            Metrics        SensitiveAttributeNames    Groups    Original    Adjusted
        _______________    _______________________    ______    ________    ________
    
        DisparateImpact              sex              Female    0.73792      1.0048 
        DisparateImpact              sex              Male            1           1 
    
    
    plot(newMetricsResults,"di")

    [Figure: bar chart titled Disparate Impact, with x-axis Fairness Metric Value and y-axis sex, comparing the Original and Adjusted predictions.]

    The disparate impact value for the nonreference group (Female) is closer to 1 when you use the adjusted predictions.

    Input Arguments


    thresholder

    Fairness classification model, specified as a fairnessThresholder object. The ScoreThreshold property of the object must be nonempty.

    Tbl

    Data set, specified as a table. Each row of Tbl corresponds to one observation, and each column corresponds to one variable. If you use a table when creating the fairnessThresholder object, then you must use a table when using the loss function. The table must include all required predictor variables, the sensitive attribute, and the response variable. The table can include additional variables. Multicolumn variables and cell arrays other than cell arrays of character vectors are not allowed.

    Data Types: table

    X

    Predictor data, specified as a numeric matrix. Each row of X corresponds to one observation, and each column corresponds to one predictor variable. If you use a matrix when creating the fairnessThresholder object, then you must use a matrix when using the loss function. X, attribute, and Y must have the same number of rows.

    Data Types: single | double

    attribute

    Sensitive attribute, specified as a numeric column vector, logical column vector, character array, string array, cell array of character vectors, or categorical column vector.

    • X, Y, and attribute must have the same number of rows.

    • If attribute is a character array, then each row of the array must correspond to a group in the sensitive attribute.

    Data Types: single | double | logical | char | string | cell | categorical

    Y

    Class labels, specified as a numeric column vector, logical column vector, character array, string array, cell array of character vectors, or categorical column vector.

    • X, attribute, and Y must have the same number of rows.

    • If Y is a character array, then each row of the array must correspond to a class label.

    • The data type of Y must be the same as the data type of the response variable used to create thresholder.

    • If thresholder.Learner is a classification model object, then the distinct classes in Y must be a subset of the classes in thresholder.Learner.ClassNames.

    Data Types: single | double | logical | char | string | cell | categorical
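    Before calling loss with matrix data, you might check two of the requirements listed above: matching row counts and class labels that the model knows. The following MATLAB sketch uses hypothetical inputs, including a hard-coded cell array classNames standing in for thresholder.Learner.ClassNames; it does not check the data type requirement for Y.

```matlab
% Hypothetical inputs.
classNames = {'<=50K'; '>50K'};          % stand-in for the model's class names
Y          = {'>50K'; '<=50K'; '>50K'};  % true labels for three observations
attribute  = ['F'; 'M'; 'F'];            % char array: one row per observation
X          = rand(3, 2);                 % predictor matrix

% X, attribute, and Y must have the same number of rows.
sameRows = size(X, 1) == size(attribute, 1) && size(X, 1) == numel(Y);

% The distinct classes in Y must be a subset of the model's class names.
classesOK = all(ismember(unique(Y), classNames));

fprintf('Same number of rows: %d, classes valid: %d\n', sameRows, classesOK);
```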

    Output Arguments


    L

    Classification loss, returned as a numeric scalar. The loss function computes the classification loss specified by the LossFun value used when creating thresholder. The function computes the loss from the predictions for the data set after adjusting them with the thresholder.ScoreThreshold value. For more information, see Reject Option-Based Classification.

    Version History

    Introduced in R2023a