public class NaiveBayesModel extends java.lang.Object implements ClassificationModel, scala.Serializable, Saveable
param: labels list of labels
param: pi log of class priors, whose dimension is C, the number of labels
param: theta log of class-conditional probabilities, whose dimension is C-by-D, where D is the number of features
param: modelType the type of NB model to fit; can be "multinomial" or "bernoulli"
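The sketch below shows how pi and theta work together in multinomial scoring. It is plain Java with no Spark dependency and is not Spark's implementation. For each class c, it adds the log prior pi[c] to the dot product of theta[c] with the feature vector, then returns the label with the highest score. The class name `MultinomialNbSketch` is hypothetical. Features are assumed to be dense, non-negative counts.

```java
/**
 * Hedged sketch, not Spark's code: multinomial Naive Bayes scoring with
 * log priors pi[c] and log class-conditional probabilities theta[c][d].
 */
final class MultinomialNbSketch {

    /** Returns the label whose unnormalized log posterior pi[c] + theta[c] . x is largest. */
    static double predict(double[] labels, double[] pi, double[][] theta, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < pi.length; c++) {
            double score = pi[c];
            for (int d = 0; d < x.length; d++) {
                score += theta[c][d] * x[d]; // each count multiplies the log probability of feature d
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0};
        double[] pi = {Math.log(0.5), Math.log(0.5)};
        double[][] theta = {
            {Math.log(0.8), Math.log(0.2)}, // class 0 favours feature 0
            {Math.log(0.2), Math.log(0.8)}  // class 1 favours feature 1
        };
        System.out.println(predict(labels, pi, theta, new double[] {1.0, 3.0})); // prints 1.0
    }
}
```

The scores stay in log space, so the sketch adds log probabilities instead of multiplying raw probabilities. This avoids floating-point underflow when D is large.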
Modifier and Type | Method and Description |
---|---|
protected java.lang.String | formatVersion(): Current version of model save/load format. |
double[] | labels() |
static NaiveBayesModel | load(SparkContext sc, java.lang.String path) |
java.lang.String | modelType() |
double[] | pi() |
RDD<java.lang.Object> | predict(RDD<Vector> testData): Predict values for the given data set using the model trained. |
double | predict(Vector testData): Predict values for a single data point using the model trained. |
RDD<Vector> | predictProbabilities(RDD<Vector> testData): Predict posterior class probabilities for the given data set using the model trained. |
Vector | predictProbabilities(Vector testData): Predict posterior class probabilities for a single data point using the model trained. |
void | save(SparkContext sc, java.lang.String path): Save this model to the given path. |
double[][] | theta() |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ClassificationModel: predict
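predictProbabilities returns posterior class probabilities instead of a single label. The sketch below shows one numerically stable way to compute them. It assumes the per-class log scores, such as pi[c] + theta[c] · x, are already available. The method subtracts the largest score, exponentiates, and normalizes so the results sum to 1. This is plain Java with a hypothetical class name, not Spark's code. The output entries are assumed to follow the same order as labels().

```java
import java.util.Arrays;

/**
 * Hedged sketch, not Spark's code: converts unnormalized log posteriors
 * into posterior class probabilities that sum to 1.
 */
final class NbPosteriorSketch {

    /** Numerically stable softmax over log-domain class scores. */
    static double[] posteriorProbabilities(double[] logScores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : logScores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[logScores.length];
        double sum = 0.0;
        for (int c = 0; c < logScores.length; c++) {
            // Subtracting the max keeps exp() from underflowing to 0 for very negative log scores.
            probs[c] = Math.exp(logScores[c] - max);
            sum += probs[c];
        }
        for (int c = 0; c < probs.length; c++) {
            probs[c] /= sum; // normalize so the probabilities sum to 1
        }
        return probs;
    }

    public static void main(String[] args) {
        // Unnormalized posteriors 0.0032 and 0.0512, given in log space.
        double[] probs = posteriorProbabilities(new double[] {Math.log(0.0032), Math.log(0.0512)});
        System.out.println(Arrays.toString(probs));
    }
}
```

Working from the maximum log score means that even scores such as -1000 and -1001, which would both underflow if exponentiated directly, still produce sensible probabilities.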
public static NaiveBayesModel load(SparkContext sc, java.lang.String path)
public double[] labels()
public double[] pi()
public double[][] theta()
public java.lang.String modelType()
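modelType() selects the scoring rule. The sketch below covers the "bernoulli" case. It assumes theta[c][d] holds log P(feature d = 1 | class c) and that every feature value is 0 or 1. Under the multinomial rule, an absent feature contributes nothing to the class score. Under the Bernoulli rule, an absent feature still contributes log(1 - P(feature d = 1 | class c)). The class name is hypothetical, and this is plain Java rather than Spark's implementation.

```java
/**
 * Hedged sketch, not Spark's code: Bernoulli Naive Bayes scoring with log priors pi[c]
 * and theta[c][d] = log P(feature d = 1 | class c). Features must be 0 or 1.
 */
final class BernoulliNbSketch {

    static double predict(double[] labels, double[] pi, double[][] theta, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < pi.length; c++) {
            double score = pi[c];
            for (int d = 0; d < x.length; d++) {
                if (x[d] != 0.0 && x[d] != 1.0) {
                    throw new IllegalArgumentException(
                        "Bernoulli scoring expects 0/1 features, got " + x[d]);
                }
                double logP = theta[c][d];                    // log P(x_d = 1 | c)
                double logNotP = Math.log1p(-Math.exp(logP)); // log P(x_d = 0 | c) = log(1 - P)
                score += (x[d] == 1.0) ? logP : logNotP;
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0};
        double[] pi = {Math.log(0.5), Math.log(0.5)};
        double[][] theta = {
            {Math.log(0.9), Math.log(0.1)}, // class 0: feature 0 usually present, feature 1 usually absent
            {Math.log(0.1), Math.log(0.9)}  // class 1: the reverse
        };
        System.out.println(predict(labels, pi, theta, new double[] {0.0, 1.0})); // prints 1.0
    }
}
```

The sketch rejects non-binary inputs instead of silently treating them as counts, consistent with the usual requirement that Bernoulli features be binary.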
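public RDD<java.lang.Object> predict(RDD<Vector> testData)
Description copied from interface: ClassificationModel
Predict values for the given data set using the model trained.
Specified by: predict in interface ClassificationModel
Parameters: testData - RDD representing data points to be predicted
public double predict(Vector testData)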
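Description copied from interface: ClassificationModel
Predict values for a single data point using the model trained.
Specified by: predict in interface ClassificationModel
Parameters: testData - Vector representing a single data point
public RDD<Vector> predictProbabilities(RDD<Vector> testData)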
Predict posterior class probabilities for the given data set using the model trained.
Parameters: testData - RDD representing data points to be predicted
public Vector predictProbabilities(Vector testData)
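Predict posterior class probabilities for a single data point using the model trained.
Parameters: testData - Vector representing a single data point
public void save(SparkContext sc, java.lang.String path)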
Description copied from interface: Saveable
Save this model to the given path.
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet-formatted data to path/data/
The model may be loaded using Loader.load.
protected java.lang.String formatVersion()
Description copied from interface: Saveable
Current version of model save/load format.
Specified by: formatVersion in interface Saveable