76 lines
1.8 KiB
Mathematica
76 lines
1.8 KiB
Mathematica
|
function final_result=objective_process(datapass,elechoose)
%OBJECTIVE_PROCESS Train an LSTM text classifier on a masked copy of the
%training data, predict on the test data and score the predictions.
%
%   final_result = objective_process(datapass, elechoose)
%
%   Inputs
%     datapass  - cell array; the entries read here are
%                   {1}  traindata   cell of word-index sequences; presumably
%                                    N-by-1 with equal-length row vectors,
%                                    since cell2mat's result is split back
%                                    into one sequence per row (confirm)
%                   {2}  trainclass  class labels for traindata (expected
%                                    categorical, as classificationLayer needs)
%                   {3}  testdata    sequences passed to predict()
%                   {6}  encode_word encoding object; only .NumWords is used
%                   {21} flag        1 = show the training-progress plot
%                   {22} networkiter maximum number of training epochs
%     elechoose - mask over word indices: every index k with
%                 elechoose(k)==0 is replaced by 0 in the training data.
%
%   Output
%     final_result - 1x4 cell:
%                   {1} accuracy in percent (trace of {2} / sum of {2} * 100)
%                   {2} confusion matrix returned by confusion()
%                   {3} one-hot targets,     no_of_class-by-N
%                   {4} one-hot predictions, no_of_class-by-N

traindata   = datapass{1};
trainclass  = datapass{2};
testdata    = datapass{3};
encode_word = datapass{6};
flag        = datapass{21};
networkiter = datapass{22};

no_of_class = 3;

% Zero out every word index deselected by the mask (elechoose(k)==0).
% NOTE(review): only the training data is masked; testdata reaches predict()
% unmasked -- confirm this is intended.
traindata_matrix = cell2mat(traindata);
traindata_matrix(ismember(traindata_matrix, find(elechoose==0))) = 0;
% One sequence per matrix row, as a 1-by-rr cell (same shape the original
% element-by-element loop produced).
traindata = num2cell(traindata_matrix, 2).';

datain_size = 1;
dim_data    = 50;
hidden_len  = 80;
total_word  = encode_word.NumWords;

network_layer_infor = [ ...
    sequenceInputLayer(datain_size)
    wordEmbeddingLayer(dim_data,total_word)
    lstmLayer(hidden_len,'OutputMode','last')
    fullyConnectedLayer(no_of_class)
    softmaxLayer
    classificationLayer];

% The only difference between the interactive and silent runs is the plot.
if flag == 1
    plot_mode = 'training-progress';
else
    plot_mode = 'none';
end
train_opt = trainingOptions('adam', ...
    'MaxEpochs',networkiter, ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Plots',plot_mode, ...
    'Verbose',false);

net = trainNetwork(traindata,trainclass,network_layer_infor,train_opt);

% predict() returns one row of class scores per observation, columns in the
% network's class order (the categories of trainclass). Take the arg-max of
% the raw scores: rounding them first turns any row whose best score is below
% 0.5 into all zeros, which max() then reports as class 1.
scores = predict(net,testdata);
[~,pred_idx] = max(scores,[],2);

% Use pred_idx directly as the class index. Converting it to categorical and
% back with double() would return indices into the categories actually
% present, renumbering every prediction whenever some class is never predicted.
target_idx = double(trainclass);

% NOTE(review): predictions on testdata are compared against trainclass, which
% only makes sense if testdata lines up one-to-one with the training labels.
% If datapass{4} holds the test labels it is probably the intended target --
% confirm against the caller.
if numel(target_idx) ~= numel(pred_idx)
    error('objective_process:sizeMismatch', ...
        'Got %d predictions for %d target labels.', ...
        numel(pred_idx), numel(target_idx));
end

tardata = one_hot_columns(target_idx, no_of_class);
resdata = one_hot_columns(pred_idx,   no_of_class);

[~,confu_result] = confusion(tardata,resdata);

%% find accuracy
accuracy = (sum(diag(confu_result))/sum(confu_result(:)))*100;

final_result = {accuracy, confu_result, tardata, resdata};


function onehot=one_hot_columns(labels,n_class)
%ONE_HOT_COLUMNS Encode class indices as columns of an n_class-by-N matrix.
%   Column k holds a 1 in row labels(k) and zeros elsewhere; an index outside
%   1..n_class (or NaN) yields an all-zero column.
onehot = double(bsxfun(@eq, (1:n_class).', labels(:).'));