
[Debugging] Fixing the tensorflow "assertion failed: [predictions must be >= 0]" error

by judy@ 2022. 10. 7.

I tried to train a deep learning model and got the following assertion failure. The error message is below.

Epoch 1/20
Traceback (most recent call last):
  File "lifelog_modeling.py", line 152, in <module>
    main()
  File "lifelog_modeling.py", line 129, in main
    history = model.fit(
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 840, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1843, in _filtered_call
    return self._call_flat(
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1923, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 545, in call
    outputs = execute.execute(
  File "/home/***/anaconda3/envs/wsi-py3.8/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (functional_1/dense/BiasAdd:0) = ] [[-0.00227704132][0.00556892389][0.0618832]...] [y (Cast_8/x:0) = ] [0]
         [[{{node assert_greater_equal/Assert/AssertGuard/else/_1/assert_greater_equal/Assert/AssertGuard/Assert}}]] [Op:__inference_train_function_3371]

Function call stack:
train_function

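Judging from the message, the values being checked are the raw outputs of the model's final Dense layer (functional_1/dense/BiasAdd:0), and some of them are negative (e.g. -0.00227704132), while the assertion requires every prediction to be >= 0. The root cause turned out to be that, although the goal was regression, the target variable (y) was stored as integers. I changed the dtype of y so that the target holds continuous values instead of class labels.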


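# Original code: target cast to integers (looks like class labels)
y = y.astype('int8')

# Changed code: target cast to floats (continuous regression values)
y = y.astype('float64')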
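The same fix can be wrapped in a small check so it is applied wherever y is prepared. This is a minimal sketch, assuming y is a NumPy array (the post does not show how y is built). The helper names prepare_regression_target and count_negative_predictions and the sample targets are hypothetical; the second helper only tests the condition from the error message (predictions >= 0) against the three output values printed in the traceback.

```python
import numpy as np


def prepare_regression_target(y):
    """Hypothetical helper: return y as floats for a regression target.

    Integer dtypes (such as the int8 used originally) are cast to float64,
    so the target holds continuous values instead of class-like integers.
    Arrays that are already floating point are returned unchanged.
    """
    y = np.asarray(y)
    if np.issubdtype(y.dtype, np.integer):
        y = y.astype('float64')
    return y


def count_negative_predictions(preds):
    """Hypothetical helper: count outputs that violate `predictions >= 0`."""
    preds = np.asarray(preds)
    return int((preds < 0).sum())


# Hypothetical integer targets standing in for the post's y
y = np.array([3, 0, 7, 2], dtype='int8')
y = prepare_regression_target(y)
print(y.dtype)  # float64

# First three Dense-layer outputs printed in the traceback above
sample_preds = np.array([[-0.00227704132], [0.00556892389], [0.0618832]])
print(count_negative_predictions(sample_preds))  # 1
```

Keeping the cast in one helper avoids reintroducing an integer target if y is loaded or resampled in more than one place.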