LightGBM Wrapper
train -- Trains a wideboost model using the lightGBM backend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| param | dict | Named parameter dictionary. Uses lightGBM conventions. Requires two parameters in addition to the usual lightGBM parameters, 'btype' and 'extra_dims'. | required |
Returns:
| Type | Description |
|---|---|
| wlgb | A wlgb object containing the lightGBM object with objective and evaluation objects. |
Source code in wideboost/wrappers/wlgb.py
def train(param, train_set, num_boost_round=100, valid_sets=None,
          valid_names=None, fobj=None, feval=None, init_model=None,
          feature_name='auto', categorical_feature='auto',
          early_stopping_rounds=None, evals_result=None, verbose_eval=True,
          learning_rates=None, keep_training_booster=False, callbacks=None):
    """train -- Trains a wideboost model using the lightGBM backend.

    Args:
        param (dict): Named parameters. Uses lightGBM conventions. Requires
            two parameters in addition to the usual lightGBM parameters,
            'btype' and 'extra_dims'. The dict is shallow-copied, so the
            caller's object is not modified.
        train_set: Training data, forwarded to ``lgb.train``.
        fobj: Optional objective. If it is an ``lgb_objective`` it is used
            as-is (and ``extra_dims`` is not checked for non-negativity);
            otherwise it is replaced by ``get_objective(params)``.
        feval: Optional evaluation function forwarded to ``lgb.train``. It
            is replaced when ``param`` contains a 'metric' that
            ``get_eval_metric`` can convert.
        feature_name, categorical_feature: Forwarded to ``lgb.train``.
        num_boost_round, valid_sets, valid_names, init_model,
        early_stopping_rounds, evals_result, verbose_eval, learning_rates,
        keep_training_booster, callbacks: Forwarded unchanged to
            ``lgb.train``.

    Returns:
        wlgb: A wlgb object containing the lightGBM object with objective and
        evaluation objects.
    """
    params = param.copy()
    custom_objective = isinstance(fobj, lgb_objective)
    if custom_objective:
        print("Using custom objective. Removed extra_dims restriction.")
    else:
        assert params['extra_dims'] >= 0, "param 'extra_dims' must be >= 0"

    # Overwrite needed params
    print("Overwriting param `num_class`")
    # Default to a single output when the caller did not specify num_class.
    params.setdefault('num_class', 1)
    if custom_objective:
        print("Found custom wideboost-compatible objective. Using user specified objective.")
    else:
        # Built before num_class is widened below, so it sees the user's value.
        fobj = get_objective(params)
    # lightGBM trains num_class + extra_dims raw outputs; 'extra_dims' is not
    # a lightGBM parameter, so it is removed from the dict passed on.
    params['num_class'] = params['num_class'] + params.pop('extra_dims')
    print("Overwriting param `objective` while setting `fobj` in train.")
    params['objective'] = 'regression'

    # Best-effort: translate a lightGBM-style 'metric' into a wideboost feval.
    # If the conversion fails, the metric is left in params for lightGBM.
    if 'metric' in params:
        try:
            feval = get_eval_metric(params, fobj)
        except Exception as err:
            print("Could not move param `metric` to an feval; leaving it "
                  "for lightGBM:", repr(err))
        else:
            params.pop('metric')
            print("Moving param `metric` to an feval.")

    lgbobject = lgb.train(params, train_set, num_boost_round=num_boost_round,
                          valid_sets=valid_sets, valid_names=valid_names,
                          fobj=fobj, feval=feval, init_model=init_model,
                          feature_name=feature_name,
                          categorical_feature=categorical_feature,
                          early_stopping_rounds=early_stopping_rounds,
                          evals_result=evals_result,
                          verbose_eval=verbose_eval,
                          learning_rates=learning_rates,
                          keep_training_booster=keep_training_booster,
                          callbacks=callbacks)
    return wlgb(lgbobject, fobj, feval)