Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/lite/kernels/op_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License.
#include <cstdlib>
#define TFLITE_ABORT abort()
#else
#include <cstdlib>
inline void AbortImpl() {
MicroPrintf("HALTED");
while (1) {
}
abort();
}
#define TFLITE_ABORT AbortImpl();
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ TfLiteStatus LoadMicroSpeechModelAndPerformInference(

tflite::MicroInterpreter interpreter(model, op_resolver, g_arena, kArenaSize);

MicroPrintf("%s: pre AllocateTensors", __func__);
TF_LITE_MICRO_EXPECT(interpreter.AllocateTensors() == kTfLiteOk);
TF_LITE_MICRO_CHECK_FAIL();
MicroPrintf("%s: post AllocateTensors", __func__);

MicroPrintf("MicroSpeech model arena size = %u",
interpreter.arena_used_bytes());
Expand All @@ -123,8 +125,10 @@ TfLiteStatus LoadMicroSpeechModelAndPerformInference(

std::copy_n(&features[0][0], kFeatureElementCount,
tflite::GetTensorData<int8_t>(input));
MicroPrintf("%s: pre Invoke", __func__);
TF_LITE_MICRO_EXPECT(interpreter.Invoke() == kTfLiteOk);
TF_LITE_MICRO_CHECK_FAIL();
MicroPrintf("%s: post Invoke", __func__);

// Dequantize output values
float category_predictions[kCategoryCount];
Expand Down Expand Up @@ -211,6 +215,7 @@ TfLiteStatus GenerateFeatures(const int16_t* audio_data,
feature_index++;
audio_data += kAudioSampleStrideCount;
remaining_samples -= kAudioSampleStrideCount;
MicroPrintf("Generated single feature %u", feature_index);
}

return kTfLiteOk;
Expand Down
44 changes: 34 additions & 10 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
Expand Down Expand Up @@ -49,6 +51,27 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
return kTfLiteOk;
}

// Looks up a custom DECODE registration for `type` on the MicroContext and,
// if one is found, invokes its factory to build the matching DecodeState.
// Returns nullptr when no registrations are installed or none match `type`.
DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
                                                  uint8_t type) {
  const MicroContext* mc = GetMicroContext(context);
  const auto [registrations, registration_count] =
      mc->GetCustomDecodeRegistrations();
  if (registrations == nullptr) {
    return nullptr;
  }

  // Linear scan: the first registration with a matching type and a non-null
  // factory wins.
  for (size_t idx = 0; idx < registration_count; ++idx) {
    const MicroContext::CustomDecodeRegistration& reg = registrations[idx];
    const bool usable = (reg.type == type) && (reg.create_state != nullptr);
    if (usable) {
      return reg.create_state(context, mc->GetAlternateProfiler());
    }
  }

  return nullptr;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
Expand Down Expand Up @@ -113,21 +136,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
uint32_t type = DecodeState::Type(*ancillary);
if (type >= DecodeState::kDcmTypeCustomFirst &&
type <= DecodeState::kDcmTypeCustomLast) {
dsp = GetDecodeStateFromCustomRegistration(context, type);
} else {
MicroPrintf("unsupported decode type %u", type);
}
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class DecodeState {
static constexpr uint8_t kDcmTypeLUT = 0;
static constexpr uint8_t kDcmTypeHuffman = 1;
static constexpr uint8_t kDcmTypePrune = 2;
static constexpr uint8_t kDcmTypeCustom = 127;
static constexpr uint8_t kDcmTypeCustomFirst = 128;
static constexpr uint8_t kDcmTypeCustomLast = 255;

static constexpr size_t kDcmSizeInBytes = 16;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/decode_state_prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
tflite::testing::TestDecode<kEncodes.size() + kAncillaries.size(),
kOutputs.size()>(
kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END
129 changes: 129 additions & 0 deletions tensorflow/lite/micro/kernels/decode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};

//
// Custom DECODE test data
//
constexpr int kDecodeTypeCustom = 200;

constexpr int8_t kAncillaryDataCustom[] = {0x42};

constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
kDecodeTypeCustom, // type: custom
1, // DCM version: 1
};

// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
0x4A, 0x52, 0x62, 0x02};

// Tensor shapes as TfLiteIntArray
constexpr int kOutputShapeCustom[] = {1, 8};
constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};

constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40};

// Minimal DecodeState implementation used to exercise the custom-registration
// path of the DECODE kernel.  "Decoding" is a byte-wise XOR of the encoded
// input with the first ancillary-data byte (0x42 in this test's fixtures).
class DecodeStateCustom : public tflite::DecodeState {
 public:
  DecodeStateCustom() = delete;

  DecodeStateCustom(const TfLiteContext* context,
                    tflite::MicroProfilerInterface* profiler)
      : DecodeState(context, profiler) {}

  // No per-tensor preparation is needed for this trivial codec.
  virtual TfLiteStatus Setup(const TfLiteTensor& input,
                             const TfLiteTensor& ancillary,
                             const TfLiteTensor& output) override {
    return kTfLiteOk;
  }

