
step

allennlp.tango.step



AllenNLP Tango is an experimental API and parts of it might change or disappear every time we release a new version.

T

T = TypeVar("T")

StepCache

class StepCache(Registrable)

This is a mapping from instances of Step to the results of those steps.

path_for_step

class StepCache(Registrable):
 | ...
 | def path_for_step(self, step: "Step") -> Optional[Path]

Steps that can be restarted (like a training job that gets interrupted half-way through) must save their state somewhere. A StepCache can help by providing a suitable location in this method.

MemoryStepCache

@StepCache.register("memory")
class MemoryStepCache(StepCache):
 | def __init__(self)

This is a StepCache that stores results in memory. It is little more than a Python dictionary.
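
To make "little more than a Python dictionary" concrete, here is a minimal sketch of a memory-backed cache keyed by each step's unique ID. ToyStep and ToyMemoryCache are hypothetical stand-ins. The dictionary-style interface (in, [], len) is an assumption about how a step cache is used, not a copy of the AllenNLP classes.

```python
from typing import Any, Dict


class ToyStep:
    """Hypothetical stand-in for a Step: only its unique ID matters for caching."""

    def __init__(self, unique_id: str) -> None:
        self.unique_id = unique_id


class ToyMemoryCache:
    """Minimal in-memory step cache, assuming entries are keyed by unique ID."""

    def __init__(self) -> None:
        self._results: Dict[str, Any] = {}

    def __contains__(self, step: ToyStep) -> bool:
        return step.unique_id in self._results

    def __getitem__(self, step: ToyStep) -> Any:
        return self._results[step.unique_id]

    def __setitem__(self, step: ToyStep, value: Any) -> None:
        self._results[step.unique_id] = value

    def __len__(self) -> int:
        return len(self._results)


cache = ToyMemoryCache()
step = ToyStep("Tokenize-001-abc123")
cache[step] = ["hello", "world"]
print(step in cache, cache[step])  # True ['hello', 'world']
```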

default_step_cache

default_step_cache = MemoryStepCache()

DirectoryStepCache

@StepCache.register("directory")
class DirectoryStepCache(StepCache):
 | def __init__(self, dir: Union[str, PathLike])

This is a StepCache that stores its results on disk, in the location given in dir.

Every cached step gets a directory under dir with that step's unique_id(). In that directory we store the results themselves in some format according to the step's FORMAT, and we also write a metadata.json file that stores some metadata. The presence of metadata.json signifies that the cache entry is complete and has been written successfully.
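
The completion-marker idea can be sketched with the standard library alone. In this hypothetical ToyDirectoryCache, pickle stands in for the step's FORMAT, and the file name data.pkl and the metadata contents are made up. The point is the ordering: results are written first and metadata.json last, so a half-written entry is never mistaken for a finished one.

```python
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any


class ToyDirectoryCache:
    """Sketch of a directory-backed cache: one subdirectory per step unique ID."""

    def __init__(self, root: Path) -> None:
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def _entry_dir(self, unique_id: str) -> Path:
        return self.root / unique_id

    def contains(self, unique_id: str) -> bool:
        # An entry without metadata.json is treated as incomplete (e.g. a crashed write).
        return (self._entry_dir(unique_id) / "metadata.json").exists()

    def put(self, unique_id: str, result: Any) -> None:
        entry = self._entry_dir(unique_id)
        entry.mkdir(parents=True, exist_ok=True)
        # 1. Write the result itself (pickle stands in for the step's FORMAT).
        with open(entry / "data.pkl", "wb") as f:
            pickle.dump(result, f)
        # 2. Only then write the metadata that marks the entry as complete.
        (entry / "metadata.json").write_text(json.dumps({"step": unique_id}))

    def get(self, unique_id: str) -> Any:
        if not self.contains(unique_id):
            raise KeyError(unique_id)
        with open(self._entry_dir(unique_id) / "data.pkl", "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    cache = ToyDirectoryCache(Path(tmp))
    # Simulate an interrupted write: a result file exists, but no metadata.json.
    partial = Path(tmp) / "Train-001-def456"
    partial.mkdir()
    (partial / "data.pkl").write_bytes(b"not a finished entry")
    partial_seen = cache.contains("Train-001-def456")
    cache.put("Tokenize-001-abc123", [1, 2, 3])
    loaded = cache.get("Tokenize-001-abc123")
print(partial_seen, loaded)  # False [1, 2, 3]
```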

LRU_CACHE_MAX_SIZE

class DirectoryStepCache(StepCache):
 | ...
 | LRU_CACHE_MAX_SIZE = 8

path_for_step

class DirectoryStepCache(StepCache):
 | ...
 | def path_for_step(self, step: "Step") -> Path

Step

class Step(Registrable, Generic[T]):
 | def __init__(
 |     self,
 |     step_name: Optional[str] = None,
 |     cache_results: Optional[bool] = None,
 |     step_format: Optional[Format] = None,
 |     only_if_needed: Optional[bool] = None,
 |     **kwargs
 | )

This class defines one step in your experiment. To write your own step, derive from this class and override the run() method. The run() method must have parameters with type hints.

Step.__init__() takes all the arguments we want to run the step with. They get passed to Step.run() (almost) as they are. If the arguments are other instances of Step, those will be replaced with the step's results before calling run(). Further, there are four special parameters:

  • step_name contains an optional human-readable name for the step. This name is used for error messages and the like, and has no consequence on the actual computation.
  • cache_results specifies whether the results of this step should be cached. If this is False, the step is recomputed every time it is needed. If this is not set at all, we cache if the step is marked as DETERMINISTIC, and we don't cache otherwise.
  • step_format gives you a way to override the step's default format (which is given in FORMAT).
  • only_if_needed specifies whether this step can be skipped if no other step depends on it. By default, this is enabled for all steps that don't have an explicit name.
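
The following simplified model shows how the constructor arguments might flow into run(): the four special parameters are kept by the constructor, and any argument that is itself a step is replaced by that step's result. ToyStep, LoadText, and CountWords are hypothetical. The real Step also handles caching, formats, and type checking, all of which are omitted here.

```python
from typing import Any, Dict, Optional


class ToyStep:
    """Hypothetical, simplified model of how Step.__init__ arguments reach run()."""

    def __init__(
        self,
        step_name: Optional[str] = None,
        cache_results: Optional[bool] = None,
        step_format: Any = None,
        only_if_needed: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        # The four special parameters are consumed here and never reach run().
        self.name = step_name
        self.cache_results = cache_results
        self.step_format = step_format
        self.only_if_needed = only_if_needed
        # Everything else is kept for run(); values may themselves be steps.
        self.kwargs: Dict[str, Any] = kwargs

    def run(self, **kwargs: Any) -> Any:
        raise NotImplementedError

    def result(self) -> Any:
        # Replace every argument that is itself a step with that step's result.
        resolved = {
            key: value.result() if isinstance(value, ToyStep) else value
            for key, value in self.kwargs.items()
        }
        return self.run(**resolved)


class LoadText(ToyStep):
    def run(self, text: str) -> str:
        return text


class CountWords(ToyStep):
    def run(self, text: str) -> int:
        return len(text.split())


load = LoadText(text="the quick brown fox")
count = CountWords(step_name="count", text=load)  # `text` is a step here, not a str
print(count.result())  # 4
```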

default_implementation

class Step(Registrable, Generic[T]):
 | ...
 | default_implementation = "ref"

DETERMINISTIC

class Step(Registrable, Generic[T]):
 | ...
 | DETERMINISTIC: bool = False

This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is False, the step can't be cached, and neither can any step that depends on it.
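
A minimal sketch of how non-determinism propagates, assuming each step knows its direct dependencies (ToyStep and can_be_cached are hypothetical names):

```python
from typing import Sequence


class ToyStep:
    """Hypothetical step: a determinism flag plus its direct dependencies."""

    def __init__(self, deterministic: bool, deps: Sequence["ToyStep"] = ()) -> None:
        self.deterministic = deterministic
        self.deps = list(deps)


def can_be_cached(step: ToyStep) -> bool:
    """Cacheable only if this step and everything it depends on are deterministic."""
    return step.deterministic and all(can_be_cached(d) for d in step.deps)


sample = ToyStep(deterministic=False)  # e.g. random sampling without a fixed seed
tokenize = ToyStep(deterministic=True, deps=[sample])
load = ToyStep(deterministic=True)
print(can_be_cached(tokenize), can_be_cached(load))  # False True
```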

CACHEABLE

class Step(Registrable, Generic[T]):
 | ...
 | CACHEABLE: Optional[bool] = None

This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn't need to be cached, because HuggingFace datasets already have their own caching mechanism. But it's still a deterministic step, and all following steps are allowed to cache. If it is None, the step figures out by itself whether it should be cacheable or not.
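
How cache_results, CACHEABLE, and DETERMINISTIC interact is not spelled out in full here. The function below encodes one plausible reading of the rules above and is an assumption, not the library's logic. Note that opting out via CACHEABLE = False does not stop downstream deterministic steps from caching.

```python
from typing import Optional


def should_cache(
    deterministic_chain: bool,
    cacheable: Optional[bool],
    cache_results: Optional[bool],
) -> bool:
    """One plausible reading of the caching rules (an assumption, not the library code).

    deterministic_chain: True if this step and all of its dependencies are DETERMINISTIC.
    cacheable:           the class-level CACHEABLE attribute (None = decide automatically).
    cache_results:       the per-instance constructor argument (None = not set).
    """
    if not deterministic_chain:
        return False  # non-deterministic results can never be cached
    if cache_results is not None:
        return cache_results  # an explicit per-instance choice wins
    if cacheable is not None:
        return cacheable  # e.g. CACHEABLE = False for a HuggingFace dataset reader
    return True  # deterministic, and nothing says otherwise


# A deterministic dataset reader that opts out of caching via CACHEABLE = False:
print(should_cache(True, cacheable=False, cache_results=None))  # False
# A downstream deterministic step still caches by default:
print(should_cache(True, cacheable=None, cache_results=None))  # True
```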

VERSION

class Step(Registrable, Generic[T]):
 | ...
 | VERSION: Optional[str] = None

This is optional, but recommended. Specifying a version gives you a way to tell AllenNLP that a step has changed during development, and should now be recomputed. This doesn't invalidate the old results, so when you revert your code, the old cache entries will stick around and be picked up.
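
A small sketch of why bumping VERSION forces a recompute without deleting anything, assuming the version is part of the cache key. The key layout mirrors the unique ID shape described under unique_id(); cache_key and its hashing are made-up stand-ins.

```python
import hashlib
from typing import Any, Dict, Optional


def cache_key(class_name: str, version: Optional[str], inputs: Dict[str, Any]) -> str:
    """Hypothetical cache key: class name, VERSION, and a hash of the inputs."""
    digest = hashlib.sha256(repr(sorted(inputs.items())).encode()).hexdigest()[:8]
    return f"{class_name}-{version}-{digest}"


cache: Dict[str, Any] = {}
inputs = {"lowercase": True}

key_v1 = cache_key("Tokenize", "001", inputs)
cache[key_v1] = "result computed with v1"

# Bumping VERSION changes the key, so the step is recomputed...
key_v2 = cache_key("Tokenize", "002", inputs)
print(key_v2 in cache)  # False
cache[key_v2] = "result computed with v2"

# ...but the v1 entry is untouched, so reverting to "001" hits the cache again.
print(cache[cache_key("Tokenize", "001", inputs)])  # result computed with v1
```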

FORMAT

class Step(Registrable, Generic[T]):
 | ...
 | FORMAT: Format = DillFormat("gz")

This specifies the format the results of this step will be serialized in. See the documentation for Format for details.
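
DillFormat("gz") presumably serializes results with the third-party dill library and gzip-compresses them. The sketch below uses the standard library's pickle and gzip as a stand-in. Unlike dill, plain pickle cannot serialize objects such as lambdas, so treat this only as an illustration of the write/read round trip.

```python
import gzip
import pickle
import tempfile
from pathlib import Path
from typing import Any


def write_gz_pickle(obj: Any, path: Path) -> None:
    """Serialize obj with pickle and gzip-compress it (stand-in for a dill + gz format)."""
    with gzip.open(path, "wb") as f:
        pickle.dump(obj, f)


def read_gz_pickle(path: Path) -> Any:
    """Decompress and unpickle an object written by write_gz_pickle."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "data.pkl.gz"
    write_gz_pickle({"vocab": ["a", "b"], "size": 2}, target)
    restored = read_gz_pickle(target)
print(restored)  # {'vocab': ['a', 'b'], 'size': 2}
```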

from_params

class Step(Registrable, Generic[T]):
 | ...
 | @classmethod
 | def from_params(
 |     cls: Type["Step"],
 |     params: Params,
 |     constructor_to_call: Callable[..., "Step"] = None,
 |     constructor_to_inspect: Union[Callable[..., "Step"], Callable[["Step"], None]] = None,
 |     existing_steps: Optional[Dict[str, "Step"]] = None,
 |     step_name: Optional[str] = None,
 |     **extras
 | ) -> "Step"

Why do we need a custom from_params? Step classes have a run() method that takes all the parameters necessary to perform the step. The __init__() method of the step takes those same parameters, but each of them could be wrapped in another Step instead of being supplied directly. from_params() doesn't know anything about these shenanigans, so we have to supply the necessary logic here.
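
A stripped-down illustration of the problem: when building a step from a config, any argument may itself be a nested step config that has to be constructed first. The registry, the "type" key convention, and toy_from_params below are hypothetical. Per the description above, the real from_params() works from the parameters of run() rather than from this simple dictionary check.

```python
from typing import Any, Callable, Dict


class ToyStep:
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

    def run(self, **kwargs: Any) -> Any:
        raise NotImplementedError

    def result(self) -> Any:
        # Arguments that are steps are replaced by their results before run().
        args = {k: v.result() if isinstance(v, ToyStep) else v for k, v in self.kwargs.items()}
        return self.run(**args)


REGISTRY: Dict[str, Callable[..., ToyStep]] = {}


def register(name: str) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls
    return decorator


@register("constant")
class Constant(ToyStep):
    def run(self, value: Any) -> Any:
        return value


@register("add")
class Add(ToyStep):
    def run(self, a: int, b: int) -> int:
        return a + b


def toy_from_params(params: Dict[str, Any]) -> ToyStep:
    """Build a step from a config dict; nested dicts with a "type" key become steps first."""
    params = dict(params)  # don't mutate the caller's config
    cls = REGISTRY[params.pop("type")]
    kwargs = {
        key: toy_from_params(value) if isinstance(value, dict) and "type" in value else value
        for key, value in params.items()
    }
    return cls(**kwargs)


step = toy_from_params({"type": "add", "a": 1, "b": {"type": "constant", "value": 41}})
print(step.result())  # 42
```

A drawback of this naive check is that a plain dictionary argument that happens to contain a "type" key would be mistaken for a step, which is one reason to drive the decision from type information instead.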

run

class Step(Registrable, Generic[T]):
 | ...
 | @abstractmethod
 | def run(self, **kwargs) -> T

This is the main method of a step. Override this method to define your step's action.

work_dir

class Step(Registrable, Generic[T]):
 | ...
 | def work_dir(self) -> Path

Returns a work directory that a step can use while its run() method runs.

This directory stays around across restarts. You cannot assume that it is empty when your step runs, but you can use it to store information that helps you restart a step if it got killed half-way through the last time it ran.
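
A sketch of the resume-from-checkpoint pattern this enables. Here a temporary directory plays the role of work_dir(), and train and checkpoint.json are hypothetical. The first call crashes part-way through, and the second call picks up where the first one stopped.

```python
import json
import tempfile
from pathlib import Path


def train(work_dir: Path, total_epochs: int, crash_after: int = -1) -> int:
    """Hypothetical restartable loop that checkpoints progress into work_dir.

    Returns the number of epochs actually executed in this call.
    """
    checkpoint = work_dir / "checkpoint.json"
    # The directory may hold state from an earlier, interrupted run: resume from it.
    start = json.loads(checkpoint.read_text())["epoch"] if checkpoint.exists() else 0
    executed = 0
    for epoch in range(start, total_epochs):
        if epoch == crash_after:
            raise RuntimeError(f"simulated crash at epoch {epoch}")
        executed += 1  # one epoch of real work would go here
        checkpoint.write_text(json.dumps({"epoch": epoch + 1}))
    return executed


with tempfile.TemporaryDirectory() as tmp:
    work = Path(tmp)
    try:
        train(work, total_epochs=5, crash_after=3)
    except RuntimeError:
        pass  # the first attempt dies after finishing epochs 0, 1, and 2
    resumed_epochs = train(work, total_epochs=5)
print(resumed_epochs)  # 2 (only epochs 3 and 4 are run again)
```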

result

class Step(Registrable, Generic[T]):
 | ...
 | def result(self, cache: Optional[StepCache] = None) -> T

Returns the result of this step. If the results are cached, it returns those. Otherwise it runs the step and returns the result from there.

ensure_result

class Step(Registrable, Generic[T]):
 | ...
 | def ensure_result(self, cache: Optional[StepCache] = None) -> None

This makes sure that the result of this step is in the cache. It does not return the result.
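
The difference between result() and ensure_result() can be sketched with a plain dictionary as the cache, keyed by a hypothetical unique ID. This toy version requires an explicit cache argument, whereas the real methods accept an optional StepCache.

```python
from typing import Any, Callable, Dict


class ToyStep:
    """Hypothetical step that counts how often its (expensive) run() executes."""

    def __init__(self, unique_id: str, fn: Callable[[], Any]) -> None:
        self.unique_id = unique_id
        self.fn = fn
        self.runs = 0

    def run(self) -> Any:
        self.runs += 1
        return self.fn()

    def result(self, cache: Dict[str, Any]) -> Any:
        # Return the cached value if present; otherwise run, store, and return it.
        if self.unique_id not in cache:
            cache[self.unique_id] = self.run()
        return cache[self.unique_id]

    def ensure_result(self, cache: Dict[str, Any]) -> None:
        # Same effect on the cache as result(), but nothing is returned.
        if self.unique_id not in cache:
            cache[self.unique_id] = self.run()


cache: Dict[str, Any] = {}
step = ToyStep("Square-001-aaa", lambda: 7 * 7)
step.ensure_result(cache)  # runs once and fills the cache
print(step.result(cache))  # 49, read from the cache
print(step.runs)           # 1
```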

det_hash_object

class Step(Registrable, Generic[T]):
 | ...
 | def det_hash_object(self) -> Any

unique_id

class Step(Registrable, Generic[T]):
 | ...
 | def unique_id(self) -> str

Returns the unique ID for this step.

Unique IDs are of the shape $class_name-$version-$hash, where the hash is the hash of the inputs for deterministic steps, and a random string of characters for non-deterministic ones.
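
A sketch of the ID shape described above. The hash function, the 16-character truncation, and the use of secrets for the random suffix are assumptions chosen for illustration; toy_unique_id is a hypothetical name.

```python
import hashlib
import secrets
from typing import Any, Dict, Optional


def toy_unique_id(
    class_name: str, version: Optional[str], deterministic: bool, inputs: Dict[str, Any]
) -> str:
    """Sketch of a "$class_name-$version-$hash" ID (hashing scheme is an assumption)."""
    if deterministic:
        # Same inputs give the same hash, so repeated runs map to the same cache entry.
        suffix = hashlib.sha256(repr(sorted(inputs.items())).encode()).hexdigest()[:16]
    else:
        # Non-deterministic steps get a fresh random suffix, so their IDs never match.
        suffix = secrets.token_hex(8)
    return f"{class_name}-{version}-{suffix}"


a = toy_unique_id("Tokenize", "001", True, {"lowercase": True})
b = toy_unique_id("Tokenize", "001", True, {"lowercase": True})
c = toy_unique_id("Sample", "001", False, {"n": 10})
d = toy_unique_id("Sample", "001", False, {"n": 10})
print(a == b, c == d)  # True False (the random suffixes differ)
```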

dependencies

class Step(Registrable, Generic[T]):
 | ...
 | def dependencies(self) -> Set["Step"]

Returns a set of steps that this step depends on.

Does not return recursive dependencies.

recursive_dependencies

class Step(Registrable, Generic[T]):
 | ...
 | def recursive_dependencies(self) -> Set["Step"]

Returns a set of steps that this step depends on.

This returns recursive dependencies.
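
The difference between the two methods, sketched on a hypothetical ToyStep that records only its direct inputs:

```python
from typing import Set


class ToyStep:
    """Hypothetical step that records which steps it takes as direct inputs."""

    def __init__(self, name: str, *inputs: "ToyStep") -> None:
        self.name = name
        self.inputs = inputs

    def dependencies(self) -> Set["ToyStep"]:
        return set(self.inputs)

    def recursive_dependencies(self) -> Set["ToyStep"]:
        # Depth-first walk; the `seen` set keeps shared dependencies from repeating.
        seen: Set[ToyStep] = set()
        stack = list(self.inputs)
        while stack:
            step = stack.pop()
            if step not in seen:
                seen.add(step)
                stack.extend(step.inputs)
        return seen


raw = ToyStep("raw")
tokens = ToyStep("tokens", raw)
vocab = ToyStep("vocab", tokens)
model = ToyStep("model", tokens, vocab)
print(sorted(s.name for s in model.dependencies()))            # ['tokens', 'vocab']
print(sorted(s.name for s in model.recursive_dependencies()))  # ['raw', 'tokens', 'vocab']
```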

step_graph_from_params

def step_graph_from_params(
    params: Dict[str, Params]
) -> Dict[str, Step]

Given a mapping from strings to Params objects, this parses each Params object into a Step and resolves dependencies between the steps. Returns a dictionary mapping step names to instances of Step.
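
The resolution step can be sketched as a depth-first build that constructs referenced steps before the steps that use them and rejects cycles. The {"ref": name} syntax, ToyStep, and toy_step_graph are hypothetical; the actual configuration syntax for referring to other steps is not described in this section.

```python
from typing import Any, Dict, List, Set


class ToyStep:
    def __init__(self, name: str, kwargs: Dict[str, Any]) -> None:
        self.name = name
        self.kwargs = kwargs


def is_ref(value: Any) -> bool:
    """True if a config value refers to another step via {"ref": name} (hypothetical syntax)."""
    return isinstance(value, dict) and "ref" in value


def refs_in(config: Dict[str, Any]) -> List[str]:
    return [v["ref"] for v in config.values() if is_ref(v)]


def toy_step_graph(configs: Dict[str, Dict[str, Any]]) -> Dict[str, ToyStep]:
    built: Dict[str, ToyStep] = {}
    in_progress: Set[str] = set()

    def build(name: str) -> ToyStep:
        if name in built:
            return built[name]
        if name in in_progress:
            raise ValueError(f"circular dependency involving {name!r}")
        in_progress.add(name)
        for dep in refs_in(configs[name]):
            build(dep)  # construct dependencies before the step that uses them
        kwargs = {k: built[v["ref"]] if is_ref(v) else v for k, v in configs[name].items()}
        in_progress.discard(name)
        built[name] = ToyStep(name, kwargs)
        return built[name]

    for step_name in configs:
        build(step_name)
    return built


graph = toy_step_graph({
    "model": {"data": {"ref": "data"}, "epochs": 3},
    "data": {"path": "train.txt"},
})
print(graph["model"].kwargs["data"] is graph["data"])  # True
```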

tango_dry_run

def tango_dry_run(
    step_or_steps: Union[Step, Iterable[Step]],
    step_cache: Optional[StepCache]
) -> List[Tuple[Step, bool]]

Returns the list of steps that will be run, or read from cache, if you call a step's result() method.

Steps come out as tuples (step, read_from_cache), so you can see which steps will be read from cache, and which have to be run.
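
A sketch of how such a plan could be computed: walk the dependency graph depth-first, stop at steps already in the cache, and emit steps in the order they would run. ToyStep, toy_dry_run, and the set of cached names are hypothetical, and the sketch assumes that reading a step from the cache never requires its own dependencies.

```python
from typing import List, Set, Tuple


class ToyStep:
    def __init__(self, name: str, *deps: "ToyStep") -> None:
        self.name = name
        self.deps = deps


def toy_dry_run(target: ToyStep, cached: Set[str]) -> List[Tuple[ToyStep, bool]]:
    """List (step, read_from_cache) pairs in the order target's result would need them."""
    plan: List[Tuple[ToyStep, bool]] = []
    visited: Set[ToyStep] = set()

    def visit(step: ToyStep) -> None:
        if step in visited:
            return
        visited.add(step)
        if step.name in cached:
            plan.append((step, True))  # read from cache; its inputs are not needed
            return
        for dep in step.deps:
            visit(dep)  # dependencies must be available before this step runs
        plan.append((step, False))

    visit(target)
    return plan


raw = ToyStep("raw")
tokens = ToyStep("tokens", raw)
model = ToyStep("model", tokens)
plan = toy_dry_run(model, cached={"tokens"})
print([(s.name, from_cache) for s, from_cache in plan])  # [('tokens', True), ('model', False)]
```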