#include #include #include #include #include "linear.h" #include "mex.h" #include "linear_model_matlab.h" #ifdef MX_API_VER #if MX_API_VER < 0x07030000 typedef int mwIndex; #endif #endif #define CMD_LEN 2048 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) #define INF HUGE_VAL void print_null(const char *s) {} void print_string_matlab(const char *s) {mexPrintf(s);} void exit_with_help() { mexPrintf( "Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n" "liblinear_options:\n" "-s type : set type of solver (default 1)\n" " for multi-class classification\n" " 0 -- L2-regularized logistic regression (primal)\n" " 1 -- L2-regularized L2-loss support vector classification (dual)\n" " 2 -- L2-regularized L2-loss support vector classification (primal)\n" " 3 -- L2-regularized L1-loss support vector classification (dual)\n" " 4 -- support vector classification by Crammer and Singer\n" " 5 -- L1-regularized L2-loss support vector classification\n" " 6 -- L1-regularized logistic regression\n" " 7 -- L2-regularized logistic regression (dual)\n" " for regression\n" " 11 -- L2-regularized L2-loss support vector regression (primal)\n" " 12 -- L2-regularized L2-loss support vector regression (dual)\n" " 13 -- L2-regularized L1-loss support vector regression (dual)\n" "-c cost : set the parameter C (default 1)\n" "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n" "-e epsilon : set tolerance of termination criterion\n" " -s 0 and 2\n" " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n" " where f is the primal function and pos/neg are # of\n" " positive/negative data (default 0.01)\n" " -s 11\n" " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n" " -s 1, 3, 4 and 7\n" " Dual maximal violation <= eps; similar to libsvm (default 0.1)\n" " -s 5 and 6\n" " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n" " where f is the primal function (default 0.01)\n" " -s 12 and 13\n" " |f'(alpha)|_1 <= eps |f'(alpha0)|,\n" " where f is the dual function (default 0.1)\n" "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n" "-wi weight: weights adjust the parameter C of different classes (see README for details)\n" "-v n: n-fold cross validation mode\n" "-C : find parameters (C for -s 0, 2 and C, p for -s 11)\n" "-q : quiet mode (no outputs)\n" "col:\n" " if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n" ); } // liblinear arguments struct parameter param; // set by parse_command_line struct problem prob; // set by read_problem struct model *model_; struct feature_node *x_space; int flag_cross_validation; int flag_find_parameters; int flag_C_specified; int flag_p_specified; int flag_solver_specified; int col_format_flag; int nr_fold; double bias; void do_find_parameters(double *best_C, double *best_p, double *best_score) { double start_C, start_p; if (flag_C_specified) start_C = param.C; else start_C = -1.0; if (flag_p_specified) start_p = param.p; else start_p = -1.0; find_parameters(&prob, ¶m, nr_fold, start_C, start_p, best_C, best_p, best_score); if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC) mexPrintf("Best C = %g CV accuracy = %g%%\n", *best_C, 100.0**best_score); else if(param.solver_type == L2R_L2LOSS_SVR) mexPrintf("Best C = %g Best p = %g CV MSE = %g\n", *best_C, *best_p, *best_score); } double do_cross_validation() { int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double *target = Malloc(double, prob.l); double retval = 0.0; cross_validation(&prob,¶m,nr_fold,target); if(param.solver_type == L2R_L2LOSS_SVR || param.solver_type == L2R_L1LOSS_SVR_DUAL || param.solver_type == L2R_L2LOSS_SVR_DUAL) { for(i=0;i 2) { mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1); if((argv[argc] = strtok(cmd, " ")) != NULL) while((argv[++argc] = strtok(NULL, " ")) != NULL) ; } // parse options for(i=1;i=argc && argv[i-1][1] != 'q' && argv[i-1][1] != 'C') // since options -q and -C have no parameter return 1; switch(argv[i-1][1]) { case 's': param.solver_type = atoi(argv[i]); flag_solver_specified = 1; break; case 'c': param.C = atof(argv[i]); flag_C_specified = 1; break; case 'p': param.p = atof(argv[i]); flag_p_specified = 1; break; case 'e': param.eps = atof(argv[i]); break; case 'B': bias = atof(argv[i]); break; case 'v': flag_cross_validation = 1; nr_fold = atoi(argv[i]); if(nr_fold < 2) { mexPrintf("n-fold cross validation: n must >= 2\n"); return 1; } break; case 'w': ++param.nr_weight; param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight); param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight); param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); param.weight[param.nr_weight-1] = atof(argv[i]); break; case 'q': print_func = &print_null; i--; break; case 'C': flag_find_parameters = 1; i--; break; default: mexPrintf("unknown option\n"); return 1; } } set_print_string_function(print_func); // default solver for parameter selection is L2R_L2LOSS_SVC if(flag_find_parameters) { if(!flag_cross_validation) nr_fold = 5; if(!flag_solver_specified) { mexPrintf("Solver not specified. Using -s 2\n"); param.solver_type = L2R_L2LOSS_SVC; } else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC && param.solver_type != L2R_L2LOSS_SVR) { mexPrintf("Warm-start parameter search only available for -s 0, -s 2 and -s 11\n"); return 1; } } if(param.eps == INF) { switch(param.solver_type) { case L2R_LR: case L2R_L2LOSS_SVC: param.eps = 0.01; break; case L2R_L2LOSS_SVR: param.eps = 0.0001; break; case L2R_L2LOSS_SVC_DUAL: case L2R_L1LOSS_SVC_DUAL: case MCSVM_CS: case L2R_LR_DUAL: param.eps = 0.1; break; case L1R_L2LOSS_SVC: case L1R_LR: param.eps = 0.01; break; case L2R_L1LOSS_SVR_DUAL: case L2R_L2LOSS_SVR_DUAL: param.eps = 0.1; break; } } return 0; } static void fake_answer(int nlhs, mxArray *plhs[]) { int i; for(i=0;i=0) { x_space[j].index = (int) max_index+1; x_space[j].value = prob.bias; j++; } x_space[j++].index = -1; } if(prob.bias>=0) prob.n = (int) max_index+1; else prob.n = (int) max_index; return 0; } // Interface function of matlab // now assume prhs[0]: label prhs[1]: features void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { const char *error_msg; // fix random seed to have same results for each run // (for cross validation) srand(1); if(nlhs > 1) { exit_with_help(); fake_answer(nlhs, plhs); return; } // Transform the input Matrix to libsvm format if(nrhs > 1 && nrhs < 5) { int err=0; if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { mexPrintf("Error: label vector and instance matrix must be double\n"); fake_answer(nlhs, plhs); return; } if(mxIsSparse(prhs[0])) { mexPrintf("Error: label vector should not be in sparse format"); fake_answer(nlhs, plhs); return; } if(parse_command_line(nrhs, prhs, NULL)) { exit_with_help(); destroy_param(¶m); fake_answer(nlhs, plhs); return; } if(mxIsSparse(prhs[1])) err = read_problem_sparse(prhs[0], prhs[1]); else { mexPrintf("Training_instance_matrix must be sparse; " "use sparse(Training_instance_matrix) first\n"); destroy_param(¶m); fake_answer(nlhs, plhs); return; } // train's original code error_msg = check_parameter(&prob, ¶m); if(err || error_msg) { if (error_msg != NULL) mexPrintf("Error: %s\n", error_msg); destroy_param(¶m); free(prob.y); free(prob.x); free(x_space); fake_answer(nlhs, plhs); return; } if (flag_find_parameters) { double best_C, best_p, best_score, *ptr; do_find_parameters(&best_C, &best_p, &best_score); plhs[0] = mxCreateDoubleMatrix(3, 1, mxREAL); ptr = mxGetPr(plhs[0]); ptr[0] = best_C; ptr[1] = best_p; ptr[2] = best_score; } else if(flag_cross_validation) { double *ptr; plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); ptr = mxGetPr(plhs[0]); ptr[0] = do_cross_validation(); } else { const char *error_msg; model_ = train(&prob, ¶m); error_msg = model_to_matlab_structure(plhs, model_); if(error_msg) mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg); free_and_destroy_model(&model_); } destroy_param(¶m); free(prob.y); free(prob.x); free(x_space); } else { exit_with_help(); fake_answer(nlhs, plhs); return; } }