Skip to content

XGBoost Wrapper

train -- Trains a wideboost model using the XGBoost backend.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `param` | `dict` | Named parameter dict using XGBoost conventions. In addition to the usual XGBoost parameters, it requires two more: `'btype'` and `'extra_dims'`. | required |

Returns:

| Type | Description |
|------|-------------|
| `wxgb` | A `wxgb` object wrapping the trained XGBoost model together with its objective and evaluation objects. |

Source code in wideboost/wrappers/wxgb.py
def train(param, dtrain, num_boost_round=10, evals=(), obj=None,
          feval=None, maximize=False, early_stopping_rounds=None, evals_result=None,
          verbose_eval=True, xgb_model=None, callbacks=None):
    """train -- Trains a wideboost model using the XGBoost backend.

    A shallow copy of ``param`` is rewritten into a form XGBoost accepts and
    handed to ``xgb.train`` together with a wideboost objective and, when one
    can be built, a wideboost evaluation function. The caller's ``param`` is
    not modified.

    Parameter rewriting performed on the copy:
      * ``num_class`` defaults to 1 and is then increased by ``extra_dims``;
        ``extra_dims`` itself is removed.
      * ``objective`` is forced to ``'reg:squarederror'``; the real objective
        is supplied through ``obj`` instead.
      * ``eval_metric`` is moved into ``feval`` when ``get_eval_metric``
        succeeds (otherwise it is left in place).
      * ``disable_default_eval_metric`` is set to 1.
      * ``base_score`` is forced to 0.0.

    Args:
        param (dict): Named parameters using XGBoost conventions. Requires
            two parameters in addition to the usual XGBoost parameters,
            'btype' and 'extra_dims'.
        dtrain: Training data, passed through to ``xgb.train``.
        num_boost_round, evals, maximize, early_stopping_rounds, evals_result,
        verbose_eval, xgb_model, callbacks: Passed through to ``xgb.train``
            unchanged.
        obj: If an ``xgb_objective`` instance, it is used as-is and the
            ``extra_dims >= 0`` check is skipped. Any other value (including
            a plain callable) is ignored and replaced by
            ``get_objective(params)``.
        feval: Passed to ``xgb.train`` unless ``get_eval_metric`` succeeds,
            in which case its result replaces this value.

    Returns:
        wxgb: A wxgb object containing the XGBoost object with objective and
        evaluation objects.

    Raises:
        KeyError: If ``'extra_dims'`` is missing from ``param``.
        AssertionError: If ``extra_dims`` is negative and ``obj`` is not an
            ``xgb_objective`` (only when assertions are enabled).
    """
    params = param.copy()
    if not isinstance(obj, xgb_objective):
        # NOTE: ``assert`` is stripped under ``python -O``; kept (rather than
        # a ValueError) so the exception type callers may see is unchanged.
        assert params['extra_dims'] >= 0, "param 'extra_dims' must be >= 0"
    else:
        print("Using custom objective. Removed extra_dims restriction.")

    # Overwrite needed params
    print("Overwriting param `num_class`")
    # Default to a single base output when the caller did not specify one.
    params.setdefault('num_class', 1)

    if isinstance(obj, xgb_objective):
        print("Found custom wideboost-compatible objective. Using user specified objective.")
    else:
        obj = get_objective(params)

    # XGBoost sees the widened output: base classes plus the extra dims.
    params['num_class'] = params['num_class'] + params['extra_dims']
    params.pop('extra_dims')

    print("Overwriting param `objective` while setting `obj` in train.")
    # The real loss is supplied through ``obj``; this only satisfies XGBoost.
    params['objective'] = 'reg:squarederror'

    try:
        feval = get_eval_metric(params, obj)
        params.pop('eval_metric')
        print("Moving param `eval_metric` to an feval.")
    except Exception:
        # Deliberate best-effort: when no wideboost eval metric can be built,
        # keep the caller's ``feval`` and leave ``params`` as they are.
        # NOTE(review): narrow to the specific exception get_eval_metric
        # raises for a missing/unknown metric (likely KeyError) — confirm.
        pass

    print("Setting param `disable_default_eval_metric` to 1.")
    params['disable_default_eval_metric'] = 1

    # TODO: base_score should be set depending on the objective chosen
    # TODO: Allow some items to be overwritten by user. This being one of them.
    params['base_score'] = 0.0

    xgbobject = xgb.train(params, dtrain, num_boost_round=num_boost_round, evals=evals, obj=obj,
                          feval=feval, maximize=maximize, early_stopping_rounds=early_stopping_rounds,
                          evals_result=evals_result, verbose_eval=verbose_eval, xgb_model=xgb_model,
                          callbacks=callbacks)

    return wxgb(xgbobject, obj, feval)