正文
def input_fn():
...<code>...
return ({ 'SepalLength':[values], ..<etc>.., 'PetalWidth':[values] },
[IrisFlowerType])
返回值必须是一个双元素元组,其组织如下:
•
第一个元素必须是一个dict(命令),其中每个输入特征都是一个键,然后是训练批量的值列表。
•
第二个元素是训练批量的标签列表。
由于我们返回了一批输入特征和训练标签,所以这意味着返回语句中的所有列表将具有相同的长度。从技术上说,每当我们在这里提到“列表”时,实际上指的是一个1-d TensorFlow张量。
为了使得能够重用input_fn,我们将添加一些参数。从而使得我们能够用不同的设置构建输入函数。这些配置是很简单的:
file_path:
要读取的数据文件。
perform_shuffle:
记录顺序是否应该是随机的。
repeat_count:
迭代数据集中记录的次数。例如,如果我们指定1,则每个记录将被读取一次。如果我们指定None,则迭代将永远持续下去。
以下是使用Dataset API实现此函数的方法。我们将把它封装在一个“输入函数”中,它将与我们馈送评估器模型相适应。
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
def decode_csv(line):
parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
label = parsed_line[-1:] # Last element is the label
del parsed_line[-1] # Delete last element
features = parsed_line # Everything (but last element) are the features
d = dict(zip(feature_names, features)), label return d
dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file
.skip(1) # Skip header row
.map(decode_csv)) # Transform each elem by applying decode_csv fn
if perform_shuffle:
# Randomizes input using a window of 256 elements (read into memory)
dataset = dataset.shuffle(buffer_size=256)
dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
dataset = dataset.batch(32) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
请注意以下事项:
TextLineDataset:
当你使用其基于文件的数据集时,Dataset API将为你处理大量的内存管理。例如,你可以通过指定列表作为参数,读取比内存大得多的数据集文件或读入多个文件。
Shuffle(随机化):
读取buffer_size记录,然后shuffle(随机化)其顺序。
Map(映射):
将数据集中的每个元素调用decode_csv函数,作为参数(因为我们使用的是TextLineDataset,每个元素都将是一行CSV文本)。然后我们将decode_csv应用于每一行。
decode_csv: