
%matplotlib magic causes "ImportError: DLL load failed while importing" with jax and Windows 11 #15153

@ethan-tau


Problem

Running the %matplotlib magic before importing the jax module makes the import fail with a DLL load error on Windows 11:

get_ipython().run_line_magic("matplotlib", "qt5")

import jax

Error

ImportError: DLL load failed while importing _jax: A dynamic link library (DLL) initialization routine failed.

Specific error line from stacktrace:
File ...\jaxlib\xla_client.py:28: from jaxlib import _jax as _xla

Full stacktrace
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 3
      1 get_ipython().run_line_magic("matplotlib", "qt5")  # noqa: F821
----> 3 import jax

File C:\...\Lib\site-packages\jax\__init__.py:25
     22 from jax.version import __version_info__ as __version_info__
     24 # Set Cloud TPU env vars if necessary before transitively loading C++ backend
---> 25 from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
     26 try:
     27   _cloud_tpu_init()

File C:\...\Lib\site-packages\jax\_src\cloud_tpu_init.py:20
     17 import re
     18 import warnings
---> 20 from jax._src import config
     21 from jax._src import hardware_utils
     23 running_in_cloud_tpu_vm: bool = False

File C:\...\Lib\site-packages\jax\_src\config.py:29
     26 import warnings
     28 from jax._src import deprecations
---> 29 from jax._src import logging_config
     30 from jax._src.lib import _jax
     31 from jax._src.lib import guard_lib

File C:\...\Lib\site-packages\jax\_src\logging_config.py:17
     15 import logging
     16 import sys
---> 17 from jax._src.lib import utils
     19 # Example log message:
     20 # DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu'
     21 logging_formatter = logging.Formatter(
     22     "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')

File C:\...\Lib\site-packages\jax\_src\lib\__init__.py:89
     86 import jaxlib.cpu_feature_guard as cpu_feature_guard
     87 cpu_feature_guard.check_cpu_features()
---> 89 import jaxlib.xla_client as xla_client  # noqa: F401
     91 # Jaxlib code is split between the Jax and the XLA repositories.
     92 # Only for the internal usage of the JAX developers, we expose a version
     93 # number that can be used to perform changes without breaking the main
     94 # branch on the Jax github.
     95 jaxlib_extension_version: int = getattr(xla_client, '_version', 0)

File C:\...\Lib\site-packages\jaxlib\xla_client.py:28
     25 import threading
     26 from typing import Any, Protocol, Union
---> 28 from jaxlib import _jax as _xla
     30 # Note this module does *not* depend on any Python protocol buffers. The XLA
     31 # Python bindings are currently packaged both as part of jaxlib and as part
     32 # of TensorFlow. If we use protocol buffers here, then importing both jaxlib
   (...)     39 # Pylint has false positives for type annotations.
     40 # pylint: disable=invalid-sequence-index
     42 ifrt_programs = _xla.ifrt_programs

ImportError: DLL load failed while importing _jax: A dynamic link library (DLL) initialization routine failed.
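One way to check whether the Qt binding alone triggers this DLL initialization failure, independent of IPython, is to try each import order in a fresh interpreter. The sketch below assumes PyQt5 is the installed Qt binding (the report does not say which one `qt5` resolves to), and `runs_cleanly` and `CANDIDATES` are hypothetical names. A bare `import PyQt5.QtWidgets` only approximates the magic, which also hooks Qt into IPython's event loop.

```python
import subprocess
import sys


def runs_cleanly(snippet: str, timeout: float = 120.0) -> tuple[bool, str]:
    """Run `snippet` in a fresh interpreter and return (succeeded, last stderr line)."""
    proc = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    lines = proc.stderr.strip().splitlines()
    return proc.returncode == 0, lines[-1] if lines else ""


# Hypothetical orderings to compare; assumes PyQt5 is the Qt binding in use.
CANDIDATES = {
    "jax only": "import jax",
    "Qt binding, then jax": "import PyQt5.QtWidgets\nimport jax",
    "jax, then Qt binding": "import jax\nimport PyQt5.QtWidgets",
}

if __name__ == "__main__":
    for label, code in CANDIDATES.items():
        ok, last_line = runs_cleanly(code)
        print(f"{label:24} {'ok' if ok else 'FAILED'}  {last_line}")
```

If only "Qt binding, then jax" fails, the conflict is between the Qt DLLs and jaxlib's `_jax` extension rather than anything specific to IPython.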

Packages/environment:

OS: Windows 11
IPython: 9.7.0
matplotlib: 3.10.8
jax: 0.9.1
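The list above omits the jaxlib version and the Qt binding, both of which are involved in the failing import. A minimal sketch for collecting them with the standard library follows; `environment_report` is a hypothetical helper, and the package list is an assumption about what is relevant here.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distributions worth reporting for this issue; jaxlib and the Qt bindings
# are assumptions about what else matters, not part of the original report.
PACKAGES = ("jax", "jaxlib", "matplotlib", "ipython", "PyQt5", "PySide6")


def environment_report(packages=PACKAGES) -> dict[str, str]:
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {"python": platform.python_version(), "os": platform.platform()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```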

Other context

Without the matplotlib magic, the import works just fine:

import jax

So does importing matplotlib and plotting before importing jax:

import matplotlib
import matplotlib.pyplot as plt
plt.plot(range(4))
import jax
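The report does not say whether the reverse order, importing jax first and running the magic afterwards, avoids the error. A sketch of that untested workaround follows; `import_jax_then_enable_qt` is a hypothetical helper, and it has not been verified on Windows 11.

```python
def import_jax_then_enable_qt(gui: str = "qt5") -> tuple[bool, bool]:
    """Return (jax_imported, magic_ran). Hypothetical helper; untested on Windows 11."""
    try:
        import jax  # noqa: F401  # load jaxlib's _jax extension before any Qt DLLs
        jax_ok = True
    except ImportError:
        # jax is not installed, or the DLL failure happened anyway.
        jax_ok = False
    try:
        # get_ipython is only defined inside an IPython session.
        get_ipython().run_line_magic("matplotlib", gui)  # noqa: F821
        magic_ok = True
    except NameError:
        magic_ok = False
    return jax_ok, magic_ok
```

Outside IPython the second flag is always False; inside IPython, a missing Qt binding raises from the magic itself rather than being swallowed.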
