
TensorFlow object_detection API source code: MobileDet

Date: 2021-09-05 23:26:54

Continued from https://www.cnblogs.com/shines87/p/15140744.html

I. Code walkthrough

Command used to run the project:

object_detection>python model_main.py   --logtostderr   --model_dir=../image/   --pipeline_config_path=../image/ssdlite_mobiledet_cpu_320x320_coco_sync_4x4.config

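Note: I removed or simplified many conditional branches that never execute for this configuration, and I also removed the tf.logging warning statements.

1. model_main.py: this file has four main parts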

import tensorflow.compat.v1 as tf
from absl import flags
from object_detection import model_lib

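# ①===== Command-line flags
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config')
# The three flags below are read in main(); their defaults match the values noted there.
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_integer('sample_1_of_n_eval_examples', 1,
                     'Sample one of every n eval examples.')
flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5,
                     'Sample one of every n train examples for eval.')
FLAGS = flags.FLAGS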


def main(unused_argv):
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')

    # ②===== Model: initialized from the config; returns four input fns and an estimator
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=tf.estimator.RunConfig(model_dir=FLAGS.model_dir),   # "../image"
        pipeline_config_path=FLAGS.pipeline_config_path,  # .config
        train_steps=FLAGS.num_train_steps,  # None
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,  # 1
        sample_1_of_n_eval_on_train_examples=FLAGS.sample_1_of_n_eval_on_train_examples)    # 5

    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fns = train_and_eval_dict['eval_input_fns']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']  # 4000

    # ③===== Data: train/eval specs built from the four input fns returned above
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)

    # Currently only a single Eval Spec is allowed.
    # ④===== Training
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])


if __name__ == '__main__':
    tf.app.run()
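To make the data flow between the four parts easier to see, here is a minimal sketch that uses plain Python stand-ins instead of TensorFlow. The names create_inputs, create_specs, TrainSpec and EvalSpec below are hypothetical placeholders invented for this illustration; they only mimic the shape of what is passed around and are not the library's implementation.

```python
from collections import namedtuple

# Hypothetical stand-ins for the spec objects; these are not the TensorFlow classes.
TrainSpec = namedtuple("TrainSpec", ["input_fn", "max_steps"])
EvalSpec = namedtuple("EvalSpec", ["name", "input_fn"])


def create_inputs():
    """Stands in for part 2: returns input functions plus the step count."""
    return {
        "train_input_fn": lambda: "train batches",
        "eval_input_fns": [lambda: "eval batches"],
        "train_steps": 4000,
    }


def create_specs(train_input_fn, eval_input_fns, train_steps):
    """Stands in for part 3: wraps each input function in a spec."""
    train_spec = TrainSpec(input_fn=train_input_fn, max_steps=train_steps)
    eval_specs = [
        EvalSpec(name="eval_{}".format(i), input_fn=fn)
        for i, fn in enumerate(eval_input_fns)
    ]
    return train_spec, eval_specs


def train_and_evaluate(train_spec, eval_spec):
    """Stands in for part 4: only reports what would be run."""
    return "train {} steps on {}, then evaluate {} on {}".format(
        train_spec.max_steps, train_spec.input_fn(),
        eval_spec.name, eval_spec.input_fn())


parts = create_inputs()
train_spec, eval_specs = create_specs(
    parts["train_input_fn"], parts["eval_input_fns"], parts["train_steps"])
# Only the first eval spec is used, mirroring eval_specs[0] in main().
summary = train_and_evaluate(train_spec, eval_specs[0])
print(summary)
```

The point of the sketch is only the wiring: part 2 hands back callables and a step count, part 3 packages them, and part 4 consumes one train spec and one eval spec.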

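Of these four parts, the two key functions are model_lib.create_estimator_and_inputs and model_lib.create_train_and_eval_specs. Let's look at the first one.

II. The model_lib.create_estimator_and_inputs function

It depends on three functions from config_util, three functions from inputs, and the create_model_fn function defined in the same file.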

import copy
import functools

import tensorflow.compat.v1 as tf
from object_detection import inputs
from object_detection.builders import model_builder
from object_detection.utils import config_util


# Excerpt from model_lib.py; create_model_fn is defined earlier in the same file.
def create_estimator_and_inputs(run_config,
                                pipeline_config_path=None,  # .config
                                sample_1_of_n_eval_examples=1,  # 1
                                sample_1_of_n_eval_on_train_examples=1,  # 5
                                model_fn_creator=create_model_fn,
                                **kwargs):
    save_final_config = False  # default value of the omitted parameter

    configs = config_util.get_configs_from_pipeline_file(
        pipeline_config_path, config_override=None)  # None
    kwargs.update({
        'train_steps': None,  # because train_steps is None
        # because configs['train_config'].use_bfloat16 and use_tpu are both False
        'use_bfloat16': False
    })
    # because sample_1_of_n_eval_examples (1) >= 1
    kwargs.update({'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples})
    # because override_eval_num_epochs is True
    kwargs.update({'eval_num_epochs': 1})
    tf.logging.warning('Forced number of epochs for all eval validations to be 1.')
    # Merge configs and kwargs; hparams=False
    configs = config_util.merge_external_params_with_configs(
        configs, False, kwargs_dict=kwargs)

    # type is <class 'object_detection.protos.model_pb2.DetectionModel'>
    model_config = configs['model']
    train_config = configs['train_config']
    train_input_config = configs['train_input_config']
    eval_config = configs['eval_config']
    eval_input_configs = configs['eval_input_configs']  # a list with len == 1
    eval_on_train_input_config = copy.deepcopy(train_input_config)
    eval_on_train_input_config.sample_1_of_n_examples = (
        sample_1_of_n_eval_on_train_examples)
    eval_on_train_input_config.num_epochs = 1  # because override_eval_num_epochs=True

    # Create the input functions for TRAIN/EVAL/PREDICT.
    train_input_fn = inputs.create_train_input_fn(
        train_config=train_config,
        train_input_config=train_input_config,
        model_config=model_config)
    # Build the list directly; assigning to index 0 of an empty list raises IndexError.
    eval_input_fns = [
        inputs.create_eval_input_fn(eval_config, eval_input_configs[0], model_config)
    ]
    eval_on_train_input_fn = inputs.create_eval_input_fn(
        eval_config=eval_config,
        eval_input_config=eval_on_train_input_config,
        model_config=model_config)
    predict_input_fn = inputs.create_predict_input_fn(
        model_config=model_config,
        predict_input_config=eval_input_configs[0])

    detection_model_fn = functools.partial(model_builder.build, model_config=model_config)
    # hparams=False, use_tpu=False, postprocess_on_cpu=False
    model_fn = model_fn_creator(detection_model_fn, configs, False, False, False)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    # Write the as-run pipeline config to disk.
    if run_config.is_chief and save_final_config:
        pipeline_config_final = config_util.create_pipeline_proto_from_configs(configs)
        config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir)

    return dict(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fns=eval_input_fns,
        eval_on_train_input_fn=eval_on_train_input_fn,
        predict_input_fn=predict_input_fn,
        train_steps=train_config.num_steps)  # number of training steps

This version of the function omits some parameters: hparams=None, config_override=None, override_eval_num_epochs=True, save_final_config=False, postprocess_on_cpu=False, as well as some TPU-related parameters.

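① First, a dictionary of configs is loaded from the .config file. A kwargs dictionary is then filled in from the training arguments (for example, sample_1_of_n_eval_examples means that one of every n eval examples is sampled), and the two dictionaries are merged.

Handling the config file uses config_util library functions for three steps: loading (get_configs_from_pipeline_file), merging the two dictionaries (merge_external_params_with_configs), and saving back to disk (create_pipeline_proto_from_configs followed by save_pipeline_config).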
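The load-then-merge step can be illustrated with a small self-contained sketch. Everything here is a simplified stand-in: TrainConfig, InputConfig, load_configs and merge_overrides are hypothetical names made up for this example, and the rule that a None value leaves the file setting untouched is my assumption about how such a merge behaves, not a copy of the config_util code.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class InputConfig:
    """Hypothetical stand-in for an input reader config."""
    sample_1_of_n_examples: int = 1
    num_epochs: int = 0


@dataclass
class TrainConfig:
    """Hypothetical stand-in for the train config."""
    num_steps: int = 4000
    use_bfloat16: bool = False


def load_configs() -> Dict[str, Any]:
    """Stands in for reading the .config file into a dict of sub-configs."""
    return {
        "train_config": TrainConfig(),
        "eval_input_configs": [InputConfig()],
    }


def merge_overrides(configs: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Applies overrides in place; a None value keeps the setting from the file."""
    for key, value in overrides.items():
        if value is None:
            continue
        if key == "train_steps":
            configs["train_config"].num_steps = value
        elif key == "use_bfloat16":
            configs["train_config"].use_bfloat16 = value
        elif key == "sample_1_of_n_eval_examples":
            for cfg in configs["eval_input_configs"]:
                cfg.sample_1_of_n_examples = value
        elif key == "eval_num_epochs":
            for cfg in configs["eval_input_configs"]:
                cfg.num_epochs = value
    return configs


configs = merge_overrides(load_configs(), {
    "train_steps": None,
    "use_bfloat16": False,
    "sample_1_of_n_eval_examples": 1,
    "eval_num_epochs": 1,
})
print(configs["train_config"].num_steps)             # 4000
print(configs["eval_input_configs"][0].num_epochs)   # 1
```

Because train_steps is None in the overrides, the step count from the file (4000 in this stand-in) survives the merge, while the eval epoch count is forced to 1.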

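② Next, the config dictionary is split into five parts, and three functions from the inputs module are called to produce four input functions. The index [0] appears in two places because configs['eval_input_configs'] is a list of length 1.

The trickier part is the estimator. I read the functools.partial part as roughly model_fn = model_fn_creator(model_builder.build(model_config), configs, ……), with one caveat: functools.partial does not call build at this point. It only pre-binds model_config and returns a callable, which is invoked later with the remaining arguments such as is_training.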
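A small self-contained example may help with the functools.partial step. The build function below is a hypothetical stand-in that just returns a dict; it is not model_builder.build, and the string "ssd-config" is a placeholder for the real config object.

```python
import functools


def build(model_config, is_training, add_summaries=True):
    """Hypothetical stand-in for a model builder; returns a description dict."""
    return {
        "config": model_config,
        "is_training": is_training,
        "add_summaries": add_summaries,
    }


# Nothing is built here: partial only pre-binds model_config and returns a callable.
detection_model_fn = functools.partial(build, model_config="ssd-config")

# Later, the caller supplies the remaining arguments and the model is actually built.
train_model = detection_model_fn(is_training=True)
eval_model = detection_model_fn(is_training=False, add_summaries=False)

print(train_model)
print(eval_model)
```

This deferral is what lets the same pre-configured builder produce a training model in one call and an evaluation model in another.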

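③ The returned train_steps is the number of training steps set in the .config file, 4000.

Next, look at the model_builder.build(model_config) function, which lives in yet another .py file.

1. The model_builder.build(model_config) function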

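① The build function first checks that model_config.WhichOneof('model') == "ssd", then dispatches to the function in the same file:

_build_ssd_model(getattr(model_config, 'ssd'), is_training, True). The is_training parameter looks odd at first, because the function never branches on it itself; it only forwards it to the feature extractor, the box predictor, and SSDMetaArch. The function is as follows: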

from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import target_assigner
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.utils import ops


def _build_ssd_model(ssd_config, is_training, add_summaries):
    """Builds an SSD detection model based on the model config.

    Args:
      ssd_config: A ssd.proto object containing the config for the desired
        SSDMetaArch.
      is_training: True if this model is being built for training purposes.
      add_summaries: Whether to add tf summaries in the model.

    Returns:
      SSDMetaArch based on the config.

    Raises:
      ValueError: If ssd_config.type is not recognized (i.e. not registered in
        model_class_map).
    """
    num_classes = ssd_config.num_classes  # 5
    
    # Feature extractor: returns an ssd_mobiledet_feature_extractor.SSDMobileDetFeatureExtractorBase instance
    feature_extractor = _build_ssd_feature_extractor(
        feature_extractor_config=ssd_config.feature_extractor,
        freeze_batchnorm=ssd_config.freeze_batchnorm,  # False
        is_training=is_training)

    box_coder = box_coder_builder.build(ssd_config.box_coder)
    non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(ssd_config.post_processing)
    (classification_loss, localization_loss, classification_weight,
     localization_weight, hard_example_miner, random_example_sampler,
     expected_loss_weights_fn) = losses_builder.build(ssd_config.loss)
    kwargs = {}
    return ssd_meta_arch.SSDMetaArch(
        is_training=is_training,
        anchor_generator=anchor_generator_builder.build(ssd_config.anchor_generator),
        box_predictor=box_predictor_builder.build(
            hyperparams_builder.build, ssd_config.box_predictor, is_training,
            num_classes, ssd_config.add_background_class),
        box_coder=box_coder,
        feature_extractor=feature_extractor,
        encode_background_as_zeros=ssd_config.encode_background_as_zeros,
        image_resizer_fn=image_resizer_builder.build(ssd_config.image_resizer),
        non_max_suppression_fn=non_max_suppression_fn,
        score_conversion_fn=score_conversion_fn,
        classification_loss=classification_loss,
        localization_loss=localization_loss,
        classification_loss_weight=classification_weight,
        localization_loss_weight=localization_weight,
        normalize_loss_by_num_matches=ssd_config.normalize_loss_by_num_matches,
        hard_example_miner=hard_example_miner,
        target_assigner_instance=target_assigner.TargetAssigner(
            sim_calc.build(ssd_config.similarity_calculator),
            matcher_builder.build(ssd_config.matcher),
            box_coder,
            negative_class_weight=ssd_config.negative_class_weight),
        add_summaries=add_summaries,
        normalize_loc_loss_by_codesize=ssd_config.normalize_loc_loss_by_codesize,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
        add_background_class=ssd_config.add_background_class,
        explicit_background_class=ssd_config.explicit_background_class,
        random_example_sampler=random_example_sampler,
        expected_loss_weights_fn=expected_loss_weights_fn,
        use_confidences_as_targets=ssd_config.use_confidences_as_targets,
        implicit_example_weight=ssd_config.implicit_example_weight,
        equalization_loss_config=ops.EqualizationLossConfig(
            weight=ssd_config.loss.equalization_loss.weight,
            exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes),
        return_raw_detections_during_predict=(
            ssd_config.return_raw_detections_during_predict),
        **kwargs)
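The dispatch on WhichOneof('model') described in ① can be sketched without protobuf. ModelConfig below is a hypothetical stand-in that imitates a message with a oneof group named model, and BUILDER_MAP and the two toy builders are illustrative names of my own; the real builders construct full models rather than tuples.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Hypothetical stand-in for a proto message with a oneof named 'model'."""
    ssd: Optional[dict] = None
    faster_rcnn: Optional[dict] = None

    def WhichOneof(self, group: str) -> Optional[str]:
        # Imitates the protobuf method: returns the name of the field that is set.
        assert group == "model"
        for name in ("ssd", "faster_rcnn"):
            if getattr(self, name) is not None:
                return name
        return None


def _build_ssd_model(ssd_config, is_training, add_summaries):
    return ("ssd", ssd_config["num_classes"], is_training, add_summaries)


def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
    return ("faster_rcnn", frcnn_config["num_classes"], is_training, add_summaries)


BUILDER_MAP = {"ssd": _build_ssd_model, "faster_rcnn": _build_faster_rcnn_model}


def build(model_config, is_training, add_summaries=True):
    """Picks the builder that matches whichever oneof field is set."""
    meta_architecture = model_config.WhichOneof("model")
    if meta_architecture not in BUILDER_MAP:
        raise ValueError("Unknown meta architecture: {}".format(meta_architecture))
    build_func = BUILDER_MAP[meta_architecture]
    return build_func(getattr(model_config, meta_architecture), is_training, add_summaries)


result = build(ModelConfig(ssd={"num_classes": 5}), is_training=True)
print(result)  # ('ssd', 5, True, True)
```

The getattr call mirrors getattr(model_config, 'ssd') in the text: the oneof name doubles as the attribute that holds the architecture-specific sub-config.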

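Inside _build_ssd_model, ssd_config is passed through various build functions, and their results become the constructor arguments of the ssd_meta_arch.SSDMetaArch class. The most puzzling of these is: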

feature_extractor = _build_ssd_feature_extractor(    # function defined in the same file
        feature_extractor_config=ssd_config.feature_extractor,
        freeze_batchnorm=ssd_config.freeze_batchnorm,  # False
        is_training=is_training)
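One plausible way to read this call, given that the docstring mentions a ValueError for types "not registered in model_class_map", is a registry lookup keyed by the extractor type. The sketch below assumes exactly that; the class names, the registry name and the 'ssd_mobiledet_cpu' key are illustrative assumptions, and the config is a plain dict rather than a proto.

```python
class FeatureExtractorBase:
    """Hypothetical base class that just remembers how it was configured."""

    def __init__(self, is_training, freeze_batchnorm):
        self.is_training = is_training
        self.freeze_batchnorm = freeze_batchnorm


class MobileDetCPUExtractor(FeatureExtractorBase):
    """Hypothetical concrete extractor used only for this illustration."""


# Hypothetical registry mapping a type string from the config to a class.
FEATURE_EXTRACTOR_CLASS_MAP = {"ssd_mobiledet_cpu": MobileDetCPUExtractor}


def build_feature_extractor(feature_extractor_config, freeze_batchnorm, is_training):
    """Looks up the configured type and instantiates the matching class."""
    feature_type = feature_extractor_config["type"]
    if feature_type not in FEATURE_EXTRACTOR_CLASS_MAP:
        raise ValueError("Unknown ssd feature_extractor: {}".format(feature_type))
    extractor_class = FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
    return extractor_class(is_training=is_training, freeze_batchnorm=freeze_batchnorm)


extractor = build_feature_extractor(
    {"type": "ssd_mobiledet_cpu"}, freeze_batchnorm=False, is_training=True)
print(type(extractor).__name__)  # MobileDetCPUExtractor
```

Under this reading, the returned object is an instance of whichever class the config's type string selects, which is why the return type is only known at run time.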

② class SSDMetaArch(model.DetectionModel): this class stores the arguments it receives as self attributes, and it also has a dozen or so methods.
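The pattern of storing constructor arguments on self and using them from stage methods can be shown with a toy class. ToyMetaArch and its tiny DetectionModel base are hypothetical and operate on plain Python lists; the method names preprocess, predict and postprocess follow the DetectionModel interface as I understand it, while the real class works on tensors and has many more arguments and methods.

```python
class DetectionModel:
    """Hypothetical minimal base class naming the stages a model implements."""

    def preprocess(self, inputs):
        raise NotImplementedError

    def predict(self, preprocessed_inputs):
        raise NotImplementedError

    def postprocess(self, prediction_dict):
        raise NotImplementedError


class ToyMetaArch(DetectionModel):
    """Toy model: keeps its constructor arguments and uses them in each stage."""

    def __init__(self, is_training, feature_extractor, score_threshold):
        # Constructor arguments become attributes for the methods below.
        self._is_training = is_training
        self._feature_extractor = feature_extractor
        self._score_threshold = score_threshold

    def preprocess(self, inputs):
        # Scale raw pixel values into the range [0, 1].
        return [x / 255.0 for x in inputs]

    def predict(self, preprocessed_inputs):
        return {"scores": [self._feature_extractor(x) for x in preprocessed_inputs]}

    def postprocess(self, prediction_dict):
        # Keep only scores at or above the configured threshold.
        return [s for s in prediction_dict["scores"] if s >= self._score_threshold]


model = ToyMetaArch(is_training=False,
                    feature_extractor=lambda x: x * 2,
                    score_threshold=0.5)
detections = model.postprocess(model.predict(model.preprocess([0, 51, 255])))
print(detections)  # [2.0]
```

The objects built by the various builder functions play the role of feature_extractor and score_threshold here: they are handed to the constructor once and then used by the methods.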

2.

 

 

  


Original: https://www.cnblogs.com/shines87/p/15225582.html
