
sgnts.transforms.correlate

AdaptiveCorrelate dataclass

Bases: Correlate

The AdaptiveCorrelate filter performs a correlation over a time-dependent set of filters. When the filters are updated, the correlation is performed over both the existing filters and the new filters, and the results are combined using a window function.

Notes

Update frequency: Only two sets of filters are supported at this time. This is equivalent to requiring that filters can only be updated once per stride; attempting to pass more than one update per stride will raise an error.

Update duration: The filter update is blended across the entire stride. There is currently no support for finer time-domain control of the start/stop times of the blending of filters.
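
The blend itself is a complementary cosine-squared crossfade over the stride: the output computed with the current filters is faded out while the output computed with the new filters is faded in. A minimal sketch of the window construction used in new() below (N stands for the number of correlation output samples in one stride; the value here is illustrative):

import numpy
import scipy.signal

N = 16  # number of correlation output samples in one stride (illustrative)

# Ramp applied to the output of the new filters: rises from ~0 to ~1
win_new = (scipy.signal.windows.cosine(2 * N, sym=True) ** 2)[:N]
# Complementary ramp applied to the output of the current filters
win_cur = 1.0 - win_new

# The ramps sum to one everywhere, so
#     data = win_cur * res_cur + win_new * res_new
# is a smooth crossfade from the old correlation output to the new one.
assert numpy.allclose(win_cur + win_new, 1.0)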

Parameters:

    filter_sink_name (str, default 'filters')
        The name of the sink pad to pull data from.

    init_filters (InitVar[Optional[EventBuffer]], default None)
        The filters to correlate over, with a t0, effectively a slice
        (t0, t_max). This is passed as an EventBuffer with the following
        fields:

            ts: int, the start time of the filter update
            te: int = TIME_MAX, the end time of the filter update (always
                set to the max time for now)
            data: Array, the filters to correlate over

Raises:

    ValueError
        Raised if more than one filter update is passed per stride.
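
A minimal construction sketch, assuming the usual sgnts element keyword arguments (name, source_pad_names, sink_pad_names) and the EventBuffer fields documented above; the import paths, pad names, filter bank, and sample rate below are illustrative assumptions, not part of this module:

# Import paths are assumptions; adjust to your sgnts layout.
import numpy
from sgnts.base import EventBuffer, TIME_MAX
from sgnts.transforms import AdaptiveCorrelate

fir_bank = numpy.zeros((4, 32))  # hypothetical bank: 4 filters, 32 taps each

element = AdaptiveCorrelate(
    name="adaptive_corr",            # assumed element/pad arguments
    source_pad_names=("out",),
    sink_pad_names=("in",),          # the 'filters' sink pad is added automatically
    sample_rate=2048,
    init_filters=EventBuffer(ts=0, te=TIME_MAX, data=fir_bank),
)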

Source code in sgnts/transforms/correlate.py
@dataclass
class AdaptiveCorrelate(Correlate):
    """Adaptive Correlate filter performs a correlation over a time-dependent set of
    filters. When the filters are updated, the correlation is performed over both the
    existing filters and the new filters, then combined using a window function.

    Notes:
        Update frequency:
            Only 2 sets of filters are supported at this time. This is equivalent
            to requiring that filters can only be updated once per stride. Attempting
            to pass more than one update per stride will raise an error.
        Update duration:
            The filter update is performed across the entire stride. There is not
            presently support for more time-domain control of start/stop times for
            the blending of filters.

    Args:
        filter_sink_name:
            str, the name of the sink pad to pull data from
        init_filters:
            EventBuffer, the filters to correlate over, with a t0,
            effectively a slice (t0, t_max). This is passed as an EventBuffer
            with the following types:

                ts: int, the start time of the filter update
                te: int = TIME_MAX, the end time of the filter update (always set to
                    max time for now)
                data: Array, the filters to correlate over

    Raises:
        ValueError:
            Raises a value error if more than one filter update is passed per stride
    """

    filter_sink_name: str = "filters"
    init_filters: InitVar[Optional[EventBuffer]] = None

    def __post_init__(self, init_filters: Optional[EventBuffer]):
        """Setup the adaptive FIR filter"""
        # Setup empty deque for storing filters
        self.filter_deque: Deque = deque()

        # Check that filters are provided for initial condition
        assert (
            init_filters is not None
        ), "filters must be provided to create AdaptiveCorrelate"

        # Set the initial filters
        self.filter_deque.append(init_filters)

        # Argument validation
        self._validate_filters_pad()
        self._validate_init_data()

        # Call the parent's post init, this will setup all the appropriate pads
        super().__post_init__()

    def _validate_init_data(self):
        """Validate arguments given to the adaptive filter"""
        # Check that the filters attribute is not used
        assert self.filters is None, (
            "filters attribute is not used, in adaptive case, use init_filters "
            "instead"
        )

        # Check that the filters are properly formatted if given
        assert self.filters_cur is not None, "filters must be provided"
        if self.filters_cur is not None:
            assert isinstance(
                self.filters_cur, EventBuffer
            ), "filters must be an EventBuffer"
            assert self.filters_cur.te == TIME_MAX, "te must be TIME_MAX"

        # Set filters to the initial filters
        self.filters = self.filters_cur.data

    def _validate_filters_pad(self):
        """Validate the filter sink pad before initializing the filter"""
        # Make sure the filter sink name is not already in use
        assert (
            self.filter_sink_name not in self.sink_pad_names
        ), "Filter sink name already in use"

        # Check that if unaligned pads are specified, that the filter sink name MUST
        # be one of them, if not included then add
        if self.unaligned is not None:
            if self.filter_sink_name not in self.unaligned:
                self.unaligned = list(self.unaligned) + [self.filter_sink_name]
        else:
            self.unaligned = [self.filter_sink_name]

        # Add the filter sink name to the sink pad names
        self.sink_pad_names = list(self.sink_pad_names) + [self.filter_sink_name]

    @property
    def filters_cur(self) -> EventBuffer:
        """Get the current filters"""
        return self.filter_deque[0]

    @property
    def filters_new(self) -> Optional[EventBuffer]:
        """Get the new filters"""
        if len(self.filter_deque) > 1:
            return self.filter_deque[1]

        return None

    @property
    def is_adapting(self) -> bool:
        """Check if the adaptive filter is adapting"""
        return self.filters_new is not None

    def can_adapt(self, frame: TSFrame) -> bool:
        """Check if the buffer can be adapted"""
        if not self.is_adapting:
            return False

        if frame.is_gap:
            return False

        # The below check is unnecessary except for Mypy
        assert self.filters_new is not None  # already checked in first line
        # Check that the frame overlaps the new filter slice
        new_slice = self.filters_new.slice
        frame_slice = frame.slice

        overlap = new_slice & frame_slice
        return overlap.isfinite()

    @wraps(TSTransform.pull)  # type: ignore
    def pull(self, pad: SinkPad, frame: Frame) -> None:  # type: ignore
        # Pull the data from the sink pad
        super().pull(pad, frame)  # type: ignore

        # If the pad is the special filter sink pad, then update filter
        # metadata values
        if pad.name == self.snks[self.filter_sink_name].name:
            # Assume frame is an EventFrame with only 1 EventBuffer in
            # the "events" list
            buf = self.unaligned_data[pad].events["events"][0]

            # If the buffer is null, then short circuit
            if buf.data is None:
                return

            # Redundant check, but more generalizable?
            if len(self.filter_deque) > 1:
                raise ValueError("Only one filter update per stride is supported")

            # Check that the new filters have the same shape as the existing filters
            if (
                self.filters_cur is not None
                and not self.filters_cur.data.shape == buf.data.shape
            ):
                raise ValueError(
                    "New filters must have the same shape as existing filters"
                )

            # Set the new filters
            self.filter_deque.append(buf)

    @wraps(TSTransform.new)  # type: ignore
    def new(self, pad: SourcePad) -> TSFrame:  # type: ignore
        # Get an aligned buffer to see if it overlaps with the new filters
        frame = self.preparedframes[self.sink_pads[0]]

        if self.can_adapt(frame):
            # Call the parent's new method for each set of filters
            assert self.filters_cur is not None
            self.filters = self.filters_cur.data
            res_cur = super().new(pad)  # type: ignore  # not recognizing self

            # Change the state of filters
            assert self.filters_new is not None
            self.filters = self.filters_new.data
            res_new = super().new(pad)  # type: ignore  # not recognizing self

            # Combine data with window functions

            # remove the new filters to indicate adaptation is complete
            self.filter_deque.popleft()

            # Compute window functions. Window functions
            # will be piecewise functions for the corresponding
            # intersection of the filter slice and data slice
            # where the window function is 0.0 before the intersection
            # and 1.0 after the intersection, and cos^2 in between
            N = res_cur[0].data.shape[-1]
            win_new = (scipy.signal.windows.cosine(2 * N, sym=True) ** 2)[:N]
            win_cur = 1.0 - win_new

            data = win_cur * res_cur[0].data + win_new * res_new[0].data

        else:
            res_new = super().new(pad)  # type: ignore  # not recognizing self
            if res_new.is_gap:
                data = None
            else:
                data = res_new.buffers[0].data

        # Return the new frame
        assert data is None or isinstance(data, numpy.ndarray)  # assert for typing
        frame = TSFrame(
            buffers=[
                SeriesBuffer(
                    offset=res_new[0].offset,
                    data=data,
                    sample_rate=res_new.sample_rate,
                    shape=res_new.shape if data is None else data.shape,
                )
            ],
            EOS=res_new.EOS,
        )
        return frame
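
For reference, pull() above expects each filter update on the filter sink pad as an event frame whose events mapping holds a single EventBuffer under the "events" key. A hedged sketch of that payload (the EventFrame constructor and import path here are assumptions inferred from how pull() reads the frame):

# Assumed imports; the EventFrame signature is inferred, not confirmed.
import numpy
from sgnts.base import EventBuffer, EventFrame, TIME_MAX

new_fir_bank = numpy.zeros((4, 32))  # must match the shape of the current bank
update = EventBuffer(ts=1_000_000_000, te=TIME_MAX, data=new_fir_bank)
update_frame = EventFrame(events={"events": [update]})
# pull() reads frame.events["events"][0] and queues it in filter_deque;
# new() then crossfades to the new bank on the first non-gap stride that
# overlaps the buffer's time slice.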

filters_cur property

Get the current filters

filters_new property

Get the new filters

is_adapting property

Check if the adaptive filter is adapting

__post_init__(init_filters)

Setup the adaptive FIR filter

Source code in sgnts/transforms/correlate.py
def __post_init__(self, init_filters: Optional[EventBuffer]):
    """Setup the adaptive FIR filter"""
    # Setup empty deque for storing filters
    self.filter_deque: Deque = deque()

    # Check that filters are provided for initial condition
    assert (
        init_filters is not None
    ), "filters must be provided to create AdaptiveCorrelate"

    # Set the initial filters
    self.filter_deque.append(init_filters)

    # Argument validation
    self._validate_filters_pad()
    self._validate_init_data()

    # Call the parent's post init, this will setup all the appropriate pads
    super().__post_init__()

_validate_filters_pad()

Validate the filter sink pad before initializing the filter

Source code in sgnts/transforms/correlate.py
def _validate_filters_pad(self):
    """Validate the filter sink pad before initializing the filter"""
    # Make sure the filter sink name is not already in use
    assert (
        self.filter_sink_name not in self.sink_pad_names
    ), "Filter sink name already in use"

    # Check that if unaligned pads are specified, that the filter sink name MUST
    # be one of them, if not included then add
    if self.unaligned is not None:
        if self.filter_sink_name not in self.unaligned:
            self.unaligned = list(self.unaligned) + [self.filter_sink_name]
    else:
        self.unaligned = [self.filter_sink_name]

    # Add the filter sink name to the sink pad names
    self.sink_pad_names = list(self.sink_pad_names) + [self.filter_sink_name]

_validate_init_data()

Validate arguments given to the adaptive filter

Source code in sgnts/transforms/correlate.py
def _validate_init_data(self):
    """Validate arguments given to the adaptive filter"""
    # Check that the filters attribute is not used
    assert self.filters is None, (
        "filters attribute is not used, in adaptive case, use init_filters "
        "instead"
    )

    # Check that the filters are properly formatted if given
    assert self.filters_cur is not None, "filters must be provided"
    if self.filters_cur is not None:
        assert isinstance(
            self.filters_cur, EventBuffer
        ), "filters must be an EventBuffer"
        assert self.filters_cur.te == TIME_MAX, "te must be TIME_MAX"

    # Set filters to the initial filters
    self.filters = self.filters_cur.data

can_adapt(frame)

Check if the buffer can be adapted

Source code in sgnts/transforms/correlate.py
def can_adapt(self, frame: TSFrame) -> bool:
    """Check if the buffer can be adapted"""
    if not self.is_adapting:
        return False

    if frame.is_gap:
        return False

    # The below check is unnecessary except for Mypy
    assert self.filters_new is not None  # already checked in first line
    # Check that the frame overlaps the new filter slice
    new_slice = self.filters_new.slice
    frame_slice = frame.slice

    overlap = new_slice & frame_slice
    return overlap.isfinite()

Correlate dataclass

Bases: TSTransform

Correlates input data with filters

Parameters:

    filters (Optional[Array], default None)
        The filter to correlate over.
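
Each output stride is the "valid"-mode correlation of the input with the filter, so a filter of M taps turns an input block of N samples into N - M + 1 output samples; __post_init__ accounts for this by setting the adapter overlap to M - 1 samples. A small standalone sketch of that shape bookkeeping:

import numpy
import scipy.signal

rng = numpy.random.default_rng(0)
data = rng.standard_normal(256)  # one input block, N = 256 samples
fir = rng.standard_normal(32)    # a single filter, M = 32 taps

out = scipy.signal.correlate(data, fir, mode="valid")
assert out.shape == (256 - 32 + 1,)  # N - M + 1 output samples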
Source code in sgnts/transforms/correlate.py
@dataclass
class Correlate(TSTransform):
    """Correlates input data with filters

    Args:
        filters:
            Array, the filter to correlate over
    """

    sample_rate: int = -1
    filters: Optional[Array] = None

    def __post_init__(self):
        # FIXME: read sample_rate from data
        assert self.filters is not None
        assert self.sample_rate != -1
        self.shape = self.filters.shape
        if self.adapter_config is None:
            self.adapter_config = AdapterConfig()
        self.adapter_config.overlap = (
            Offset.fromsamples(self.shape[-1] - 1, self.sample_rate),
            0,
        )
        self.adapter_config.pad_zeros_startup = False
        super().__post_init__()
        assert (
            len(self.aligned_sink_pads) == 1 and len(self.source_pads) == 1
        ), "only one sink_pad and one source_pad is allowed"

    def corr(self, data: Array) -> Array:
        """Correlate an array of data with an array of filters.

        Args:
            data:
                Array, the data to correlate with the filters

        Returns:
            Array, the result of the correlation
        """
        assert self.filters is not None
        if len(self.filters.shape) == 1:
            return scipy.signal.correlate(data, self.filters, mode="valid")

        # Skip the reshape for now
        os = []
        shape = self.shape
        self.filters = self.filters.reshape(-1, shape[-1])
        for j in range(self.shape[0]):
            os.append(scipy.signal.correlate(data, self.filters[j], mode="valid"))
        return numpy.vstack(os).reshape(shape[:-1] + (-1,))

    # FIXME: wraps are not playing well with mypy.  For now ignore and hope
    # that a future version of mypy will be able to handle this
    @wraps(TSTransform.new)
    def new(self, pad: SourcePad) -> TSFrame:  # type: ignore
        outbufs = []
        outoffsets = self.preparedoutoffsets[self.sink_pads[0]]
        frames = self.preparedframes[self.sink_pads[0]]
        for i, buf in enumerate(frames):
            assert buf.sample_rate == self.sample_rate
            if buf.is_gap:
                data = None
            else:
                # FIXME: Are there multi-channel correlation in numpy or scipy?
                # FIXME: consider multi-dimensional filters
                data = self.corr(buf.data)
            outoffset = outoffsets[i]
            outbufs.append(
                SeriesBuffer(
                    offset=outoffset["offset"],
                    sample_rate=buf.sample_rate,
                    data=data,
                    shape=(
                        self.shape[:-1]
                        + (Offset.tosamples(outoffset["noffset"], buf.sample_rate),)
                        if data is None
                        else data.shape
                    ),
                )
            )
        return TSFrame(buffers=outbufs, EOS=frames.EOS)

corr(data)

Correlate an array of data with an array of filters.

Parameters:

    data (Array, required)
        The data to correlate with the filters.

Returns:

    Array
        The result of the correlation.
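
For a two-dimensional bank of filters, corr loops over the rows of the (reshaped) bank and stacks the valid-mode outputs, so the result has shape filters.shape[:-1] + (N - M + 1,). A sketch of the equivalent computation with plain scipy calls:

import numpy
import scipy.signal

rng = numpy.random.default_rng(1)
data = rng.standard_normal(256)      # N = 256 input samples
bank = rng.standard_normal((4, 32))  # 4 filters, M = 32 taps each

out = numpy.vstack(
    [scipy.signal.correlate(data, f, mode="valid") for f in bank]
)
assert out.shape == (4, 256 - 32 + 1)  # filters.shape[:-1] + (N - M + 1,)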

Source code in sgnts/transforms/correlate.py
def corr(self, data: Array) -> Array:
    """Correlate an array of data with an array of filters.

    Args:
        data:
            Array, the data to correlate with the filters

    Returns:
        Array, the result of the correlation
    """
    assert self.filters is not None
    if len(self.filters.shape) == 1:
        return scipy.signal.correlate(data, self.filters, mode="valid")

    # Skip the reshape for now
    os = []
    shape = self.shape
    self.filters = self.filters.reshape(-1, shape[-1])
    for j in range(self.shape[0]):
        os.append(scipy.signal.correlate(data, self.filters[j], mode="valid"))
    return numpy.vstack(os).reshape(shape[:-1] + (-1,))