  // XORs each input byte with the single key byte located just past the DCM
  // header in the ancillary tensor, writing the result into `output`.
  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
                              const TfLiteEvalTensor& ancillary,
                              const TfLiteEvalTensor& output) override {
    const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
    // context_ is stored const; TF_LITE_ENSURE needs a mutable TfLiteContext*
    // for error reporting, hence the const_casts below.
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
    uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
        const_cast<TfLiteEvalTensor*>(&output));
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
    const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
    // The ancillary tensor begins with the DCM header; the XOR key byte
    // immediately follows it.
    vp += kDcmSizeInBytes;

    // simple XOR de-obfuscation
    std::transform(inp, inp + input.dims->data[0], outp,
                   [vp](uint8_t i) { return i ^ *vp; });

    return kTfLiteOk;
  }

  // Factory with the signature expected by
  // MicroContext::CustomDecodeRegistration::create_state.  Uses placement-new
  // into static storage (TFLM avoids the heap); the instance is intentionally
  // never destroyed, which is acceptable for a process-lifetime test object.
  // NOTE(review): a second call re-constructs over the same buffer without
  // destroying the previous instance — fine here since the dtor is trivial.
  static DecodeState* CreateDecodeStateCustom(
      const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
    alignas(4) static uint8_t buffer[sizeof(DecodeStateCustom)];
    DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
    return instance;
  }

 protected:
  // Non-public dtor pairs with TF_LITE_REMOVE_VIRTUAL_DELETE to forbid
  // delete-through-base on this heap-less object.
  virtual ~DecodeStateCustom() = default;

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace

TF_LITE_MICRO_TESTS_BEGIN
Expand Down Expand Up @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr);
}

// End-to-end test: the DECODE kernel must dispatch a DCM type in the custom
// range (kDecodeTypeCustom == 200) to the factory installed via
// SetCustomDecodeRegistrations, and the custom XOR codec must reproduce
// kExpectCustom from kEncodedCustom.
TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
  // Align the tensor data the same as a Buffer in the TfLite schema
  alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
      kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};

  constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};

  // Encoded-input tensor datum.
  const TfLiteIntArray* const encoded_dims =
      tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
  static const TensorInDatum tid_encode = {
      kEncodedCustom,
      *encoded_dims,
  };
  static constexpr std::initializer_list<const TensorInDatum*> encodes = {
      &tid_encode,
  };

  // Ancillary tensor datum: DCM header followed by the XOR key byte.
  const TfLiteIntArray* const ancillary_dims =
      tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
  static const TensorInDatum tid_ancillary = {
      &kAncillaryData,
      *ancillary_dims,
  };
  static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
      &tid_ancillary};

  // Output tensor datum: int8 output with a single zero-point of 0 and a
  // scales array that carries only its size (no scale values).
  const TfLiteIntArray* const output_dims =
      tflite::testing::IntArrayFromInts(kOutputShapeCustom);
  constexpr int kOutputZeroPointsData[] = {0};
  const TfLiteIntArray* const kOutputZeroPoints =
      tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
  static const TensorOutDatum tod = {
      output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
      0, {},
  };
  static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
      &tod};

  const std::initializer_list<const void*> expected = {kExpectCustom};

  // Install one registration mapping kDecodeTypeCustom to the test factory.
  const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
      cdr = {
          {
              kDecodeTypeCustom,
              0,  // reserved
              0,  // reserved
              0,  // reserved
              DecodeStateCustom::CreateDecodeStateCustom,
          },
      };

  tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
                              outputs.size()>(
      encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
      nullptr, &cdr);
}

TF_LITE_MICRO_TESTS_END
10 changes: 9 additions & 1 deletion tensorflow/lite/micro/kernels/decode_test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
TfLiteTensor* tensors, const TFLMRegistration& registration,
const std::initializer_list<const void*>& expected,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr) {
int kInputArrayData[kNumInputs + 1] = {kNumInputs};
for (size_t i = 0; i < kNumInputs; i++) {
Expand All @@ -105,6 +107,10 @@ TfLiteStatus ExecuteDecodeTest(
runner.GetFakeMicroContext()->SetDecompressionMemory(amr->begin(),
amr->size());
}
if (cdr != nullptr) {
runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(cdr->begin(),
cdr->size());
}

if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
return kTfLiteError;
Expand Down Expand Up @@ -150,6 +156,8 @@ void TestDecode(
const TFLMRegistration& registration,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};

Expand Down Expand Up @@ -183,7 +191,7 @@ void TestDecode(
}

TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
tensors, registration, expected, amr);
tensors, registration, expected, amr, cdr);
TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"

#ifdef USE_TFLM_COMPRESSION

Expand Down
10 changes: 10 additions & 0 deletions tensorflow/lite/micro/micro_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,14 @@ void MicroContext::ResetDecompressionMemoryAllocations() {
std::fill_n(decompress_regions_allocations_, decompress_regions_size_, 0);
}

// Installs the table of custom DECODE registrations for this context.  Only
// the pointer is stored — no copy is made — so the table must remain valid
// for as long as it can be queried (NOTE(review): presumably the lifetime of
// this MicroContext; confirm against callers).  Returns kTfLiteError if a
// table has already been installed: registrations are set-once.
TfLiteStatus MicroContext::SetCustomDecodeRegistrations(
    const CustomDecodeRegistration* registrations, size_t count) {
  if (custom_decode_registrations_ != nullptr) {
    return kTfLiteError;
  }
  custom_decode_registrations_ = registrations;
  custom_decode_registrations_size_ = count;
  return kTfLiteOk;
}

} // namespace tflite
Loading
Loading