Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reformat with 0.2.6 #2

Open
wants to merge 1 commit into
base: 0.2.5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ trait KeyAuthentication {
Future {

val passedKey = accessKeyParamOpt.getOrElse {
Left(AuthenticationFailedRejection(
AuthenticationFailedRejection.CredentialsRejected, List()))
Left(
AuthenticationFailedRejection(
AuthenticationFailedRejection.CredentialsRejected,
List()))
}

if (passedKey.equals(ServerKey.get)) Right(ctx.request)
else
Left(AuthenticationFailedRejection(
AuthenticationFailedRejection.CredentialsRejected, List()))
Left(
AuthenticationFailedRejection(
AuthenticationFailedRejection.CredentialsRejected,
List()))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ trait SSLConfiguration {
// provide implicit SSLEngine with some protocols
implicit def sslEngineProvider: ServerSSLEngineProvider = {
ServerSSLEngineProvider { engine =>
engine.setEnabledCipherSuites(Array("TLS_RSA_WITH_AES_256_CBC_SHA",
"TLS_ECDH_ECDSA_WITH_RC4_128_SHA",
"TLS_RSA_WITH_AES_128_CBC_SHA"))
engine.setEnabledCipherSuites(
Array("TLS_RSA_WITH_AES_256_CBC_SHA",
"TLS_ECDH_ECDSA_WITH_RC4_128_SHA",
"TLS_RSA_WITH_AES_128_CBC_SHA"))
engine.setEnabledProtocols(Array("TLSv1", "TLSv1.2", "TLSv1.1"))
engine
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ import scala.language.implicitConversions
* @group Engine
*/
class Engine[TD, EI, PD, Q, P, A](
val dataSourceClassMap: Map[
String, Class[_ <: BaseDataSource[TD, EI, Q, A]]],
val dataSourceClassMap: Map[String,
Class[_ <: BaseDataSource[TD, EI, Q, A]]],
val preparatorClassMap: Map[String, Class[_ <: BasePreparator[TD, PD]]],
val algorithmClassMap: Map[String, Class[_ <: BaseAlgorithm[PD, _, Q, P]]],
val servingClassMap: Map[String, Class[_ <: BaseServing[Q, P]]])
Expand Down Expand Up @@ -117,7 +117,8 @@ class Engine[TD, EI, PD, Q, P, A](
def this(dataSourceClass: Class[_ <: BaseDataSource[TD, EI, Q, A]],
preparatorClass: Class[_ <: BasePreparator[TD, PD]],
algorithmClassMap: _root_.java.util.Map[
String, Class[_ <: BaseAlgorithm[PD, _, Q, P]]],
String,
Class[_ <: BaseAlgorithm[PD, _, Q, P]]],
servingClass: Class[_ <: BaseServing[Q, P]]) = this(
Map("" -> dataSourceClass),
Map("" -> preparatorClass),
Expand Down Expand Up @@ -201,63 +202,62 @@ class Engine[TD, EI, PD, Q, P, A](
Doer(algorithmClassMap(algoName), algoParams)
}

val models =
if (persistedModels.exists(m => m.isInstanceOf[Unit.type])) {
// If any of persistedModels is Unit, we need to re-train the model.
logger.info("Some persisted models are Unit, need to re-train.")
val (dataSourceName, dataSourceParams) = engineParams.dataSourceParams
val dataSource = Doer(
dataSourceClassMap(dataSourceName), dataSourceParams)

val (preparatorName, preparatorParams) = engineParams.preparatorParams
val preparator = Doer(
preparatorClassMap(preparatorName), preparatorParams)

val td = dataSource.readTrainingBase(sc)
val pd = preparator.prepareBase(sc, td)

val models = algorithms.zip(persistedModels).map {
case (algo, m) =>
m match {
case Unit => algo.trainBase(sc, pd)
case _ => m
}
}
models
} else {
logger.info("Using persisted model")
persistedModels
val models = if (persistedModels.exists(m => m.isInstanceOf[Unit.type])) {
// If any of persistedModels is Unit, we need to re-train the model.
logger.info("Some persisted models are Unit, need to re-train.")
val (dataSourceName, dataSourceParams) = engineParams.dataSourceParams
val dataSource =
Doer(dataSourceClassMap(dataSourceName), dataSourceParams)

val (preparatorName, preparatorParams) = engineParams.preparatorParams
val preparator =
Doer(preparatorClassMap(preparatorName), preparatorParams)

val td = dataSource.readTrainingBase(sc)
val pd = preparator.prepareBase(sc, td)

val models = algorithms.zip(persistedModels).map {
case (algo, m) =>
m match {
case Unit => algo.trainBase(sc, pd)
case _ => m
}
}
models
} else {
logger.info("Using persisted model")
persistedModels
}

models.zip(algorithms).zip(algoParamsList).zipWithIndex.map {
case (((model, algo), (algoName, algoParams)), ax) => {
model match {
case modelManifest: PersistentModelManifest => {
logger.info("Custom-persisted model detected for algorithm " +
algo.getClass.getName)
SparkWorkflowUtils.getPersistentModel(
modelManifest,
Seq(engineInstanceId, ax, algoName).mkString("-"),
algoParams,
Some(sc),
getClass.getClassLoader)
}
case m => {
try {
logger.info(
s"Loaded model ${m.getClass.getName} for algorithm " +
s"${algo.getClass.getName}")
sc.stop
m
} catch {
case e: NullPointerException =>
logger.warn(
s"Null model detected for algorithm ${algo.getClass.getName}")
m
}
}
} // model match
}
model match {
case modelManifest: PersistentModelManifest => {
logger.info("Custom-persisted model detected for algorithm " +
algo.getClass.getName)
SparkWorkflowUtils.getPersistentModel(
modelManifest,
Seq(engineInstanceId, ax, algoName).mkString("-"),
algoParams,
Some(sc),
getClass.getClassLoader)
}
case m => {
try {
logger.info(
s"Loaded model ${m.getClass.getName} for algorithm " +
s"${algo.getClass.getName}")
sc.stop
m
} catch {
case e: NullPointerException =>
logger.warn(
s"Null model detected for algorithm ${algo.getClass.getName}")
m
}
}
} // model match
}
}
}

Expand Down Expand Up @@ -319,25 +319,26 @@ class Engine[TD, EI, PD, Q, P, A](

val algorithms = algoParamsList.map {
case (algoName, algoParams) => {
try {
Doer(algorithmClassMap(algoName), algoParams)
} catch {
case e: NoSuchElementException => {
if (algoName == "") {
logger.error(
"Empty algorithm name supplied but it could not " +
"match with any algorithm in the engine's definition. " +
"Existing algorithm name(s) are: " +
s"${algorithmClassMap.keys.mkString(", ")}. Aborting.")
} else {
logger.error(s"$algoName cannot be found in the engine's " +
"definition. Existing algorithm name(s) are: " +
s"${algorithmClassMap.keys.mkString(", ")}. Aborting.")
}
sys.exit(1)
}
try {
Doer(algorithmClassMap(algoName), algoParams)
} catch {
case e: NoSuchElementException => {
if (algoName == "") {
logger.error(
"Empty algorithm name supplied but it could not " +
"match with any algorithm in the engine's definition. " +
"Existing algorithm name(s) are: " +
s"${algorithmClassMap.keys.mkString(", ")}. Aborting.")
} else {
logger.error(
s"$algoName cannot be found in the engine's " +
"definition. Existing algorithm name(s) are: " +
s"${algorithmClassMap.keys.mkString(", ")}. Aborting.")
}
sys.exit(1)
}
}
}
}

val (servingName, servingParams) = engineParams.servingParams
Expand Down Expand Up @@ -419,8 +420,9 @@ class Engine[TD, EI, PD, Q, P, A](
val (name, params) =
read[(String, JValue)](engineInstance.dataSourceParams)
if (!dataSourceClassMap.contains(name)) {
logger.error(s"Unable to find datasource class with name '$name'" +
" defined in Engine.")
logger.error(
s"Unable to find datasource class with name '$name'" +
" defined in Engine.")
sys.exit(1)
}
val extractedParams = WorkflowUtils.extractParams(
Expand All @@ -435,8 +437,9 @@ class Engine[TD, EI, PD, Q, P, A](
val (name, params) =
read[(String, JValue)](engineInstance.preparatorParams)
if (!preparatorClassMap.contains(name)) {
logger.error(s"Unable to find preparator class with name '$name'" +
" defined in Engine.")
logger.error(
s"Unable to find preparator class with name '$name'" +
" defined in Engine.")
sys.exit(1)
}
val extractedParams = WorkflowUtils.extractParams(
Expand All @@ -461,8 +464,9 @@ class Engine[TD, EI, PD, Q, P, A](
val servingParamsWithName: (String, Params) = {
val (name, params) = read[(String, JValue)](engineInstance.servingParams)
if (!servingClassMap.contains(name)) {
logger.error(s"Unable to find serving class with name '$name'" +
" defined in Engine.")
logger.error(
s"Unable to find serving class with name '$name'" +
" defined in Engine.")
sys.exit(1)
}
val extractedParams = WorkflowUtils.extractParams(
Expand Down Expand Up @@ -633,22 +637,24 @@ object Engine {
case e: StorageClientException =>
logger.error(
s"Error occured reading from data source. (Reason: " +
e.getMessage + ") Please see the log for debugging details.",
e.getMessage + ") Please see the log for debugging details.",
e)
sys.exit(1)
}

if (!params.skipSanityCheck) {
td match {
case sanityCheckable: SanityCheck => {
logger.info(s"${td.getClass.getName} supports data sanity" +
logger.info(
s"${td.getClass.getName} supports data sanity" +
" check. Performing check.")
sanityCheckable.sanityCheck()
}
sanityCheckable.sanityCheck()
}
case _ => {
logger.info(s"${td.getClass.getName} does not support" +
logger.info(
s"${td.getClass.getName} does not support" +
" data sanity check. Skipping check.")
}
}
}
}

Expand All @@ -662,14 +668,16 @@ object Engine {
if (!params.skipSanityCheck) {
pd match {
case sanityCheckable: SanityCheck => {
logger.info(s"${pd.getClass.getName} supports data sanity" +
logger.info(
s"${pd.getClass.getName} supports data sanity" +
" check. Performing check.")
sanityCheckable.sanityCheck()
}
sanityCheckable.sanityCheck()
}
case _ => {
logger.info(s"${pd.getClass.getName} does not support" +
logger.info(
s"${pd.getClass.getName} does not support" +
" data sanity check. Skipping check.")
}
}
}
}

Expand All @@ -685,14 +693,16 @@ object Engine {
{
model match {
case sanityCheckable: SanityCheck => {
logger.info(s"${model.getClass.getName} supports data sanity" +
logger.info(
s"${model.getClass.getName} supports data sanity" +
" check. Performing check.")
sanityCheckable.sanityCheck()
}
sanityCheckable.sanityCheck()
}
case _ => {
logger.info(s"${model.getClass.getName} does not support" +
logger.info(
s"${model.getClass.getName} does not support" +
" data sanity check. Skipping check.")
}
}
}
}
}
Expand Down Expand Up @@ -777,8 +787,8 @@ object Engine {
algo.batchPredictBase(sc, model, qs)
val predicts: RDD[(QX, (AX, P))] = rawPredicts.map {
case (qx, p) => {
(qx, (ax, p))
}
(qx, (ax, p))
}
}
predicts
}
Expand All @@ -800,18 +810,18 @@ object Engine {

val servingQPAMap: Map[EX, RDD[(Q, P, A)]] = algoPredictsMap.map {
case (ex, psMap) => {
// The query passed to serving.serve is the original one, not
// supplemented.
val qasMap: RDD[(QX, (Q, A))] = evalQAsMap(ex)
val qpsaMap: RDD[(QX, Q, Seq[P], A)] = psMap.join(qasMap).map {
case (qx, t) => (qx, t._2._1, t._1, t._2._2)
}
// The query passed to serving.serve is the original one, not
// supplemented.
val qasMap: RDD[(QX, (Q, A))] = evalQAsMap(ex)
val qpsaMap: RDD[(QX, Q, Seq[P], A)] = psMap.join(qasMap).map {
case (qx, t) => (qx, t._2._1, t._1, t._2._2)
}

val qpaMap: RDD[(Q, P, A)] = qpsaMap.map {
case (qx, q, ps, a) => (q, serving.serveBase(q, ps), a)
}
(ex, qpaMap)
val qpaMap: RDD[(Q, P, A)] = qpsaMap.map {
case (qx, q, ps, a) => (q, serving.serveBase(q, ps), a)
}
(ex, qpaMap)
}
}

(0 until evalCount).map { ex =>
Expand Down
Loading