Model
A `Model` is the simplest way to build, train, and evaluate neural networks in etch. The `Model` type takes care of the underlying implementation details for `Graph`, `DataLoader`, and `Optimiser`.
There are three types of `Model`:

- Sequential: trains a computational graph to predict either continuous variables or classes, and gives you more control over the layers of the network.
- Regressor: trains a computational graph to predict continuous variables. For example, what will the future price of a particular currency be?
- Classifier: trains a computational graph to predict classes. For example, is this a picture of a cat or a dog?
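To make the difference between the regressor and classifier outputs concrete, here is a small Python sketch. It is an illustration of the idea only, not etch code and not how etch implements these types. The function names are hypothetical, and the classifier half assumes the network produces one score per class and reports the highest-scoring class.

```python
def regression_prediction(output):
    # A regressor reports the network's output value directly,
    # for example a predicted currency price.
    return output


def classification_prediction(scores, labels):
    # Assumed classifier behaviour: one score per class, and the
    # label with the highest score is reported as the prediction.
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best]


print(regression_prediction(1.0842))                          # 1.0842
print(classification_prediction([0.2, 1.5], ["cat", "dog"]))  # dog
```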
Construct a Model
Create a `Model` by passing its type as a flag to the constructor: `"sequential"`, `"regressor"`, or `"classifier"`.
```
function main()
    var model1 = Model("sequential");
    var model2 = Model("regressor");
    var model3 = Model("classifier");
endfunction
```
Add
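Add layers manually to a sequential `Model` with `add()`.

The four-parameter form `add(x, y, z, a)` requires you to specify a layer type flag, the layer's dimensions, and an activation type. The three-parameter form `add(x, y, z)` requires only the layer type flag and the dimensions, and omits the activation type. In the examples on this page, the two dimensions are the layer's input and output sizes, so each layer's output size matches the next layer's input size.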
```
function main()
    var model = Model("sequential");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 10u64, "relu");
    // The three-parameter call omits the activation type.
    model.add("dense", 10u64, 1u64);
endfunction
```
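The following Python sketch (not etch) shows what a stack of dense layers computes and why the dimensions must line up. It assumes the two dimension arguments to `add()` are a layer's input and output sizes, as the examples on this page suggest, and treats a layer added without an activation as applying none. `Dense`, `check_chain`, and `forward` are hypothetical names, and the small random starting weights are an arbitrary choice.

```python
import random


class Dense:
    """Hypothetical stand-in for add("dense", in_size, out_size[, activation])."""

    def __init__(self, in_size, out_size, activation=None, seed=0):
        rng = random.Random(seed)
        self.in_size, self.out_size, self.activation = in_size, out_size, activation
        # One weight per (output, input) pair and one bias per output.
        self.weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_size)]
                        for _ in range(out_size)]
        self.bias = [0.0] * out_size

    def __call__(self, x):
        # Weighted sum of the inputs for each output unit, plus its bias.
        out = [sum(w * xi for w, xi in zip(row, x)) + b
               for row, b in zip(self.weights, self.bias)]
        if self.activation == "relu":
            out = [max(0.0, v) for v in out]  # relu clamps negatives to zero
        return out


def check_chain(layers):
    # Each layer must accept exactly as many inputs as the previous one outputs.
    for prev, nxt in zip(layers, layers[1:]):
        if prev.out_size != nxt.in_size:
            raise ValueError(f"{prev.out_size} outputs feed a layer expecting {nxt.in_size}")
    return layers


def forward(layers, x):
    # Pass the input through every layer in order.
    for layer in layers:
        x = layer(x)
    return x


# Mirrors the etch example: 10 -> 10 (relu) -> 10 (relu) -> 1.
net = check_chain([Dense(10, 10, "relu"), Dense(10, 10, "relu"), Dense(10, 1)])
print(len(forward(net, [1.0] * 10)))  # 1
```

Checking the chain up front catches a size mismatch before any data flows through the network, which is the kind of error the matching dimensions in the etch example avoid.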
Compile
Compile the `Model` with `compile(loss-function-flag, optimiser-flag)`, which takes two inputs: a loss function flag and an optimiser flag. The example below compiles with a mean squared error loss function (`"mse"`) and an Adam optimiser (`"adam"`).
```
function main()
    var model = Model("sequential");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 1u64);
    // Mean squared error loss, Adam optimiser.
    model.compile("mse", "adam");
endfunction
```
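Both compile flags name standard techniques. As a rough Python illustration (not etch's implementation), mean squared error averages the squared differences between predictions and targets, and Adam updates each parameter using running averages of its gradient and squared gradient. The default learning rate, beta, and epsilon values below are the commonly cited Adam defaults, assumed here; this document does not state the values etch uses. The sketch optimises a single scalar parameter to stay short, and uses a larger learning rate of 0.1 in the usage example so it converges quickly.

```python
import math


def mse(predictions, targets):
    # Mean squared error: (1/n) * sum((prediction - target)^2).
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


class Adam:
    """Adam update for one scalar parameter (sketch)."""

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # running average of the gradient
        self.v = 0.0  # running average of the squared gradient
        self.t = 0    # step counter, used for bias correction

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Fit a constant prediction w to targets [2, 4] by minimising MSE with Adam.
targets = [2.0, 4.0]
w, opt = 0.0, Adam(lr=0.1)
for _ in range(1000):
    grad = sum(2 * (w - t) for t in targets) / len(targets)  # d(MSE)/dw
    w = opt.step(w, grad)
print(round(w, 2))  # close to 3.0, the mean of the targets
```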
Read in input data
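The `readCSV()` function reads data from a CSV file. Combined with `System.Argv()`, it lets you pass file paths to the script as command-line arguments to etch: `readCSV(System.Argv(1))` reads the file given as the first argument.

For example, run the script below with the following command:

```
./etch SCRIPT_FILE -- file1 file2 file3 file4
```

The `if` block at the start of the script uses `System.Argc()` to check that the correct number of arguments was supplied, and prints a usage message and returns if it was not.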
```
function main()
    // Print a usage message and stop if the argument count is wrong.
    if (System.Argc() != 5)
        print("Usage: SCRIPT_FILE -- PATH/TO/BOSTON_TRAIN_DATA.CSV PATH/TO/BOSTON_TRAIN_LABELS.CSV PATH/TO/BOSTON_TEST_DATA.CSV PATH/TO/BOSTON_TEST_LABELS.CSV ");
        return;
    endif
    // Read the four CSV files named on the command line.
    var data = readCSV(System.Argv(1));
    var label = readCSV(System.Argv(2));
    var test_data = readCSV(System.Argv(3));
    var test_label = readCSV(System.Argv(4));
endfunction
```
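For readers coming from other languages, the argument check and file loading above correspond roughly to this Python sketch. It assumes the CSV files are headerless and purely numeric, and that the argument list counts the script name as well as the four paths, which is one reading of the etch check for 5 arguments. Whatever `readCSV()` returns in etch, this sketch returns plain lists of floats; `read_numeric_csv` and `csv_paths` are hypothetical names.

```python
import csv
import io


def read_numeric_csv(stream):
    """Parse a headerless, all-numeric CSV into a list of rows of floats."""
    return [[float(cell) for cell in row] for row in csv.reader(stream) if row]


def csv_paths(argv):
    """Return the four CSV paths, or None after printing usage if the count is wrong."""
    # argv[0] is the script name, so four paths means five entries in total.
    if len(argv) != 5:
        print("Usage: SCRIPT_FILE TRAIN_DATA TRAIN_LABELS TEST_DATA TEST_LABELS")
        return None
    return argv[1:]


# An in-memory file stands in for the file named by System.Argv(1).
rows = read_numeric_csv(io.StringIO("0.1,2.0\n0.3,4.0\n"))
print(rows)  # [[0.1, 2.0], [0.3, 4.0]]
```

In a real script, `csv_paths(sys.argv)` would supply the paths, and each file would be opened with `open(path, newline="")` and passed to `read_numeric_csv`.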
Fit
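With the `Model` set up as above, you can now pass in the training data and labels and run the training function `fit(data, labels, batch_size)`. Note that the first layer now takes 13 inputs rather than 10, matching the 13 features of the Boston housing data that the script reads.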
```
function main()
    // Training and test data, read as in the previous section.
    var data = readCSV(System.Argv(1));
    var label = readCSV(System.Argv(2));
    var test_data = readCSV(System.Argv(3));
    var test_label = readCSV(System.Argv(4));
    // The first layer takes 13 inputs, one per input feature.
    var model = Model("sequential");
    model.add("dense", 13u64, 10u64, "relu");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 1u64);
    model.compile("mse", "adam");
    // Train on the data and labels in batches of 10 examples.
    var batch_size = 10u64;
    model.fit(data, label, batch_size);
endfunction
```
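The `batch_size` argument controls how many examples contribute to each parameter update. The Python sketch below illustrates mini-batch training on a one-feature linear model with an MSE loss; it is not etch's implementation. Plain gradient descent stands in for Adam to keep it short, batches are taken in order without shuffling, and the learning rate and number of passes over the data (epochs) are assumptions; this document does not say how many passes etch's `fit()` makes.

```python
def batches(n, batch_size):
    # Consecutive (start, end) index ranges; the last batch may be smaller.
    for start in range(0, n, batch_size):
        yield start, min(start + batch_size, n)


def fit_linear(xs, ys, batch_size, lr=0.01, epochs=2000):
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for start, end in batches(len(xs), batch_size):
            bx, by = xs[start:end], ys[start:end]
            m = len(bx)
            errors = [w * x + b - y for x, y in zip(bx, by)]
            # Gradient of this batch's mean squared error, then one update.
            grad_w = 2 / m * sum(e * x for e, x in zip(errors, bx))
            grad_b = 2 / m * sum(errors)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


xs = [float(i) for i in range(10)]
ys = [2 * x + 1 for x in xs]       # noise-free data on the line y = 2x + 1
print(list(batches(len(xs), 4)))   # [(0, 4), (4, 8), (8, 10)]
w, b = fit_linear(xs, ys, batch_size=4)
print(round(w, 3), round(b, 3))    # close to 2.0 and 1.0
```

Smaller batches mean more, noisier updates per pass; larger batches mean fewer, smoother ones.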
Evaluate
Evaluate the prediction error with the `evaluate()` function.
```
function main()
    var data = readCSV(System.Argv(1));
    var label = readCSV(System.Argv(2));
    var test_data = readCSV(System.Argv(3));
    var test_label = readCSV(System.Argv(4));
    var model = Model("sequential");
    model.add("dense", 13u64, 10u64, "relu");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 1u64);
    model.compile("mse", "adam");
    var batch_size = 10u64;
    model.fit(data, label, batch_size);
    // Compute and print the loss.
    var loss = model.evaluate();
    printLn(loss);
endfunction
```
Predict
Finally, make predictions on the test data with `predict()`. The example prints the first prediction, `predictions.at(0u64, 0u64)`.
```
function main()
    var data = readCSV(System.Argv(1));
    var label = readCSV(System.Argv(2));
    var test_data = readCSV(System.Argv(3));
    var test_label = readCSV(System.Argv(4));
    var model = Model("sequential");
    model.add("dense", 13u64, 10u64, "relu");
    model.add("dense", 10u64, 10u64, "relu");
    model.add("dense", 10u64, 1u64);
    model.compile("mse", "adam");
    var batch_size = 10u64;
    model.fit(data, label, batch_size);
    var loss = model.evaluate();
    printLn(loss);
    // Predict on the test data and print the first prediction.
    var predictions = model.predict(test_data);
    print(predictions.at(0u64, 0u64));
endfunction
```