Replace remaining GenericVector by std::vector for src/training

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2021-03-18 08:00:19 +01:00
parent 4d8e9dc659
commit 7df1cb0bab

View File

@ -2,7 +2,6 @@
// File: ctc.cpp
// Description: Slightly improved standard CTC to compute the targets.
// Author: Ray Smith
// Created: Wed Jul 13 15:50:06 PDT 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
@ -18,7 +17,6 @@
#include "ctc.h"
#include "genericvector.h"
#include "matrix.h"
#include "network.h"
#include "networkio.h"
@ -217,7 +215,7 @@ static int BestLabel(const GENERIC_2D_ARRAY<float> &outputs, int t) {
// to the network outputs.
float CTC::CalculateBiasFraction() {
// Compute output labels via basic decoding.
GenericVector<int> output_labels;
std::vector<int> output_labels;
for (int t = 0; t < num_timesteps_; ++t) {
int label = BestLabel(outputs_, t);
while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label)
@ -226,8 +224,8 @@ float CTC::CalculateBiasFraction() {
output_labels.push_back(label);
}
// Simple bag of labels error calculation.
GenericVector<int> truth_counts(num_classes_, 0);
GenericVector<int> output_counts(num_classes_, 0);
std::vector<int> truth_counts(num_classes_, 0);
std::vector<int> output_counts(num_classes_, 0);
for (int l = 0; l < num_labels_; ++l) {
++truth_counts[labels_[l]];
}
@ -353,10 +351,10 @@ void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double> *probs) const {
void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double> &probs, NetworkIO *targets) const {
// For each timestep compute the max prob for each class over all
// instances of the class in the labels_.
GenericVector<double> class_probs;
std::vector<double> class_probs;
for (int t = 0; t < num_timesteps_; ++t) {
float *targets_t = targets->f(t);
class_probs.init_to_size(num_classes_, 0.0);
class_probs.resize(num_classes_);
for (int u = 0; u < num_labels_; ++u) {
double prob = probs(t, u);
// Note that although Graves specifies sum over all labels of the same