diff --git a/csdr/chain/__init__.py b/csdr/chain/__init__.py index 3f0a3b24..6be0ce88 100644 --- a/csdr/chain/__init__.py +++ b/csdr/chain/__init__.py @@ -1,14 +1,11 @@ from pycsdr.modules import Buffer -import logging -logger = logging.getLogger(__name__) - class Chain: def __init__(self, *workers): self.input = None self.output = None - self.workers = workers + self.workers = list(workers) for i in range(1, len(self.workers)): self._connect(self.workers[i - 1], self.workers[i]) @@ -55,6 +52,35 @@ class Chain: else: return self.input.getOutputFormat() + def replace(self, index, newWorker): + if index >= len(self.workers): + raise IndexError("Index {} does not exist".format(index)) + + self.workers[index].stop() + self.workers[index] = newWorker + + if index == 0: + newWorker.setInput(self.input) + else: + previousWorker = self.workers[index - 1] + if isinstance(previousWorker, Chain): + newWorker.setInput(previousWorker.getOutput()) + else: + buffer = Buffer(previousWorker.getOutputFormat()) + previousWorker.setOutput(buffer) + newWorker.setInput(buffer) + + if index < len(self.workers) - 1: + nextWorker = self.workers[index + 1] + if isinstance(newWorker, Chain): + nextWorker.setInput(newWorker.getOutput()) + else: + buffer = Buffer(newWorker.getOutputFormat()) + newWorker.setOutput(buffer) + nextWorker.setInput(buffer) + else: + newWorker.setOutput(self.output) + def pump(self, write): output = self.getOutput() diff --git a/csdr/chain/am.py b/csdr/chain/am.py index 9f0640fe..550032bc 100644 --- a/csdr/chain/am.py +++ b/csdr/chain/am.py @@ -18,6 +18,4 @@ class Am(Demodulator): super().__init__(*workers) def setLastDecimation(self, decimation: Chain): - # TODO: build api to replace workers - # TODO: replace placeholder - pass + self.replace(2, decimation) diff --git a/csdr/chain/fft.py b/csdr/chain/fft.py index 68d03796..9a5f541f 100644 --- a/csdr/chain/fft.py +++ b/csdr/chain/fft.py @@ -1,45 +1,25 @@ from csdr.chain import Chain from pycsdr.modules import Fft, LogPower, LogAveragePower, FftSwap, FftAdpcm -import logging -logger = logging.getLogger(__name__) - class FftAverager(Chain): def __init__(self, fft_size, fft_averages): self.fftSize = fft_size - self.fftAverages = None - self.worker = None - self.input = None - self.output = None - self.setFftAverages(fft_averages) - workers = [self.worker] + self.fftAverages = fft_averages + workers = [self._getWorker()] super().__init__(*workers) def setFftAverages(self, fft_averages): if self.fftAverages == fft_averages: return - if fft_averages == 0 and self.fftAverages != 0: - if self.worker is not None: - self.worker.stop() - self.worker = LogPower(add_db=70) - if self.output is not None: - self.worker.setOutput(self.output) - if self.input is not None: - self.worker.setInput(self.input) - elif fft_averages != 0: - if self.fftAverages == 0 or self.worker is None: - if self.worker is not None: - self.worker.stop() - self.worker = LogAveragePower(add_db=-70, fft_size=self.fftSize, avg_number=fft_averages) - if self.output is not None: - self.worker.setOutput(self.output) - if self.input is not None: - self.worker.setInput(self.input) - else: - self.worker.setAvgNumber(avg_number=fft_averages) - self.workers = [self.worker] self.fftAverages = fft_averages + self.replace(0, self._getWorker()) + + def _getWorker(self): + if self.fftAverages == 0: + return LogPower(add_db=-70) + else: + return LogAveragePower(add_db=-70, fft_size=self.fftSize, avg_number=self.fftAverages) class FftChain(Chain): diff --git a/csdr/chain/fm.py b/csdr/chain/fm.py index ebc42ac7..62e535e5 100644 --- a/csdr/chain/fm.py +++ b/csdr/chain/fm.py @@ -17,6 +17,4 @@ class Fm(Demodulator): super().__init__(*workers) def setLastDecimation(self, decimation: Chain): - # TODO: build api to replace workers - # TODO: replace placeholder - pass + self.replace(2, decimation)