mirror of
				https://github.com/ytdl-org/youtube-dl.git
				synced 2025-10-29 09:26:20 -07:00 
			
		
		
		
	[utils] Align traverse_obj() with yt-dlp
Thanks Grub4k for these:
* traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc
* traverse `set` key for transformations/filters, `re.Match` group names, from
  776995bc10, etc
* traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174
* always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170
			
			
This commit is contained in:
		| @@ -20,7 +20,7 @@ import xml.etree.ElementTree | |||||||
| from youtube_dl.utils import ( | from youtube_dl.utils import ( | ||||||
|     age_restricted, |     age_restricted, | ||||||
|     args_to_str, |     args_to_str, | ||||||
|     encode_base_n, |     base_url, | ||||||
|     caesar, |     caesar, | ||||||
|     clean_html, |     clean_html, | ||||||
|     clean_podcast_url, |     clean_podcast_url, | ||||||
| @@ -29,10 +29,12 @@ from youtube_dl.utils import ( | |||||||
|     detect_exe_version, |     detect_exe_version, | ||||||
|     determine_ext, |     determine_ext, | ||||||
|     dict_get, |     dict_get, | ||||||
|  |     encode_base_n, | ||||||
|     encode_compat_str, |     encode_compat_str, | ||||||
|     encodeFilename, |     encodeFilename, | ||||||
|     escape_rfc3986, |     escape_rfc3986, | ||||||
|     escape_url, |     escape_url, | ||||||
|  |     expand_path, | ||||||
|     extract_attributes, |     extract_attributes, | ||||||
|     ExtractorError, |     ExtractorError, | ||||||
|     find_xpath_attr, |     find_xpath_attr, | ||||||
| @@ -51,6 +53,7 @@ from youtube_dl.utils import ( | |||||||
|     js_to_json, |     js_to_json, | ||||||
|     LazyList, |     LazyList, | ||||||
|     limit_length, |     limit_length, | ||||||
|  |     lowercase_escape, | ||||||
|     merge_dicts, |     merge_dicts, | ||||||
|     mimetype2ext, |     mimetype2ext, | ||||||
|     month_by_name, |     month_by_name, | ||||||
| @@ -66,17 +69,16 @@ from youtube_dl.utils import ( | |||||||
|     parse_resolution, |     parse_resolution, | ||||||
|     parse_bitrate, |     parse_bitrate, | ||||||
|     pkcs1pad, |     pkcs1pad, | ||||||
|     read_batch_urls, |  | ||||||
|     sanitize_filename, |  | ||||||
|     sanitize_path, |  | ||||||
|     sanitize_url, |  | ||||||
|     expand_path, |  | ||||||
|     prepend_extension, |     prepend_extension, | ||||||
|     replace_extension, |     read_batch_urls, | ||||||
|     remove_start, |     remove_start, | ||||||
|     remove_end, |     remove_end, | ||||||
|     remove_quotes, |     remove_quotes, | ||||||
|  |     replace_extension, | ||||||
|     rot47, |     rot47, | ||||||
|  |     sanitize_filename, | ||||||
|  |     sanitize_path, | ||||||
|  |     sanitize_url, | ||||||
|     shell_quote, |     shell_quote, | ||||||
|     smuggle_url, |     smuggle_url, | ||||||
|     str_or_none, |     str_or_none, | ||||||
| @@ -93,10 +95,8 @@ from youtube_dl.utils import ( | |||||||
|     unified_timestamp, |     unified_timestamp, | ||||||
|     unsmuggle_url, |     unsmuggle_url, | ||||||
|     uppercase_escape, |     uppercase_escape, | ||||||
|     lowercase_escape, |  | ||||||
|     url_basename, |     url_basename, | ||||||
|     url_or_none, |     url_or_none, | ||||||
|     base_url, |  | ||||||
|     urljoin, |     urljoin, | ||||||
|     urlencode_postdata, |     urlencode_postdata, | ||||||
|     urshift, |     urshift, | ||||||
| @@ -1586,6 +1586,11 @@ Line 1 | |||||||
|             'dict': {}, |             'dict': {}, | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         # define a pukka Iterable | ||||||
|  |         def iter_range(stop): | ||||||
|  |             for from_ in range(stop): | ||||||
|  |                 yield from_ | ||||||
|  |  | ||||||
|         # Test base functionality |         # Test base functionality | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', |         self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', | ||||||
|                          msg='allow tuple path') |                          msg='allow tuple path') | ||||||
| @@ -1602,13 +1607,13 @@ Line 1 | |||||||
|         # Test Ellipsis behavior |         # Test Ellipsis behavior | ||||||
|         self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), |         self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), | ||||||
|                               (item for item in _TEST_DATA.values() if item not in (None, {})), |                               (item for item in _TEST_DATA.values() if item not in (None, {})), | ||||||
|                               msg='`...` should give all non discarded values') |                               msg='`...` should give all non-discarded values') | ||||||
|         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), |         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), | ||||||
|                               msg='`...` selection for dicts should select all values') |                               msg='`...` selection for dicts should select all values') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), |         self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), | ||||||
|                          ['https://www.example.com/0', 'https://www.example.com/1'], |                          ['https://www.example.com/0', 'https://www.example.com/1'], | ||||||
|                          msg='nested `...` queries should work') |                          msg='nested `...` queries should work') | ||||||
|         self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), |         self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4), | ||||||
|                               msg='`...` query result should be flattened') |                               msg='`...` query result should be flattened') | ||||||
|         self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)), |         self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)), | ||||||
|                          msg='`...` should accept iterables') |                          msg='`...` should accept iterables') | ||||||
| @@ -1618,7 +1623,7 @@ Line 1 | |||||||
|                          [_TEST_DATA['urls']], |                          [_TEST_DATA['urls']], | ||||||
|                          msg='function as query key should perform a filter based on (key, value)') |                          msg='function as query key should perform a filter based on (key, value)') | ||||||
|         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)), |         self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)), | ||||||
|                               msg='exceptions in the query function should be catched') |                               msg='exceptions in the query function should be caught') | ||||||
|         self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], |         self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], | ||||||
|                          msg='function key should accept iterables') |                          msg='function key should accept iterables') | ||||||
|         if __debug__: |         if __debug__: | ||||||
| @@ -1706,7 +1711,7 @@ Line 1 | |||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, |         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, | ||||||
|                          msg='remove empty values when dict key') |                          msg='remove empty values when dict key') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, |         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, | ||||||
|                          msg='use `default` when dict key and `default`') |                          msg='use `default` when dict key and a default') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, |         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, | ||||||
|                          msg='remove empty values when nested dict key fails') |                          msg='remove empty values when nested dict key fails') | ||||||
|         self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, |         self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, | ||||||
| @@ -1768,7 +1773,7 @@ Line 1 | |||||||
|         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), |         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), | ||||||
|                          'str', msg='accept matching `expected_type` type') |                          'str', msg='accept matching `expected_type` type') | ||||||
|         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), |         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), | ||||||
|                          None, msg='reject non matching `expected_type` type') |                          None, msg='reject non-matching `expected_type` type') | ||||||
|         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), |         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), | ||||||
|                          '0', msg='transform type using type function') |                          '0', msg='transform type using type function') | ||||||
|         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), |         self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), | ||||||
| @@ -1780,7 +1785,7 @@ Line 1 | |||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), |         self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), | ||||||
|                          {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') |                          {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int), |         self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int), | ||||||
|                          1, msg='expected_type should not filter non final dict values') |                          1, msg='expected_type should not filter non-final dict values') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), |         self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), | ||||||
|                          {0: {0: 100}}, msg='expected_type should transform deep dict values') |                          {0: {0: 100}}, msg='expected_type should transform deep dict values') | ||||||
|         self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)), |         self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)), | ||||||
| @@ -1838,7 +1843,7 @@ Line 1 | |||||||
|         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), |         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), | ||||||
|                                       _traverse_string=True), 'sr', |                                       _traverse_string=True), 'sr', | ||||||
|                          msg='`slice` should result in string if `traverse_string`') |                          msg='`slice` should result in string if `traverse_string`') | ||||||
|         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), |         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'), | ||||||
|                                       _traverse_string=True), 'str', |                                       _traverse_string=True), 'str', | ||||||
|                          msg='function should result in string if `traverse_string`') |                          msg='function should result in string if `traverse_string`') | ||||||
|         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), |         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), | ||||||
|   | |||||||
| @@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT): | |||||||
|  |  | ||||||
|  |  | ||||||
| def dict_get(d, key_or_keys, default=None, skip_false_values=True): | def dict_get(d, key_or_keys, default=None, skip_false_values=True): | ||||||
|     if isinstance(key_or_keys, (list, tuple)): |     exp = (lambda x: x or None) if skip_false_values else IDENTITY | ||||||
|         for key in key_or_keys: |     return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default) | ||||||
|             if key not in d or d[key] is None or skip_false_values and not d[key]: |  | ||||||
|                 continue |  | ||||||
|             return d[key] |  | ||||||
|         return default |  | ||||||
|     return d.get(key_or_keys, default) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def try_call(*funcs, **kwargs): | def try_call(*funcs, **kwargs): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user