Skip to content
Snippets Groups Projects
Commit 8be19316 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

MNT: Modify saveLabelFile to support saving classification probabilities

parent d45b3ce9
No related branches found
No related tags found
No related merge requests found
......@@ -343,33 +343,40 @@ def saveLabelFile(allLabels,
filename,
dirname=None,
listBad=True,
signalLabels=None):
signalLabels=None,
probabilities=None):
"""Saves the given classification labels to the specified file. The
classifications are saved in the format described in the
:func:`loadLabelFile` method.
:arg allLabels: A list of lists, one list for each component, where
each list contains the labels for the corresponding
component.
:arg allLabels: A list of lists, one list for each component, where
each list contains the labels for the corresponding
component.
:arg filename: Name of the file to which the labels should be saved.
:arg filename: Name of the file to which the labels should be saved.
:arg dirname: If provided, is output as the first line of the file.
Intended to be a relative path to the MELODIC analysis
directory with which this label file is associated. If
not provided, a ``'.'`` is output as the first line.
:arg dirname: If provided, is output as the first line of the file.
Intended to be a relative path to the MELODIC analysis
directory with which this label file is associated. If
not provided, a ``'.'`` is output as the first line.
:arg listBad: If ``True`` (the default), the last line of the file
will contain a comma separated list of components which
are deemed 'noisy' (see :func:`isNoisyComponent`).
:arg listBad: If ``True`` (the default), the last line of the file
will contain a comma separated list of components which
are deemed 'noisy' (see :func:`isNoisyComponent`).
:arg signalLabels: Labels which should be deemed 'signal' - see the
:func:`isNoisyComponent` function.
:arg signalLabels: Labels which should be deemed 'signal' - see the
:func:`isNoisyComponent` function.
:arg probabilities: Classification probabilities. If provided, the
probability for each component is saved to the file.
"""
lines = []
noisyComps = []
if probabilities is not None and len(probabilities) != len(allLabels):
raise ValueError('len(probabilities) != len(allLabels)')
# The first line - the melodic directory name
if dirname is None:
dirname = '.'
......@@ -387,6 +394,9 @@ def saveLabelFile(allLabels,
labels = [l.replace(',', '_') for l in labels]
tokens = [str(comp)] + labels + [str(noise)]
if probabilities is not None:
tokens.append(f'{probabilities[i]:0.6f}')
lines.append(', '.join(tokens))
if noise:
......@@ -422,4 +432,3 @@ class InvalidLabelFileError(Exception):
"""Exception raised by the :func:`loadLabelFile` function when an attempt
is made to load an invalid label file.
"""
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment