Skip to content
Snippets Groups Projects
Commit 100cd400 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

track std of scores along with their averages

parent f17269dd
No related branches found
No related tags found
No related merge requests found
......@@ -151,37 +151,55 @@ class ScoreSet(object):
if self.verbose:
print("--> %s:%s %0.4f" % (group, score_type, value))
def averages(self):
def score_types(self):
result = set([])
for (g, d) in sorted(self.groups.items()):
for score_type in sorted(d.keys()):
result.add(score_type)
return list(sorted(result))
def _reduce_scores(self, reduce_fn):
score_types = self.score_types()
return {
group:
OrderedDict([
(score_type, np.mean(scores))
for (score_type, scores)
in sorted(score_dict.items())
(score_type, reduce_fn(score_dict[score_type]))
for score_type
in score_types
])
for (group, score_dict)
in self.groups.items()
}
def score_types(self):
result = set([])
for (g, d) in sorted(self.groups.items()):
for score_type in sorted(d.keys()):
result.add(score_type)
return list(result)
def averages(self):
return self._reduce_scores(np.mean)
def stds(self):
return self._reduce_scores(np.std)
def to_csv(self, filename):
with open(filename, "w") as f:
score_types = scores.score_types()
header_list = ["name"] + list(score_types)
header_list = ["name"]
score_types = self.score_types()
for score_type in score_types:
header_list.append(score_type)
header_list.append(score_type + "_std")
header_line = ",".join(header_list) + "\n"
if self.verbose:
print(header_line)
f.write(header_line)
for name, score_type_dict in sorted(scores.averages().items()):
score_averages = self.averages()
score_stds = self.stds()
for name in sorted(score_averages.keys()):
line_elements = [name]
for _, value in sorted(score_type_dict.items()):
line_elements.append("%0.4f" % value)
for score_type in score_types:
line_elements.append(
"%0.4f" % score_averages[name][score_type])
line_elements.append(
"%0.4f" % score_stds[name][score_type])
line = ",".join(line_elements) + "\n"
if self.verbose:
print(line)
......@@ -204,7 +222,7 @@ if __name__ == "__main__":
"zeroFill": SimpleFill("zero"),
"MICE": MICE(
n_burn_in=5,
n_imputations=20,
n_imputations=25,
min_value=None if args.normalize_rows or args.normalize_columns else 0,
max_value=None if args.normalize_rows or args.normalize_columns else 1,
verbose=VERBOSE),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment