Shot Features and Filters
A common task is to add a new feature to a shot. This document walks you through that process.
Adding a feature
Adding the feature to feature extractions
The first step is to add the feature as an @property or @functools.cached_property on the FeatureExtraction class in railbird/shot_parsing/feature_detection/feature_extraction.py. The logic implemented in this method is specific to your feature, but otherwise the process should be identical for any additional feature.
@functools.cached_property
def my_new_feature(self) -> bool:
    # `feature` stands in for whatever value your detection logic computes
    return feature is not None
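As a rough illustration of why @functools.cached_property is often a good fit for features that are expensive to compute, here is a minimal, self-contained sketch. ToyCachedExtraction, ball_positions, and compute_calls are hypothetical stand-ins invented for this example; they are not part of the real FeatureExtraction class.

```python
import functools


class ToyCachedExtraction:
    """Hypothetical stand-in for FeatureExtraction; not the real class."""

    def __init__(self, ball_positions):
        self.ball_positions = ball_positions
        self.compute_calls = 0  # counts how often the feature body runs

    @functools.cached_property
    def my_new_feature(self) -> bool:
        # The body runs on first access; later accesses reuse the stored result.
        self.compute_calls += 1
        return len(self.ball_positions) > 0


extraction = ToyCachedExtraction([(0.1, 0.2)])
first = extraction.my_new_feature
second = extraction.my_new_feature
```

With a plain @property the body would run on every access, so compute_calls would be 2 here instead of 1.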
Then add the name of your new property to the method FeatureExtraction.eagerly_evaluate() (assuming that it makes sense to eagerly evaluate your new feature; it almost always does).
class FeatureExtraction:
    ...  # Omitted for brevity

    def eagerly_evaluate(self):
        if self.cue_and_object_ids is None:
            self._add_warning("Could not find cue and object balls")
            return
        if self.exceptions:
            return
        try:
            self.object_ball
            self.intention_info
            ...  # Omitted for brevity
            self.is_direct
            self.my_new_feature  # We add our new feature to this list
        except Exception as e:
            self.exceptions.append(e)
            logger.warn(
                "Error eagerly evaluating feature_extraction for shot",
                exc_info=e,
                indices=(self.shot.start_index, self.shot.end_index),
            )
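The eager-evaluation pattern above can be sketched in isolation as follows. This is only a toy under stated assumptions: ToyEagerExtraction and its positions field are hypothetical, and the real method also logs the error and checks cue_and_object_ids first.

```python
class ToyEagerExtraction:
    """Hypothetical stand-in; only the eager-evaluation pattern is shown."""

    def __init__(self, positions):
        self.positions = positions
        self.exceptions = []

    @property
    def is_direct(self) -> bool:
        return True

    @property
    def my_new_feature(self) -> float:
        # Raises ZeroDivisionError when there are no positions.
        return sum(self.positions) / len(self.positions)

    def eagerly_evaluate(self):
        try:
            # Touching each property forces it to be computed now.
            self.is_direct
            self.my_new_feature
        except Exception as e:
            # Record the failure instead of letting it propagate.
            self.exceptions.append(e)


good = ToyEagerExtraction([1.0, 3.0])
good.eagerly_evaluate()
bad = ToyEagerExtraction([])
bad.eagerly_evaluate()
```

The point of the pattern is that a failing feature surfaces at evaluation time and is collected on the extraction, rather than raising later when the value is first read.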
Adding the feature to the ShotModel
The next step is to ensure this feature is written to the database when we write a shot. We do this by altering the _apply_feature_extraction method on the ShotModel in railbird/datatypes/models/shot.py.
class ShotModel(Base):
    __tablename__ = "shot"
    ...  # Omitted for brevity

    def _apply_feature_extraction(
        self,
        feature_extraction: shot_parsing.FeatureExtraction,
        frame_offset: int = 0,
    ):
        # Now we just add it to the model as usual
        # For the case of a feature directly on the shot we do something like:
        self.my_new_feature = feature_extraction.my_new_feature
        # But we can also implement it on an attached model like:
        self.my_new_feature_model = self.my_new_feature_model or MyNewFeatureModel()
        self.my_new_feature_model.my_new_feature = (
            feature_extraction.my_new_feature
        )
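The "reuse the attached model or create one" idiom used for the attached-model case can be tried out with plain dataclasses. ToyShot, ToyFeatureModel, the apply method, and the note field are hypothetical names for this sketch only; the real code operates on SQLAlchemy models.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyFeatureModel:
    """Hypothetical stand-in for an attached model such as MyNewFeatureModel."""

    my_new_feature: Optional[float] = None
    note: str = ""


@dataclass
class ToyShot:
    """Hypothetical stand-in for ShotModel."""

    my_new_feature_model: Optional[ToyFeatureModel] = None

    def apply(self, value: float) -> None:
        # Reuse the attached model if it exists, otherwise create one.
        self.my_new_feature_model = self.my_new_feature_model or ToyFeatureModel()
        self.my_new_feature_model.my_new_feature = value


shot = ToyShot(my_new_feature_model=ToyFeatureModel(note="existing row"))
existing = shot.my_new_feature_model
shot.apply(1.5)

fresh = ToyShot()
fresh.apply(2.5)
```

Reusing the existing object means that re-applying a feature extraction updates the attached row in place instead of replacing it.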
This part of the tutorial assumes you've handled the business of adding the field to the relevant model, and created the relevant changes in the database. You can read more about how to do that here.
The critical thing to note is that we map the value by passing info=_qb() in the parameters to the column when we define it in the model. This allows the query builder to resolve the field.
So for example, if we have MyNewFeatureModel, we define my_new_feature like so:
class MyNewFeatureModel(Base):
    __tablename__ = "my_new_feature"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    my_new_feature: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),  # THIS IS THE CRITICAL LINE. YOU MUST ADD THIS SO QUERY BUILDER CAN RESOLVE IT.
    )
Adding the feature (and optionally, a filter) to the GQL
The final step is to add the feature to the gql. That mostly follows this primer on the gql. Add the field to the relevant gql type(s), run the just gql command, commit in the gql, then commit in the railbird main repo. The major difference here is if we want to add a filter.
In the case of a filter, we begin by adding the filter to the FilterInput type in railbird/datatypes/gql/filter.py. It is important that we add a filter of the correct type for the new field we want to filter over.
For example, to filter over a numerical value that may vary in some range, we add a field of type RangeFilter:
@strawberry.input
class RangeFilter:
    less_than: Optional[float] = None
    greater_than_equal_to: Optional[float] = None
    greater_than: Optional[float] = None
    include_on_none: bool = False
    less_than_inclusive: bool = False
    greater_than_inclusive: bool = True


@strawberry.input
class ValueFilterString:
    equals: Optional[str] = None


@strawberry.input
class ValueFilterBool:
    equals: Optional[bool] = None


@strawberry.input(one_of=True)
class FilterInput:
    and_filters: List["FilterInput"] | None = strawberry.UNSET
    or_filters: List["FilterInput"] | None = strawberry.UNSET
    not_filter: Union["FilterInput", None] = strawberry.UNSET
    ...
    # Pick the ONE declaration below that matches your field's type.
    # If the value varies over some range
    my_new_feature: RangeFilter | None = strawberry.UNSET
    # If we wanted to match some subset of a field with an enumerated type
    my_new_feature: List[MyNewFilterEnum] | None = strawberry.UNSET
    # If we wanted to match one or more specific values like an id
    my_new_feature_id: List[int] | None = strawberry.UNSET
    # Or a boolean
    my_new_feature: ValueFilterBool | None = strawberry.UNSET
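To build intuition for what a RangeFilter selects, the following plain-Python sketch mirrors its fields and applies them to a few values. The matching semantics here are an assumption inferred from the field names (for instance, that less_than_inclusive switches less_than from < to <=); the real query builder translates filters to SQL and may differ. ToyRangeFilter and matches are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyRangeFilter:
    """Plain-Python mirror of the RangeFilter fields; semantics are assumed."""

    less_than: Optional[float] = None
    greater_than_equal_to: Optional[float] = None
    greater_than: Optional[float] = None
    include_on_none: bool = False
    less_than_inclusive: bool = False
    greater_than_inclusive: bool = True


def matches(f: ToyRangeFilter, value: Optional[float]) -> bool:
    # Assumption: a missing value matches only when include_on_none is set.
    if value is None:
        return f.include_on_none
    if f.less_than is not None:
        ok = value <= f.less_than if f.less_than_inclusive else value < f.less_than
        if not ok:
            return False
    if f.greater_than_equal_to is not None and value < f.greater_than_equal_to:
        return False
    if f.greater_than is not None:
        ok = value >= f.greater_than if f.greater_than_inclusive else value > f.greater_than
        if not ok:
            return False
    return True


range_filter = ToyRangeFilter(less_than=35, greater_than_equal_to=10)
matched = [v for v in (20, 30, 40) if matches(range_filter, v)]
```

Under these assumed semantics, a filter of less_than=35 and greater_than_equal_to=10 keeps 20 and 30 and drops 40.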
Adding testing for your feature
In theory, this should "just" "work"; to be "certain" of that you should add some tests.
Chances are your new feature touches the ShotModel or one of its associated models, so be sure to extend the mock in tests/conftest.py so that your field is handled. For example:
@FixtureManager.db_fixture_factory
def shot_factory(video_factory, user_factory, default_serialized_shot):
    def _build_shot(session, *args, **params):
        ...
        shot = models.ShotModel(
            ...
            # Use ONE of the two options below.
            # If the feature is directly on the shot
            my_new_feature=params.pop("my_new_feature", None),
            # Or if the feature is part of an associated model
            my_new_feature_model=models.MyNewFeatureModel(
                my_new_feature=params.pop("my_new_feature", None)
            ),
            **params,
        )
        ...
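The params.pop(...) idiom in the factory matters because the remaining params are forwarded with **params. A minimal sketch, using a hypothetical build_shot function that returns a dict rather than a real ShotModel:

```python
def build_shot(**params):
    """Hypothetical factory; returns a dict instead of a real ShotModel."""
    shot = {
        # pop() removes the key so it is not passed a second time via **params,
        # and supplies a default when the caller did not provide it.
        "my_new_feature": params.pop("my_new_feature", None),
        **params,
    }
    return shot


with_feature = build_shot(video_id=1, my_new_feature=20)
without_feature = build_shot(video_id=2)
```

Tests that do not care about the new field can keep calling the factory unchanged, while tests that do can pass my_new_feature explicitly.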
Write a test to be sure that you can query the field in tests/test_query_builder.py. This test assumes you're using a range filter:
@pytest.mark.writes_db
@pytest.mark.asyncio
async def test_my_new_feature_filter(
    api_server, video_factory, upload_factory, user_factory, shot_factory
):
    # Make fake shots (`video` is assumed to have been created first, e.g. with video_factory)
    shot_factory(video_id=video.id, my_new_feature=20)
    shot_factory(video_id=video.id, my_new_feature=30)
    shot_factory(video_id=video.id, my_new_feature=40)
    await aggregation_test(
        api_server,
        {
            "aggregations": [],
            "filterInput": {
                "myNewFeature": {"lessThan": 35, "greaterThanEqualTo": 10}
            },
        },
        # This count means we expect 2 of our shots to pass this filter
        [((), {"count": 2})],
        target_metrics=(dsl.TargetMetricsGQL.count,),
    )
Finally, if you have a video with a shot that typifies your new feature, you can optionally write an actual feature test in tests/test_features.py.
For example:
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_my_new_feature(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    path_to_video_file = identifier_path_factory(
        # TODO: Ivan needs to fill in how to populate this thing...I cant remember right now...
        "4f292947c860a427a60e8eb6703c86cd32c6eeb2b71fcc40839a2f98db67eda4",
        "cue-features-after-object1.mp4",
    )
    feature_extraction = assert_one_and_get(
        get_feature_extractions(path_to_video_file, in_sample_videos=False)
    )
    assert feature_extraction.my_new_feature is not None
    # Replace the placeholder with the value you expect for this video
    assert feature_extraction.my_new_feature == <SOME EXPECTED VALUE>