64 changes: 64 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.cpp
@@ -616,6 +616,70 @@ int64_t computeSafeDuration(
}
}

std::optional<double> getRotationFromStream(const AVStream* avStream) {
// av_stream_get_side_data() was deprecated in FFmpeg 6.0, but its replacement
// (av_packet_side_data_get() + codecpar->coded_side_data) is only available
// from FFmpeg 6.1. We need some #pragma magic to silence the deprecation
// warning which our compile chain would otherwise treat as an error.
if (avStream == nullptr) {
return std::nullopt;
}

const int32_t* displayMatrix = nullptr;

// FFmpeg >= 6.1: Use codecpar->coded_side_data
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(60, 31, 100)
const AVPacketSideData* sideData = av_packet_side_data_get(
avStream->codecpar->coded_side_data,
avStream->codecpar->nb_coded_side_data,
AV_PKT_DATA_DISPLAYMATRIX);
if (sideData != nullptr) {
displayMatrix = reinterpret_cast<const int32_t*>(sideData->data);
}
#elif LIBAVFORMAT_VERSION_MAJOR >= 60 // FFmpeg 6.0
// FFmpeg 6.0: Use av_stream_get_side_data (deprecated but still available)
// Suppress deprecation warning for this specific call
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
size_t sideDataSize = 0;
const uint8_t* sideData = av_stream_get_side_data(
avStream, AV_PKT_DATA_DISPLAYMATRIX, &sideDataSize);
#pragma GCC diagnostic pop
if (sideData != nullptr) {
displayMatrix = reinterpret_cast<const int32_t*>(sideData);
}
#else
// FFmpeg < 6: Use av_stream_get_side_data.
// The size parameter type changed from int* (FFmpeg 4) to size_t* (FFmpeg 5)
#if LIBAVFORMAT_VERSION_MAJOR >= 59 // FFmpeg 5
size_t sideDataSize = 0;
#else // FFmpeg 4
int sideDataSize = 0;
#endif
const uint8_t* sideData = av_stream_get_side_data(
avStream, AV_PKT_DATA_DISPLAYMATRIX, &sideDataSize);
if (sideData != nullptr) {
displayMatrix = reinterpret_cast<const int32_t*>(sideData);
}
#endif

if (displayMatrix == nullptr) {
return std::nullopt;
}

// av_display_rotation_get returns the rotation angle in degrees needed to
// rotate the video counter-clockwise to make it upright.
// Returns NaN if the matrix is invalid.
double rotation = av_display_rotation_get(displayMatrix);

// Check for invalid matrix
if (std::isnan(rotation)) {
return std::nullopt;
}

return rotation;
}

SwsFrameContext::SwsFrameContext(
int inputWidth,
int inputHeight,
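For orientation, a minimal usage sketch of the new helper (illustrative only, not part of the diff; it assumes an already-opened AVFormatContext* and that facebook::torchcodec::getRotationFromStream is in scope):

#include <optional>

extern "C" {
#include <libavformat/avformat.h>
}

// Find the first video stream and query its display-matrix rotation.
// Returns std::nullopt if there is no video stream or no display matrix.
std::optional<double> getFirstVideoRotation(const AVFormatContext* formatContext) {
  for (unsigned int i = 0; i < formatContext->nb_streams; ++i) {
    const AVStream* stream = formatContext->streams[i];
    if (stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
      return getRotationFromStream(stream);
    }
  }
  return std::nullopt;
}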
6 changes: 6 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
@@ -7,6 +7,7 @@
#pragma once

#include <memory>
#include <optional>
#include <stdexcept>
#include <string>

@@ -280,6 +281,11 @@ int64_t computeSafeDuration(
const AVRational& frameRate,
const AVRational& timeBase);

// Extracts the rotation angle in degrees from the stream's display matrix
// side data. The display matrix is used to specify how the video should be
// rotated for correct display.
std::optional<double> getRotationFromStream(const AVStream* avStream);

AVFilterContext* createAVFilterContextWithOptions(
AVFilterGraph* filterGraph,
const AVFilter* buffer,
7 changes: 5 additions & 2 deletions src/torchcodec/_core/Metadata.h
@@ -53,9 +53,12 @@ struct StreamMetadata {
std::optional<int64_t> numFramesFromContent;

// Video-only fields
// Post-rotation dimensions
std::optional<int> postRotationWidth;
std::optional<int> postRotationHeight;
std::optional<AVRational> sampleAspectRatio;
// Rotation angle in degrees from display matrix, in the range [-180, 180].
std::optional<double> rotation;

// Audio-only fields
std::optional<int64_t> sampleRate;
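To illustrate how the new fields relate (hypothetical values, not taken from this PR): a stream encoded as 1920x1080 whose display matrix encodes a 90 degree rotation would be reported with swapped, post-rotation dimensions:

// Hypothetical example: 1920x1080 encoded frames with a 90 degree
// display-matrix rotation. The reported dimensions are post-rotation, so
// width and height are swapped relative to codecpar->width/height.
StreamMetadata metadata;
metadata.rotation = 90.0;
metadata.postRotationWidth = 1080;
metadata.postRotationHeight = 1920;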
93 changes: 68 additions & 25 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -150,8 +150,20 @@ void SingleStreamDecoder::initializeDecoder() {
if (fps > 0) {
streamMetadata.averageFpsFromHeader = fps;
}
streamMetadata.rotation = getRotationFromStream(avStream);

// Report post-rotation dimensions: swap width/height for 90 or -90
// degree rotations so metadata matches what the decoder returns.
int width = avStream->codecpar->width;
int height = avStream->codecpar->height;
Rotation rotation = rotationFromDegrees(streamMetadata.rotation);
// 90° rotations swap dimensions
if (rotation == Rotation::CCW90 || rotation == Rotation::CW90) {
std::swap(width, height);
}
streamMetadata.postRotationWidth = width;
streamMetadata.postRotationHeight = height;

streamMetadata.sampleAspectRatio =
avStream->codecpar->sample_aspect_ratio;
containerMetadata_.numVideoStreams++;
@@ -540,19 +552,46 @@ void SingleStreamDecoder::addVideoStream(
activeStreamIndex_, customFrameMappings.value());
}

// Set preRotationDims_ for the active stream. These are the raw encoded
// dimensions from FFmpeg, used as a fallback for tensor pre-allocation when
// no resize/rotation transforms are applied.
preRotationDims_ = FrameDims(
streamInfo.stream->codecpar->height, streamInfo.stream->codecpar->width);

FrameDims currInputDims = preRotationDims_;

// If there's rotation, prepend a RotationTransform to handle it in the
// filter graph. This way user transforms (resize, crop) operate in
// post-rotation coordinate space, preserving x/y coordinates for crops.
//
// It is critical to apply the rotation *before* any user-supplied
// transforms. By design, we want:
// A: VideoDecoder(..., transforms=tv_transforms)[i]
// to be equivalent to:
// B: tv_transforms(VideoDecoder(...)[i])
// In B, rotation is applied before transforms, so A must behave the same.
//
// TODO: benchmark the performance of doing this additional filtergraph
// transform
Rotation rotation = rotationFromDegrees(streamMetadata.rotation);
if (rotation != Rotation::NONE) {
auto rotationTransform =
std::make_unique<RotationTransform>(rotation, currInputDims);
currInputDims = rotationTransform->getOutputFrameDims().value();
resizedOutputDims_ = currInputDims;
transforms_.push_back(std::move(rotationTransform));
}

  // Validate and add user transforms.
for (auto& transform : transforms) {
TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
transform->validate(currInputDims);
if (transform->getOutputFrameDims().has_value()) {
resizedOutputDims_ = transform->getOutputFrameDims().value();
currInputDims = resizedOutputDims_.value();
}

// Note that we are claiming ownership of the transform objects passed in to
// us.
transforms_.push_back(std::unique_ptr<Transform>(transform));
}

@@ -679,9 +718,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
const auto& streamInfo = streamInfos_[activeStreamIndex_];
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
FrameBatchOutput frameBatchOutput(
      frameIndices.numel(), getOutputDims(), videoStreamOptions.device);

auto previousIndexInVideo = -1;
for (int64_t f = 0; f < frameIndices.numel(); ++f) {
@@ -738,9 +775,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
FrameBatchOutput frameBatchOutput(
      numOutputFrames, getOutputDims(), videoStreamOptions.device);

for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
FrameOutput frameOutput =
@@ -873,9 +908,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
// below. Hence, we need this special case below.
if (startSeconds == stopSeconds) {
FrameBatchOutput frameBatchOutput(
        0, getOutputDims(), videoStreamOptions.device);
frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
return frameBatchOutput;
}
@@ -918,9 +951,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
int64_t numOutputFrames = static_cast<int64_t>(std::round(product));

FrameBatchOutput frameBatchOutput(
      numOutputFrames, getOutputDims(), videoStreamOptions.device);

// Decode frames, reusing already-decoded frames for duplicates
int64_t lastDecodedSourceIndex = -1;
@@ -961,9 +992,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
int64_t numFrames = stopFrameIndex - startFrameIndex;

FrameBatchOutput frameBatchOutput(
      numFrames, getOutputDims(), videoStreamOptions.device);
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
FrameOutput frameOutput =
getFrameAtIndexInternal(i, frameBatchOutput.data[f]);
@@ -1513,6 +1542,20 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
}
}

FrameDims SingleStreamDecoder::getOutputDims() const {
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
Rotation rotation = rotationFromDegrees(streamMetadata.rotation);
// If there is a rotation, then resizedOutputDims_ is necessarily non-null
// (the rotation transform would have set it).
if (rotation != Rotation::NONE) {
TORCH_CHECK(
resizedOutputDims_.has_value(),
"Internal error: rotation is applied but resizedOutputDims_ is not set");
}
return resizedOutputDims_.value_or(preRotationDims_);
}

// --------------------------------------------------------------------------
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------
27 changes: 20 additions & 7 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -296,6 +296,16 @@ class SingleStreamDecoder {

int64_t getPts(int64_t frameIndex);

// Returns the output frame dimensions for video frames.
// If resizedOutputDims_ is set (via resize, crop, or rotation transforms),
// returns that. Otherwise, returns preRotationDims_.
//
  // Note: if resizedOutputDims_ has no value, there is no rotation (a
  // rotation transform would have set it), so the pre-rotation and
  // post-rotation dimensions are identical. This makes it safe to use
  // preRotationDims_ as the fallback.
FrameDims getOutputDims() const;

// --------------------------------------------------------------------------
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------
@@ -362,18 +372,21 @@
// resizedOutputDims_. If resizedOutputDims_ has no value, that means there
// are no transforms that change the output frame dimensions.
//
  // The priority order for output frame dimensions is:
  //
  // 1. resizedOutputDims_; the resize requested by the user (or rotation)
  //    always takes priority.
  // 2. The dimensions of the actual decoded AVFrame. This can change
  //    per-decoded frame, and is unknown in SingleStreamDecoder. Only the
  //    DeviceInterface learns it immediately after decoding a raw frame but
  //    before the color conversion.
  // 3. preRotationDims_; the raw encoded dimensions from FFmpeg metadata
  //    (before any rotation is applied). Used as a fallback for tensor
  //    allocation when resizedOutputDims_ is not set, which only happens
  //    when no rotation is needed, so preRotationDims_ is the correct value.
std::vector<std::unique_ptr<Transform>> transforms_;
std::optional<FrameDims> resizedOutputDims_;
FrameDims preRotationDims_;

// Whether or not we have already scanned all streams to update the metadata.
bool scannedAllStreams_ = false;
56 changes: 56 additions & 0 deletions src/torchcodec/_core/Transform.cpp
@@ -125,4 +125,60 @@ void CropTransform::validate(const FrameDims& inputDims) const {
}
}

Rotation rotationFromDegrees(std::optional<double> degrees) {
if (!degrees.has_value()) {
return Rotation::NONE;
}
// Round to nearest multiple of 90 degrees
int rounded = static_cast<int>(std::round(*degrees / 90.0)) * 90;
switch (rounded) {
case 0:
return Rotation::NONE;
case 90:
return Rotation::CCW90;
case -90:
return Rotation::CW90;
case 180:
case -180:
return Rotation::ROTATE180;
default:
TORCH_CHECK(
false,
"Unexpected rotation value: ",
*degrees,
". Expected range is [-180, 180].");
}
}

RotationTransform::RotationTransform(
Rotation rotation,
const FrameDims& inputDims)
: rotation_(rotation) {
// 90° rotations swap dimensions
if (rotation_ == Rotation::CCW90 || rotation_ == Rotation::CW90) {
outputDims_ = FrameDims(inputDims.width, inputDims.height);
} else {
outputDims_ = inputDims;
}
}

std::string RotationTransform::getFilterGraphCpu() const {
switch (rotation_) {
case Rotation::NONE:
return "";
case Rotation::CCW90:
return "transpose=cclock";
case Rotation::CW90:
return "transpose=clock";
case Rotation::ROTATE180:
return "hflip,vflip";
default:
return "";
}
}

std::optional<FrameDims> RotationTransform::getOutputFrameDims() const {
return outputDims_;
}

} // namespace facebook::torchcodec
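A small sketch (not part of the file) of how the rounding in rotationFromDegrees() maps raw angles from av_display_rotation_get() onto the Rotation enum; it assumes the relevant Transform header is included:

#include <cassert>
#include <optional>

void checkRotationRounding() {
  using facebook::torchcodec::Rotation;
  using facebook::torchcodec::rotationFromDegrees;

  assert(rotationFromDegrees(std::nullopt) == Rotation::NONE);
  assert(rotationFromDegrees(89.7) == Rotation::CCW90); // rounds to 90
  assert(rotationFromDegrees(-90.0) == Rotation::CW90);
  assert(rotationFromDegrees(179.9) == Rotation::ROTATE180); // rounds to 180
  assert(rotationFromDegrees(-180.0) == Rotation::ROTATE180);
}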