LightGBM Wrapper

train -- Trains a wideboost model using the lightGBM backend.

Parameters:

Name Type Description Default
param dict

Named parameter dictionary. Uses lightGBM conventions. Requires two parameters in addition to the usual lightGBM parameters, 'btype' and 'extra_dims'.

required

Returns:

Type Description
wlgb

A wlgb object containing the lightGBM object with objective and evaluation objects.

Source code in wideboost/wrappers/wlgb.py
def train(param, train_set, num_boost_round=100, valid_sets=None,
        valid_names=None, fobj=None, feval=None, init_model=None,
        feature_name='auto', categorical_feature='auto',
        early_stopping_rounds=None, evals_result=None, verbose_eval=True,
        learning_rates=None, keep_training_booster=False, callbacks=None):
    """train -- Trains a wideboost model using the lightGBM backend.

    Args:
        param (dict): Named parameter dictionary. Uses lightGBM conventions.
            Requires two parameters in addition to the usual lightGBM
            parameters, 'btype' and 'extra_dims'.
        train_set: Training data; forwarded unchanged to ``lgb.train``.
        fobj: Optional objective. When an ``lgb_objective`` instance is
            supplied it is used as-is and the ``extra_dims >= 0`` check is
            skipped; otherwise an objective is built via ``get_objective``.
        All remaining arguments follow ``lightgbm.train`` conventions and
        are forwarded unchanged.

    Returns:
        wlgb: A wlgb object containing the lightGBM object with objective
            and evaluation objects.
    """
    # Work on a copy so the caller's parameter dict is never mutated.
    params = param.copy()
    if isinstance(fobj, lgb_objective):
        print("Using custom objective. Removed extra_dims restriction.")
    else:
        assert params['extra_dims'] >= 0

    # Overwrite needed params: default to one output class when the caller
    # did not set `num_class`. (Replaces a try/except with an unused local.)
    print("Overwriting param `num_class`")
    params.setdefault('num_class', 1)

    if isinstance(fobj, lgb_objective):
        print("Found custom wideboost-compatible objective. Using user specified objective.")
    else:
        fobj = get_objective(params)

    # Widen the output: lightGBM sees the base classes plus the extra
    # wideboost dimensions as one flat set of outputs.
    params['num_class'] = params['num_class'] + params['extra_dims']
    params.pop('extra_dims')

    print("Overwriting param `objective` while setting `fobj` in train.")
    params['objective'] = 'regression'

    try:
        # Translate the lightGBM `metric` param into a wideboost-aware feval.
        # NOTE: feval is assigned before pop('metric') may raise, matching
        # the original best-effort ordering.
        feval = get_eval_metric(params, fobj)
        params.pop('metric')
        print("Moving param `metric` to an feval.")
    except Exception:
        # Best-effort: keep the feval passed in (or None) when no metric is
        # configured. Narrowed from a bare `except:` so KeyboardInterrupt
        # and SystemExit still propagate.
        pass

    # BUG FIX: forward the caller's feature_name / categorical_feature
    # instead of the previously hard-coded 'auto' values.
    lgbobject = lgb.train(params, train_set, num_boost_round=num_boost_round,
        valid_sets=valid_sets, valid_names=valid_names, fobj=fobj,
        feval=feval, init_model=init_model, feature_name=feature_name,
        categorical_feature=categorical_feature,
        early_stopping_rounds=early_stopping_rounds,
        evals_result=evals_result, verbose_eval=verbose_eval,
        learning_rates=learning_rates,
        keep_training_booster=keep_training_booster, callbacks=callbacks)

    return wlgb(lgbobject, fobj, feval)