Skip to content

feat: Implement optimal segmentation algorithm #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ optional arguments:
Strategy how to determine rMQR Code size.
```

### Generate rMQR Code
### Generate rMQR Code in scripts
Alternatively, you can also use in python scripts:
```py
from rmqrcode import rMQR
Expand Down Expand Up @@ -126,6 +126,17 @@ The value for `encoder_class` is listed in the below table.
|Byte|ByteEncoder|Any|
|Kanji|KanjiEncoder|from 0x8140 to 0x9FFC, from 0xE040 to 0xEBBF in Shift JIS value|

### Optimal Segmentation
The `rMQR.fit` method mentioned above computes the optimal segmentation.
For example, the data "123Abc" is divided into the following two segments.

|Segment No.|Data|Encoding Mode|
|-|-|-|
|Segment1|123|Numeric|
|Segment2|Abc|Byte|

In the case of other segmentation like "123A bc", the length of the bit string after
encoding will be longer than the above optimal case.

## 🤝 Contributing
Any suggestions are welcome! If you are interesting in contributing, please read [CONTRIBUTING](https://github.com/OUDON/rmqrcode-python/blob/develop/CONTRIBUTING.md).
Expand Down
13 changes: 13 additions & 0 deletions src/rmqrcode/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class DataTooLongError(ValueError):
"A class represents an error raised when the given data is too long."
pass


class IllegalVersionError(ValueError):
"A class represents an error raised when the given version name is illegal."
pass


class NoSegmentError(ValueError):
"A class represents an error raised when no segments are add"
pass
38 changes: 16 additions & 22 deletions src/rmqrcode/rmqrcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import logging

from . import encoder
from . import segments as qr_segments
from .enums.color import Color
from .enums.fit_strategy import FitStrategy
from .errors import DataTooLongError, IllegalVersionError, NoSegmentError
from .format.alignment_pattern_coordinates import AlignmentPatternCoordinates
from .format.data_capacities import DataCapacities
from .format.error_correction_level import ErrorCorrectionLevel
Expand Down Expand Up @@ -77,14 +79,12 @@ def fit(data, ecc=ErrorCorrectionLevel.M, fit_strategy=FitStrategy.BALANCED):
determined_width = set()
determined_height = set()

# Fixed value currently
encoder_class = encoder.ByteEncoder

logger.debug("Select rMQR Code version")
for version_name, qr_version in DataCapacities.items():
data_length = encoder_class.length(
data, rMQRVersions[version_name]["character_count_indicator_length"][encoder_class]
)
optimizer = qr_segments.SegmentOptimizer()
optimized_segments = optimizer.compute(data, version_name)
data_length = qr_segments.compute_length(optimized_segments, version_name)

if data_length <= qr_version["number_of_data_bits"][ecc]:
width, height = qr_version["width"], qr_version["height"]
if width not in determined_width and height not in determined_height:
Expand All @@ -95,6 +95,7 @@ def fit(data, ecc=ErrorCorrectionLevel.M, fit_strategy=FitStrategy.BALANCED):
"version": version_name,
"width": width,
"height": height,
"segments": optimized_segments,
}
)
logger.debug(f"ok: {version_name}")
Expand All @@ -121,10 +122,14 @@ def sort_key(x):
logger.debug(f"selected: {selected}")

qr = rMQR(selected["version"], ecc)
qr.add_segment(data, encoder_class)
qr.add_segments(selected["segments"])
qr.make()
return qr

def _optimized_segments(self, data):
optimizer = qr_segments.SegmentOptimizer()
return optimizer.compute(data, self.version_name())

def __init__(self, version, ecc, with_quiet_zone=True, logger=None):
self._logger = logger or rMQR._init_logger()

Expand Down Expand Up @@ -155,6 +160,10 @@ def add_segment(self, data, encoder_class=encoder.ByteEncoder):
"""
self._segments.append({"data": data, "encoder_class": encoder_class})

def add_segments(self, segments):
for segment in segments:
self.add_segment(segment["data"], segment["encoder_class"])

def make(self):
"""Makes an rMQR Code for stored segments.

Expand Down Expand Up @@ -652,18 +661,3 @@ def validate_version(version_name):

"""
return version_name in rMQRVersions


class DataTooLongError(ValueError):
"A class represents an error raised when the given data is too long."
pass


class IllegalVersionError(ValueError):
"A class represents an error raised when the given version name is illegal."
pass


class NoSegmentError(ValueError):
"A class represents an error raised when no segments are add"
pass
200 changes: 200 additions & 0 deletions src/rmqrcode/segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from . import encoder
from .errors import DataTooLongError
from .format.rmqr_versions import rMQRVersions

encoders = [
encoder.NumericEncoder,
encoder.AlphanumericEncoder,
encoder.ByteEncoder,
encoder.KanjiEncoder,
]


def compute_length(segments, version_name):
"""Computes the sum of length of the segments.

Args:
segments (list): The list of segment.
version_name (str): The version name.

Returns:
int: The sum of the length of the segments.

"""
return sum(
map(
lambda s: s["encoder_class"].length(
s["data"], rMQRVersions[version_name]["character_count_indicator_length"][s["encoder_class"]]
),
segments,
)
)


class SegmentOptimizer:
"""A class for computing optimal segmentation of the given data by dynamic programming.

Attributes:
MAX_CHARACTER (int): The maximum characters of the given data.
INF (int): Large enough value. This is used as initial value of the dynamic programming table.

