Skip to content

Commit

Permalink
fix: added getModel validation with no inputs/outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
yawetse committed Aug 4, 2022
1 parent 45859c2 commit 411010b
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 9 deletions.
233 changes: 233 additions & 0 deletions src/__test__/mock_automl_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1071,4 +1071,237 @@ export const autoMLdata = {
1,
4
]
};

export const autoMLdataSM = {
"outputs": 5,
"rowRange": [
1,
151
],
"colRange": [
1,
5
],
"data": [
[
"sepal_length_cm",
"sepal_width_cm",
"petal_length_cm",
"petal_width_cm",
"plant"
],
[
5.1,
3.5,
1.4,
0.2,
"Iris-setosa"
],
[
4.9,
3,
1.4,
0.2,
"Iris-setosa"
],
[
4.7,
3.2,
1.3,
0.2,
"Iris-setosa"
],
[
4.6,
3.1,
1.5,
0.2,
"Iris-setosa"
],
[
6,
2.7,
5.1,
1.6,
"Iris-versicolor"
],
[
5.4,
3,
4.5,
1.5,
"Iris-versicolor"
],
[
6,
3.4,
4.5,
1.6,
"Iris-versicolor"
],
[
6.7,
3.1,
4.7,
1.5,
"Iris-versicolor"
],
[
6.3,
2.3,
4.4,
1.3,
"Iris-versicolor"
],
[
6.3,
2.5,
5,
1.9,
"Iris-virginica"
],
[
6.5,
3,
5.2,
2,
"Iris-virginica"
],
[
6.2,
3.4,
5.4,
2.3,
"Iris-virginica"
],
[
5.9,
3,
5.1,
1.8,
"Iris-virginica"
]
],
"inputs": [
1,
4
]
};


