"""
Finalfusion Vocabulary interface
"""
import abc
import struct
from typing import List, Optional, Dict, Tuple, BinaryIO, Iterable, Any, Union, Sequence, \
Iterator, Collection
from finalfusion.io import Chunk, _read_required_binary, _write_binary
[docs]class Vocab(Chunk, Collection[str]):
"""
Finalfusion vocabulary interface.
Vocabs provide at least a simple string to index mapping and index to
string mapping. Vocab is the base type of all vocabulary types.
"""
@property
@abc.abstractmethod
def words(self) -> List[str]:
"""
Get the list of known words
Returns
-------
words : List[str]
list of known words
"""
@property
@abc.abstractmethod
def word_index(self) -> Dict[str, int]:
"""
Get the index of known words
Returns
-------
dict : Dict[str, int]
index of known words
"""
@property
@abc.abstractmethod
def upper_bound(self) -> int:
"""
The exclusive upper bound of indices in this vocabulary.
Returns
-------
upper_bound : int
Exclusive upper bound of indices covered by the vocabulary.
"""
[docs] @abc.abstractmethod
def idx(self, item: str, default: Optional[Union[int, List[int]]] = None
) -> Optional[Union[int, List[int]]]:
"""
Lookup the given query item.
This lookup does not raise an exception if the vocab can't produce indices.
Parameters
----------
item : str
The query item.
default : Optional[Union[int, List[int]]]
Fall-back value to return if the vocab can't provide indices.
Returns
-------
index : Optional[Union[int, List[int]]]
* An integer if there is a single index for a known item.
* A list if the vocab can provide subword indices for a unknown item.
* The provided `default` item if the vocab can't provide indices.
"""
def __getitem__(self, item: str) -> Union[int, List[int]]:
return self.word_index[item]
def __contains__(self, item: Any) -> bool:
# usual case: checking whether a str is known
if isinstance(item, str):
return self.word_index.get(item) is not None
# e.g. allows checking whether one vocab is the superset of another
if hasattr(item, "__iter__"):
return all(w in self for w in item)
return False
def __iter__(self) -> Iterator[str]:
return iter(self.words)
def __len__(self) -> int:
return len(self.word_index)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)):
return False
if self.words != other.words:
return False
if self.word_index != other.word_index:
return False
return True
def __repr__(self) -> str:
return f"{type(self).__name__}(n_words={len(self)}, upper_bound={self.upper_bound})"
def _write_words_binary(b_words: Iterable[bytes], file: BinaryIO):
"""
Helper method to write an iterable of bytes and their lengths.
"""
for word in b_words:
_write_binary(file, "<I", len(word))
file.write(word)
def _read_items(file: BinaryIO, length: int) -> List[str]:
"""
Helper method to read items from a vocabulary chunk.
Parameters
----------
file : BinaryIO
input file
length : int
number of items to read
Returns
-------
words : List[str]
The word list
"""
items = []
for _ in range(length):
item_length = _read_required_binary(file, "<I")[0]
word = file.read(item_length).decode("utf-8")
items.append(word)
return items
def _read_items_with_indices(file: BinaryIO,
length: int) -> Tuple[List[str], Dict[str, int]]:
"""
Helper method to read items from a vocabulary chunk.
Parameters
----------
file : BinaryIO
input file
length : int
number of items to read
Returns
-------
words : List[str]
The word list
"""
items = []
index = dict()
for _ in range(length):
item_length = _read_required_binary(file, "<I")[0]
item = file.read(item_length).decode("utf-8")
idx = _read_required_binary(file, "<Q")[0]
items.append(item)
index[item] = idx
return items, index
def _calculate_binary_list_size(items: List[str]):
size = sum(len(bytes(item, "utf-8")) for item in items)
size += struct.calcsize("<Q")
size += len(items) * struct.calcsize("<I")
return size
def _validate_items_and_create_index(items: Sequence[str]) -> Dict[str, int]:
index = dict((item, idx) for idx, item in enumerate(items))
n_unique_items = len(index)
assert len(items) == n_unique_items,\
f"Vocab items cannot be duplicated. List: {len(items)}, Unique: {n_unique_items}"
assert len(index) == len(items),\
f"Items and index need to have same length ({len(items)}, {len(index)})"
return index
__all__ = ['Vocab']