diff --git a/setup.cfg b/setup.cfg index 94f2368..59f03bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = rmqrcode -version = 0.3.0 +version = 0.3.1 author = Takahiro Tomita author_email = ttp8101@gmail.com description = An rMQR Code Generetor diff --git a/src/rmqrcode/rmqrcode.py b/src/rmqrcode/rmqrcode.py index 74dc670..8330a8a 100644 --- a/src/rmqrcode/rmqrcode.py +++ b/src/rmqrcode/rmqrcode.py @@ -82,23 +82,24 @@ def fit(data, ecc=ErrorCorrectionLevel.M, fit_strategy=FitStrategy.BALANCED): logger.debug("Select rMQR Code version") for version_name, qr_version in DataCapacities.items(): optimizer = qr_segments.SegmentOptimizer() - optimized_segments = optimizer.compute(data, version_name) - data_length = qr_segments.compute_length(optimized_segments, version_name) - - if data_length <= qr_version["number_of_data_bits"][ecc]: - width, height = qr_version["width"], qr_version["height"] - if width not in determined_width and height not in determined_height: - determined_width.add(width) - determined_height.add(height) - ok_versions.append( - { - "version": version_name, - "width": width, - "height": height, - "segments": optimized_segments, - } - ) - logger.debug(f"ok: {version_name}") + try: + optimized_segments = optimizer.compute(data, version_name, ecc) + except DataTooLongError: + continue + + width, height = qr_version["width"], qr_version["height"] + if width not in determined_width and height not in determined_height: + determined_width.add(width) + determined_height.add(height) + ok_versions.append( + { + "version": version_name, + "width": width, + "height": height, + "segments": optimized_segments, + } + ) + logger.debug(f"ok: {version_name}") if len(ok_versions) == 0: raise DataTooLongError("The data is too long.") @@ -128,7 +129,7 @@ def sort_key(x): def _optimized_segments(self, data): optimizer = qr_segments.SegmentOptimizer() - return optimizer.compute(data, self.version_name()) + return optimizer.compute(data, self.version_name(), self._error_correction_level) def __init__(self, version, ecc, with_quiet_zone=True, logger=None): self._logger = logger or rMQR._init_logger() diff --git a/src/rmqrcode/segments.py b/src/rmqrcode/segments.py index 1ca5b40..02aa436 100644 --- a/src/rmqrcode/segments.py +++ b/src/rmqrcode/segments.py @@ -1,5 +1,6 @@ from . import encoder from .errors import DataTooLongError +from .format.data_capacities import DataCapacities from .format.rmqr_versions import rMQRVersions encoders = [ @@ -47,12 +48,13 @@ def __init__(self): self.dp = [[[self.INF for n in range(3)] for mode in range(4)] for length in range(self.MAX_CHARACTER + 1)] self.parents = [[[-1 for n in range(3)] for mode in range(4)] for length in range(self.MAX_CHARACTER + 1)] - def compute(self, data, version): + def compute(self, data, version, ecc): """Computes the optimize segmentation for the given data. Args: data (str): The data to encode. version (str): The version name. + ecc (rmqrcode.ErrorCorrectionLevel): The error correction level. Returns: list: The list of segments. @@ -66,8 +68,11 @@ def compute(self, data, version): self.qr_version = rMQRVersions[version] self._compute_costs(data) - best_index = self._find_best(data) - path = self._reconstruct_path(best_index) + best = self._find_best(data) + if best["cost"] > DataCapacities[version]["number_of_data_bits"][ecc]: + raise DataTooLongError + + path = self._reconstruct_path(best["index"]) segments = self._compute_segments(path, data) return segments @@ -102,36 +107,67 @@ def _compute_costs(self, data): if not encoders[new_mode].is_valid_characters(data[n]): continue - encoder_class = encoders[new_mode] - character_count_indicator_length = self.qr_version["character_count_indicator_length"][ - encoder_class - ] if new_mode == mode: - # Keep the mode - if encoder_class == encoder.NumericEncoder: - new_length = (unfilled_length + 1) % 3 - cost = 4 if unfilled_length == 0 else 3 - elif encoder_class == encoder.AlphanumericEncoder: - new_length = (unfilled_length + 1) % 2 - cost = 6 if unfilled_length == 0 else 5 - elif encoder_class == encoder.ByteEncoder: - new_length = 0 - cost = 8 - elif encoder_class == encoder.KanjiEncoder: - new_length = 0 - cost = 13 + cost, new_length = self._compute_new_state_without_mode_changing( + data[n], new_mode, unfilled_length + ) else: - # Change the mode - if encoder_class in [encoder.NumericEncoder, encoder.AlphanumericEncoder]: - new_length = 1 - elif encoder_class in [encoder.ByteEncoder, encoder.KanjiEncoder]: - new_length = 0 - cost = encoders[new_mode].length(data[n], character_count_indicator_length) + cost, new_length = self._compute_new_state_with_mode_changing( + data[n], new_mode, unfilled_length + ) if self.dp[n][mode][unfilled_length] + cost < self.dp[n + 1][new_mode][new_length]: self.dp[n + 1][new_mode][new_length] = self.dp[n][mode][unfilled_length] + cost self.parents[n + 1][new_mode][new_length] = (n, mode, unfilled_length) + def _compute_new_state_without_mode_changing(self, character, new_mode, unfilled_length): + """Computes the new state values without mode changing. + + Args: + character (str): The current character. Assume this as one length string. + new_mode (int): The state of the new mode. + unfilled_length (int): The state of the current unfilled_length. + + Returns: + tuple: (cost, new_length). + + """ + encoder_class = encoders[new_mode] + if encoder_class == encoder.NumericEncoder: + new_length = (unfilled_length + 1) % 3 + cost = 4 if unfilled_length == 0 else 3 + elif encoder_class == encoder.AlphanumericEncoder: + new_length = (unfilled_length + 1) % 2 + cost = 6 if unfilled_length == 0 else 5 + elif encoder_class == encoder.ByteEncoder: + new_length = 0 + cost = 8 * len(character.encode("utf-8")) + elif encoder_class == encoder.KanjiEncoder: + new_length = 0 + cost = 13 + return (cost, new_length) + + def _compute_new_state_with_mode_changing(self, character, new_mode, unfilled_length): + """Computes the new state values with mode changing. + + Args: + character (str): The current character. Assume this as one length string. + new_mode (int): The state of the new mode. + unfilled_length (int): The state of the current unfilled_length. + + Returns: + tuple: (cost, new_length). + + """ + encoder_class = encoders[new_mode] + character_count_indicator_length = self.qr_version["character_count_indicator_length"][encoder_class] + if encoder_class in [encoder.NumericEncoder, encoder.AlphanumericEncoder]: + new_length = 1 + elif encoder_class in [encoder.ByteEncoder, encoder.KanjiEncoder]: + new_length = 0 + cost = encoder_class.length(character, character_count_indicator_length) + return (cost, new_length) + def _find_best(self, data): """Find the index which has the minimum costs. @@ -139,7 +175,8 @@ def _find_best(self, data): data (str): The data to encode. Returns: - tuple: The best index as tuple (n, mode, unfilled_length). + dict: The dict object includes "cost" and "index". The "cost" is the value of minimum cost. + The "index" is the index of the dp table as a tuple (n, mode, unfilled_length). """ best = self.INF @@ -149,7 +186,7 @@ def _find_best(self, data): if self.dp[len(data)][mode][unfilled_length] < best: best = self.dp[len(data)][mode][unfilled_length] best_index = (len(data), mode, unfilled_length) - return best_index + return {"cost": best, "index": best_index} def _reconstruct_path(self, best_index): """Reconstructs the path. diff --git a/tests/segments_test.py b/tests/segments_test.py index 073d230..552dd27 100644 --- a/tests/segments_test.py +++ b/tests/segments_test.py @@ -1,18 +1,59 @@ from rmqrcode.segments import SegmentOptimizer, compute_length -from rmqrcode import encoder +from rmqrcode import encoder, ErrorCorrectionLevel, DataTooLongError import pytest class TestSegments: - def test_can_optimize_segments(self): + def test_can_optimize_segments_numeric_and_byte(self): optimizer = SegmentOptimizer() - segments = optimizer.compute("123Abc", "R7x43") + segments = optimizer.compute("123Abc", "R7x43", ErrorCorrectionLevel.M) assert segments == [ {"data": "123", "encoder_class": encoder.NumericEncoder}, {"data": "Abc", "encoder_class": encoder.ByteEncoder}, ] + def test_can_optimize_segments_alphanumeric_and_kanji(self): + optimizer = SegmentOptimizer() + segments = optimizer.compute("17:30集合", "R7x59", ErrorCorrectionLevel.M) + assert segments == [ + {"data": "17:30", "encoder_class": encoder.AlphanumericEncoder}, + {"data": "集合", "encoder_class": encoder.KanjiEncoder}, + ] + + def test_can_optimize_segments_numeric_only(self): + optimizer = SegmentOptimizer() + segments = optimizer.compute("123456", "R7x59", ErrorCorrectionLevel.M) + assert segments == [ + {"data": "123456", "encoder_class": encoder.NumericEncoder}, + ] + + def test_can_optimize_segments_alphanumeric_only(self): + optimizer = SegmentOptimizer() + segments = optimizer.compute("HTTPS://", "R7x59", ErrorCorrectionLevel.M) + assert segments == [ + {"data": "HTTPS://", "encoder_class": encoder.AlphanumericEncoder}, + ] + + def test_can_optimize_segments_byte_only(self): + optimizer = SegmentOptimizer() + segments = optimizer.compute("1+zY!a:K", "R7x59", ErrorCorrectionLevel.M) + assert segments == [ + {"data": "1+zY!a:K", "encoder_class": encoder.ByteEncoder}, + ] + + def test_can_optimize_segments_kanji_only(self): + optimizer = SegmentOptimizer() + segments = optimizer.compute("漢字", "R7x59", ErrorCorrectionLevel.M) + assert segments == [ + {"data": "漢字", "encoder_class": encoder.KanjiEncoder}, + ] + + def test_optimize_segments_raises_data_too_long_error(self): + optimizer = SegmentOptimizer() + with pytest.raises(DataTooLongError) as e: + segments = optimizer.compute("a" * 12, "R7x59", ErrorCorrectionLevel.M) + def test_compute_length(self): optimizer = SegmentOptimizer() - segments = optimizer.compute("123Abc", "R7x43") + segments = optimizer.compute("123Abc", "R7x43", ErrorCorrectionLevel.M) assert compute_length(segments, "R7x43") is 47