export const autoMLdataTNE = {
"outputs": 5,
"rowRange": [
1,
151
],
"colRange": [
1,
5
],
"data": [
[
"sepal_length_cm",
"sepal_width_cm",
"petal_length_cm",
"petal_width_cm",
"plant"
],
[
5.1,
3.5,
1.4,
0.2,
"Iris-setosa"
],
[
4.9,
3,
1.4,
0.2,
"Iris-setosa"
],
[
4.7,
3.2,
1.3,
0.2,
"Iris-setosa"
],
[
4.6,
3.1,
1.5,
0.2,
""
],
[
6,
2.7,
5.1,
1.6,
"Iris-versicolor"
],
[
5.4,
3,
4.5,
1.5,
"Iris-versicolor"
],
[
6,
3.4,
4.5,
1.6,
"Iris-versicolor"
],
[
6.7,
3.1,
4.7,
1.5,
"Iris-versicolor"
],
[
6.3,
2.3,
4.4,
1.3,
""
],
[
6.3,
2.5,
5,
1.9,
"Iris-virginica"
],
[
6.5,
3,
5.2,
2,
"Iris-virginica"
],
[
6.2,
3.4,
5.4,
2.3,
"Iris-virginica"
],
[
5.9,
3,
5.1,
1.8,
""
]
],
"inputs": [
1,
4
]
};
83 changes: 77 additions & 6 deletions src/automl.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import * as JSONM from './index';
import { ModelTypes } from './model';
import { toBeWithinRange, } from './jest.test';
expect.extend({ toBeWithinRange });
import {autoMLdata} from './__test__/mock_automl_data'
import {autoMLdata, autoMLdataSM, autoMLdataTNE} from './__test__/mock_automl_data'
import { Data } from '@jsonstack/data/src/DataSet';
import { setBackend } from './tensorflow_singleton';
import * as tf from '@tensorflow/tfjs-node';
Expand Down Expand Up @@ -66,8 +66,44 @@ describe('AutoML Sheets Test',()=>{
});
})
describe('mock end to end example',()=>{

it('should run a basic test from spreadsheet data',async ()=>{
it('should run a basic test from spreadsheet with no prediction data',async ()=>{
const on_progress = ({
completion_percentage,
loss,
epoch,
logs,
status,
defaultLog,
}:TrainingProgressUpdate)=>{
if(status!=='training') console.log({status,defaultLog})
}
// const vectors = autoMLdata?.data.concat([]);
// const labels = vectors?.splice(0,1)[0] as string[];
// const dataset = JSONM.Data.DataSet.reverseColumnMatrix({labels,vectors});\
//@ts-ignore
const{vectors,labels,dataset}=getSpreadsheetDataset(autoMLdataSM?.data,{on_progress});
//@ts-ignore
const {columns,inputs,outputs} = JSONM.getInputsOutputsFromDataset({dataset,labels, on_progress});
const {trainingData,predictionData} = await splitTrainingPredictionData({
inputs,
outputs,
data: dataset,
});
try{
const SpreadsheetModel = await getModel({
type:'prediction',
inputs,
outputs,
dataset:trainingData,
//@ts-ignore
on_progress,
});
await SpreadsheetModel.trainModel();
} catch(e){
expect(e).toBeInstanceOf(RangeError)
}
},30000)
it('should run a basic test from spreadsheet with small prediction data',async ()=>{
const on_progress = ({
completion_percentage,
loss,
Expand All @@ -82,15 +118,15 @@ describe('AutoML Sheets Test',()=>{
// const labels = vectors?.splice(0,1)[0] as string[];
// const dataset = JSONM.Data.DataSet.reverseColumnMatrix({labels,vectors});\
//@ts-ignore
const{vectors,labels,dataset}=getSpreadsheetDataset(autoMLdata?.data,{on_progress});
const{vectors,labels,dataset}=getSpreadsheetDataset(autoMLdataTNE?.data,{on_progress});
//@ts-ignore
const {columns,inputs,outputs} = JSONM.getInputsOutputsFromDataset({dataset,labels, on_progress});
const {trainingData,predictionData} = await splitTrainingPredictionData({
inputs,
outputs,
data: dataset,
});
// console.log({trainingData,predictionData});
console.log({trainingData,predictionData})
const SpreadsheetModel = await getModel({
type:'prediction',
inputs,
Expand All @@ -100,8 +136,43 @@ describe('AutoML Sheets Test',()=>{
on_progress,
});
await SpreadsheetModel.trainModel();
},30000)
// describe('mock end to end example',()=>{
// it('should run a basic test from spreadsheet data',async ()=>{
// const on_progress = ({
// completion_percentage,
// loss,
// epoch,
// logs,
// status,
// defaultLog,
// }:TrainingProgressUpdate)=>{
// if(status!=='training') console.log({status,defaultLog})
// }
// // const vectors = autoMLdata?.data.concat([]);
// // const labels = vectors?.splice(0,1)[0] as string[];
// // const dataset = JSONM.Data.DataSet.reverseColumnMatrix({labels,vectors});\
// //@ts-ignore
// const{vectors,labels,dataset}=getSpreadsheetDataset(autoMLdata?.data,{on_progress});
// //@ts-ignore
// const {columns,inputs,outputs} = JSONM.getInputsOutputsFromDataset({dataset,labels, on_progress});
// const {trainingData,predictionData} = await splitTrainingPredictionData({
// inputs,
// outputs,
// data: dataset,
// });
// // console.log({trainingData,predictionData});
// const SpreadsheetModel = await getModel({
// type:'prediction',
// inputs,
// outputs,
// dataset:trainingData,
// //@ts-ignore
// on_progress,
// });
// await SpreadsheetModel.trainModel();


},30000)
// },30000)
})
});
8 changes: 5 additions & 3 deletions src/jsonm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ export async function getModelFromJSONM(jml?: JML): Promise<ModelX> {
const trainingData = Array.isArray(jml.dataset)
? jml.dataset
: await getDataSet(jml.dataset);
if(jml.outputs.length<1) throw new RangeError('Every model requires at least one output')
if(jml.inputs.length<1) throw new RangeError('Every model requires at least one input')

return new ModelX({
trainingData,
Expand Down Expand Up @@ -149,10 +151,10 @@ export function getModelOptions(jml?:JML,datum?:Datum){
}
})
const dataset = await getDataSet(options?.data);
const {trainingData, predictionData} = dataset.reduce((result,datum)=>{
const {trainingData, predictionData} = dataset.reduce((result,datum,idx)=>{
if(options?.outputs?.filter((output)=> isEmpty(datum[output])
).length) result.predictionData.push(datum);
else result.trainingData.push(datum);
).length) result.predictionData.push({...datum,__original_dataset_index: idx});
else result.trainingData.push({...datum,__original_dataset_index: idx});
return result;
},{trainingData:[],predictionData:[],})
return {trainingData,predictionData}
Expand Down

0 comments on commit 411010b

Please sign in to comment.