Что такое TensorFlow Lite?

TensorFlow Lite обеспечивает встроенный (это означает, что он работает на самом мобильном устройстве) TensorFlow для мобильных устройств. Анонсированный в 2017 году программный стек TFLite разработан специально для разработки мобильных приложений.

TensorFlow Lite доступен на Android и iOS через C++ API и Java-оболочку для разработчиков Android. На устройствах, которые его поддерживают, библиотека также может использовать преимущества Android Neural Networks API для аппаратного ускорения.

В этой статье я покажу вам, как создать Android-приложение для обнаружения птиц с помощью TensorFlow Lite.

Для начала вам понадобится обученная модель .tflite.

Скачать можно здесь.

После загрузки назовите его как BirdsModel.tflite.

Добавьте TensorFlow Lite в приложение для Android.

  1. Запустите новый Java-проект Android Studio и добавьте следующие зависимости в файл build.gradle вашего приложения.
buildFeatures{
    mlModelBinding true
}
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'

2. Добавьте TensorFlow Lite в приложение для Android.

Щелкните правой кнопкой мыши имя пакета или выберите «Файл», затем «Создать» -> «Другое» -> «Модель TensorFlow Lite». Выберите расположение модели, куда вы загрузили ранее (BirdsModel.tflite), и нажмите «Готово».

После добавления модели в ваш проект вы увидите сообщение о том, что модель была успешно импортирована.

Обработка изображения и показ результата

Ниже приведены простые шаги для реализации модели классификации птиц.

  1. Создайте переменную модели TensorFlow Lite и инициализируйте ее.
BirdsModel model = BirdsModel.newInstance(context);

2. Преобразование входного изображения в изображение тензорного потока.

TensorImage image = TensorImage.fromBitmap(bitmap);

3. Обработайте изображение и отсортируйте результат вывода в порядке убывания.

И выберите изображение с наибольшей вероятностью.

BirdsModel.Outputs outputs = model.process(image);
    List<Category> probability = outputs.getProbabilityAsCategoryList();

    int index = 0;
    float max = probability.get(0).getScore();
    for(int i=0;i<probability.size();i++){
    if(max<probability.get(i).getScore()){
        max=probability.get(i).getScore();
        index=i;
    }
}

4. Отобразите результат в TextView

Category output = probability.get(index);
tvResult.setText(output.getLabel());

Исходный код вышеуказанного проекта:

Файл макета: activity_main.xml

Этот макет содержит следующие представления:

  • Представление изображения для отображения нашей захваченной фотографии птицы
  • Кнопка для захвата изображения
  • Текстовый вид для отображения результата
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:background="#272343"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <ImageView
        android:id="@+id/iv_add_image"
        android:layout_width="300dp"
        android:layout_height="450dp"
        android:layout_marginTop="20dp"
        android:src="@drawable/add_image_icon"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent"
        android:contentDescription="@string/image" />

    <androidx.appcompat.widget.AppCompatButton
        android:id="@+id/bt_load_image"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:text="@string/load_image2"
        android:background="@drawable/design"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.0"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/iv_add_image" />

    <TextView
        android:id="@+id/result"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="36dp"
        android:text=""
        android:textStyle="bold"
        android:textSize="18sp"
        android:fontFamily="sans-serif-condensed"
        android:textColor="#ffffff"
        android:textAlignment="center"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.498"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/bt_load_image" />

    <TextView
        android:id="@+id/textView2"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="48dp"
        android:textSize="19sp"
        android:textColor="#ffffff"
        android:textAlignment="center"
        android:text="@string/click_the_result_to_more_about_the_bird"
        android:visibility="invisible"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />



</androidx.constraintlayout.widget.ConstraintLayout>

Файл Java: MainActivity.java

package com.aravind.birdsdetection;

import androidx.activity.result.ActivityResultCallback;
import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import androidx.appcompat.app.AppCompatActivity;


import android.content.Intent;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import com.aravind.birdsdetection.ml.BirdsModel;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.label.Category;


import java.io.IOException;
import java.util.List;

public class MainActivity extends AppCompatActivity {

    Button btLoadImage;
    TextView tvResult,textView2;
    ImageView ivAddImage;
    ActivityResultLauncher<String> mgetContent;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        ivAddImage = findViewById(R.id.iv_add_image);
        tvResult = findViewById(R.id.result);
        btLoadImage = findViewById(R.id.bt_load_image);
        textView2 = findViewById(R.id.textView2);


        mgetContent = registerForActivityResult(new ActivityResultContracts.GetContent(), new ActivityResultCallback<Uri>() {
            @Override
            public void onActivityResult(Uri result) {
                Bitmap imageBitmap = null;
                try {
                    imageBitmap = UriToBitmap(result);
                } catch (IOException e) {
                    e.printStackTrace();
                }


                // ivAddImage.setImageURI(result);
                ivAddImage.setImageBitmap(imageBitmap);
                outputGenerator(imageBitmap);
            }
        });

        btLoadImage.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                mgetContent.launch("image/*");
            }
        });

        tvResult.setOnClickListener(new View.OnClickListener() {

            @Override
            public void onClick(View view) {
                Intent intent = new Intent(Intent.ACTION_VIEW,Uri.parse("https://www.google.com/search?q=" +tvResult.getText().toString()));
                startActivity(intent);
            }
        });
    }

    private void outputGenerator(Bitmap imageBitmap) {
        try {
                BirdsModel model = BirdsModel.newInstance(MainActivity.this);

                // Creates inputs for reference.
                TensorImage image = TensorImage.fromBitmap(imageBitmap);

                // Runs model inference and gets result.
                BirdsModel.Outputs outputs = model.process(image);
                List<Category> probability = outputs.getProbabilityAsCategoryList();

                // Releases model resources if no longer used.

                int index = 0;
                float max = probability.get(0).getScore();
                for(int i=0;i<probability.size();i++){
                if(max<probability.get(i).getScore()){
                    max=probability.get(i).getScore();
                    index=i;
                }
            }

            Category output = probability.get(index);
            tvResult.setText(output.getLabel());
            textView2.setVisibility(View.VISIBLE);


            // Releases model resources if no longer used.
            model.close();
            }catch (IOException e) {
            e.printStackTrace();

        }
    }

    private Bitmap UriToBitmap(Uri result) throws IOException {
        return MediaStore.Images.Media.getBitmap(this.getContentResolver(), result);

    }
}

Это базовая классификация изображений с Tensorflow Lite. Вы можете попробовать разные наборы данных и комбинации моделей.

Если вам понравилась эта статья и вы хотите поговорить об этом подробнее и поделиться своими работами, давайте подключимся в LinkedIn.

Спасибо, что прочитали!