"""

MAX_CHARACTER = 360
INF = 100000

def __init__(self):
self.dp = [[[self.INF for n in range(3)] for mode in range(4)] for length in range(self.MAX_CHARACTER + 1)]
self.parents = [[[-1 for n in range(3)] for mode in range(4)] for length in range(self.MAX_CHARACTER + 1)]

def compute(self, data, version):
"""Computes the optimize segmentation for the given data.

Args:
data (str): The data to encode.
version (str): The version name.

Returns:
list: The list of segments.

Raises:
rmqrcode.DataTooLongError: If the data is too long to encode.

"""
if len(data) > self.MAX_CHARACTER:
raise DataTooLongError()

self.qr_version = rMQRVersions[version]
self._compute_costs(data)
best_index = self._find_best(data)
path = self._reconstruct_path(best_index)
segments = self._compute_segments(path, data)
return segments

def _compute_costs(self, data):
"""Computes costs by dynamic programming.

This method computes costs of the dynamic programming table. Define
dp[n][mode][unfilled_length] as the minimize bit length when encode only
the `n`-th leading characters which the last character is encoded in `mode`
and the remainder bits length is `unfilled_length`.

Args:
data (str): The data to encode.

Returns:
void

"""
for mode in range(len(encoders)):
encoder_class = encoders[mode]
character_count_indicator_length = self.qr_version["character_count_indicator_length"][encoder_class]
self.dp[0][mode][0] = encoder_class.length("", character_count_indicator_length)
self.parents[0][mode][0] = (0, 0, 0)

for n in range(0, len(data)):
for mode in range(4):
for unfilled_length in range(3):
if self.dp[n][mode][unfilled_length] == self.INF:
continue

for new_mode in range(4):
if not encoders[new_mode].is_valid_characters(data[n]):
continue

encoder_class = encoders[new_mode]
character_count_indicator_length = self.qr_version["character_count_indicator_length"][
encoder_class
]
if new_mode == mode:
# Keep the mode
if encoder_class == encoder.NumericEncoder:
new_length = (unfilled_length + 1) % 3
cost = 4 if unfilled_length == 0 else 3
elif encoder_class == encoder.AlphanumericEncoder:
new_length = (unfilled_length + 1) % 2
cost = 6 if unfilled_length == 0 else 5
elif encoder_class == encoder.ByteEncoder:
new_length = 0
cost = 8
elif encoder_class == encoder.KanjiEncoder:
new_length = 0
cost = 13
else:
# Change the mode
if encoder_class in [encoder.NumericEncoder, encoder.AlphanumericEncoder]:
new_length = 1
elif encoder_class in [encoder.ByteEncoder, encoder.KanjiEncoder]:
new_length = 0
cost = encoders[new_mode].length(data[n], character_count_indicator_length)

if self.dp[n][mode][unfilled_length] + cost < self.dp[n + 1][new_mode][new_length]:
self.dp[n + 1][new_mode][new_length] = self.dp[n][mode][unfilled_length] + cost
self.parents[n + 1][new_mode][new_length] = (n, mode, unfilled_length)

def _find_best(self, data):
"""Find the index which has the minimum costs.

Args:
data (str): The data to encode.

Returns:
tuple: The best index as tuple (n, mode, unfilled_length).

"""
best = self.INF
best_index = (-1, -1)
for mode in range(4):
for unfilled_length in range(3):
if self.dp[len(data)][mode][unfilled_length] < best:
best = self.dp[len(data)][mode][unfilled_length]
best_index = (len(data), mode, unfilled_length)
return best_index

def _reconstruct_path(self, best_index):
"""Reconstructs the path.

Args:
best_index: The best index computed by self._find_best().

Returns:
list: The path of minimum cost in the dynamic programming table.

"""
path = []
index = best_index
while index[0] != 0:
path.append(index)
index = self.parents[index[0]][index[1]][index[2]]
path.reverse()
return path

def _compute_segments(self, path, data):
"""Computes the segments.

This method computes the segments. The adjacent characters has same mode are merged.

Args:
path (list): The path computed by self._reconstruct_path().
data (str): The data to encode.

Returns:
list: The list of segments.

"""
segments = []
current_segment_data = ""
current_mode = -1
for p in path:
if current_mode == -1:
current_mode = p[1]
current_segment_data += data[p[0] - 1]
elif current_mode == p[1]:
current_segment_data += data[p[0] - 1]
else:
segments.append({"data": current_segment_data, "encoder_class": encoders[current_mode]})
current_segment_data = data[p[0] - 1]
current_mode = p[1]
if current_mode != -1:
segments.append({"data": current_segment_data, "encoder_class": encoders[current_mode]})
return segments
3 changes: 1 addition & 2 deletions tests/rmqrcode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def test_raise_too_long_error_kanji_encoder(self):

def test_raise_too_long_error_fit(self):
with pytest.raises(DataTooLongError) as e:
s = "a".ljust(200, "a")
rMQR.fit(s)
rMQR.fit("a" * 200)

def test_raise_invalid_version_error(self):
with pytest.raises(IllegalVersionError) as e:
Expand Down
18 changes: 18 additions & 0 deletions tests/segments_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from rmqrcode.segments import SegmentOptimizer, compute_length
from rmqrcode import encoder
import pytest


class TestSegments:
def test_can_optimize_segments(self):
optimizer = SegmentOptimizer()
segments = optimizer.compute("123Abc", "R7x43")
assert segments == [
{"data": "123", "encoder_class": encoder.NumericEncoder},
{"data": "Abc", "encoder_class": encoder.ByteEncoder},
]

def test_compute_length(self):
optimizer = SegmentOptimizer()
segments = optimizer.compute("123Abc", "R7x43")
assert compute_length(segments, "R7x43") is 47