
Shot Features and Filters

A common task is adding a new feature to a shot. This document walks you through that process.

Adding a feature

Adding the feature to feature extractions

The first step is to add the feature as an @property or @functools.cached_property on the FeatureExtraction class in railbird/shot_parsing/feature_detection/feature_extraction.py. The logic inside this method is specific to your feature, but otherwise the process should be identical for any additional feature.

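import functools

@functools.cached_property
def my_new_feature(self) -> bool:
    # Placeholder logic: derive the feature from existing properties on self.
    return self.object_ball is not None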
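The choice between the two decorators matters mainly for cost: a plain @property reruns its body on every access, while functools.cached_property runs it once per instance and stores the result. The toy class below is a hypothetical stand-in (not the real FeatureExtraction) that counts how often each body runs, so the difference can be seen in isolation.

```python
import functools


class ToyFeatureExtraction:
    """Hypothetical stand-in for FeatureExtraction; only the caching matters."""

    def __init__(self, object_ball):
        self.object_ball = object_ball
        self.cached_calls = 0    # times the cached_property body has run
        self.uncached_calls = 0  # times the plain property body has run

    @functools.cached_property
    def my_new_feature(self) -> bool:
        self.cached_calls += 1
        return self.object_ball is not None

    @property
    def my_plain_feature(self) -> bool:
        self.uncached_calls += 1
        return self.object_ball is not None


extraction = ToyFeatureExtraction(object_ball="one-ball")
for _ in range(3):
    extraction.my_new_feature
    extraction.my_plain_feature

print(extraction.cached_calls, extraction.uncached_calls)  # 1 3
```

For anything that is expensive to compute from the shot, cached_property is usually the better fit.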

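Next, add the name of your new property to FeatureExtraction.eagerly_evaluate() (assuming it makes sense to evaluate your new feature eagerly, which it almost always does). Accessing the property there forces it to be computed up front, and any exception it raises is recorded in self.exceptions and logged rather than propagated.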

class FeatureExtraction:

    ... # Omitted for brevity

    def eagerly_evaluate(self):
        if self.cue_and_object_ids is None:
            self._add_warning("Could not find cue and object balls")
            return
        if self.exceptions:
            return
        try:
            self.object_ball
            self.intention_info
            ... # Omitted for brevity
            self.is_direct
            self.my_new_feature # We add our new feature to this list
        except Exception as e:
            self.exceptions.append(e)
            logger.warning(
                "Error eagerly evaluating feature_extraction for shot",
                exc_info=e,
                indices=(self.shot.start_index, self.shot.end_index),
            )
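To make the eager-evaluation pattern concrete, here is a stripped-down, hypothetical class that mimics it using the standard logging module (the project's real logger and warning helper are not reproduced). Touching a cached_property inside the try block computes and caches it, and a failure is collected instead of raised.

```python
import functools
import logging

logger = logging.getLogger(__name__)


class ToyExtraction:
    """Hypothetical sketch of the eager-evaluation pattern, not the real class."""

    def __init__(self, object_ball):
        self.object_ball = object_ball
        self.exceptions: list[Exception] = []

    @functools.cached_property
    def is_direct(self) -> bool:
        return self.object_ball is not None

    @functools.cached_property
    def my_new_feature(self) -> float:
        # Deliberately fails when the input is missing, to show error capture.
        if self.object_ball is None:
            raise ValueError("no object ball")
        return 1.0

    def eagerly_evaluate(self):
        if self.exceptions:
            return
        try:
            self.is_direct
            self.my_new_feature  # touching the property forces it to compute
        except Exception as e:
            self.exceptions.append(e)
            logger.warning("Error eagerly evaluating features", exc_info=e)


good = ToyExtraction(object_ball="one-ball")
good.eagerly_evaluate()
print(good.exceptions)                 # []
print("my_new_feature" in vars(good))  # True: result cached on the instance

bad = ToyExtraction(object_ball=None)
bad.eagerly_evaluate()
print(type(bad.exceptions[0]).__name__)  # ValueError
```

Because all the accesses share one try block, a failure in an earlier feature stops later features from being evaluated eagerly; they are still computed lazily on first access.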

Adding the feature to the ShotModel

The next step is to ensure this feature is written to the database when we write a shot. We do this by altering the _apply_feature_extraction method on the ShotModel in railbird/datatypes/models/shot.py.

class ShotModel(Base):
    __tablename__ = "shot"

    ... # Omitted for brevity

    def _apply_feature_extraction(
        self,
        feature_extraction: shot_parsing.FeatureExtraction,
        frame_offset: int = 0,
    ):
        # Copy the computed value onto the model, using ONE of the options below.

        # Option 1: the feature is a column directly on the shot table.
        self.my_new_feature = feature_extraction.my_new_feature

        # Option 2: the feature lives on an attached one-to-one model.
        self.my_new_feature_model = self.my_new_feature_model or MyNewFeatureModel()
        self.my_new_feature_model.my_new_feature = (
            feature_extraction.my_new_feature
        )
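The attached-model option uses the "x = x or Model()" idiom so that re-running feature extraction on an existing shot updates the attached row instead of replacing it. A plain-Python sketch of that idiom, using dataclasses as hypothetical stand-ins for the SQLAlchemy models:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MyNewFeatureModel:
    """Hypothetical stand-in for an attached one-to-one ORM model."""
    my_new_feature: Optional[float] = None


@dataclass
class ToyShot:
    """Hypothetical stand-in for ShotModel with one attached model."""
    my_new_feature_model: Optional[MyNewFeatureModel] = None

    def apply(self, value: float) -> None:
        # Reuse the attached object if it exists, otherwise create it, then
        # copy the value across. Re-applying updates the same object in place.
        self.my_new_feature_model = self.my_new_feature_model or MyNewFeatureModel()
        self.my_new_feature_model.my_new_feature = value


shot = ToyShot()
shot.apply(20.0)
first = shot.my_new_feature_model
shot.apply(30.0)
print(shot.my_new_feature_model is first, first.my_new_feature)  # True 30.0
```

One caveat: the idiom relies on the model instance being truthy, which holds for ordinary objects that do not define __bool__ or __len__.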

This part of the tutorial assumes you have already added the field to the relevant model and made the corresponding changes to the database schema. You can read more about how to do that here.

The critical detail is to pass info=_qb() to mapped_column() when you define the column on the model. This is what allows the query builder to resolve the field.

For example, if the feature lives on MyNewFeatureModel, define my_new_feature like so:

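import sqlalchemy as sa
from sqlalchemy.orm import Mapped, mapped_column

# Base and _qb() are the project's existing declarative base and query-builder helper.
class MyNewFeatureModel(Base):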

    __tablename__ = "my_new_feature"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    my_new_feature: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(), # THIS IS THE CRITICAL LINE. YOU MUST ADD THIS SO QUERY BUILDER CAN RESOLVE IT.
    )
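SQLAlchemy keeps the info= argument as a plain metadata dictionary attached to the column, which other code can inspect later; presumably the query builder scans for the marker that _qb() places there. The sketch below reproduces that discovery idea with the standard library only, using dataclasses field metadata as a stand-in for the column's info dictionary. Both _qb() and resolvable_fields() here are hypothetical, not the project's implementations.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


def _qb() -> dict:
    # Hypothetical stand-in: the real _qb() is project code whose return value
    # is not shown here. We assume it returns a marker dict for the column.
    return {"query_builder": True}


@dataclass
class ToyFeatureModel:
    shot_id: int = 0
    # Tagged: discoverable by the toy resolver below.
    my_new_feature: Optional[float] = field(default=None, metadata=_qb())
    # Untagged: invisible to the resolver, like a column missing info=_qb().
    internal_note: Optional[str] = None


def resolvable_fields(model) -> list[str]:
    """Return the names of fields carrying the query-builder marker."""
    return [f.name for f in fields(model) if f.metadata.get("query_builder")]


print(resolvable_fields(ToyFeatureModel))  # ['my_new_feature']
```

The point of the sketch is the failure mode: a column defined without the marker is simply never found, which is why forgetting info=_qb() breaks filtering on the field.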

Adding the feature (and optionally, a filter) to the GQL

The final step is to expose the feature in the GQL. This mostly follows this primer on the GQL: add the field to the relevant GQL type(s), run just gql, commit in the gql repository, then commit in the main railbird repository. The main difference comes when you also want to add a filter.

To add a filter, start by adding a field to the FilterInput type in railbird/datatypes/gql/filter.py. Make sure the filter's type matches the kind of value you want to filter over.

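For example, to filter over a numerical value that varies within some range, add a field of type RangeFilter. The existing filter input types, and the FilterInput field you would add for each kind of value, look like this: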

from typing import List, Optional, Union

import strawberry


@strawberry.input
class RangeFilter:
    less_than: Optional[float] = None
    greater_than_equal_to: Optional[float] = None
    greater_than: Optional[float] = None
    include_on_none: bool = False
    less_than_inclusive: bool = False
    greater_than_inclusive: bool = True


@strawberry.input
class ValueFilterString:
    equals: Optional[str] = None


@strawberry.input
class ValueFilterBool:
    equals: Optional[bool] = None


@strawberry.input(one_of=True)
class FilterInput:
    and_filters: List["FilterInput"] | None = strawberry.UNSET
    or_filters: List["FilterInput"] | None = strawberry.UNSET
    not_filter: Union["FilterInput", None] = strawberry.UNSET

    ...

    # Choose ONE of the following, matching the type of the new field.
    # If the value varies over some range:
    my_new_feature: RangeFilter | None = strawberry.UNSET
    # If we want to match some subset of a field with an enumerated type:
    # my_new_feature: List[MyNewFilterEnum] | None = strawberry.UNSET
    # If we want to match one or more specific values, like an id:
    # my_new_feature_id: List[int] | None = strawberry.UNSET
    # Or, for a boolean:
    # my_new_feature: ValueFilterBool | None = strawberry.UNSET
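The one_of=True flag means each FilterInput sets exactly one field: either a combinator (and_filters, or_filters, not_filter) or a single leaf filter such as a RangeFilter. The sketch below is a hypothetical in-memory evaluator for that tree shape, meant only to make the structure concrete; the real query builder presumably translates filters into database queries instead. The exact meaning given to each RangeFilter flag here (for example, that greater_than_inclusive makes greater_than behave like >=) is an assumption to verify against the actual resolver.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RangeFilter:
    # Mirrors the fields of the strawberry input above; plain dataclass here.
    less_than: Optional[float] = None
    greater_than_equal_to: Optional[float] = None
    greater_than: Optional[float] = None
    include_on_none: bool = False
    less_than_inclusive: bool = False
    greater_than_inclusive: bool = True


def range_matches(value: Optional[float], f: RangeFilter) -> bool:
    """ASSUMED semantics for each RangeFilter flag; check the real resolver."""
    if value is None:
        return f.include_on_none
    if f.less_than is not None:
        if value > f.less_than or (value == f.less_than and not f.less_than_inclusive):
            return False
    if f.greater_than_equal_to is not None and value < f.greater_than_equal_to:
        return False
    if f.greater_than is not None:
        if value < f.greater_than or (
            value == f.greater_than and not f.greater_than_inclusive
        ):
            return False
    return True


def filter_matches(shot: dict, filt: dict) -> bool:
    """Evaluate a one-of FilterInput tree (as a dict) against a shot dict."""
    (key, arg), = filt.items()  # one_of=True: exactly one key is set
    if key == "and_filters":
        return all(filter_matches(shot, sub) for sub in arg)
    if key == "or_filters":
        return any(filter_matches(shot, sub) for sub in arg)
    if key == "not_filter":
        return not filter_matches(shot, arg)
    return range_matches(shot.get(key), arg)  # leaf: a field with a RangeFilter


shots = [{"my_new_feature": v} for v in (20, 30, 40, None)]
in_range = {"my_new_feature": RangeFilter(less_than=35, greater_than_equal_to=10)}
print(sum(filter_matches(s, in_range) for s in shots))  # 2
```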

Adding testing for your feature

In theory, this should now "just work"; to be certain of that, add some tests.

Your new feature most likely touches the ShotModel or one of its associated models, so extend the shot_factory fixture in tests/conftest.py to handle your field. For example:

@FixtureManager.db_fixture_factory
def shot_factory(video_factory, user_factory, default_serialized_shot):
    def _build_shot(session, *args, **params):
        ...
        shot = models.ShotModel(
            ...
            # If the feature is directly on the shot:
            my_new_feature=params.pop("my_new_feature", None),
            # Or, instead, if the feature lives on an associated model:
            # my_new_feature_model=models.MyNewFeatureModel(
            #     my_new_feature=params.pop("my_new_feature", None)
            # ),
            **params,
        )
        ...
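The factory pops the feature out of params before **params is unpacked. Python evaluates call arguments left to right, so the popped key is already gone when the remaining params are spread into the constructor; using params.get instead would pass the same keyword twice and raise a TypeError. A minimal sketch with a hypothetical build_shot that returns a dict instead of a model:

```python
def build_shot(**params):
    """Hypothetical factory mirroring _build_shot's params.pop pattern."""
    return dict(
        # pop removes the key, so **params below cannot pass it a second time
        my_new_feature=params.pop("my_new_feature", None),
        **params,
    )


def build_shot_without_pop(**params):
    # get leaves the key in params, so it is passed twice
    return dict(my_new_feature=params.get("my_new_feature"), **params)


print(build_shot(my_new_feature=20, video_id=7))
# {'my_new_feature': 20, 'video_id': 7}
print(build_shot(video_id=7))
# {'my_new_feature': None, 'video_id': 7}

try:
    build_shot_without_pop(my_new_feature=20)
except TypeError as exc:
    print("TypeError:", exc)
```

The default of None in pop also means existing tests that never mention the new field keep working unchanged.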

Then write a test in tests/test_query_builder.py to confirm that you can filter on the field. This test assumes you are using a RangeFilter:

@pytest.mark.writes_db
@pytest.mark.asyncio
async def test_my_new_feature_filter(
    api_server, video_factory, upload_factory, user_factory, shot_factory
):
    # Create a video to attach the fake shots to (assumes the factory's
    # defaults produce a valid video).
    video = video_factory()
    # Make fake shots with known feature values
    shot_factory(video_id=video.id, my_new_feature=20)
    shot_factory(video_id=video.id, my_new_feature=30)
    shot_factory(video_id=video.id, my_new_feature=40)
    await aggregation_test(
        api_server,
        {
            "aggregations": [],
            "filterInput": {
                "myNewFeature": {"lessThan": 35, "greaterThanEqualTo": 10}
            },
        },
        # Only the shots with values 20 and 30 satisfy 10 <= x < 35
        [((), {"count": 2})],
        target_metrics=(dsl.TargetMetricsGQL.count,),
    )
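Note that the test refers to fields by their GraphQL names (myNewFeature, lessThan, greaterThanEqualTo) rather than their Python names. Strawberry converts snake_case names to camelCase by default; the helper below is a simplified approximation of that conversion (not Strawberry's own function), handy for working out which key to use in a test payload.

```python
def to_camel_case(name: str) -> str:
    """Approximate snake_case -> camelCase conversion for GraphQL field names."""
    first, *rest = name.split("_")
    return first + "".join(word.capitalize() for word in rest)


for python_name in ("my_new_feature", "less_than", "greater_than_equal_to"):
    print(python_name, "->", to_camel_case(python_name))
# my_new_feature -> myNewFeature
# less_than -> lessThan
# greater_than_equal_to -> greaterThanEqualTo
```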

Finally, if you have a video containing a shot that typifies your new feature, you can optionally write an actual feature test in tests/test_features.py.

For example:

@pytest.mark.writes_db
@pytest.mark.runs_model
def test_my_new_feature(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    path_to_video_file = identifier_path_factory(
        # TODO (Ivan): document how to populate these arguments.
        "4f292947c860a427a60e8eb6703c86cd32c6eeb2b71fcc40839a2f98db67eda4",
        "cue-features-after-object1.mp4",
    )
    feature_extraction = assert_one_and_get(
        get_feature_extractions(path_to_video_file, in_sample_videos=False)
    )
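    assert feature_extraction.my_new_feature is not None
    # Replace the ... placeholder with the value this shot should produce
    # (wrap it in pytest.approx(...) if the feature is a float).
    expected_value = ...
    assert feature_extraction.my_new_feature == expected_value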
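The test relies on assert_one_and_get to insist that the video yields exactly one feature extraction, which presumably keeps the test from silently passing if shot detection splits or merges shots. Judging by its name, it behaves roughly like the hypothetical version below; check the real helper in the test utilities before relying on details such as the error message.

```python
from typing import Iterable, TypeVar

T = TypeVar("T")


def assert_one_and_get(items: Iterable[T]) -> T:
    """Hypothetical re-implementation: the real helper's code is not shown."""
    items = list(items)
    # Raise explicitly so the check still runs under python -O.
    if len(items) != 1:
        raise AssertionError(f"expected exactly one item, got {len(items)}")
    return items[0]


print(assert_one_and_get(["only-extraction"]))  # only-extraction
```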